A Step-by-Step Guide to Implementing Batch Normalization in TensorFlow – TensorFlow Tutorial

By | July 13, 2021

Batch normalization is widely used in neural networks. In this tutorial, we will show how to use it in TensorFlow.

To understand batch normalization, you can read this tutorial:

Understand Batch Normalization: A Beginner Explain

To use batch normalization in neural networks, there are two important tips you must know:

  • The moving mean and moving variance are only updated during the training step. At inference time, the stored moving values are used instead of the statistics of the current batch.
  • In a CNN, it is usually best to add batch normalization before the nonlinear activation function.

How to use batch normalization in TensorFlow?

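The simplest way is to use the tf.layers.batch_normalization() function from the TensorFlow 1.x API. In TensorFlow 2.x, it is still available as tf.compat.v1.layers.batch_normalization().

Here is its definition: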

tf.layers.batch_normalization(
    inputs, axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(), beta_regularizer=None,
    gamma_regularizer=None, beta_constraint=None, gamma_constraint=None,
    training=False, trainable=True, name=None, reuse=None, renorm=False,
    renorm_clipping=None, renorm_momentum=0.99, fused=None, virtual_batch_size=None,
    adjustment=None
)
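The parameters above map onto the usual batch normalization formulas. The sketch below shows the textbook form for one feature along axis over a mini-batch of m values:

- center=True adds the offset β (beta).
- scale=True multiplies by γ (gamma).
- epsilon is the ε in the denominator.
- momentum controls the moving-average update.

With training=False, μ_B and σ_B² are replaced by the moving averages. TensorFlow's internal details may differ slightly, for example which variance estimate feeds the moving average. Treat this as a reference sketch, not the exact implementation.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta \\
\mu_{\text{moving}} &\leftarrow \text{momentum}\cdot\mu_{\text{moving}} + (1-\text{momentum})\cdot\mu_B \\
\sigma^2_{\text{moving}} &\leftarrow \text{momentum}\cdot\sigma^2_{\text{moving}} + (1-\text{momentum})\cdot\sigma^2_B
\end{aligned}
```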

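Pay attention to the training parameter: set it to True when training your model and to False during evaluation or inference. With training=True, the layer normalizes each mini-batch with that batch's own mean and variance, and its update ops refresh the moving averages. With training=False, it uses the stored moving mean and moving variance instead.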
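To see what the training flag changes, here is a toy, pure-Python sketch of a batch normalization layer for a single feature. It is an illustration, not TensorFlow code. The class name SimpleBatchNorm is hypothetical. Its defaults mirror the signature above (momentum=0.99, epsilon=0.001, moving mean 0, moving variance 1), and γ and β are kept fixed at their initial values.

```python
import math


class SimpleBatchNorm:
    """Hypothetical toy batch normalization for one feature (illustration only)."""

    def __init__(self, momentum=0.99, epsilon=0.001):
        self.momentum = momentum
        self.epsilon = epsilon
        self.gamma = 1.0            # scale; trainable in a real layer, fixed here
        self.beta = 0.0             # offset; trainable in a real layer, fixed here
        self.moving_mean = 0.0      # like the default zeros moving_mean_initializer
        self.moving_variance = 1.0  # like the default ones moving_variance_initializer

    def __call__(self, batch, training):
        if training:
            # Training: normalize with the statistics of the current mini-batch
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and update the moving averages (the job of UPDATE_OPS in TensorFlow)
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_variance = m * self.moving_variance + (1 - m) * var
        else:
            # Inference: use the stored moving averages and leave them unchanged
            mean, var = self.moving_mean, self.moving_variance
        std = math.sqrt(var + self.epsilon)
        return [self.gamma * (x - mean) / std + self.beta for x in batch]


bn = SimpleBatchNorm()
batch = [1.0, 2.0, 3.0, 4.0]
train_out = bn(batch, training=True)   # uses batch mean 2.5 and variance 1.25
eval_out = bn(batch, training=False)   # uses the (barely updated) moving averages
print(train_out)
print(eval_out)
```

The two outputs differ for the same input. This is why the flag must be fed correctly in each step: after only one training step, the moving averages are still close to their initial values.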

Here is an example showing how to use this function in a convolutional layer.

import tensorflow as tf  # TensorFlow 1.x API

pooled_outputs = []
for i, filter_size in enumerate(filter_sizes):
    with tf.name_scope("conv-maxpool-%s" % filter_size):
        # Convolution layer, e.g. filter_shape = [3, 200, 1, 200]
        filter_shape = [filter_size, embedding_size, 1, num_filters]
        W = tf.Variable(tf.glorot_uniform_initializer()(filter_shape), name="W")
        b = tf.Variable(tf.constant(0., shape=[num_filters]), name="bias")  # shape [200]
        conv = tf.nn.conv2d(
            self.inputs,
            W,
            strides=[1, 1, embedding_size, 1],
            padding="SAME",
            name="conv")
        output = tf.nn.bias_add(conv, b)
        # Batch normalization on the pre-activation output;
        # is_training is a boolean tensor that is True only while training
        output = tf.layers.batch_normalization(
            output,
            training=is_training)
        # Apply the nonlinearity after batch normalization
        h = tf.nn.relu(output, name="relu")
        pooled_outputs.append(h)

num_filters_total = num_filters * len(filter_sizes)  # 3 * 200 = 600
self.h_pool = tf.concat(pooled_outputs, 3)

In this example, we add batch normalization with tf.layers.batch_normalization() before the nonlinear activation function (ReLU).
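The conv output above has shape [batch, height, width, channels], and axis defaults to -1. Batch normalization therefore computes one mean and one variance per channel, over the batch, height and width dimensions, before ReLU is applied.

Below is a pure-Python sketch of that order of operations in training mode. A few things are assumptions for illustration:

- The conv output values are made up.
- The helper names batch_norm_nhwc and relu are hypothetical.
- γ=1 and β=0 are fixed at their default initial values.

```python
import math


def batch_norm_nhwc(x, epsilon=0.001):
    """Training-mode batch normalization of a nested list [N][H][W][C] over axis=-1.

    For each channel c, the mean and variance are computed over N, H and W.
    Assumes gamma=1 and beta=0 (their default initial values).
    """
    channels = len(x[0][0][0])
    values = [[] for _ in range(channels)]
    for sample in x:
        for row in sample:
            for pixel in row:
                for c, v in enumerate(pixel):
                    values[c].append(v)
    stats = []  # (mean, std) per channel
    for vals in values:
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        stats.append((mean, math.sqrt(var + epsilon)))
    return [[[[(v - stats[c][0]) / stats[c][1] for c, v in enumerate(pixel)]
              for pixel in row] for row in sample] for sample in x]


def relu(x):
    """Element-wise ReLU on a nested list."""
    if isinstance(x, list):
        return [relu(v) for v in x]
    return max(0.0, x)


# Hypothetical conv output: batch of 2, height 3, width 1, 2 channels.
# Channel 1 is simply channel 0 scaled by 10.
conv_out = [
    [[[1.0, 10.0]], [[2.0, 20.0]], [[3.0, 30.0]]],
    [[[4.0, 40.0]], [[5.0, 50.0]], [[6.0, 60.0]]],
]
h = relu(batch_norm_nhwc(conv_out))  # normalize first, then apply the nonlinearity
print(h)
```

After normalization, both channels have roughly zero mean and unit variance despite their very different scales. ReLU then acts on these centered values.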

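Meanwhile, to make batch normalization work, the moving mean and moving variance must be updated during training. tf.layers.batch_normalization() places the ops that update them in the tf.GraphKeys.UPDATE_OPS collection. These ops do not run on their own, so we make the training op depend on them: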

global_step = tf.Variable(0, name="global_step", trainable=False)

# Ops that update moving_mean / moving_variance of every batch normalization layer
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# train_op depends on update_ops, so they run at every training step
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(FLAGS.lr)
    grads_and_vars = optimizer.compute_gradients(hm.loss)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
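Why does the control dependency matter? Here is a toy, pure-Python simulation of the moving mean over many training steps. It is not TensorFlow code. The function name simulate_moving_mean and the constant batch mean of 5.0 are hypothetical. run_update_ops=False models building train_op without tf.control_dependencies(update_ops), so the update never runs.

```python
def simulate_moving_mean(batch_means, momentum=0.99, run_update_ops=True):
    """Track moving_mean across training steps (toy model of UPDATE_OPS).

    run_update_ops=False mimics a train_op that does not depend on the
    update ops: the moving mean stays at its initial value.
    """
    moving_mean = 0.0  # default moving_mean_initializer is zeros
    for batch_mean in batch_means:
        if run_update_ops:
            moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
    return moving_mean


steps = [5.0] * 500  # hypothetical: every mini-batch has mean 5.0
with_updates = simulate_moving_mean(steps)
without_updates = simulate_moving_mean(steps, run_update_ops=False)
print(with_updates, without_updates)
```

With updates, the moving mean approaches the data mean (about 5 × (1 − 0.99^500) after 500 steps). Without them, it stays at 0. Evaluation with training=False would then normalize using the wrong statistics.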

In the training step, we feed training = True. Here is an example:

feed_dict = {
    hm.input_x: x,
    hm.input_y: y,
    hm.word_in_sen_len: word_in_sen_len,
    hm.sen_in_doc_len: sen_in_doc_len,
    hm.primary_input_x: p,
    hm.input_mask_y: mask_y,
    hm.training: True,
    hm.prob: 0.5
}
_, step, loss, accuracy = sess.run(
    [train_op, global_step, hm.loss, hm.accuracy], feed_dict)

In the evaluation step, we feed training = False so that batch normalization uses the moving mean and moving variance accumulated during training.

feed_dict = {
    hm.input_x: x,
    hm.input_y: y,
    hm.word_in_sen_len: word_in_sen_len,
    hm.sen_in_doc_len: sen_in_doc_len,
    hm.primary_input_x: p,
    hm.input_mask_y: mask_y,
    hm.training: False,
    hm.prob: 1.0
}
# Do not run train_op here: evaluation must not update the model weights
step, loss, accuracy = sess.run(
    [global_step, hm.loss, hm.accuracy], feed_dict)
