Understand torch.sort() with Examples – Sort a Tensor – PyTorch Tutorial

By | April 26, 2022

In this tutorial, we will use some examples to show you how to use the torch.sort() function in PyTorch correctly.

torch.sort()

It is defined as:

torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)

This function sorts the elements of the input tensor along a given dimension in ascending order by value (or in descending order if descending=True).

Input parameters

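input: the input tensor

dim: the dimension to sort along (default: -1, the last dimension).

descending: if True, sort in descending order instead of ascending order (default: False).

stable: if True, the sort is stable, which guarantees that equal elements keep their original relative order (default: False).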
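To make the effect of descending concrete, here is a small sketch on a hand-made 1-D tensor (the values are illustrative only), sorted once with the default ascending order and once with descending=True:

```python
import torch

# A small 1-D tensor with a known order (illustrative values)
x = torch.tensor([3.0, 1.0, 2.0])

# Default: ascending order along the last dimension (dim=-1)
asc_values, asc_indices = torch.sort(x)
print(asc_values, asc_indices)    # tensor([1., 2., 3.]) tensor([1, 2, 0])

# descending=True reverses the order
desc_values, desc_indices = torch.sort(x, descending=True)
print(desc_values, desc_indices)  # tensor([3., 2., 1.]) tensor([0, 2, 1])
```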

Output

This function returns a namedtuple (values, indices), that is, a (Tensor, LongTensor) pair:

values: the sorted tensor

indices: a LongTensor holding, for each sorted element, its original index in input along dim
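One way to see what the indices mean is to use them to look the values up again in the input. The sketch below (with a made-up 2x3 tensor) reads the result by name and uses torch.gather along the same dim to rebuild the sorted values from x:

```python
import torch

# Made-up 2x3 input
x = torch.tensor([[0.5, -1.0, 2.0],
                  [3.0, 0.0, -2.0]])

# torch.sort returns a namedtuple, so the results can also be read by name
result = torch.sort(x, dim=1)
print(result.values)
print(result.indices)

# The indices point back into the input: gathering x with them
# along the same dim reproduces the sorted values
rebuilt = torch.gather(x, 1, result.indices)
print(torch.equal(rebuilt, result.values))  # True
```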

How to use torch.sort()

Here we will use some examples to show you how to use this function.

For example:

import torch

x = torch.randn(3, 4)
sorted_tensor, tensor_indices = torch.sort(x)
print("x", x)
print(sorted_tensor)
print(tensor_indices)

Running this code, we get output like the following (x is random, so your values will differ):

x tensor([[-0.2957, -0.4860, -1.7153, -0.5024],
        [-0.4195, -0.1657, -0.9259, -0.6146],
        [ 1.0523, -1.3726,  0.8340, -2.0896]])
tensor([[-1.7153, -0.5024, -0.4860, -0.2957],
        [-0.9259, -0.6146, -0.4195, -0.1657],
        [-2.0896, -1.3726,  0.8340,  1.0523]])
tensor([[2, 3, 1, 0],
        [2, 3, 0, 1],
        [3, 1, 2, 0]])

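You can understand this result as follows:

[Figure: how to use torch.sort()]

Here, the elements of x are sorted along dim = -1 (the last dimension), so each row is sorted independently. The original position of each sorted element is saved in tensor_indices. For example, the first row of tensor_indices is [2, 3, 1, 0] because the smallest value in the first row of x, -1.7153, was at index 2.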
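Because x comes from torch.randn, the output changes on every run. The sketch below rebuilds x from the printed values (rounded to 4 decimals, so this is an approximation of the original random tensor) to make the dim = -1 result reproducible:

```python
import torch

# x copied from the printed output above (values rounded to 4 decimals)
x = torch.tensor([[-0.2957, -0.4860, -1.7153, -0.5024],
                  [-0.4195, -0.1657, -0.9259, -0.6146],
                  [ 1.0523, -1.3726,  0.8340, -2.0896]])

# dim=-1 (the default): each row is sorted on its own
sorted_tensor, tensor_indices = torch.sort(x, dim=-1)
print(tensor_indices)
# tensor([[2, 3, 1, 0],
#         [2, 3, 0, 1],
#         [3, 1, 2, 0]])

# Row 0: the smallest value, -1.7153, sits at column 2 of x
print(x[0, tensor_indices[0, 0]])  # tensor(-1.7153)
```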

What happens if dim = 0?

For example:

x = torch.randn(3, 4)
sorted_tensor, tensor_indices = torch.sort(x, dim = 0)
print("x", x)
print(sorted_tensor)
print(tensor_indices)

Running this code, we get output like the following (again, your random values will differ):

x tensor([[ 0.8447,  0.0204,  0.1753,  0.7887],
        [-0.2143, -0.0590, -0.4153, -0.4881],
        [-0.0828, -0.1560, -0.6996, -0.3439]])
tensor([[-0.2143, -0.1560, -0.6996, -0.4881],
        [-0.0828, -0.0590, -0.4153, -0.3439],
        [ 0.8447,  0.0204,  0.1753,  0.7887]])
tensor([[1, 2, 2, 1],
        [2, 1, 1, 2],
        [0, 0, 0, 0]])

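In other words, with dim = 0 each column of x is sorted independently from top to bottom, and tensor_indices records the row each value came from:

[Figure: understand torch.sort() with examples]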
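A quick way to convince yourself that dim = 0 sorts columns is to compare it with a row sort of the transposed tensor. This sketch reuses the printed x (rounded to 4 decimals); since no column contains equal values, the indices from both approaches must match exactly:

```python
import torch

# x copied from the dim = 0 output above (rounded to 4 decimals)
x = torch.tensor([[ 0.8447,  0.0204,  0.1753,  0.7887],
                  [-0.2143, -0.0590, -0.4153, -0.4881],
                  [-0.0828, -0.1560, -0.6996, -0.3439]])

# dim=0: every column is sorted independently, top to bottom
col_values, col_indices = torch.sort(x, dim=0)
print(col_indices)
# tensor([[1, 2, 2, 1],
#         [2, 1, 1, 2],
#         [0, 0, 0, 0]])

# Same thing, phrased as a row sort of the transposed tensor
row_values, row_indices = torch.sort(x.T, dim=1)
print(torch.equal(col_values, row_values.T))    # True
print(torch.equal(col_indices, row_indices.T))  # True
```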

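What about stable = True?

It matters when a dimension contains equal elements: with stable = True, equal elements keep their original relative order in the output, so the returned indices are uniquely determined.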
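Before the larger example, here is a minimal sketch with just two pairs of equal values (made-up data), where the stable order can be checked by hand:

```python
import torch

# Two pairs of equal values: the 3s at positions 0 and 2, the 1s at 1 and 3
x = torch.tensor([3, 1, 3, 1])

values, indices = torch.sort(x, stable=True)
print(values)   # tensor([1, 1, 3, 3])
print(indices)  # tensor([1, 3, 0, 2]): each tie keeps its original left-to-right order
```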

For example:

import torch

x = torch.tensor([0, 1] * 20)
print("x", x)
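# default sort: the order of equal elements is not guaranteed
sorted_tensor, tensor_indices = torch.sort(x, stable=False)
print(sorted_tensor)
print(tensor_indices)
# stable sort: equal elements keep their original order
sorted_tensor_stable, tensor_indices_stable = torch.sort(x, stable=True)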
print(sorted_tensor_stable)
print(tensor_indices_stable)

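Running this code, we get output like the following (the indices for stable=False may differ on your device or PyTorch version):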

x tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 20,  2, 22,  4, 24,  6, 26,  8, 28, 10, 30, 12, 32, 14, 34, 16, 36,
        18, 38,  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 25, 31, 23, 33, 27, 35,
        21, 37, 29, 39])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
        36, 38,  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
        33, 35, 37, 39])

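We can see that sorted_tensor and sorted_tensor_stable are identical, but tensor_indices and tensor_indices_stable differ. With stable=True, the indices of equal elements appear in increasing order (0, 2, 4, ... for the zeros and 1, 3, 5, ... for the ones); with stable=False, their order is not guaranteed.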
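The stable indices for this input can be predicted exactly, which makes them easy to check. The sketch below compares them with the expected order and then runs a more general test (our own check, not a PyTorch utility): within every run of equal sorted values, stable indices must be strictly increasing.

```python
import torch

x = torch.tensor([0, 1] * 20)
_, tensor_indices_stable = torch.sort(x, stable=True)

# With a stable sort, the zeros (even positions) come first in their original order,
# followed by the ones (odd positions) in their original order
expected = torch.cat([torch.arange(0, 40, 2), torch.arange(1, 40, 2)])
print(torch.equal(tensor_indices_stable, expected))  # True

# General check: within each run of equal sorted values,
# the stable indices must be strictly increasing
sorted_values = x[tensor_indices_stable]
same_value = sorted_values[1:] == sorted_values[:-1]
increasing = tensor_indices_stable[1:] > tensor_indices_stable[:-1]
print(bool(torch.all(increasing[same_value])))  # True
```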

[Figure: understand torch.sort() stable with examples]

Therefore, use stable = True when a dimension contains equal elements and you need their original relative order preserved, so that the returned indices are predictable.
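A common case where this matters is sorting rows by more than one key. The sketch below (the records tensor is made-up data) uses the classic technique of sorting by the secondary key first and then stable-sorting by the primary key; without a stable second sort, rows that tie on the primary key could lose their secondary-key order.

```python
import torch

# Hypothetical records: column 0 is the primary key, column 1 the secondary key
records = torch.tensor([[2, 1],
                        [1, 2],
                        [2, 0],
                        [1, 1]])

# Step 1: sort by the secondary key
_, order = torch.sort(records[:, 1], stable=True)
records = records[order]

# Step 2: stable sort by the primary key; rows with equal primary keys
# keep the secondary-key order from step 1
_, order = torch.sort(records[:, 0], stable=True)
records = records[order]

print(records)
# tensor([[1, 1],
#         [1, 2],
#         [2, 0],
#         [2, 1]])
```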