Understand torch.nn.BatchNorm1d() with Examples – PyTorch Tutorial

By | June 23, 2022

In this tutorial, we will use examples to show how to use torch.nn.BatchNorm1d() correctly in PyTorch.

Batch Normalization

Batch normalization can speed up model training and often improves the accuracy of deep networks such as CNNs.

Dropout vs Batch Normalization – Which is Better for Multilayered Neural Network

To understand what batch normalization is and how its parameters are computed and updated, you can read:

Batch Normalization Vs Layer Normalization: The Difference Explained

How to Update the Mean and Variance of Population and Test Sample in Batch Normalization – Machine Learning Tutorial

Understand Batch Normalization: A Beginner Explain – Machine Learning Tutorial

How to compute batch normalization in PyTorch?

It is easy to implement a batch normalization layer in PyTorch: we can use torch.nn.BatchNorm1d(). This class is defined as:

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

It applies batch normalization over a 2D or 3D input.

2D input: the shape is (N, C)

3D input: the shape is (N, C, L)

Here:

N: the batch size, for example N = 32

C: the number of features or channels, so C = num_features

L: the sequence length

The output has the same shape as the input: (N, C) or (N, C, L).
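
The constructor arguments affine and track_running_stats decide which tensors the layer holds. The following is a minimal sketch for inspecting them; num_features = 4 is an arbitrary value picked for illustration, not one used elsewhere in this tutorial.

```python
import torch
import torch.nn as nn

num_features = 4  # illustrative value only

# Default settings: affine=True, track_running_stats=True
bn = nn.BatchNorm1d(num_features)

# Learnable scale (weight) and shift (bias), one value per feature
print(bn.weight.shape, bn.bias.shape)

# Running statistics are buffers: saved with the model, not trained by the optimizer
print(bn.running_mean.shape, bn.running_var.shape)

# With affine=False the layer has no learnable parameters
bn_plain = nn.BatchNorm1d(num_features, affine=False)
print(bn_plain.weight, bn_plain.bias)
print(len(list(bn_plain.parameters())))
```

Each of these tensors has shape (num_features,), which is why num_features must match C in the input.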

We will use an example to show you how to use it.

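import torch
import torch.nn as nn

C = 200  # number of features (num_features)
B = 20   # batch size
# affine=False: no learnable scale and shift
m = nn.BatchNorm1d(C, affine=False)
input = torch.randn(B, C)  # 2D input with shape (N, C)
output = m(input)
print(output.shape)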

Here input is a 2D tensor with shape 20*200. Running this code, we also get a 20*200 output.

torch.Size([20, 200])
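
To see what the layer computes beyond the shape, the 2D case can be checked by hand. This is a rough sketch, assuming the layer is in training mode (the default for a new module) and affine=False; the names x, mean, var and manual are introduced here only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, B = 200, 20
m = nn.BatchNorm1d(C, affine=False)  # a new module starts in training mode
x = torch.randn(B, C)
output = m(x)

# Per-feature statistics over the batch dimension (dim 0)
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)  # the biased variance is used to normalize
manual = (x - mean) / torch.sqrt(var + m.eps)

# Should print True if the manual result matches the layer output
print(torch.allclose(output, manual, atol=1e-5))
# Each feature of the output should have a mean close to 0
print(output.mean(dim=0).abs().max())
```

So each of the 200 features is normalized separately, using the mean and variance of its 20 values in the batch.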

Moreover, if the input is 3D:

import torch
import torch.nn as nn

C = 200  # number of channels (num_features)
B = 20   # batch size
L = 50   # sequence length
m = nn.BatchNorm1d(C, affine=False)
input = torch.randn(B, C, L)  # 3D input with shape (N, C, L)
output = m(input)
print(output.shape)

We also get a 20*200*50 output: torch.Size([20, 200, 50]).
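
For a 3D input, the statistics are still computed per channel, but over both the batch and the sequence dimensions. The sketch below checks this by hand under the same assumptions as before (training mode, affine=False); x, mean, var and manual are illustrative names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, L = 20, 200, 50
m = nn.BatchNorm1d(C, affine=False)
x = torch.randn(B, C, L)
output = m(x)

# One mean and variance per channel, computed over dims 0 (batch) and 2 (length)
mean = x.mean(dim=(0, 2), keepdim=True)                 # shape (1, C, 1)
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)   # shape (1, C, 1)
manual = (x - mean) / torch.sqrt(var + m.eps)

# Should print True if the manual result matches the layer output
print(torch.allclose(output, manual, atol=1e-5))
```

Here each channel is normalized with statistics taken from 20 * 50 = 1000 values.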

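Notice: batch normalization behaves differently when testing or evaluating a PyTorch model. It is not switched off. Instead of the statistics of the current batch, it normalizes with the running mean and variance collected during training. In TensorFlow, we set training=False to get this behavior.

In PyTorch, we can call model.eval() to put the model, including its batch normalization layers, in evaluation mode.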
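
The difference between the two modes can be sketched as follows. This is an illustrative check, not a full training loop; it assumes the default momentum=0.1 and track_running_stats=True, and the tensor x and the expected_* names are made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m = nn.BatchNorm1d(3, affine=False)
x = torch.randn(8, 3) * 5 + 10  # a batch with mean near 10 and std near 5

# Training mode: normalize with batch statistics and update the running stats
m.train()
out_train = m(x)
# running = (1 - momentum) * old + momentum * batch statistic
# (the initial running_mean is 0 and the initial running_var is 1)
expected_mean = (1 - m.momentum) * torch.zeros(3) + m.momentum * x.mean(dim=0)
# The running variance is updated with the unbiased batch variance
expected_var = (1 - m.momentum) * torch.ones(3) + m.momentum * x.var(dim=0, unbiased=True)
print(torch.allclose(m.running_mean, expected_mean, atol=1e-4))
print(torch.allclose(m.running_var, expected_var, atol=1e-4))

# Evaluation mode: normalize with the running stats and leave them unchanged
m.eval()
saved_mean = m.running_mean.clone()
out_eval = m(x)
manual_eval = (x - m.running_mean) / torch.sqrt(m.running_var + m.eps)
print(torch.allclose(out_eval, manual_eval, atol=1e-4))
print(torch.equal(m.running_mean, saved_mean))
```

After a single training batch the running statistics are still far from the data statistics, so out_train and out_eval differ noticeably. Calling m.train() switches the layer back to batch statistics.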

An Introduction to PyTorch model.eval() for Beginners – PyTorch Tutorial