We often see with torch.no_grad(): in PyTorch scripts. What does it mean? In this tutorial, we will use an example to explain it.
torch.no_grad()
torch.no_grad() is a context manager. To understand Python context managers, you can read:
Create Customized Context Manager for Python With Statement: A Completed Guide – Python Tutorial
It disables gradient calculation for all operations inside its context.
For example:
import torch

x = torch.randn([3, 4], requires_grad=True)
print(x.requires_grad)

with torch.no_grad():
    y = x * 2
    print(x.requires_grad)
    print(y.requires_grad)

z = x * 2
print(z.requires_grad)
Running this code, we will see:
True
True
False
True
In this example, tensor x is created with requires_grad=True, so gradient tracking is enabled for it. Tensor y, however, is created inside the torch.no_grad() context. We can find:
x requires_grad = True and y requires_grad = False
Meanwhile, tensor z is created outside the torch.no_grad() context. We can find:
z requires_grad = True
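Because y is created inside the context, autograd records no operation for it, so it cannot participate in backpropagation, while z can. A minimal sketch to illustrate this (the tensor shapes here are chosen only for illustration):

```python
import torch

x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # not tracked: no computation graph is built

z = x * 2      # tracked: autograd records the multiplication

print(y.grad_fn)  # None, so gradients cannot flow through y
print(z.grad_fn)  # a MulBackward0 node

z.sum().backward()
print(x.grad)     # each element is 2.0, since z = 2 * x
```

Calling backward() on anything derived only from y would raise a RuntimeError, because y has no grad_fn to propagate through.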
How to use torch.no_grad()?
We usually use torch.no_grad() when we test or evaluate our model. Here is an example:

model.eval()
with torch.no_grad():
    # your test / evaluation code
    ...
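To make this concrete, here is a minimal runnable sketch. The nn.Linear model and random input batch are hypothetical stand-ins for your own model and data:

```python
import torch
import torch.nn as nn

# Hypothetical model and input batch, for illustration only
model = nn.Linear(4, 2)
inputs = torch.randn(8, 4)

model.eval()              # switch layers like dropout / batch norm to eval mode
with torch.no_grad():     # skip building the autograd graph during inference
    outputs = model(inputs)

print(outputs.requires_grad)  # False: no graph was recorded
```

Note that model.eval() and torch.no_grad() are independent: eval() changes layer behavior, while no_grad() disables gradient tracking, which saves memory and speeds up inference.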