Complete Tutorial for torch.max() in PyTorch with Examples

Introduction

In this article, we will go through the torch.max() function of PyTorch, which is used to get the maximum element of a tensor. We will first look at how it works and at its syntax, and then walk through some beginner-friendly examples.

What is the torch.max() function in PyTorch

The torch.max() function returns the maximum value of a tensor. Called without a dimension, it returns only the maximum over all elements; called with a dimension, it returns the maximum values along that dimension together with their indices. At first glance it looks straightforward, but you should understand how the dim and keepdim arguments shape the output, otherwise the results can be surprising.

Let us first understand the syntax of this PyTorch max() function –

Syntax & Parameters

torch.max(input) –> Returns the maximum value of all elements as a 0-dimensional tensor
torch.max(input, dim, keepdim=False) –> Returns a named tuple (values, indices) with the maximum values along dim and their indices

  • input : The input tensor from which the maximum element(s) are retrieved.
  • dim : The dimension (a single integer) along which the maximum values are computed. torch.max() does not accept a list of dimensions. If dim is not specified, the maximum is taken over all elements and no indices are returned.
  • keepdim : Used together with dim. When True, the output tensors keep the same number of dimensions as the input, with size 1 in the reduced dimension. When False (the default), the reduced dimension is removed, so the outputs have one dimension fewer than the input.
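A minimal sketch of both call forms on a small hypothetical tensor `t` (not one of the tensors used later in this article). It shows that the dim form returns a named tuple whose fields can be read as `.values` and `.indices`, and, for the case where you really want to reduce over several dimensions at once, that the separate function `torch.amax()` accepts a tuple of dimensions but returns values only.

```python
import torch

# Hypothetical 2x3 tensor used only for this illustration
t = torch.tensor([[1, 7, 3],
                  [4, 2, 9]])

# Form 1: no dim -> a single 0-d tensor with the overall maximum, no indices
overall = torch.max(t)
print(overall)            # tensor(9)

# Form 2: with dim -> a named tuple of (values, indices)
result = torch.max(t, dim=1)
print(result.values)      # tensor([7, 9])
print(result.indices)     # tensor([1, 2])

# The named tuple unpacks like a plain tuple; keepdim keeps the reduced dim as size 1
values, indices = torch.max(t, dim=1, keepdim=True)
print(values.shape)       # torch.Size([2, 1])

# torch.max() takes a single int dim; torch.amax() accepts a tuple of dims (values only)
print(torch.amax(t, dim=(0, 1)))   # tensor(9)
```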

Examples of torch.max() with 1-D Tensor

Before starting the examples, let us import the torch library.

In [0]:

import torch

Creating 1-D Tensor

For our first few examples, let us create a 1-D tensor as shown below –

In [1]:

tensor = torch.randint(high=10, size=(7,))   # 7 random integers in [0, 10)

tensor
Out[1]:
tensor([8, 9, 8, 8, 0, 8, 4])
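Because torch.randint() draws random integers, your tensor will almost certainly differ from the one printed above. A small sketch of two ways to handle that: seeding the generator with torch.manual_seed() makes your own runs repeatable (the seed 0 is an arbitrary choice, and it will not reproduce the exact numbers in this article), while torch.tensor() lets you type in the article's values directly so that the following outputs match.

```python
import torch

# Seeding makes repeated runs produce the same random values on the same setup
torch.manual_seed(0)  # arbitrary seed, chosen only for illustration
random_tensor = torch.randint(high=10, size=(7,))
print(random_tensor.shape)   # torch.Size([7])

# Writing the values out reproduces the tensor shown in Out[1] exactly
tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])
print(tensor)                # tensor([8, 9, 8, 8, 0, 8, 4])
```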

Applying torch.max() Function with Default Values

We start by calling the PyTorch max function on the 1-D tensor with only the input argument. Since we do not specify a dimension (the dim parameter), the maximum is taken over all elements, and the result is a single 0-dimensional tensor with no index.

In [2]:

max_value = torch.max(tensor)   # named max_value to avoid shadowing Python's built-in max()

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
Out[2]:
tensor([8, 9, 8, 8, 0, 8, 4])

Shape: torch.Size([7])

Max element is tensor(9)

Max element shape is torch.Size([])
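As the output shows, the no-dim form returns only a value. A short sketch, reusing the values from Out[1], of two related calls: .item() turns the 0-d result into a plain Python number, and torch.argmax() gives the position of the maximum, which torch.max() does not report without dim. The tensor.max() method form is equivalent to torch.max(tensor).

```python
import torch

tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])   # values from Out[1]

max_value = torch.max(tensor)    # 0-d tensor, no index
print(max_value.item())          # 9  (a plain Python int)

# Method form, equivalent to torch.max(tensor)
print(tensor.max())              # tensor(9)

# The no-dim form gives no index; argmax returns the position of the maximum
print(torch.argmax(tensor))      # tensor(1)
```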

Applying torch.max() with Dim=0

Next, we apply the max function along dim=0. When dim is given, torch.max() returns a tuple holding the maximum value and its index.

In [3]:

max_value, max_index = torch.max(tensor, dim=0)   # returns (values, indices)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(max_index))
Out[3]:
tensor([8, 9, 8, 8, 0, 8, 4])

Shape: torch.Size([7])

Max element is tensor(9)

Max element shape is torch.Size([])

Max element indices are tensor(1)
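The tuple returned with dim is a named tuple, so its fields can also be read by name. A minimal sketch with the Out[1] values, showing that result.indices indexes straight back into the input to recover the maximum:

```python
import torch

tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])   # values from Out[1]

result = torch.max(tensor, dim=0)
print(result.values)            # tensor(9)
print(result.indices)           # tensor(1)

# The returned index points at the maximum in the original tensor
print(tensor[result.indices])   # tensor(9)
```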

Applying torch.max() with Dim=0 & keepdim=True

In the previous two examples, the result had one dimension fewer than the original 1-D tensor, i.e. it was a 0-dimensional (scalar) tensor with torch.Size([]). This time we pass keepdim=True, which keeps the reduced dimension in the output with size 1, giving torch.Size([1]).

In [4]:

max_value, max_index = torch.max(tensor, dim=0, keepdim=True)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(max_index))
Out[4]:
tensor([8, 9, 8, 8, 0, 8, 4])

Shape: torch.Size([7])

Max element is tensor([9])

Max element shape is torch.Size([1])

Max element indices are tensor([1])
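One way to think about keepdim=True: it gives the same result as reducing without it and then re-inserting the reduced dimension with unsqueeze(). A quick check of that equivalence on the Out[1] values:

```python
import torch

tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])   # values from Out[1]

kept = torch.max(tensor, dim=0, keepdim=True).values   # shape [1]
dropped = torch.max(tensor, dim=0).values              # shape []
print(kept.shape, dropped.shape)    # torch.Size([1]) torch.Size([])

# Re-inserting dim 0 into the keepdim=False result gives the keepdim=True result
print(torch.equal(kept, dropped.unsqueeze(0)))   # True
```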

Examples of torch.max() with 2-D Tensor

Creating 2-D Tensor

Let us create a 2-D tensor to use with torch.max() in the examples below.

In [5]:

tensor = torch.randint(high=20, size=(3, 3))   # 3x3 random integers in [0, 20)

tensor
Out[5]:
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
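As in the 1-D case, torch.randint() will give you different numbers. To reproduce the outputs that follow, the tensor can be written out with the values from Out[5]; a short sketch:

```python
import torch

# The 3x3 values printed in Out[5], written out so later outputs can be reproduced
tensor = torch.tensor([[13,  5, 16],
                       [ 0, 16,  6],
                       [18, 11,  1]])

print(tensor.shape)   # torch.Size([3, 3])
print(tensor.dtype)   # torch.int64, the same default dtype torch.randint() produces
```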

Applying torch.max() Function with Default Values

With default values (no dim argument), the torch max function returns the maximum over all elements as a 0-dimensional tensor.

In [6]:

max_value = torch.max(tensor)   # named max_value to avoid shadowing Python's built-in max()

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
Out[6]:
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor(18)

Max element shape is torch.Size([])
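Without dim there are no indices, so the position of 18 is not reported. A minimal sketch of one way to recover it: torch.argmax() without dim returns an index into the flattened tensor, which Python's divmod() converts back to (row, column) for a 2-D tensor.

```python
import torch

tensor = torch.tensor([[13,  5, 16],
                       [ 0, 16,  6],
                       [18, 11,  1]])   # values from Out[5]

max_value = torch.max(tensor)          # tensor(18), no index returned
flat_index = torch.argmax(tensor)      # index into the flattened tensor

# For a 2-D tensor, flat index = row * num_columns + col
row, col = divmod(flat_index.item(), tensor.shape[1])
print(max_value.item(), flat_index.item(), (row, col))   # 18 6 (2, 0)
```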

Applying torch.max() with Dim=0

With dim=0, torch.max() compares the elements down each column (across dim 0) and returns the maximum of every column. Since keepdim=False by default, the maximum values are returned in a 1-D tensor, one dimension fewer than the 2-D input. The indices of the maximum values (the row in which each column's maximum was found) are returned in a second tensor.

In [7]:

values, indices = torch.max(tensor, dim=0)   # one max (and row index) per column

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(values))
print('\nMax element shape is ' + str(values.shape))
print('\nMax element indices are ' + str(indices))
Out[7]:
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([18, 16, 16])

Max element shape is torch.Size([3])

Max element indices are tensor([2, 1, 0])
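To see what the indices mean for dim=0: indices[j] is the row in which column j reaches its maximum. A small sketch, using the Out[5] values, that uses the indices to pick the maxima back out of the input and compares against a plain Python loop over the columns:

```python
import torch

tensor = torch.tensor([[13,  5, 16],
                       [ 0, 16,  6],
                       [18, 11,  1]])   # values from Out[5]

values, indices = torch.max(tensor, dim=0)

# Pair each column j with the row indices[j] where its maximum sits
cols = torch.arange(tensor.shape[1])
print(tensor[indices, cols])   # tensor([18, 16, 16])

# Same column maxima computed with Python's built-in max(), one column at a time
print([max(tensor[:, j].tolist()) for j in range(tensor.shape[1])])   # [18, 16, 16]
```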

Applying torch.max() with Dim=0 & keepdim=True

In contrast with the above example, when we pass keepdim=True the output tensors of maximum values and indices keep the same number of dimensions as the input tensor, i.e. 2, with size 1 along dim 0.

In [8]:

values, indices = torch.max(tensor, dim=0, keepdim=True)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(values))
print('\nMax element shape is ' + str(values.shape))
print('\nMax element indices are ' + str(indices))
Out[8]:
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([[18, 16, 16]])

Max element shape is torch.Size([1, 3])

Max element indices are tensor([[2, 1, 0]])
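A practical reason to use keepdim=True: the returned indices already have the same number of dimensions as the input, so they can be passed straight to torch.gather() to collect the maxima. A minimal sketch with the Out[5] values:

```python
import torch

tensor = torch.tensor([[13,  5, 16],
                       [ 0, 16,  6],
                       [18, 11,  1]])   # values from Out[5]

values, indices = torch.max(tensor, dim=0, keepdim=True)   # both shape [1, 3]

# gather along dim 0 picks tensor[indices[0][j]][j] for every column j
gathered = torch.gather(tensor, 0, indices)
print(gathered)                        # tensor([[18, 16, 16]])
print(torch.equal(gathered, values))   # True
```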

Applying torch.max() with Dim=1

Similar to the earlier examples, this one finds the maximum values and their indices along dim=1, i.e. the maximum of each row. Since keepdim=False by default, the output tensors have one dimension fewer than the 2-D input, i.e. they are 1-D.

In [9]:

values, indices = torch.max(tensor, dim=1)   # one max (and column index) per row

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(values))
print('\nMax element shape is ' + str(values.shape))
print('\nMax element indices are ' + str(indices))
Out[9]:
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([16, 16, 18])

Max element shape is torch.Size([3])

Max element indices are tensor([2, 1, 0])
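A common surprise with dim=1 and the default keepdim=False: the row maxima come back with shape [3], and when combined with the [3, 3] input, broadcasting lines them up with the columns rather than the rows. A short sketch with the Out[5] values, trying to subtract each row's maximum from that row:

```python
import torch

tensor = torch.tensor([[13,  5, 16],
                       [ 0, 16,  6],
                       [18, 11,  1]])   # values from Out[5]

row_max, _ = torch.max(tensor, dim=1)   # shape [3]

# Wrong: shape [3] broadcasts as [1, 3], so column j gets row_max[j] subtracted
print((tensor - row_max).tolist())
# [[-3, -11, -2], [-16, 0, -12], [2, -5, -17]]

# Right: make it a column of shape [3, 1] so each row gets its own maximum
print((tensor - row_max.unsqueeze(1)).tolist())
# [[-3, -11, 0], [-16, 0, -10], [0, -7, -17]]
```

Passing keepdim=True produces the [3, 1] shape directly, without the unsqueeze() call.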

Applying torch.max() with Dim=1 & keepdim=True

Again, when we pass keepdim=True along with dim=1, the resulting tensors have the same number of dimensions as the input tensor, i.e. 2, here with shape [3, 1].

In [10]:

values, indices = torch.max(tensor, dim=1, keepdim=True)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(values))
print('\nMax element shape is ' + str(values.shape))
print('\nMax element indices are ' + str(indices))
Out[10]:
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([[16],
        [16],
        [18]])

Max element shape is torch.Size([3, 1])

Max element indices are tensor([[2],
        [1],
        [0]])
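With keepdim=True the [3, 1] outputs line up with the rows of the input, so they can be used directly in broadcasting and with torch.gather() without any reshaping. A minimal sketch with the Out[5] values:

```python
import torch

tensor = torch.tensor([[13,  5, 16],
                       [ 0, 16,  6],
                       [18, 11,  1]])   # values from Out[5]

row_max, row_idx = torch.max(tensor, dim=1, keepdim=True)   # both shape [3, 1]

# [3, 1] broadcasts across columns: every row is compared with its own maximum
print((tensor == row_max).tolist())
# [[False, False, True], [False, True, False], [True, False, False]]

# gather along dim 1 picks tensor[i][row_idx[i][0]] for every row i
print(torch.gather(tensor, 1, row_idx).tolist())   # [[16], [16], [18]]
```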
Reference – PyTorch Documentation
  • MLK

    MLK is a knowledge sharing community platform for machine learning enthusiasts, beginners and experts. Let us create a powerful hub together to Make AI Simple for everyone.
