Introduction
This tutorial explores the torch.mean() function in PyTorch. We begin with its syntax and then work through examples on 1-D and 2-D tensors, so that beginners can see how to use this function to calculate the mean of the elements in a PyTorch tensor.
What is the torch.mean() function in PyTorch?
The torch.mean() function computes the mean of the elements of a tensor in PyTorch, either over all elements or along a given dimension (axis).
Syntax & Parameters
torch.mean(input, dim, keepdim=False, dtype=None)
- input : The input tensor for which the mean has to be determined
- dim : The dimension or the list of dimensions along which the mean is calculated. If not specified, the mean is calculated over all dimensions.
- keepdim : If True, each reduced dimension is kept in the output tensor with size 1, so the output has the same number of dimensions as the input. If False, the reduced dimensions are removed, so the output has len(dim) fewer dimensions, where dim is the parameter above. The default is False.
- dtype : The desired data type of the output tensor. If specified, the input tensor is cast to dtype before the mean is computed. The default is None.
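The following is a minimal sketch of where the dtype parameter matters in practice. The tensor int_tensor and its values are illustrative and do not come from the examples in this tutorial. torch.mean() expects a floating-point (or complex) input, so calling it on an integer tensor raises a RuntimeError unless a floating-point dtype is supplied.

```python
import torch

# Illustrative integer tensor (values chosen so the mean is easy to check by hand)
int_tensor = torch.tensor([1, 2, 3, 4])

# Without dtype, the mean of an integer tensor cannot be computed
try:
    torch.mean(int_tensor)
    raised = False
except RuntimeError:
    raised = True

# With dtype, the input is cast to float32 before the mean is taken
mean_as_float = torch.mean(int_tensor, dtype=torch.float32)

print(raised)
print(mean_as_float)
```

An alternative with the same effect is to convert the tensor first, for example with int_tensor.float(), and then call torch.mean() on the result.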
Note: How the dim (i.e. dimension) parameter works in PyTorch can be counterintuitive and confusing. We explained it with illustrations in the following article and advise you to go through it, because torch.mean() uses the dim parameter in the same way.
Complete Tutorial for torch.sum() to Sum Tensor Elements in PyTorch
Examples of torch.mean() with 1-D Tensor
Before starting the examples of the torch.mean() function, let us first import the torch library.
In [0]:
import torch
Creating 1-D Tensor
Below we create a 1-D tensor of size 4 that will be used in the subsequent examples.
In [1]:
tensor1 = torch.randn(4)
print(tensor1)
print('\nShape: ' + str(tensor1.shape))
tensor([-1.4441,  0.1200,  0.3438, -0.0846])

Shape: torch.Size([4])
Applying torch.mean() Function with Default Values
In the first example, we apply the torch.mean() function with default values to the 1-D tensor. For a 1-D tensor, this gives the same result as torch.mean(tensor1, dim=0, keepdim=False).
Since we have not passed any parameters explicitly, the mean is calculated with keepdim=False. Hence the single dimension of the input tensor is not retained in the output: the number of dimensions is reduced by len(dim), which is 1, producing a scalar (0-dimensional) output.
In [2]:
output = torch.mean(tensor1)
print(output)
tensor(-0.2662)
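Because torch.randn() gives different values on every run, the result above is hard to verify by hand. As a small sketch with fixed, illustrative values (the tensor name values is not from the original example), the default mean can be checked against the sum of the elements divided by their count:

```python
import torch

# Fixed values so the result can be verified by hand: (1 + 2 + 3 + 4) / 4
values = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Built-in mean over all elements
mean_builtin = torch.mean(values)

# The same quantity computed manually as sum / number of elements
mean_manual = values.sum() / values.numel()

print(mean_builtin)          # a 0-dimensional tensor
print(mean_builtin.item())   # the same value as a plain Python float
print(mean_manual)
```

The .item() call is a convenient way to pull the scalar result out of the 0-dimensional tensor when a regular Python number is needed.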
Applying torch.mean() with keepdim=True
In this example, we use keepdim=True. As a result, the reduced dimension is kept with size 1, so the output tensor is 1-D like the input, as evident below.
In [3]:
output = torch.mean(tensor1, dim=0, keepdim=True)
output
tensor([-0.2662])
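The difference between the two keepdim settings is easiest to see in the output shapes. Here is a brief sketch, again using an illustrative tensor named values rather than the random tensor1:

```python
import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0])

# keepdim=False removes the reduced dimension, leaving a 0-dimensional tensor
reduced = torch.mean(values, dim=0, keepdim=False)

# keepdim=True keeps the reduced dimension with size 1
kept = torch.mean(values, dim=0, keepdim=True)

print(reduced.shape)
print(kept.shape)
```

Both results hold the same mean value; only the number of dimensions differs.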
Examples of torch.mean() with 2-D Tensor
Creating 2-D Tensor
Let us create a 2-D tensor that will be used for the torch.mean() examples below.
In [4]:
tensor2 = torch.randn(4, 3)
print(tensor2)
print('\nShape: ' + str(tensor2.shape))
tensor([[ 0.1219,  0.3153, -0.0304],
        [-0.1506,  0.3405, -0.7328],
        [-0.6010, -0.3473,  0.6824],
        [-0.9229,  0.1360,  0.7910]])

Shape: torch.Size([4, 3])
Applying torch.mean() Function with Default Values
In this example, we apply the torch.mean() function with default values by not passing any explicit parameters. This is equivalent to torch.mean(tensor2, dim=(0, 1), keepdim=False).
It calculates the mean across both dimensions 0 and 1, i.e. over all elements of the 2-D tensor. Since keepdim=False, the dimensions of the input tensor are not retained; the number of dimensions is reduced by len(dim), which is 2 in this case. Hence the output has 2 fewer dimensions than the input and is a scalar.
In [5]:
output = torch.mean(tensor2)
print(output)
Out[5]:
tensor(-0.0331)
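The equivalence described above can be checked with a short sketch. The tensor grid below is an illustrative 4x3 tensor holding the values 0 to 11, chosen so that the overall mean (66 / 12 = 5.5) is easy to verify; it is not the random tensor2 from this tutorial.

```python
import torch

# Illustrative 4x3 tensor with the values 0.0, 1.0, ..., 11.0
grid = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Default call: mean over every element
mean_default = torch.mean(grid)

# Explicitly reducing over both dimensions gives the same scalar
mean_all_dims = torch.mean(grid, dim=(0, 1))

print(mean_default)
print(mean_all_dims)
print(mean_all_dims.shape)
```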
Applying torch.mean() with dim=(0,1) and keepdim=True
This time we pass keepdim=True. Hence both reduced dimensions are kept with size 1, and the output is a 2-D tensor of shape [1, 1].
In [6]:
output = torch.mean(tensor2, dim=(0, 1), keepdim=True)
print(output)
tensor([[-0.0331]])
Applying torch.mean() with dim=0 and keepdim=False
In this example, the mean is calculated along dimension 0 of the 2-D tensor, i.e. down each column. Since keepdim=False, the number of dimensions of the output is reduced by len(dim), i.e. by 1. Hence the final output is a 1-D tensor with one value per column.
In [7]:
output = torch.mean(tensor2, dim=0, keepdim=False)
print(output)
tensor([-0.3881, 0.1111, 0.1776])
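To make the dim=0 behaviour concrete, here is a small sketch on an illustrative tensor named grid (not the random tensor2 above). Reducing along dimension 0 collapses the rows, which leaves one mean per column; the sketch compares this with averaging each column by hand.

```python
import torch

# Illustrative 4x3 tensor:
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
grid = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Mean along dimension 0: shape [4, 3] becomes [3]
column_means = torch.mean(grid, dim=0)

# The same values computed column by column
manual = torch.stack([grid[:, j].sum() / grid.shape[0] for j in range(grid.shape[1])])

print(column_means)
print(column_means.shape)
print(manual)
```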
Applying torch.mean() with dim=0 and keepdim=True
In this case we pass keepdim=True, hence dimension 0 is kept with size 1 and the output is a 2-D tensor of shape [1, 3].
In [8]:
output = torch.mean(tensor2, dim=0, keepdim=True)
print(output)
tensor([[-0.3881, 0.1111, 0.1776]])
Applying torch.mean() with dim=1 and keepdim=False
In this example, the mean is calculated along dimension 1 of the 2-D tensor, i.e. across each row. Since keepdim=False, the number of dimensions of the output is reduced by len(dim), i.e. by 1. As a result, the output is a 1-D tensor with one value per row.
In [9]:
output = torch.mean(tensor2, dim=1, keepdim=False)
print(output)
Out[9]:
tensor([ 0.1356, -0.1809, -0.0886, 0.0014])
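As a sketch of the dim=1 case on fixed values, the illustrative tensor grid below (not the random tensor2 above) has rows whose means are easy to read off. The sketch also shows that, for a 2-D tensor, the negative index dim=-1 refers to the same last dimension as dim=1.

```python
import torch

# Illustrative 4x3 tensor:
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
grid = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Mean along dimension 1: shape [4, 3] becomes [4], one value per row
row_means = torch.mean(grid, dim=1)

# dim=-1 counts from the end, so it selects the same dimension here
row_means_negative = torch.mean(grid, dim=-1)

print(row_means)
print(row_means.shape)
print(row_means_negative)
```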
Applying torch.mean() with dim=1 and keepdim=True
In this example we pass keepdim=True, hence dimension 1 is kept with size 1 and the output is a 2-D tensor of shape [4, 1].
In [10]:
output = torch.mean(tensor2, dim=1, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
Out[10]:
tensor([[ 0.1356],
        [-0.1809],
        [-0.0886],
        [ 0.0014]])

Shape: torch.Size([4, 1])
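One common reason to use keepdim=True with dim=1 is that the [4, 1] result broadcasts directly against the original [4, 3] tensor. The sketch below, on an illustrative tensor named grid, subtracts each row's mean from that row. Without keepdim=True the row means would have shape [4], which does not broadcast against [4, 3].

```python
import torch

# Illustrative 4x3 tensor:
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
grid = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Row means kept as a column of shape [4, 1]
row_means = torch.mean(grid, dim=1, keepdim=True)

# Broadcasting subtracts each row's mean from every element of that row
centered = grid - row_means

print(row_means.shape)
print(centered)
print(torch.mean(centered, dim=1))  # each centered row now averages to zero
```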
Reference: PyTorch Documentation