# Understand torch.nn.BatchNorm1d() with Examples – PyTorch Tutorial

By | June 23, 2022

In this tutorial, we will use examples to show how to use torch.nn.BatchNorm1d() in PyTorch correctly.

## Batch Normalization

Batch normalization can speed up model training and often improves the accuracy of CNNs and other neural networks.


To understand what batch normalization is and how its parameters are computed and updated, you can read:

- Batch Normalization Vs Layer Normalization: The Difference Explained
- How to Update the Mean and Variance of Population and Test Sample in Batch Normalization – Machine Learning Tutorial
- Understand Batch Normalization: A Beginner Explain – Machine Learning Tutorial
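For quick reference, here is the computation described in the PyTorch documentation, written per feature c for a 2D input of shape (N, C). In training mode each feature is normalized with its batch mean and biased batch variance. The scale γ (weight) and shift β (bias) exist only when affine=True, ε is eps, and m is momentum. The running estimates used later in evaluation mode are updated with the unbiased variance. For a 3D input (N, C, L), the sums run over both n and l, and N is replaced by N * L.

```latex
y_{n,c} = \frac{x_{n,c} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}\,\gamma_c + \beta_c,
\qquad
\mu_c = \frac{1}{N}\sum_{n=1}^{N} x_{n,c},
\qquad
\sigma_c^2 = \frac{1}{N}\sum_{n=1}^{N} \left(x_{n,c} - \mu_c\right)^2

\hat{\mu}_c \leftarrow (1 - m)\,\hat{\mu}_c + m\,\mu_c,
\qquad
\hat{\sigma}_c^2 \leftarrow (1 - m)\,\hat{\sigma}_c^2 + m\,\frac{N}{N - 1}\,\sigma_c^2
```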

## How to compute batch normalization in pytorch?

Implementing a batch normalization layer in PyTorch is easy: we can use torch.nn.BatchNorm1d(), a module class defined as:

```python
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
```
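To see what the constructor arguments actually create, the short sketch below inspects a layer with 4 features (an arbitrary size chosen for illustration). With the defaults, affine=True adds a learnable per-feature weight and bias, and track_running_stats=True adds buffers holding the running mean and variance.

```python
import torch
import torch.nn as nn

# BatchNorm1d for 4 features with the default arguments
bn = nn.BatchNorm1d(4)
print(bn.eps, bn.momentum)              # 1e-05 0.1

# affine=True (default): learnable per-feature scale (weight) and shift (bias)
print(bn.weight.shape, bn.bias.shape)   # torch.Size([4]) torch.Size([4])

# track_running_stats=True (default): buffers for the running estimates
print(bn.running_mean)                  # tensor([0., 0., 0., 0.])
print(bn.running_var)                   # tensor([1., 1., 1., 1.])

# affine=False: no weight or bias at all
bn_plain = nn.BatchNorm1d(4, affine=False)
print(bn_plain.weight, bn_plain.bias)   # None None
```

The weight starts at ones and the bias at zeros, so a freshly created layer with affine=True initially behaves like affine=False.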

It applies batch normalization over a 2D or 3D input:

- 2D input: shape (N, C)
- 3D input: shape (N, C, L)

where

- N is the batch size, for example N = 32
- C is the number of features or channels, which must equal num_features
- L is the sequence length

The output has the same shape as the input: (N, C) or (N, C, L).
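The shapes above also determine which values are pooled into each statistic. As a minimal sketch (sizes chosen arbitrarily), the code below normalizes by hand and compares against BatchNorm1d with affine=False in training mode: for a 2D input the mean and biased variance of each feature are taken over N, and for a 3D input over both N and L.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
eps = 1e-5

# 2D input (N, C): per-feature statistics over the batch dimension N
x2 = torch.randn(8, 3)
mean2 = x2.mean(dim=0, keepdim=True)                  # shape (1, C)
var2 = x2.var(dim=0, unbiased=False, keepdim=True)    # biased variance, shape (1, C)
manual2 = (x2 - mean2) / torch.sqrt(var2 + eps)

# 3D input (N, C, L): per-channel statistics over both N and L
x3 = torch.randn(8, 3, 5)
mean3 = x3.mean(dim=(0, 2), keepdim=True)                 # shape (1, C, 1)
var3 = x3.var(dim=(0, 2), unbiased=False, keepdim=True)   # shape (1, C, 1)
manual3 = (x3 - mean3) / torch.sqrt(var3 + eps)

# A newly created module is in training mode, so it uses the batch statistics
bn = nn.BatchNorm1d(3, affine=False)
print(torch.allclose(bn(x2), manual2, atol=1e-4))  # True
print(torch.allclose(bn(x3), manual3, atol=1e-4))  # True
```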

The following example shows how to use it with a 2D input.

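```python
import torch
import torch.nn as nn

C = 200  # number of features (num_features)
B = 20   # batch size
m = nn.BatchNorm1d(C, affine=False)  # no learnable weight or bias
x = torch.randn(B, C)                # 2D input of shape (B, C)
output = m(x)
print(output.shape)
```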

Here the input is a 2D tensor of shape 20*200. Running this code, we get an output of the same shape:

```
torch.Size([20, 200])
```
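Beyond the shape, it can help to check what the normalization did. This sketch reuses the sizes of the 2D example (with a fixed seed added for reproducibility) and confirms that, in training mode, every one of the 200 output columns has a mean close to 0 and a biased variance close to 1. The variance lands slightly below 1 because eps is added to the denominator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, B = 200, 20
m = nn.BatchNorm1d(C, affine=False)
x = torch.randn(B, C)
output = m(x)

# Each column (feature) is normalized with its own batch statistics
col_mean = output.mean(dim=0)
col_var = output.var(dim=0, unbiased=False)
print(col_mean.abs().max())        # a value close to 0
print((col_var - 1).abs().max())   # a value close to 0
```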

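If the input is 3D, the usage is the same:

```python
import torch
import torch.nn as nn

C = 200  # number of channels (num_features)
B = 20   # batch size
L = 50   # sequence length
m = nn.BatchNorm1d(C, affine=False)
x = torch.randn(B, C, L)   # 3D input of shape (B, C, L)
output = m(x)
print(output.shape)
```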

We again get an output of the same shape as the input: torch.Size([20, 200, 50]).
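In practice BatchNorm1d usually sits inside a model. The sketch below uses a hypothetical small classifier (layer sizes chosen for illustration) that normalizes the hidden features of a Linear layer before the activation. It also shows a common pitfall with 3D data: BatchNorm1d expects channels in dimension 1, so sequence data stored as (N, L, C) must be transposed first.

```python
import torch
import torch.nn as nn

# Hypothetical classifier: BatchNorm1d normalizes the 64 hidden features
model = nn.Sequential(
    nn.Linear(200, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

x = torch.randn(20, 200)    # batch of 20 samples with 200 features
print(model(x).shape)       # torch.Size([20, 10])

# Sequence data laid out as (N, L, C): move channels to dim 1, then back
seq = torch.randn(20, 50, 200)                 # (N, L, C)
bn = nn.BatchNorm1d(200)
out = bn(seq.transpose(1, 2)).transpose(1, 2)  # normalize over C, return to (N, L, C)
print(out.shape)                               # torch.Size([20, 50, 200])
```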

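Notice: batch normalization behaves differently during training and during testing or evaluation. In training mode, it normalizes with the mean and variance of the current batch and updates its running estimates. In evaluation mode (with the default track_running_stats=True), it normalizes with the stored running mean and variance instead, so the output no longer depends on the other samples in the batch. In TensorFlow, we switch modes by setting training=False.

In PyTorch, we call model.eval() to put batch normalization layers into evaluation mode, and model.train() to switch back to training mode.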
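The sketch below makes the difference between the two modes concrete on a single layer (the batch size and data distribution are arbitrary choices). One training-mode call updates the running estimates with momentum 0.1, using the unbiased batch variance. In evaluation mode the layer then normalizes with those running estimates, so the same input produces a different output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3, affine=False)    # momentum=0.1 by default
x = torch.randn(16, 3) * 2 + 5          # batch with mean ~5 and std ~2

# Training mode: normalize with batch statistics and update running estimates
bn.train()
y_train = bn(x)
expected_mean = 0.9 * torch.zeros(3) + 0.1 * x.mean(dim=0)
expected_var = 0.9 * torch.ones(3) + 0.1 * x.var(dim=0, unbiased=True)
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-4))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-4))    # True

# Evaluation mode: normalize with the stored running estimates (not updated)
bn.eval()
with torch.no_grad():
    y_eval = bn(x)
manual_eval = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
print(torch.allclose(y_eval, manual_eval, atol=1e-4))  # True
print(torch.allclose(y_train, y_eval))                 # False
```

After only one update the running estimates are still far from the true batch statistics, which is why the evaluation output here is not centered at 0; with many training steps they converge toward the data's statistics.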

