The PyTorch torch.split() function splits a tensor into chunks. In this tutorial, we will use some examples to show how to use it.
Syntax
torch.split() is defined as:
torch.split(tensor, split_size_or_sections, dim=0)
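It splits a tensor into chunks along the dimension dim and returns them as a tuple.
split_size_or_sections: an integer or a list of integers. An integer sets the size of every chunk (the last chunk may be smaller), while a list sets the size of each chunk explicitly.
dim: the dimension along which to split the tensor. The default is 0.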
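To make the role of dim concrete, here is a minimal sketch. The tensor x and its shape are made up for illustration and are not used elsewhere in this tutorial.

```python
import torch

# Hypothetical example tensor with 2 rows and 6 columns
x = torch.arange(12).reshape(2, 6)

# dim=0 splits along the rows, dim=1 splits along the columns
rows = torch.split(x, 1, dim=0)
cols = torch.split(x, 3, dim=1)

row_shapes = [tuple(t.shape) for t in rows]
col_shapes = [tuple(t.shape) for t in cols]
print(row_shapes)  # [(1, 6), (1, 6)]
print(col_shapes)  # [(2, 3), (2, 3)]
```

Only the size along dim changes. The other dimensions of each chunk stay the same as in the input.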
How to use torch.split()?
The examples below show how to use it.
Step 1: create a tensor
>>> import torch
>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
a is a tensor with shape [5, 2].
split_size_or_sections is an integer
For example: split_size_or_sections = 2
>>> torch.split(a, 2)
(tensor([[0, 1],
        [2, 3]]),
 tensor([[4, 5],
        [6, 7]]),
 tensor([[8, 9]]))
We get 3 tensors.
However, 5 is not divisible by 2 (5 % 2 != 0), so why do we get 3 tensors?
When split_size_or_sections is an integer, torch.split() makes each chunk exactly that size where possible. If the size of the tensor along dim is not divisible by it, the last chunk is smaller. Here the chunks have 2, 2 and 1 rows.
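The rule for an integer split size can be checked with a short sketch. It rebuilds the same tensor a and compares the number of chunks with a ceiling division. The variable names sizes and expected_count are introduced here only for illustration.

```python
import math
import torch

a = torch.arange(10).reshape(5, 2)
split_size = 2

chunks = torch.split(a, split_size)

# Number of rows in each chunk: full chunks first, then the remainder
sizes = [c.shape[0] for c in chunks]
print(sizes)  # [2, 2, 1]

# The chunk count is the size along dim divided by split_size, rounded up
expected_count = math.ceil(a.shape[0] / split_size)
print(len(chunks) == expected_count)  # True
```

So 5 rows with a split size of 2 give ceil(5 / 2) = 3 chunks, and only the last one is short.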
split_size_or_sections is a list
For example: split_size_or_sections = [1, 4]
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
        [4, 5],
        [6, 7],
        [8, 9]]))
We get 2 tensors: the first has 1 row and the second has 4 rows.
The sizes in the list must add up to the size of the tensor along dim: 1 + 4 = 5.
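As a quick check of the sum rule, the sketch below splits the same tensor a with a valid list and then with a list that sums to 4 instead of 5. The flag raised is a name chosen here for illustration. The failing call is expected to raise a RuntimeError.

```python
import torch

a = torch.arange(10).reshape(5, 2)

# Valid: 1 + 4 equals the 5 rows of a
first, second = torch.split(a, [1, 4])
print(first.shape[0], second.shape[0])  # 1 4

# Invalid: 1 + 3 does not equal 5, so the split is rejected
try:
    torch.split(a, [1, 3])
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```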