Understand with torch.no_grad() with Examples – PyTorch Tutorial

By | April 27, 2022

We often see with torch.no_grad(): in PyTorch scripts. What does it mean? In this tutorial, we will use an example to explain it.

torch.no_grad()

torch.no_grad() is a context manager. To understand Python context managers, you can view:

Create Customized Context Manager for Python With Statement: A Completed Guide – Python Tutorial

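Inside its context, gradient calculation is disabled. Tensors produced by operations in the block are not tracked by autograd, so their requires_grad is False even when their inputs require gradients. When the block exits, the previous gradient mode is restored.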
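The on/off behavior of the context manager can be checked directly with torch.is_grad_enabled(), which reports whether gradient tracking is currently active. A minimal sketch, assuming a fresh script where the global gradient mode has not been changed elsewhere:

```python
import torch

states = [torch.is_grad_enabled()]          # gradient mode is on by default
with torch.no_grad():
    states.append(torch.is_grad_enabled())  # switched off inside the context
states.append(torch.is_grad_enabled())      # restored once the context exits

print(states)  # [True, False, True]
```

torch.no_grad() can also be used as a function decorator (@torch.no_grad()), which runs the whole function body with gradient tracking disabled.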

For example:

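import torch

x = torch.randn([3, 4], requires_grad=True)  # leaf tensor that requires gradients
print(x.requires_grad)
with torch.no_grad():
    y = x * 2               # computed with gradient tracking disabled
    print(x.requires_grad)  # x itself is not changed by the context
    print(y.requires_grad)
z = x * 2                   # computed with gradient tracking enabled
print(z.requires_grad)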

Running this code prints:

True
True
False
True

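In this example, tensor x requires gradients, while tensor y is created inside the torch.no_grad() context. We can find:

x.requires_grad = True and y.requires_grad = False

Note that x.requires_grad is still True inside the context: torch.no_grad() does not modify existing tensors, it only stops autograd from tracking new operations. Meanwhile, tensor z is computed outside the torch.no_grad() context. We can find:

z.requires_grad = True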
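The difference between y and z matters when calling backward(). The sketch below repeats the setup from the example above (same shapes; variable names are illustrative) and shows that y has no grad_fn, so gradients cannot flow through it, while z is part of the autograd graph:

```python
import torch

x = torch.randn(3, 4, requires_grad=True)

with torch.no_grad():
    y = x * 2  # not recorded by autograd, so y has no grad_fn
z = x * 2      # recorded: z.grad_fn links back to x

print(y.grad_fn)              # None
print(z.grad_fn is not None)  # True

# Gradients flow back to x through z: d(sum(2 * x))/dx = 2 for every entry.
z.sum().backward()
print(x.grad)

# Calling backward through y fails because y is not connected to any graph.
y_backward_failed = False
try:
    y.sum().backward()
except RuntimeError as err:
    y_backward_failed = True
    print("backward through y failed:", err)
```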

How to use torch.no_grad()?

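We usually use torch.no_grad() when testing or evaluating a model. No gradients are needed there, so skipping graph construction saves memory and computation. It is typically paired with model.eval(), which switches layers such as dropout and batch normalization to evaluation behavior. Here is an example: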

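import torch
model = torch.nn.Linear(4, 2)  # placeholder: use your own trained model here
model.eval()
with torch.no_grad():
    # your test / evaluation code, for example:
    output = model(torch.randn(3, 4))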
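model.eval() and torch.no_grad() do different jobs, so evaluation code normally uses both. A slightly fuller sketch, assuming a hypothetical toy classifier with a dropout layer and random stand-in data (evaluate is a helper name chosen for this example, not a PyTorch API):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy classifier and random data standing in for a real test set.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 3))
inputs = torch.randn(10, 4)
labels = torch.randint(0, 3, (10,))


def evaluate(model, inputs, labels):
    """Return (accuracy, logits) without recording an autograd graph."""
    was_training = model.training
    model.eval()               # dropout becomes a no-op in eval mode
    with torch.no_grad():      # the forward pass is not tracked by autograd
        logits = model(inputs)
        preds = logits.argmax(dim=1)
        accuracy = (preds == labels).float().mean().item()
    model.train(was_training)  # restore the caller's mode
    return accuracy, logits


acc1, logits1 = evaluate(model, inputs, labels)
acc2, logits2 = evaluate(model, inputs, labels)

print(logits1.requires_grad)          # False: no graph was built
print(torch.equal(logits1, logits2))  # True: dropout is disabled in eval mode
print(f"accuracy: {acc1:.2f}")
```

The helper restores the model's previous mode with model.train(was_training), so it can be called in the middle of a training loop without leaving dropout disabled afterwards.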
