Understand torch.split(): Split Tensor into Chunks – PyTorch Tutorial

By | March 24, 2023

PyTorch's torch.split() function splits a tensor into chunks. In this tutorial, we will create some examples to show how to use it.

Syntax

torch.split() is defined as:

torch.split(tensor, split_size_or_sections, dim=0)

It splits tensor into chunks along the dimension dim (dim 0 by default).

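split_size_or_sections: an integer or a list of integers that determines the size of each chunk along dim. An integer gives the size of every chunk (the last chunk may be smaller). A list gives the size of each chunk, in order.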
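The examples in this tutorial all split along dim 0. As a short sketch of the dim parameter (this column split is our own illustration, not one of the tutorial's examples), a [5, 2] tensor like the one created in Step 1 can be split along dim=1 instead, which gives one chunk per column:

```python
import torch

a = torch.arange(10).reshape(5, 2)

# dim=1 splits along the columns: a has 2 columns, so chunks of size 1 give 2 tensors.
cols = torch.split(a, 1, dim=1)
print(len(cols))                       # 2
print([tuple(t.shape) for t in cols])  # [(5, 1), (5, 1)]
print(cols[0].flatten().tolist())      # [0, 2, 4, 6, 8]
print(cols[1].flatten().tolist())      # [1, 3, 5, 7, 9]
```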

How to use torch.split()?

Here we will use some examples to show how it works.

Step 1: create a tensor

>>> import torch
>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])

a is a tensor with shape [5, 2].

split_size_or_sections is an integer

For example: split_size_or_sections = 2

>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))

We get 3 tensors.

However, 5 is not divisible by 2 (5 % 2 != 0), so why do we get 3 tensors?

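When split_size_or_sections is an integer, torch.split() makes every chunk exactly split_size_or_sections long along dim, except possibly the last one, which holds whatever is left over. Here dim 0 has size 5, so the chunks have sizes 2, 2 and 1: ceil(5 / 2) = 3 chunks in total.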
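The rule can be written out as plain arithmetic. The helper expected_chunk_sizes() below is hypothetical, not part of PyTorch; it is a minimal sketch that predicts the chunk sizes for an integer split_size_or_sections and checks that prediction against torch.split() on small 1-D tensors:

```python
import torch

def expected_chunk_sizes(n, split_size):
    """Hypothetical helper: chunk sizes when a dimension of length n is split
    into pieces of split_size, with any remainder in a smaller last chunk."""
    full, rest = divmod(n, split_size)
    return [split_size] * full + ([rest] if rest else [])

a = torch.arange(10).reshape(5, 2)
actual = [t.size(0) for t in torch.split(a, 2)]
print(expected_chunk_sizes(5, 2))  # [2, 2, 1]
print(actual)                      # [2, 2, 1]

# Compare the prediction with torch.split() for a range of lengths and split sizes.
for n in range(1, 11):
    for s in range(1, 6):
        x = torch.arange(n)
        assert [t.numel() for t in torch.split(x, s)] == expected_chunk_sizes(n, s)
```

The number of chunks is always ceil(n / split_size), since only the last chunk can be shorter.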

split_size_or_sections is a list

For example: split_size_or_sections = [1, 4]

>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

We get 2 tensors: the first has size 1 along dim 0, and the second has size 4.

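The sizes in the list must add up to the size of the dimension being split: 1 + 4 = 5, which is a.size(0). If they do not, torch.split() raises an error.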
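A minimal sketch of this constraint, using a [5, 2] tensor like the one above: [2, 3] sums to 5 and is accepted, while [1, 3] sums to 4 and is rejected. We assume the failure surfaces as a RuntimeError; the exact message may differ between PyTorch versions, so the code prints it rather than matching it.

```python
import torch

a = torch.arange(10).reshape(5, 2)

# Section sizes that sum to a.size(0) == 5 are accepted.
parts = torch.split(a, [2, 3])
print([t.size(0) for t in parts])  # [2, 3]

# Section sizes that do not sum to 5 are rejected.
raised = False
try:
    torch.split(a, [1, 3])
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
print(raised)  # True
```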