# Understand the Mean and Variance Computed in Batch Normalization – Machine Learning Tutorial

By | December 3, 2020

In this tutorial, we will discuss how the mean and variance are computed in batch normalization, which is key to understanding how batch normalization works.

For example:

Suppose we have a batch of shape (64, 200): 64 samples, each with 200 features. There are three ways to compute the mean ($$\mu$$) and variance ($$\sigma^2$$).

## Way 1:

Compute the mean ($$\mu$$) and variance ($$\sigma^2$$) over all 12800 (64 × 200) values. You will get two scalars.

## Way 2:

Compute the mean ($$\mu$$) and variance ($$\sigma^2$$) on axis = 1, i.e. once per sample. You will get two vectors of length 64.

## Way 3:

Compute the mean ($$\mu$$) and variance ($$\sigma^2$$) on axis = 0, i.e. once per feature. You will get two vectors of length 200.
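A minimal NumPy sketch of the three options. The array `batch` is a hypothetical stand-in filled with random numbers; only its shape matters here. Note that `np.var` defaults to the population variance (`ddof=0`), dividing by the number of values rather than that number minus one.

```python
import numpy as np

# Hypothetical batch: 64 samples, 200 features each (random values, shape is what matters)
rng = np.random.default_rng(0)
batch = rng.normal(size=(64, 200))

# Way 1: statistics over all 12800 values -> two scalars
mean1, var1 = batch.mean(), batch.var()

# Way 2: statistics over axis 1 (one per sample) -> two vectors of length 64
mean2, var2 = batch.mean(axis=1), batch.var(axis=1)

# Way 3: statistics over axis 0 (one per feature) -> two vectors of length 200
mean3, var3 = batch.mean(axis=0), batch.var(axis=0)

print(np.shape(mean1), mean2.shape, mean3.shape)  # () (64,) (200,)
```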

If we normalize the data on axis = 1 (the feature axis, which is what the `axis` argument of tf.layers.batch_normalization() refers to), which way does batch normalization use?

The answer is way 3.
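Written out for a batch of $$m$$ samples, where $$x_{ij}$$ is feature $$j$$ of sample $$i$$ (so $$m = 64$$ and $$j = 1, \dots, 200$$ in the example above), way 3 computes one mean and one variance per feature. $$\epsilon$$ is a small constant for numerical stability, and the learnable scale and shift that usually follow are omitted here:

$$
\mu_j = \frac{1}{m}\sum_{i=1}^{m} x_{ij}, \qquad
\sigma_j^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_{ij} - \mu_j\right)^2, \qquad
\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}
$$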

We can verify this with an example.

Consider the code below:

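```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.layers)

# A 4 x 3 batch: 4 samples, 3 features
cap_map = tf.convert_to_tensor(
    np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=float),
    dtype=tf.float32)

def norm(xs):
    # Way 3: mean and variance over the batch axis (axis 0), one pair per feature
    fc_mean, fc_var = tf.nn.moments(xs, axes=[0], keepdims=True)
    epsilon = 0.001  # same as the default epsilon of tf.layers.batch_normalization
    # offset = 0.0 and scale = 1.0: no shift or scale is applied
    return tf.nn.batch_normalization(xs, fc_mean, fc_var, 0.0, 1.0, epsilon)

# Built-in layer: features on axis 1, batch statistics (training=True), no gamma (scale=False)
n1 = tf.layers.batch_normalization(cap_map, axis=1, training=True, scale=False)
n2 = norm(cap_map)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(n1))
    print(sess.run(n2))
```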

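`cap_map` is a 4 × 3 tensor, so the batch size is 4 and each sample has 3 features.

We normalize the data on axis = 1, the feature axis.

tf.layers.batch_normalization() applies batch normalization to the data. With training=True it uses the statistics of the current batch, and with scale=False there is no learned γ. The learned β is initialized to 0, so it has no effect here either.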

n1 is:

```
[[-0.57731885  0.39222473 -1.1829456 ]
 [-1.2371118   0.         -0.5069766 ]
 [ 0.4123706  -1.568899    0.16899228]
 [ 1.40206     1.1766742   1.5209303 ]]
```

tf.nn.batch_normalization() normalizes the data with whatever mean and variance you pass to it. In norm(), we compute them with tf.nn.moments() on axis = 0, which is way 3.
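To see the arithmetic without TensorFlow, here is a minimal NumPy sketch of way 3 followed by the formula tf.nn.batch_normalization() applies, `(x - mean) / sqrt(var + epsilon) * scale + offset`. The function name `batch_norm_way3` is hypothetical, and its defaults mirror the arguments passed in norm() above. Up to float32 rounding, it should reproduce the same 4 × 3 result.

```python
import numpy as np

def batch_norm_way3(x, offset=0.0, scale=1.0, epsilon=0.001):
    # Way 3: per-feature statistics over the batch axis; keepdims gives shape (1, n_features)
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)  # population variance (ddof=0), as in tf.nn.moments
    # Normalize, then apply scale and offset as tf.nn.batch_normalization does
    return (x - mean) / np.sqrt(var + epsilon) * scale + offset

cap_map = np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=np.float32)
print(batch_norm_way3(cap_map))
```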

Running the TensorFlow code above, we find that n2 is:

```
[[-0.57731885  0.39222473 -1.1829456 ]
 [-1.2371118   0.         -0.5069766 ]
 [ 0.4123706  -1.568899    0.16899228]
 [ 1.40206     1.1766742   1.5209303 ]]
```

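n1 and n2 are identical, which shows that batch normalization uses way 3: the mean and variance are computed over the batch axis (axis = 0), giving one mean and one variance per feature.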
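As a further check, the following NumPy sketch applies all three ways to the same 4 × 3 batch. The helper `normalize` is hypothetical and omits γ and β. Only way 3 matches n1 and n2: way 2 centers each sample (row) instead of each feature, and way 1 uses a single mean and variance for every value.

```python
import numpy as np

cap_map = np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=np.float64)
epsilon = 0.001

def normalize(x, axis):
    # axis=None -> way 1, axis=1 -> way 2, axis=0 -> way 3
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

way1 = normalize(cap_map, axis=None)
way2 = normalize(cap_map, axis=1)
way3 = normalize(cap_map, axis=0)

# Way 3 gives every feature (column) a mean of about 0 over the batch
print(np.round(way3.mean(axis=0), 6))
# Way 2 gives every sample (row) a mean of about 0 instead
print(np.round(way2.mean(axis=1), 6))
```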