The PyTorch torch.unbind() function removes a tensor dimension and returns a tuple of all the slices along that dimension. In this tutorial, we will use some examples to show you how to use it.
Syntax
It is defined as:
torch.unbind(input, dim=0)
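Here input is the tensor to split and dim is the dimension we plan to remove. The function returns a tuple that contains every slice along dim, so the tuple has input.shape[dim] elements.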
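As a minimal sketch of the syntax above, the snippet below unbinds a small 2 x 3 tensor whose values are easy to follow. The variable names rows and cols are only illustrative. It also assumes you want to count dimensions from the end, which a negative dim allows.

```python
import torch

# A small tensor with known values: [[0, 1, 2], [3, 4, 5]]
x = torch.arange(6).reshape(2, 3)

# dim defaults to 0, so this returns the 2 rows
rows = torch.unbind(x)

# A negative dim counts from the end; for a 2-D tensor, -1 means dim=1
cols = torch.unbind(x, dim=-1)

print(len(rows), rows[0])  # 2 tensor([0, 1, 2])
print(len(cols), cols[0])  # 3 tensor([0, 3])
```

The number of returned tensors always equals the size of the removed dimension.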
How to use torch.unbind()?
Look at this example:
import torch

# Create a random tensor with shape [5, 4, 3]
x = torch.randn(5, 4, 3)
print(x)

# Remove dim 0 and collect the slices in a tuple
y = torch.unbind(x, dim=0)
print("y len = ", len(y))
print("y type = ", type(y))
print("y[0] shape = ", y[0].shape)
Here we remove dim = 0, which has size 5, so we will get 5 tensors, each with shape [4, 3]. The values printed for x are random and are omitted here.
y len = 5
y type = <class 'tuple'>
y[0] shape = torch.Size([4, 3])
From this output, we can see that y is a Python tuple.
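To make the result more concrete, here is a small sketch that checks two properties of the tuple. It assumes the same [5, 4, 3] shape as the example above. The names same_as_indexing, restored and round_trip_ok are only illustrative. It compares each element with plain indexing, then uses torch.stack() to put the slices back together.

```python
import torch

x = torch.randn(5, 4, 3)
y = torch.unbind(x, dim=0)

# Each element of the tuple equals the slice we would get by indexing dim 0
same_as_indexing = all(torch.equal(y[i], x[i]) for i in range(x.shape[0]))

# torch.stack() along the same dim rebuilds the original tensor
restored = torch.stack(y, dim=0)
round_trip_ok = torch.equal(restored, x)

print(same_as_indexing)  # True
print(round_trip_ok)     # True
print(restored.shape)    # torch.Size([5, 4, 3])
```

In other words, torch.unbind() along a dimension can be undone by torch.stack() along the same dimension.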
When dim = 1
# x is the [5, 4, 3] tensor created above; remove dim 1 this time
y = torch.unbind(x, dim=1)
print("y len = ", len(y))
print("y type = ", type(y))
print("y[0] shape = ", y[0].shape)
print(y[0])
We will get the following (the random values printed for y[0] are omitted):
y len = 4
y type = <class 'tuple'>
y[0] shape = torch.Size([5, 3])
When dim = 2
# x is the [5, 4, 3] tensor created above; remove dim 2 this time
y = torch.unbind(x, dim=2)
print("y len = ", len(y))
print("y type = ", type(y))
print("y[0] shape = ", y[0].shape)
print(y[0])
We will get the following (the random values printed for y[0] are omitted):
y len = 3
y type = <class 'tuple'>
y[0] shape = torch.Size([5, 4])
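The three cases can also be checked in one loop. This is only a sketch that assumes the same [5, 4, 3] tensor shape. The list name summary is illustrative.

```python
import torch

x = torch.randn(5, 4, 3)

# Collect (dim, number of slices, shape of each slice) for every dimension
summary = []
for dim in range(x.dim()):
    parts = torch.unbind(x, dim=dim)
    summary.append((dim, len(parts), tuple(parts[0].shape)))

for dim, count, shape in summary:
    print("dim =", dim, "len =", count, "shape =", shape)
# dim = 0 len = 5 shape = (4, 3)
# dim = 1 len = 4 shape = (5, 3)
# dim = 2 len = 3 shape = (5, 4)
```

For each dim, the tuple length is the size of that dimension, and each slice keeps the remaining dimensions in their original order.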