Understand torch.unbind() with Examples – PyTorch Tutorial

March 14, 2023

The PyTorch torch.unbind() function removes one dimension of a tensor and returns a tuple of all the slices along that dimension. In this tutorial, we will use some examples to show you how to use it.

Syntax

It is defined as:

torch.unbind(input, dim=0)

Here input is the tensor to split and dim is the dimension to remove. dim defaults to 0.
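To make the signature concrete, here is a minimal sketch on a small tensor with known values. The variable names (x, rows, cols) are only illustrative, and the negative dim relies on PyTorch's usual convention that -1 means the last dimension.

```python
import torch

# A small 2 x 3 tensor with known values: [[0, 1, 2], [3, 4, 5]]
x = torch.arange(6).reshape(2, 3)

# dim defaults to 0, so this returns one tensor per row.
rows = torch.unbind(x)

# A negative dim counts from the end; -1 is the last dimension (the columns here).
cols = torch.unbind(x, dim=-1)

# The same operation is also available as a tensor method.
rows_via_method = x.unbind(0)

print(len(rows), rows[1])
print(len(cols), cols[0])
```

In this sketch, rows holds 2 tensors of shape [3] and cols holds 3 tensors of shape [2].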

How to use torch.unbind()?

Look at this example:

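import torch

# A random tensor with shape [5, 4, 3]
x = torch.randn(5, 4, 3)
print(x)

# Remove dimension 0: one slice per index along that dimension
y = torch.unbind(x, dim=0)
print("y len = ", len(y))
print("y type = ", type(y))
print("y[0] shape = ", y[0].shape)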

Here we remove dim = 0, which has size 5, so we get 5 tensors, each with shape [4, 3]. Apart from the random values printed by print(x), the output is:

y len =  5
y type =  <class 'tuple'>
y[0] shape =  torch.Size([4, 3])

As the output shows, y is a Python tuple.
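One way to check what the tuple contains is to compare each element with plain indexing and then rebuild the tensor with torch.stack(). This is a sketch for illustration only; the names same_as_indexing, restored and round_trip_ok are assumptions of this example.

```python
import torch

x = torch.randn(5, 4, 3)
y = torch.unbind(x, dim=0)

# Each element of the tuple should match plain indexing along dim 0.
same_as_indexing = all(torch.equal(y[i], x[i]) for i in range(x.shape[0]))

# Stacking the slices along the same dimension should rebuild the original tensor.
restored = torch.stack(y, dim=0)
round_trip_ok = torch.equal(restored, x)

print("same as indexing:", same_as_indexing)
print("round trip ok:", round_trip_ok)
print("restored shape:", restored.shape)
```

In this sense, torch.stack() along the same dim acts as the inverse of torch.unbind().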

When dim = 1

# Reuses x (shape [5, 4, 3]) from the first example
y = torch.unbind(x, dim=1)
print("y len = ", len(y))
print("y type = ", type(y))
print("y[0] shape = ", y[0].shape)
print(y[0])

We will get the following (the random values printed by print(y[0]) are omitted):

y len =  4
y type =  <class 'tuple'>
y[0] shape =  torch.Size([5, 3])
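For a dimension other than 0, the slices correspond to indexing that dimension while keeping the others whole. The sketch below compares the dim = 1 result with x[:, i] and with Tensor.select(); the names matches_slicing and matches_select are only illustrative.

```python
import torch

x = torch.randn(5, 4, 3)
y = torch.unbind(x, dim=1)

# y[i] should hold the same values as indexing position i of dim 1.
matches_slicing = all(torch.equal(y[i], x[:, i]) for i in range(x.shape[1]))

# Tensor.select(dim, index) picks a single slice along a given dimension.
matches_select = all(torch.equal(y[i], x.select(1, i)) for i in range(x.shape[1]))

print("matches x[:, i]:", matches_slicing)
print("matches x.select(1, i):", matches_select)
```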

When dim = 2

# Reuses x (shape [5, 4, 3]) from the first example
y = torch.unbind(x, dim=2)
print("y len = ", len(y))
print("y type = ", type(y))
print("y[0] shape = ", y[0].shape)
print(y[0])

We will get the following (the random values printed by print(y[0]) are omitted):

y len =  3
y type =  <class 'tuple'>
y[0] shape =  torch.Size([5, 4])
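The three cases above follow one pattern: the tuple length equals the size of the removed dimension, and each slice keeps the remaining dimensions in order. Here is a short sketch that checks this for every dimension of x; the results list and expected_shape are names made up for this example.

```python
import torch

x = torch.randn(5, 4, 3)

results = []
for dim in range(x.dim()):
    slices = torch.unbind(x, dim=dim)
    # Expected slice shape: the original shape with entry `dim` dropped.
    expected_shape = list(x.shape[:dim]) + list(x.shape[dim + 1:])
    results.append((dim, len(slices), list(slices[0].shape), expected_shape))

for dim, count, actual_shape, expected_shape in results:
    print("dim =", dim, "len =", count, "shape =", actual_shape, "expected =", expected_shape)
```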