Understand Tensor.detach() with Examples in PyTorch – PyTorch Tutorial

By | September 15, 2022

In PyTorch, tensor.detach() returns a new tensor that is detached from the current computation graph. In this tutorial, we will use some examples to show you how to use it.


It returns a new tensor that never requires gradient. The returned tensor shares the same storage as the original one.

For example, in TensorFlow we can stop the gradient of a tensor with tf.stop_gradient(). PyTorch has no function with that name.

In PyTorch, we can use tensor.detach() to stop the gradient of a tensor instead.

For example:

import torch

x = torch.randn([5, 4], requires_grad=True)  # x tracks gradients
y = x.detach()  # y holds the same values but is cut off from the graph

Here y is detached from the graph, so y.requires_grad is False and no gradient can be computed through tensor y.
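To see what detaching does to gradient flow, here is a minimal sketch. The names w, a and b are illustrative and are not part of the example above; the sketch assumes a small leaf tensor created with requires_grad=True.

```python
import torch

# Hypothetical leaf tensor that tracks gradients
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

a = w * 2        # a is part of the graph
b = a.detach()   # b has the same values but is cut off from the graph

print(a.requires_grad)  # True
print(b.requires_grad)  # False

# b is treated as a constant, so the gradient reaches w only through a
loss = (a * b).sum()
loss.backward()
print(w.grad)  # tensor([ 4.,  8., 12.])
```

Without the detach() call the loss would be (a * a).sum(), and the gradient of w would be twice as large, because it would flow through both factors.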

Meanwhile, we often use it when we are evaluating a model.

For example:

for line in lines:
	embedding_11, embedding_12 = embeddings[line.split()[1]]
	embedding_21, embedding_22 = embeddings[line.split()[2]]
	# Compute the scores
	score_1 = torch.mean(torch.matmul(embedding_11, embedding_21.T)) # higher is positive
	score_2 = torch.mean(torch.matmul(embedding_12, embedding_22.T))
	score = (score_1 + score_2) / 2
	# Detach from the graph, move to the CPU, then convert to a NumPy array
	score = score.detach().cpu().numpy()

Here score = score.detach().cpu().numpy() detaches score from the graph, moves it to the CPU and converts it to a NumPy array. Because we have used score.detach(), no gradient will be computed for the new score tensor.
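The detach() call is also needed for the conversion itself: calling numpy() on a tensor that still requires gradient raises a RuntimeError. A minimal sketch, assuming a hypothetical random tensor emb in place of the real embeddings:

```python
import torch

# Hypothetical stand-in for an embedding produced by a model
emb = torch.randn(3, 4, requires_grad=True)
score = torch.mean(torch.matmul(emb, emb.T))

try:
    score.numpy()  # fails because score still requires gradient
    failed = False
except RuntimeError:
    failed = True

# Detaching first makes the conversion work
value = score.detach().cpu().numpy()
print(failed, value.shape)
```

The result is a zero-dimensional NumPy array, which can be turned into a plain Python number with value.item().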