Understand torch.nn.MSELoss(): Compute the Mean Squared Error (Squared L2 Norm) – PyTorch Tutorial

July 11, 2022

In PyTorch, we can use torch.nn.MSELoss() to compute the mean squared error (MSE). In this tutorial, we will use some examples to show you how to use it.

torch.nn.MSELoss()

It is defined as:

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

It computes the squared error between each element of the input and the target, then reduces the result with either a mean or a sum.


When reduction="mean", we get the mean of the squared errors; when reduction="sum", we get their sum.
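The difference between the two reductions can be checked on a small fixed tensor (the values below are made up for illustration):

```python
import torch
import torch.nn as nn

pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.zeros(2, 2)

# Squared errors are 1, 4, 9, 16.
mse_mean = nn.MSELoss(reduction="mean")(pred, target)  # (1 + 4 + 9 + 16) / 4 = 7.5
mse_sum = nn.MSELoss(reduction="sum")(pred, target)    # 1 + 4 + 9 + 16 = 30.0
print(mse_mean.item(), mse_sum.item())
```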

MSE loss

It is defined as:

MSE = (1/N) * sum_{n=1}^{N} (x_n - y_n)^2

where x is the input, y is the target, and N is the total number of elements. With reduction="sum", the 1/N factor is dropped.
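To connect the definition above to the API, here is a quick sketch comparing a manual computation against nn.MSELoss() (the tensor values are just illustrative):

```python
import torch
import torch.nn as nn

x = torch.tensor([0.5, 1.0, 2.0])
y = torch.tensor([0.0, 1.0, 1.0])

# Manual MSE: average of the element-wise squared differences.
manual = ((x - y) ** 2).mean()   # (0.25 + 0 + 1) / 3
builtin = nn.MSELoss()(x, y)
print(manual.item(), builtin.item())
```

The two values should match.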

How to use torch.nn.MSELoss() to compute the MSE loss?

It is easy in PyTorch. Here is an example:

import torch
import torch.nn as nn
loss = nn.MSELoss()  # reduction="mean" by default
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output_mean = loss(input, target)
print(output_mean)

Here reduction="mean". Running this code, we may get:

tensor(1.3199, grad_fn=<MseLossBackward>)

Moreover, if reduction="sum", we will get the sum of the squared errors:

loss = nn.MSELoss(reduction="sum")
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output_sum = loss(input, target)
print(output_sum)
print(output_sum / 15)  # dividing by 3 * 5 = 15 elements recovers the mean

Running this code, we may get:

tensor(22.6790, grad_fn=<MseLossBackward>)
tensor(1.5119, grad_fn=<DivBackward0>)
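Besides "mean" and "sum", torch.nn.MSELoss() also accepts reduction="none", which skips the reduction and returns the element-wise squared errors. A quick sketch (tensor values chosen for illustration):

```python
import torch
import torch.nn as nn

loss_none = nn.MSELoss(reduction="none")
pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.ones(2, 2)

# Element-wise (pred - target) ** 2: 0, 1, 4, 9
per_element = loss_none(pred, target)
print(per_element)
print(per_element.mean())  # same value as reduction="mean"
```

This is useful when you want to weight or mask individual errors before reducing them yourself.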