# Understand PyTorch torch.max(): Return the Maximum Value of a Tensor – PyTorch Tutorial

By | May 26, 2022

In this tutorial, we will use some examples to show you how to use the PyTorch torch.max() function, which returns the maximum value of a tensor.

## torch.max()

Its simplest form is:

```python
torch.max(input)
```

Called this way, it returns the maximum value of all elements in the input tensor.
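Note that `torch.max(input)` returns a zero-dimensional tensor, not a plain Python number. The sketch below uses a small tensor `t` made up for illustration. It reads the value with `.item()` and checks that the tensor method `t.max()` gives the same result:

```python
import torch

# A small tensor made up for this example
t = torch.tensor([[3.0, 7.0], [1.0, 5.0]])

m = torch.max(t)                # zero-dimensional tensor holding the largest element
print(m.dim())                  # 0: the result has no dimensions
print(m.item())                 # 7.0: converted to a plain Python float
print(torch.equal(m, t.max()))  # True: the method form gives the same result
```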

However, its behavior changes when you also pass a dimension, so there are a few details to keep in mind.

## torch.max() without dimension

It returns the maximum value among all elements of a tensor.

For example:

```python
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)
m = torch.max(input)
print(m)
```

Here torch.max() is called without a dimension, so it returns the largest element of input, 200, as a zero-dimensional tensor:

```
tensor(200.)
```
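Without `dim`, `torch.max()` gives only the value, not its position. One way to recover the position is to call `torch.argmax()`, which returns an index into the flattened tensor, and then convert that index to a (row, column) pair with `divmod`. This approach is chosen for the example and is not part of `torch.max()` itself. The tensor `x` holds the same values as `input` above:

```python
import torch

x = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)

# Without dim, argmax returns the index into the flattened tensor
flat_idx = torch.argmax(x).item()

# Convert the flat index into (row, column) for a 2-D tensor
row, col = divmod(flat_idx, x.shape[1])
print(flat_idx, row, col)   # 7 2 1
print(x[row, col].item())   # 200.0
```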

## torch.max() with dimension

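When a dimension is given, torch.max() reduces the tensor only along that dimension and returns a tuple (max, max_indices): the maximum values, and the indices along dim where they occur.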

For example:

```python
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)
# dim=0 reduces over the rows: one maximum (and its row index) per column
max_value, max_indices = torch.max(input, dim=0)
print(input)
print(max_value)
print(max_indices)
```

Here we call torch.max() with dim=0, which reduces over the rows and returns the maximum of each column. The output is:

```
tensor([[  1.,   2., 100.],
        [  4.,   3.,   3.],
        [  1., 200.,   3.],
        [  4.,   5.,   6.]])
tensor([  4., 200., 100.])
tensor([1, 2, 0])
```
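The result of `torch.max(input, dim=...)` is a named tuple, so its two parts can also be read as `.values` and `.indices`. As an optional sanity check, `torch.gather()` can look up the returned indices in the input, which should reproduce the maximum values. The tensor `x` below holds the same values as `input`:

```python
import torch

x = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)

result = torch.max(x, dim=0)
print(result.values)   # tensor([  4., 200., 100.])
print(result.indices)  # tensor([1, 2, 0])

# For dim=0, gather picks x[indices[j], j] for each column j.
# gather needs an index with the same number of dims as x,
# so add a leading dimension and drop it again afterwards.
picked = torch.gather(x, 0, result.indices.unsqueeze(0)).squeeze(0)
print(torch.equal(picked, result.values))  # True
```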

With dim=1 (reusing the same input tensor), we get the maximum of each row:

```python
max_value, max_indices = torch.max(input, dim=1)
print(input)
print(max_value)
print(max_indices)
```

The result is:

```
tensor([[  1.,   2., 100.],
        [  4.,   3.,   3.],
        [  1., 200.,   3.],
        [  4.,   5.,   6.]])
tensor([100.,   4., 200.,   6.])
tensor([2, 0, 1, 2])
```
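For a 2-D tensor, `dim=-1` (the last dimension) refers to the same axis as `dim=1`. The sketch below compares the two. It also computes the per-row maxima with a plain-Python loop, which serves only to make the reduction explicit. The tensor `x` holds the same values as `input`:

```python
import torch

x = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)

values_pos, indices_pos = torch.max(x, dim=1)
values_neg, indices_neg = torch.max(x, dim=-1)  # last dim, same as dim 1 here
print(torch.equal(values_pos, values_neg))      # True
print(torch.equal(indices_pos, indices_neg))    # True

# The same per-row maxima computed by hand, one row at a time
row_max = [max(row) for row in x.tolist()]
print(row_max)  # [100.0, 4.0, 200.0, 6.0]
```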

We can also pass keepdim=True, which keeps the reduced dimension in the output with size 1 instead of removing it. For example:

```python
import torch

input = torch.tensor([[
    [[1, 2, 100],
     [4, 3, 3]],
    [[1, 200, 3],
     [4, 5, 6]]
]], dtype=torch.float32)
print(input.shape)
# Reduce over dim 1 (the two 2x3 blocks), keeping it with size 1
max_value, max_indices = torch.max(input, dim=1, keepdim=True)
print(input)
print(max_value)
print(max_indices)
```

Running this code prints:

```
torch.Size([1, 2, 2, 3])
tensor([[[[  1.,   2., 100.],
          [  4.,   3.,   3.]],

         [[  1., 200.,   3.],
          [  4.,   5.,   6.]]]])
tensor([[[[  1., 200., 100.],
          [  4.,   5.,   6.]]]])
tensor([[[[0, 1, 0],
          [0, 1, 1]]]])
```
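In practice, keepdim=True matters because of the output shape. With it, the result keeps a size-1 dimension and can broadcast directly against the input. The sketch below uses this to build a boolean mask of where each maximum occurs; the mask is an illustrative use, not something torch.max() produces itself. The tensor `x` holds the same values as the 4-D `input`:

```python
import torch

x = torch.tensor([[
    [[1, 2, 100],
     [4, 3, 3]],
    [[1, 200, 3],
     [4, 5, 6]]
]], dtype=torch.float32)

kept = torch.max(x, dim=1, keepdim=True)
dropped = torch.max(x, dim=1)  # keepdim defaults to False
print(kept.values.shape)       # torch.Size([1, 1, 2, 3])
print(dropped.values.shape)    # torch.Size([1, 2, 3])

# Shape [1, 1, 2, 3] broadcasts against [1, 2, 2, 3],
# so every element is compared with the maximum of its slice
mask = x == kept.values
print(mask.shape)              # torch.Size([1, 2, 2, 3])
print(mask.sum().item())       # 8
```

Where values tie (the 1s and 4s in the first column), the mask marks both positions, whereas max_indices reports only one of them.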