Understand PyTorch torch.max(): Return the Maximum Value of a Tensor – PyTorch Tutorial

By | May 26, 2022

In this tutorial, we will use some examples to show you how to use the PyTorch torch.max() function, which returns the maximum value of a tensor.

torch.max()

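It can be called in two main ways:

torch.max(input)
torch.max(input, dim, keepdim=False)

The first form returns the maximum value of all elements in the input tensor. The second form reduces only along dimension dim and returns both the maximum values and their indices. The examples below cover each case.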

torch.max() without dimension

It returns the maximum value among all elements of the tensor.

For example:

import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)
m = torch.max(input)
print(m)

Here we call torch.max() without a dimension, so it returns the maximum value in input, 200, as a 0-dimensional tensor:

tensor(200.)
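Because the result is a 0-dimensional tensor rather than a Python number, you may need to convert it before using it in ordinary Python code. A minimal sketch on the same data (renamed x here to avoid shadowing Python's built-in input), showing .item() and the equivalent tensor method x.max():

```python
import torch

x = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)

m = torch.max(x)      # 0-d tensor holding the largest element
print(m.dim())        # number of dimensions: 0
print(m.item())       # plain Python float: 200.0
print(x.max())        # method form gives the same value: tensor(200.)
```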

torch.max() with dimension

It returns a named tuple (values, indices): the maximum values along dim and the positions where they occur.

For example:

import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)
max_value, max_indices = torch.max(input, dim = 0)
print(input)
print(max_value)
print(max_indices)

Here we call torch.max() with dim = 0, which takes the maximum down each column, so max_indices holds row numbers. Running it prints:

tensor([[  1.,   2., 100.],
        [  4.,   3.,   3.],
        [  1., 200.,   3.],
        [  4.,   5.,   6.]])
tensor([  4., 200., 100.])
tensor([1, 2, 0])
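Since the result is a named tuple, you can also read its fields as .values and .indices instead of unpacking. The sketch below does that on the same data and then uses torch.gather (not part of the original example) to check that each index really points at the maximum of its column:

```python
import torch

x = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)

result = torch.max(x, dim=0)
print(result.values)   # tensor([  4., 200., 100.])
print(result.indices)  # row index of each column's maximum

# gather picks x[result.indices[j], j] for every column j
picked = torch.gather(x, 0, result.indices.unsqueeze(0)).squeeze(0)
print(torch.equal(picked, result.values))  # True
```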


If we set dim = 1 instead, the maximum is taken across each row, and max_indices holds column numbers:

max_value, max_indices = torch.max(input, dim = 1)
print(input)
print(max_value)
print(max_indices)

The result is:

tensor([[  1.,   2., 100.],
        [  4.,   3.,   3.],
        [  1., 200.,   3.],
        [  4.,   5.,   6.]])
tensor([100.,   4., 200.,   6.])
tensor([2, 0, 1, 2])
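If you only need the positions, torch.argmax(input, dim) returns just the indices. A short sketch on the same 2-D data comparing it with torch.max(x, dim=1), and showing that for a 2-D tensor dim=-1 refers to the same (last) dimension as dim=1:

```python
import torch

x = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)

values, indices = torch.max(x, dim=1)
print(torch.argmax(x, dim=1))  # tensor([2, 0, 1, 2]), same as indices

# For a 2-D tensor, dim=-1 is the last dimension, i.e. dim=1
values_last, indices_last = torch.max(x, dim=-1)
print(torch.equal(values, values_last), torch.equal(indices, indices_last))  # True True
```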

We can also pass keepdim=True, which keeps the reduced dimension in the output with size 1. For example:

import torch

input = torch.tensor([[
    [[1, 2, 100],
     [4, 3, 3]],
    [[1, 200, 3],
     [4, 5, 6]]
]], dtype=torch.float32)
print(input.shape)
max_value, max_indices = torch.max(input, dim=1, keepdim=True)
print(input)
print(max_value)
print(max_indices)

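Running this code prints the following. Because keepdim=True, dimension 1 is kept with size 1, so max_value and max_indices have shape [1, 1, 2, 3] rather than [1, 2, 3]: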

torch.Size([1, 2, 2, 3])
tensor([[[[  1.,   2., 100.],
          [  4.,   3.,   3.]],

         [[  1., 200.,   3.],
          [  4.,   5.,   6.]]]])
tensor([[[[  1., 200., 100.],
          [  4.,   5.,   6.]]]])
tensor([[[[0, 1, 0],
          [0, 1, 1]]]])
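Keeping the reduced dimension is mainly useful when you want to combine the result with the original tensor through broadcasting. A sketch on the same 4-D data: it compares the output shapes with and without keepdim, then subtracts the maximum along dim 1 from every element. The subtraction step is only an illustration (a shift commonly done before a softmax), not part of the original example:

```python
import torch

x = torch.tensor([[
    [[1, 2, 100],
     [4, 3, 3]],
    [[1, 200, 3],
     [4, 5, 6]]
]], dtype=torch.float32)  # shape [1, 2, 2, 3]

kept, _ = torch.max(x, dim=1, keepdim=True)
dropped, _ = torch.max(x, dim=1)
print(kept.shape)     # torch.Size([1, 1, 2, 3])
print(dropped.shape)  # torch.Size([1, 2, 3])

# kept lines up with x dimension by dimension and broadcasts along dim 1,
# so after the shift every entry is <= 0 and each maximum becomes 0
shifted = x - kept
print(shifted.max().item())  # 0.0
```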