# Complete Tutorial for torch.max() in PyTorch with Examples

## Introduction

In this article, we walk through the torch.max() function of PyTorch, which returns the maximum element of a tensor. We first look at how it works and at its syntax, and then work through some examples to help beginners understand it better.

## What is torch.max() function in PyTorch

The torch.max() function retrieves the maximum value of a tensor. Called on the tensor alone, it returns the maximum over all elements. Called with a dimension, it returns the maximum values along that dimension together with their indices. It looks quite straightforward, but you should know these nuances (what is returned, and in what shape), otherwise you may be in for some unexpected surprises.

Let us first understand the syntax of this PyTorch max() function –

### Syntax & Parameters

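torch.max(input) –> Returns the maximum value over all elements as a 0-dimensional tensor

torch.max(input, dim, keepdim=False) –> Returns the maximum values and their indices as a named tuple (values, indices)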

• input : The input tensor from which the maximum is retrieved.
• dim : The dimension along which the maximum values are retrieved, given as a single integer. If it is not specified, the maximum is taken over all elements and only the value is returned, without indices.
• keepdim : When True, the output tensors keep the same number of dimensions as the input, with the reduced dimension having size 1. When False (the default), the reduced dimension is removed, so the outputs have one dimension fewer than the input.
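The effect of `dim` and `keepdim` on the output shape can be summarized with a small pure-Python helper. `reduced_shape` is a hypothetical function written for this tutorial, not part of PyTorch; it only models the shape rule described in the parameter list.

```python
def reduced_shape(shape, dim=None, keepdim=False):
    """Model the shape of torch.max's output (hypothetical helper, not PyTorch API).

    shape   -- tuple of ints, the input tensor's shape
    dim     -- None to reduce over all elements, else the dimension to reduce
    keepdim -- keep the reduced dimension with size 1 instead of dropping it
    """
    if dim is None:
        # No dim: every element is compared and a 0-d (scalar) result is produced.
        return ()
    if dim < 0:
        # Negative dims count from the end, as they do in PyTorch.
        dim += len(shape)
    if keepdim:
        # The reduced dimension stays, but with size 1.
        return shape[:dim] + (1,) + shape[dim + 1:]
    # The reduced dimension is removed entirely.
    return shape[:dim] + shape[dim + 1:]


print(reduced_shape((7,)))                         # ()
print(reduced_shape((7,), dim=0))                  # ()
print(reduced_shape((7,), dim=0, keepdim=True))    # (1,)
print(reduced_shape((3, 3), dim=0))                # (3,)
print(reduced_shape((3, 3), dim=1, keepdim=True))  # (3, 1)
```

These are exactly the shapes that appear in the 1-D and 2-D examples that follow.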

## Examples of torch.max() with 1-D Tensor

Before starting the examples of the torch max function, let us import the torch library.

In :

```python
import torch
```

### Creating 1-D Tensor

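For our first few examples, let us create a 1-D tensor of random integers as shown below. Since torch.randint() draws random values, your tensor will most likely differ from the one shown here –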

In :

```python
tensor = torch.randint(high=10, size=(7,))
tensor
```
Out :
```
tensor([8, 9, 8, 8, 0, 8, 4])
```

In :

```python
max_value = torch.max(tensor)  # named max_value to avoid shadowing Python's built-in max()

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
```
Out:
```
tensor([8, 9, 8, 8, 0, 8, 4])

Shape: torch.Size([7])

Max element is tensor(9)

Max element shape is torch.Size([])
```

### Applying torch.max() with Dim=0

Next, we apply the max function along dim=0. For a 1-D tensor this reduces over its only dimension and returns both the maximum element and its index.

In :

```python
max_value, index = torch.max(tensor, dim=0)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(index))
```
Out:
```
tensor([8, 9, 8, 8, 0, 8, 4])

Shape: torch.Size([7])

Max element is tensor(9)

Max element shape is torch.Size([])

Max element indices are tensor(1)
```
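For a 1-D tensor, reducing along `dim=0` amounts to finding the largest value and the position where it occurs. The plain-Python sketch below mirrors that on the list printed above; `max_with_index` is a hypothetical helper, not a PyTorch function, and it returns the first position of the maximum because that is how Python's built-in `max` resolves ties.

```python
def max_with_index(values):
    """Return (largest value, its index) for a non-empty list.

    Hypothetical helper that mimics the (values, indices) pair torch.max
    returns along dim=0 of a 1-D tensor. Python's max() keeps the first
    maximal element it sees, so ties resolve to the smallest index.
    """
    index = max(range(len(values)), key=lambda i: values[i])
    return values[index], index


values = [8, 9, 8, 8, 0, 8, 4]  # same numbers as the tensor above
print(max_with_index(values))   # (9, 1)
```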

### Applying torch.max() with Dim=0 & keepdim=True

In the previous two examples, we observed that the resulting max element had one dimension fewer than the original 1-D tensor, i.e. it was a scalar with torch.Size([]). This time we pass the parameter keepdim=True, which keeps the reduced dimension (with size 1) in the output, giving torch.Size([1]).

In :

```python
max_value, index = torch.max(tensor, dim=0, keepdim=True)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(index))
```
Out :
```
tensor([8, 9, 8, 8, 0, 8, 4])

Shape: torch.Size([7])

Max element is tensor([9])

Max element shape is torch.Size([1])

Max element indices are tensor([1])
```

## Examples of torch.max() with 2-D Tensor

### Creating 2-D Tensor

Let us create a 2-D tensor of random integers to use in the following examples of the torch max function. Again, your values will differ from run to run.

In :

```python
tensor = torch.randint(high=20, size=(3, 3))
tensor
```
Out:
```
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])
```

### Applying torch.max() Function with Default Values

With default values (no dim), the torch max function returns the maximum value among all elements of the tensor as a 0-dimensional (scalar) tensor, without any index.

In :

```python
max_value = torch.max(tensor)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
```
Out:
```
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor(18)

Max element shape is torch.Size([])
```
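Without `dim`, every element takes part in the comparison, so the result is the same as flattening the tensor first. A plain-Python sketch of that idea on the nested list printed above (the list literal is copied from the output, not produced by PyTorch):

```python
matrix = [
    [13, 5, 16],
    [0, 16, 6],
    [18, 11, 1],
]

# Flatten all rows into one sequence and take the overall maximum.
# Like torch.max(tensor) without dim, only the value is produced, no index.
overall_max = max(value for row in matrix for value in row)
print(overall_max)  # 18
```

In PyTorch itself, `torch.max(tensor)` and `torch.max(tensor.flatten())` give the same value.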

### Applying torch max() with Dim = 0

With dim=0, the torch max() function reduces along the rows, returning the maximum element of each column. Since keepdim=False by default, the max values are returned in a 1-D tensor, i.e. one dimension fewer than the 2-D input tensor. It also returns the row indices of those max values in a second tensor.

In :

```python
max_value, index = torch.max(tensor, dim=0)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(index))
```
Out:
```
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([18, 16, 16])

Max element shape is torch.Size([3])

Max element indices are tensor([2, 1, 0])
```
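Reducing along `dim=0` compares the entries down each column. The sketch below reproduces the values and indices above in plain Python: `zip(*matrix)` yields the columns, and `max_along_dim0` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
def max_along_dim0(matrix):
    """Column-wise max of a 2-D nested list, returned as (values, indices).

    Hypothetical helper mirroring torch.max(tensor, dim=0) with keepdim=False.
    """
    values, indices = [], []
    for column in zip(*matrix):  # zip(*matrix) walks the columns
        # Row index of the largest entry in this column (first one on ties).
        row = max(range(len(column)), key=lambda i: column[i])
        values.append(column[row])
        indices.append(row)
    return values, indices


matrix = [
    [13, 5, 16],
    [0, 16, 6],
    [18, 11, 1],
]
print(max_along_dim0(matrix))  # ([18, 16, 16], [2, 1, 0])
```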

### Applying torch max() with Dim = 0 & keepdim = True

In contrast with the above example, when we pass keepdim=True, the output tensors of max values and indices keep the same number of dimensions as the input tensor, i.e. 2, with the reduced dimension having size 1.
In :

```python
max_value, index = torch.max(tensor, dim=0, keepdim=True)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(index))
```
Out:
```
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([[18, 16, 16]])

Max element shape is torch.Size([1, 3])

Max element indices are tensor([[2, 1, 0]])
```
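With keepdim=True the column maxima come back as a single row of shape [1, 3], which lines up with every row of the input. Dividing each column by its own maximum is one way to use that; in PyTorch it could be written as `tensor / torch.max(tensor, dim=0, keepdim=True).values`. The plain-Python sketch below performs the same arithmetic by hand on the example values, reusing the one row of maxima for every row of the matrix (which is what broadcasting does); it does not call PyTorch.

```python
matrix = [
    [13, 5, 16],
    [0, 16, 6],
    [18, 11, 1],
]

# Column maxima kept as a 1 x 3 "row", like the keepdim=True result above.
col_max = [[max(column) for column in zip(*matrix)]]

# Broadcasting by hand: the single row of maxima is reused for every row,
# so each entry is divided by the maximum of its own column.
scaled = [
    [value / col_max[0][j] for j, value in enumerate(row)]
    for row in matrix
]

for row in scaled:
    print([round(v, 3) for v in row])
```

After scaling, the largest entry in every column is 1.0.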

### Applying torch max() with Dim = 1

Similar to the earlier examples, this example finds the max values and their indices along dim=1, i.e. the maximum of each row. Since keepdim=False by default, the output tensors have one dimension fewer than the 2-D input tensor, i.e. 1.

In :

```python
max_value, index = torch.max(tensor, dim=1)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(index))
```
Out:
```
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([16, 16, 18])

Max element shape is torch.Size([3])

Max element indices are tensor([2, 1, 0])
```
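Reducing along `dim=1` scans each row instead of each column, so for a 2-D tensor it is the same as reducing the transpose along `dim=0` (in PyTorch, `torch.max(tensor, dim=1)` and `torch.max(tensor.T, dim=0)` agree). The plain-Python sketch below checks this on the example values; `max_along_dim1` is a hypothetical helper, not a PyTorch function.

```python
def max_along_dim1(matrix):
    """Row-wise max of a 2-D nested list, returned as (values, indices).

    Hypothetical helper mirroring torch.max(tensor, dim=1) with keepdim=False.
    """
    values, indices = [], []
    for row in matrix:
        # Column index of the largest entry in this row (first one on ties).
        col = max(range(len(row)), key=lambda j: row[j])
        values.append(row[col])
        indices.append(col)
    return values, indices


matrix = [
    [13, 5, 16],
    [0, 16, 6],
    [18, 11, 1],
]
print(max_along_dim1(matrix))  # ([16, 16, 18], [2, 1, 0])

# Column maxima of the transpose are the row maxima of the original.
transposed = [list(column) for column in zip(*matrix)]
print(max_along_dim1(matrix)[0] == [max(column) for column in zip(*transposed)])  # True
```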

### Applying torch max() with Dim = 1 & keepdim = True

Again, when we pass keepdim=True along with dim=1, the resulting tensors have the same number of dimensions as the input tensor, i.e. 2, with shape torch.Size([3, 1]).

In :

```python
max_value, index = torch.max(tensor, dim=1, keepdim=True)

print(tensor)
print('\nShape: ' + str(tensor.shape))
print('\nMax element is ' + str(max_value))
print('\nMax element shape is ' + str(max_value.shape))
print('\nMax element indices are ' + str(index))
```
Out:
```
tensor([[13,  5, 16],
        [ 0, 16,  6],
        [18, 11,  1]])

Shape: torch.Size([3, 3])

Max element is tensor([[16],
        [16],
        [18]])

Max element shape is torch.Size([3, 1])

Max element indices are tensor([[2],
        [1],
        [0]])
```
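Keeping the reduced dimension matters most for broadcasting: a [3, 1] column of row maxima lines up with the [3, 3] input row by row, whereas a plain [3] result would be broadcast across columns instead. A typical use is the numerically stable softmax, where each row's maximum is subtracted before exponentiating; in PyTorch that step is often written as `tensor - torch.max(tensor, dim=1, keepdim=True).values`. The sketch below carries out the same computation in plain Python on the example values; `row_softmax` is a hypothetical helper, not a PyTorch function.

```python
import math


def row_softmax(matrix):
    """Softmax of each row of a 2-D nested list (hypothetical helper).

    The row maxima play the role of torch.max(tensor, dim=1, keepdim=True).values:
    one value per row, shaped [rows, 1], reused across that row's entries.
    """
    row_max = [[max(row)] for row in matrix]  # [[16], [16], [18]] for the example
    result = []
    for row, (m,) in zip(matrix, row_max):
        # Subtracting the row maximum keeps every exponent <= 0, avoiding overflow.
        exps = [math.exp(value - m) for value in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


matrix = [
    [13, 5, 16],
    [0, 16, 6],
    [18, 11, 1],
]
for row in row_softmax(matrix):
    print([round(p, 4) for p in row])
```

Each row of the result sums to 1, and the largest probability sits at the same position as the max index returned above.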
Reference – PyTorch Documentation