Introduction
In this article, we will walk through the torch.max() function of PyTorch, which returns the maximum element of a tensor. We will first look at how it works and at its syntax, and then go through some examples to help beginners understand it better.
What is the torch.max() Function in PyTorch?
The torch.max() function retrieves the maximum value in a tensor. If you call it with only a tensor, it returns the maximum over all elements as a single value. If you also pass a dimension (dim), it returns the maximum values along that dimension together with their indices. It looks straightforward at first glance, but you should understand these nuances, such as what is returned and how the output shape changes. Otherwise you may be in for some unexpected surprises.
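To illustrate the two call forms side by side, here is a minimal sketch on a small hand-written tensor. The values are chosen only for illustration. Calling torch.max() with just a tensor gives a single 0-d tensor. Passing dim gives a named tuple whose fields are .values and .indices.

```python
import torch

t = torch.tensor([3, 7, 2, 5])  # illustrative values only

# Form 1: no dim -> a single 0-d tensor holding the overall maximum (no index)
overall = torch.max(t)
print(overall)  # tensor(7)

# Form 2: with dim -> a named tuple with .values and .indices
result = torch.max(t, dim=0)
print(result.values, result.indices)  # tensor(7) tensor(1)
```

Because the first form returns only one tensor, unpacking it as `values, indices = torch.max(t)` does not give you an index. You need the dim argument for that.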
Let us first look at the syntax of torch.max().
Syntax & Parameters
torch.max(input) -> Returns the maximum value of all elements as a 0-d tensor
torch.max(input, dim, keepdim=False) -> Returns the maximum values along dim and their indices as a named tuple (values, indices)
- input: The input tensor from which the maximum element is retrieved.
- dim: The dimension (a single int) along which the maximum is taken. If it is not specified, the maximum is taken over all elements and only the value is returned, without an index.
- keepdim: If True, the output tensors keep the reduced dimension with size 1, so they have the same number of dimensions as the input. If False, the reduced dimension is removed, so the outputs have one dimension fewer than the input. Defaults to False.
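Here is a sketch of how keepdim changes the output shape on a 3-D tensor. The shape (2, 3, 4) is an arbitrary choice. torch.max() reduces over a single integer dim. If you need to reduce several dimensions at once, torch.amax() accepts a tuple of dims in recent PyTorch versions. Note that torch.amax() returns only values, not indices.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)  # arbitrary 3-D tensor for illustration

# keepdim=False (default): the reduced dimension disappears
print(torch.max(x, dim=1).values.shape)                # torch.Size([2, 4])

# keepdim=True: the reduced dimension stays, with size 1
print(torch.max(x, dim=1, keepdim=True).values.shape)  # torch.Size([2, 1, 4])

# Reducing over several dims at once: torch.amax takes a tuple (values only)
print(torch.amax(x, dim=(1, 2)))                       # tensor([11, 23])
```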
Examples of torch.max() with 1-D Tensor
Before starting the examples, let us import the torch library.
import torch
Creating 1-D Tensor
For our first few examples, let us create a 1-D tensor of seven random integers from 0 to 9, as shown below. Since the values are random, your output will likely differ.
tensor = torch.randint(high=10, size=(7,))
tensor
tensor([8, 9, 8, 8, 0, 8, 4])
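torch.randint() draws random values, so running the cell above will most likely give a different tensor from the one shown. If you want to follow along with exactly the article's values, a simple option is to build the tensor by hand with torch.tensor(). Alternatively, torch.manual_seed() makes random draws repeatable from run to run. The seed value 0 below is arbitrary and will not necessarily reproduce the values printed above.

```python
import torch

# Option 1: hard-code the values printed above to follow along exactly
tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])

# Option 2: fix the random seed so repeated runs give the same draw
# (seed 0 is an arbitrary choice; it will not reproduce the article's values)
torch.manual_seed(0)
random_tensor = torch.randint(high=10, size=(7,))
```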
Applying torch.max() Function with Default Values
We start by calling torch.max() with its default values on the 1-D tensor. We have not specified a dimension (the dim parameter), so the maximum is taken over all elements. It is returned as a 0-d tensor, without an index.
max_val = torch.max(tensor)  # named max_val to avoid shadowing Python's built-in max()
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
tensor([8, 9, 8, 8, 0, 8, 4])
Shape: torch.Size([7])
Max element is tensor(9)
Max element shape is torch.Size([])
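The result above is a 0-d tensor with shape torch.Size([]). When you need a plain Python number, for example for printing or comparisons, you can call .item() on it. A small sketch using the same values:

```python
import torch

tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])  # values from the example above

max_val = torch.max(tensor)
print(max_val.dim())   # 0 -> a 0-d (scalar) tensor
print(max_val.item())  # 9 -> a plain Python int
```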
Applying torch.max() with Dim=0
Next, we apply torch.max() along dim=0, which returns both the maximum value and its index.
max_val, index = torch.max(tensor, dim=0)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
print('\nMax element indices are ' + str(index))
tensor([8, 9, 8, 8, 0, 8, 4])
Shape: torch.Size([7])
Max element is tensor(9)
Max element shape is torch.Size([])
Max element indices are tensor(1)
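The pair returned when dim is given is a named tuple, so you can also read its fields by name instead of unpacking. In this sketch the index also matches torch.argmax(). That agreement is only guaranteed here because the maximum (9) occurs once. With repeated maxima, check the documentation for your PyTorch version to see which index is reported.

```python
import torch

tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])  # values from the example above

result = torch.max(tensor, dim=0)
print(result.values)         # tensor(9)
print(result.indices)        # tensor(1)
print(torch.argmax(tensor))  # tensor(1): same index, since 9 is unique here
```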
Applying torch.max() with Dim=0 & keepdim=True
In the previous two examples, the resulting maximum had one dimension fewer than the original 1-D tensor, i.e. it was a scalar with shape torch.Size([]). This time we pass keepdim=True. This keeps the reduced dimension in the output, so the shape is torch.Size([1]).
max_val, index = torch.max(tensor, dim=0, keepdim=True)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
print('\nMax element indices are ' + str(index))
tensor([8, 9, 8, 8, 0, 8, 4])
Shape: torch.Size([7])
Max element is tensor([9])
Max element shape is torch.Size([1])
Max element indices are tensor([1])
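One way to see what keepdim=True does is to compare it with the default keepdim=False plus unsqueeze(). Reducing with the default and then re-inserting the removed dimension gives the same result. The following sketch checks this on the tensor from above.

```python
import torch

tensor = torch.tensor([8, 9, 8, 8, 0, 8, 4])  # values from the example above

kept = torch.max(tensor, dim=0, keepdim=True).values  # shape [1]
dropped = torch.max(tensor, dim=0).values             # shape [] (0-d)
print(torch.equal(kept, dropped.unsqueeze(0)))        # True
```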
Examples of torch.max() with 2-D Tensor
Creating 2-D Tensor
Let us create a 2-D tensor that will be used for the torch max function below.
tensor = torch.randint(high=20, size=(3, 3))
tensor
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
Applying torch.max() Function with Default Values
With default values, torch.max() returns the maximum over all elements as a scalar (0-d) tensor.
max_val = torch.max(tensor)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
Shape: torch.Size([3, 3])
Max element is tensor(18)
Max element shape is torch.Size([])
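You can get the same overall maximum with the Tensor method form, tensor.max(). It also equals the maximum of the flattened tensor, which is a handy way to remember what the no-dim call does. Here is a sketch using the 2-D values shown above:

```python
import torch

tensor = torch.tensor([[13, 5, 16],
                       [0, 16, 6],
                       [18, 11, 1]])  # values from the example above

print(tensor.max())                 # tensor(18): method form of torch.max(tensor)
print(torch.max(tensor.flatten()))  # tensor(18): max over all 9 elements
```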
Applying torch.max() with Dim=0
With dim=0, torch.max() reduces over the rows and returns the maximum of each column. Since keepdim=False by default, the maximum values are returned in a 1-D tensor, one dimension fewer than the 2-D input. The row indices of those maxima are returned in a second tensor.
max_val, index = torch.max(tensor, dim=0)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
print('\nMax element indices are ' + str(index))
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
Shape: torch.Size([3, 3])
Max element is tensor([18, 16, 16])
Max element shape is torch.Size([3])
Max element indices are tensor([2, 1, 0])
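Reducing over dim=0 collapses the rows. Each entry of the result is therefore a column maximum, and each index is a row number. The sketch below checks this by using the returned indices to pick the maxima back out of the tensor: for each column j, it looks up row indices[j].

```python
import torch

tensor = torch.tensor([[13, 5, 16],
                       [0, 16, 6],
                       [18, 11, 1]])  # values from the example above

values, indices = torch.max(tensor, dim=0)
cols = torch.arange(tensor.shape[1])  # column numbers 0, 1, 2
# tensor[indices[j], j] for each column j recovers that column's maximum
print(tensor[indices, cols])                       # tensor([18, 16, 16])
print(torch.equal(tensor[indices, cols], values))  # True
```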
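Applying torch.max() with Dim=0 & keepdim=True
Passing keepdim=True together with dim=0 keeps the reduced dimension with size 1. Both outputs therefore have shape [1, 3] instead of [3].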
max_val, index = torch.max(tensor, dim=0, keepdim=True)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
print('\nMax element indices are ' + str(index))
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
Shape: torch.Size([3, 3])
Max element is tensor([[18, 16, 16]])
Max element shape is torch.Size([1, 3])
Max element indices are tensor([[2, 1, 0]])
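torch.gather() requires its index tensor to have the same number of dimensions as the input. With keepdim=True, the indices already meet that requirement, so you can pass them straight to torch.gather(). A minimal sketch:

```python
import torch

tensor = torch.tensor([[13, 5, 16],
                       [0, 16, 6],
                       [18, 11, 1]])  # values from the example above

values, indices = torch.max(tensor, dim=0, keepdim=True)  # both of shape [1, 3]
gathered = torch.gather(tensor, 0, indices)  # out[0][j] = tensor[indices[0][j]][j]
print(gathered)                       # tensor([[18, 16, 16]])
print(torch.equal(gathered, values))  # True
```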
Applying torch.max() with Dim=1
As in the earlier examples, this example finds the maximum values and their indices, this time along dim=1, i.e. the maximum of each row. Since keepdim=False by default, the output tensors have one dimension fewer than the 2-D input, i.e. they are 1-D.
max_val, index = torch.max(tensor, dim=1)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
print('\nMax element indices are ' + str(index))
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
Shape: torch.Size([3, 3])
Max element is tensor([16, 16, 18])
Max element shape is torch.Size([3])
Max element indices are tensor([2, 1, 0])
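Negative dimensions count from the end. For a 2-D tensor, dim=-1 therefore refers to the same axis as dim=1 (the last one). A quick sketch to confirm:

```python
import torch

tensor = torch.tensor([[13, 5, 16],
                       [0, 16, 6],
                       [18, 11, 1]])  # values from the example above

row_max = torch.max(tensor, dim=1)
last_max = torch.max(tensor, dim=-1)  # -1 = last dimension = dim 1 for a 2-D tensor
print(row_max.values, last_max.values)  # tensor([16, 16, 18]) tensor([16, 16, 18])
```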
Applying torch.max() with Dim=1 & keepdim=True
Again, when we pass keepdim=True along with dim=1, the resulting tensors have the same number of dimensions as the input tensor, i.e. 2, with shape [3, 1].
max_val, index = torch.max(tensor, dim=1, keepdim=True)
print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_val))
print('\nMax element shape is ' + str(max_val.shape))
print('\nMax element indices are ' + str(index))
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
Shape: torch.Size([3, 3])
Max element is tensor([[16],
        [16],
        [18]])
Max element shape is torch.Size([3, 1])
Max element indices are tensor([[2],
        [1],
        [0]])
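Broadcasting is a common reason to use keepdim=True. The row maxima have shape [3, 1], so they line up with the rows of the [3, 3] input. Subtracting them therefore shifts each row by its own maximum, which is, for example, the usual first step of a numerically stable softmax. Without keepdim, the shape-[3] result would broadcast along the last axis instead and subtract the wrong values. A sketch:

```python
import torch

tensor = torch.tensor([[13, 5, 16],
                       [0, 16, 6],
                       [18, 11, 1]])  # values from the example above

row_max = torch.max(tensor, dim=1, keepdim=True).values  # shape [3, 1]
shifted = tensor - row_max  # each row minus its own maximum
print(shifted.tolist())     # [[-3, -11, 0], [-16, 0, -10], [0, -7, -17]]

# Without keepdim the values have shape [3] and broadcast as [1, 3]:
# element [i, j] then has the max of row j subtracted, not the max of row i
wrong = tensor - torch.max(tensor, dim=1).values
print(torch.equal(wrong, shifted))  # False
```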