# Understand tf.nn.batch_normalization(): Normalize a Layer – TensorFlow Tutorial

May 24, 2021

The TensorFlow tf.nn.batch_normalization() function normalizes a tensor with a given mean and variance, then optionally scales and shifts the result. In this tutorial, we will use an example to show how to use it.

## tf.nn.batch_normalization()

tf.nn.batch_normalization() is defined as:

tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)

It can normalize input $$x$$ with $$mean$$ and $$variance$$.

TensorFlow computes the normalization as:

$$y_i=\lambda(\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}})+\beta$$

Here

$$\mu$$ is mean

$$\sigma^2$$ is variance

$$\beta$$ is offset

$$\lambda$$ is scale

$$\epsilon$$ is variance_epsilon, a small constant that avoids division by zero
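To make the equation concrete, here is a minimal pure-Python sketch that applies it to three sample values. The helper name `batch_norm_1d` is hypothetical and only mirrors the formula above; it is not TensorFlow code. It assumes that passing `None` for scale or offset means "no scaling" or "no shift".

```python
import math

def batch_norm_1d(xs, mean, variance, offset=None, scale=None, eps=1e-12):
    """Apply y_i = scale * (x_i - mean) / sqrt(variance + eps) + offset to a list."""
    scale = 1.0 if scale is None else scale      # assumption: None means no scaling
    offset = 0.0 if offset is None else offset   # assumption: None means no shift
    inv_std = 1.0 / math.sqrt(variance + eps)
    return [scale * (x - mean) * inv_std + offset for x in xs]

row = [18.369314, 2.6570225, 20.402943]
mu = sum(row) / len(row)                          # mean of the three values
var = sum((x - mu) ** 2 for x in row) / len(row)  # population variance
y = batch_norm_1d(row, mu, var)
print(y)
```

With scale and offset left out, the result has a mean close to 0 and a variance close to 1.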

Batch normalization is implemented differently in PyTorch and TensorFlow. This tutorial covers the difference:

Understand Batch Normalization: A Beginner Explain – Machine Learning Tutorial

## How to use tf.nn.batch_normalization()?

To use tf.nn.batch_normalization(), we first need the mean and variance of the input $$x$$. We can compute both with tf.nn.moments().
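As a rough illustration of what these two statistics are, the NumPy sketch below computes a mean and a population variance (divided by N, not N-1) along one axis. It is an analogy written for this tutorial, not the TensorFlow implementation, and the small array `x` is made up for the illustration.

```python
import numpy as np

# Hypothetical 2 x 3 input: two rows, three values per row
x = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 6.0]])

# keepdims=True keeps the reduced axis with size 1, so the
# results broadcast against x when we subtract and divide later
mean = x.mean(axis=1, keepdims=True)
variance = x.var(axis=1, keepdims=True)   # ddof=0: population variance

print(mean.shape, variance.shape)
print(mean.ravel())
print(variance.ravel())
```

Keeping the reduced dimension is what lets the mean and variance line up with the input in the normalization step.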

Here is an example:

import tensorflow as tf

x1 = tf.convert_to_tensor(
    [[[18.369314, 2.6570225, 20.402943],
      [10.403599, 2.7813416, 20.794857]],
     [[19.0327, 2.6398268, 6.3894367],
      [3.921237, 10.761424, 2.7887821]],
     [[11.466338, 20.210938, 8.242946],
      [22.77081, 11.555874, 11.183836]],
     [[8.976935, 10.204252, 11.20231],
      [-7.356888, 6.2725096, 1.1952505]]])

Here we have created an input $$x1$$ of shape 4 x 2 x 3. Next, we compute its mean and variance.

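mean_x, var_x = tf.nn.moments(x1, axes=[2], keep_dims=True)

Note $$axes = [2]$$: the mean and variance are computed along the last axis, so each row of three values is normalized on its own. Strictly speaking, this is layer normalization rather than batch normalization, which would compute the statistics across the batch axis. Also note that tf.nn.moments() returns the variance, not the standard deviation, which is what tf.nn.batch_normalization() expects as its third argument. This example uses the TensorFlow 1.x API (tf.Session and keep_dims); in TensorFlow 2.x the argument is named keepdims.

# offset and scale are None, so no shift or scaling is applied
v1 = tf.nn.batch_normalization(x1, mean_x, var_x, None, None, 1e-12)
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    print(sess1.run(v1))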

Running this code gives the following result:

[[[ 0.574993   -1.4064413   0.8314482 ]
  [-0.12501884 -1.1574404   1.2824591 ]]

 [[ 1.3801125  -0.95738953 -0.422723  ]
  [-0.5402142   1.4019756  -0.86176133]]

 [[-0.36398554  1.3654773  -1.0014919 ]
  [ 1.4136491  -0.67222667 -0.7414224 ]]

 [[-1.2645674   0.08396816  1.1806011 ]
  [-1.3146634   1.108713    0.20595042]]]
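The choice of axes decides what kind of normalization we get. The NumPy sketch below contrasts the two cases on a small made-up array: statistics taken over the feature axis (as in the example above) versus statistics taken over the batch axis. The helper `normalize` and the data are hypothetical and only imitate the moments-then-normalize steps; they are not TensorFlow calls.

```python
import numpy as np

def normalize(x, axes, eps=1e-12):
    """Subtract the mean and divide by sqrt(variance + eps) along `axes`."""
    mean = x.mean(axis=axes, keepdims=True)
    variance = x.var(axis=axes, keepdims=True)   # population variance
    return (x - mean) / np.sqrt(variance + eps)

# Hypothetical input with shape (batch=4, features=3)
x = np.array([[1.0, 10.0, 100.0],
              [3.0, 30.0, 300.0],
              [5.0, 50.0, 500.0],
              [7.0, 70.0, 700.0]])

per_example = normalize(x, axes=1)   # statistics over the features of each example
per_feature = normalize(x, axes=0)   # statistics over the batch, one set per feature

print(per_example.mean(axis=1))   # each example is centered around 0
print(per_feature.mean(axis=0))   # each feature is centered around 0
```

In the first case every example is normalized independently of the others; in the second case every feature is normalized using all examples in the batch, which is the usual meaning of batch normalization.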