Introduction
In this tutorial, we will take an in-depth look at how to use the torch.sum() function to sum the elements of a PyTorch tensor. We will first understand its syntax and then cover its functionality with various examples and illustrations to make it easy for beginners.
What is the torch.sum() function in PyTorch?
The torch.sum() function sums up the elements of a tensor in PyTorch, either all of them or along one or more given dimensions (axes). On the surface this looks like a very easy function. However, the way the dim and keepdim arguments decide which elements are added, and what shape the result has, is not intuitive and often gives beginners headaches. Don't worry, we will explain it with proper illustrations, but before that let us understand its syntax.
Syntax & Parameters
torch.sum(input, *, dtype=None)
torch.sum(input, dim, keepdim=False, *, dtype=None)
 input : The input tensor whose elements are summed.
 dim : The dimension, or tuple of dimensions, to sum over. If it is not specified, all elements are summed.
 keepdim : If True, every summed dimension is kept in the output with size 1, so the output has the same number of dimensions as the input. If False, the summed dimensions are removed, so the output has len(dim) fewer dimensions. The default is False.
 dtype : The desired data type of the output tensor. If specified, the input is cast to this type before the sum is computed, which is useful for preventing overflow. The default is None.
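As a minimal sketch of the dtype parameter, the snippet below sums the same integer tensor twice: once with the default dtype=None and once with dtype=torch.float64, which casts the input before summing. The variable names are only illustrative.

```python
import torch

values = torch.arange(10)  # int64 tensor: 0, 1, ..., 9

# Default dtype=None: for this int64 input the result is also int64
default_sum = torch.sum(values)

# dtype given: the input is cast to float64 before the sum is computed
float_sum = torch.sum(values, dtype=torch.float64)

print(default_sum, default_sum.dtype)  # tensor(45) torch.int64
print(float_sum, float_sum.dtype)      # tensor(45., dtype=torch.float64) torch.float64
```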
Examples of torch.sum() with 1D Tensor
Before going through the examples of torch.sum(), let us first import the torch library.
import torch
Creating 1D Tensor
We start by creating a one-dimensional tensor of size 10, as shown below.
tensor1 = torch.arange(10)
print(tensor1)
print('\nShape: ' + str(tensor1.shape))
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Shape: torch.Size([10])
Applying torch.sum() Function
keepdim=False
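Without a dim argument, torch.sum() adds up all the elements, 0 + 1 + ... + 9 = 45. With the default keepdim=False, the result is a zero-dimensional (scalar) tensor, as its empty shape shows.
output = torch.sum(tensor1)
print(output)
print('\nShape: ' + str(output.shape))
tensor(45)

Shape: torch.Size([])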
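To connect this result to plain Python, here is a small check that uses only torch: with no dim, torch.sum() adds every element, and the zero-dimensional result can be turned into an ordinary Python number with .item().

```python
import torch

tensor1 = torch.arange(10)
total = torch.sum(tensor1)

# No dim given: every element is added, 0 + 1 + ... + 9
print(total.dim())                     # 0, a zero-dimensional (scalar) tensor
print(total.item())                    # 45, as a plain Python int
print(total.item() == sum(range(10)))  # True
```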
keepdim=True
When keepdim=True is passed along with dim=0, the summed dimension is not removed but kept with size 1. As the example below shows, the output is a 1D tensor with a single element, 45.
output = torch.sum(tensor1, dim=0, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([45])

Shape: torch.Size([1])
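One way to think about keepdim=True is that it puts the summed dimension back with size 1. The sketch below checks that, for this 1D tensor, it matches calling unsqueeze(0) on the keepdim=False result; this equivalence is our own illustration, not a separate API.

```python
import torch

tensor1 = torch.arange(10)

kept = torch.sum(tensor1, dim=0, keepdim=True)      # shape [1]
dropped = torch.sum(tensor1, dim=0, keepdim=False)  # shape []

# Re-inserting the reduced dimension as size 1 reproduces keepdim=True
restored = dropped.unsqueeze(0)
print(kept.shape, restored.shape)   # torch.Size([1]) torch.Size([1])
print(torch.equal(kept, restored))  # True
```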
Examples of torch.sum() with 2D Tensor
Creating 2D Tensor
Again, we start by creating a two-dimensional tensor of size 4×3 that will be used in the following examples of the torch.sum() function.
tensor2 = torch.arange(12).reshape(4, 3)
print(tensor2)
print('\nShape: ' + str(tensor2.shape))
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

Shape: torch.Size([4, 3])
Applying torch.sum() Function With dim= (0,1)
keepdim=False
In this example, torch.sum() sums and collapses the tensor along both dim 0 and dim 1. Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 2, so the result is a zero-dimensional (scalar) tensor with the value 66, the sum of 0 through 11.
output = torch.sum(tensor2, dim=(0, 1), keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor(66)

Shape: torch.Size([])
keepdim=True
With keepdim=True, both summed dimensions are kept with size 1, so the result is a 1×1 tensor.
output = torch.sum(tensor2, dim=(0, 1), keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[66]])

Shape: torch.Size([1, 1])
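Passing a tuple of dimensions gives the same result as reducing those dimensions one after another. The following sketch checks this for tensor2, and also that summing over every dimension gives the same value as calling torch.sum() with no dim at all.

```python
import torch

tensor2 = torch.arange(12).reshape(4, 3)

both = torch.sum(tensor2, dim=(0, 1))

# Reduce dim 1 first (row sums), then dim 0 of what is left
step_by_step = torch.sum(torch.sum(tensor2, dim=1), dim=0)

print(both, step_by_step, torch.sum(tensor2))  # tensor(66) tensor(66) tensor(66)
```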
Applying torch.sum() Function With dim= 0
keepdim=False
In this example, torch.sum() sums and collapses the tensor along dim 0, i.e. down each column, as explained in the illustration above: 0 + 3 + 6 + 9 = 18, 1 + 4 + 7 + 10 = 22 and 2 + 5 + 8 + 11 = 26. Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 1, so the result is a one-dimensional tensor of size 3.
output = torch.sum(tensor2, dim=0, keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor([18, 22, 26])

Shape: torch.Size([3])
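To make "collapsing dim 0" concrete: summing along dim 0 is the same as adding the rows of tensor2 together element by element. The short sketch below verifies this by adding the four rows by hand.

```python
import torch

tensor2 = torch.arange(12).reshape(4, 3)

# Add the four rows (each of shape [3]) element by element
row_total = tensor2[0] + tensor2[1] + tensor2[2] + tensor2[3]

print(row_total)                                          # tensor([18, 22, 26])
print(torch.equal(row_total, torch.sum(tensor2, dim=0)))  # True
```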
keepdim=True
When keepdim=True, the number of dimensions of the input tensor is retained: dim 0 is kept with size 1, so the result is a two-dimensional tensor of size 1×3.
output = torch.sum(tensor2, dim=0, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[18, 22, 26]])

Shape: torch.Size([1, 3])
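The same reduction is also available as a tensor method, Tensor.sum(), which accepts the same dim and keepdim arguments. A quick sketch checking that both spellings agree:

```python
import torch

tensor2 = torch.arange(12).reshape(4, 3)

as_function = torch.sum(tensor2, dim=0, keepdim=True)
as_method = tensor2.sum(dim=0, keepdim=True)  # method form of the same call

print(as_method)                            # tensor([[18, 22, 26]])
print(torch.equal(as_function, as_method))  # True
```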
Applying torch.sum() Function With dim= 1
keepdim=False
In this example, torch.sum() sums and collapses the tensor along dim 1, i.e. across each row, as explained in the illustration above: 0 + 1 + 2 = 3, 3 + 4 + 5 = 12, 6 + 7 + 8 = 21 and 9 + 10 + 11 = 30. Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 1, so the result is a one-dimensional tensor of size 4.
output = torch.sum(tensor2, dim=1, keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor([ 3, 12, 21, 30])

Shape: torch.Size([4])
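Like most PyTorch reductions, torch.sum() also accepts negative dimension indices that count from the end, so for a 2D tensor dim=-1 refers to the same axis as dim=1. A brief check:

```python
import torch

tensor2 = torch.arange(12).reshape(4, 3)

last_axis = torch.sum(tensor2, dim=-1)  # -1 is the last dimension, here dim 1

print(last_axis)                                          # tensor([ 3, 12, 21, 30])
print(torch.equal(last_axis, torch.sum(tensor2, dim=1)))  # True
```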
keepdim=True
When keepdim=True, the number of dimensions of the input tensor is retained: dim 1 is kept with size 1, so the result is a two-dimensional tensor of size 4×1.
output = torch.sum(tensor2, dim=1, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[ 3],
        [12],
        [21],
        [30]])

Shape: torch.Size([4, 1])
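A common reason to pass keepdim=True is broadcasting. As a sketch (the row normalization itself is our own illustration, not part of torch.sum()), dividing each row of tensor2 by its row sum works with the 4×1 result, because it broadcasts against the 4×3 input. The 1D result of shape [4] does not line up with the last dimension of size 3, so the same division raises an error.

```python
import torch

tensor2 = torch.arange(12).reshape(4, 3)

row_sums = torch.sum(tensor2, dim=1, keepdim=True)  # shape [4, 1]

# Each row is divided by its own sum; [4, 3] / [4, 1] broadcasts
row_fractions = tensor2 / row_sums
print(row_fractions.shape)  # torch.Size([4, 3])
print(torch.allclose(torch.sum(row_fractions, dim=1), torch.ones(4)))  # True

# Without keepdim the shapes [4, 3] and [4] cannot be broadcast together
broadcast_failed = False
try:
    tensor2 / torch.sum(tensor2, dim=1)
except RuntimeError:
    broadcast_failed = True
print(broadcast_failed)  # True
```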
Examples of torch.sum() with 3D Tensor
Creating 3D Tensor
Again, we start by creating a three-dimensional tensor of size 2×2×3 that will be used in the following examples of the torch.sum() function.
tensor3 = torch.arange(12).reshape(2, 2, 3)
print(tensor3)
print('\nShape: ' + str(tensor3.shape))
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])

Shape: torch.Size([2, 2, 3])
Applying torch.sum() Function With dim= (0,1,2)
keepdim=False
In this example, torch.sum() sums and collapses the tensor along all three dimensions, 0, 1 and 2, as explained in the illustration above. Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 3, so the result is a zero-dimensional (scalar) tensor with the value 66.
output = torch.sum(tensor3, dim=(0, 1, 2), keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor(66)

Shape: torch.Size([])
keepdim=True
With keepdim=True, all three summed dimensions are kept with size 1, so the result is a 1×1×1 tensor.
output = torch.sum(tensor3, dim=(0, 1, 2), keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[[66]]])

Shape: torch.Size([1, 1, 1])
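The tuple passed to dim does not have to cover every dimension. As a small extension that is not one of the original examples, summing tensor3 over dims 1 and 2 collapses each 2×3 block to its total and leaves dim 0 intact.

```python
import torch

tensor3 = torch.arange(12).reshape(2, 2, 3)

block_sums = torch.sum(tensor3, dim=(1, 2))  # one total per block along dim 0

print(block_sums)        # tensor([15, 51])
print(block_sums.shape)  # torch.Size([2])
```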
Applying torch.sum() Function With dim=0
keepdim=False
In this example, torch.sum() sums and collapses the tensor along dim 0, as explained in the illustration above: the two 2×3 blocks are added element by element (for example 0 + 6 = 6 and 5 + 11 = 16). Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 1, so the result is a two-dimensional tensor of size 2×3.
output = torch.sum(tensor3, dim=0, keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[ 6,  8, 10],
        [12, 14, 16]])

Shape: torch.Size([2, 3])
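For the 3D case the row-addition idea holds one level up: summing along dim 0 adds the two 2×3 blocks tensor3[0] and tensor3[1] element by element. A short sketch to confirm:

```python
import torch

tensor3 = torch.arange(12).reshape(2, 2, 3)

# tensor3[0] and tensor3[1] are the two 2x3 blocks along dim 0
block_total = tensor3[0] + tensor3[1]

print(block_total)
print(torch.equal(block_total, torch.sum(tensor3, dim=0)))  # True
```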
keepdim=True
When keepdim=True, the number of dimensions of the input tensor is retained: dim 0 is kept with size 1, so the result is a three-dimensional tensor of size 1×2×3.
output = torch.sum(tensor3, dim=0, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[[ 6,  8, 10],
         [12, 14, 16]]])

Shape: torch.Size([1, 2, 3])
Applying torch.sum() Function With dim=1
keepdim=False
In this example, torch.sum() sums and collapses the tensor along dim 1, as explained in the illustration above: within each block, the two rows are added (for example 0 + 3 = 3 and 6 + 9 = 15). Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 1, so the result is a two-dimensional tensor of size 2×3.
output = torch.sum(tensor3, dim=1, keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[ 3,  5,  7],
        [15, 17, 19]])

Shape: torch.Size([2, 3])
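Another way to read dim=1 on a 3D tensor: inside each 2×3 block, it is the same row-collapsing sum we saw with dim=0 on the 2D tensor. The sketch below rebuilds the result by summing each block along its own dim 0 and joining the pieces with torch.stack().

```python
import torch

tensor3 = torch.arange(12).reshape(2, 2, 3)

# Iterating over a tensor yields its slices along dim 0, here two 2x3 blocks
per_block = [block.sum(dim=0) for block in tensor3]
rebuilt = torch.stack(per_block)  # stack the two [3] results into shape [2, 3]

print(rebuilt)
print(torch.equal(rebuilt, torch.sum(tensor3, dim=1)))  # True
```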
keepdim=True
When keepdim=True, the number of dimensions of the input tensor is retained: dim 1 is kept with size 1, so the result is a three-dimensional tensor of size 2×1×3.
output = torch.sum(tensor3, dim=1, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[[ 3,  5,  7]],

        [[15, 17, 19]]])

Shape: torch.Size([2, 1, 3])
Applying torch.sum() Function With dim=2
keepdim=False
In this example, torch.sum() sums and collapses the tensor along dim 2, the last dimension, as explained in the illustration above: the three entries of each innermost row are added (for example 0 + 1 + 2 = 3 and 9 + 10 + 11 = 30). Since keepdim=False, the number of dimensions is reduced by len(dim), i.e. 1, so the result is a two-dimensional tensor of size 2×2.
output = torch.sum(tensor3, dim=2, keepdim=False)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[ 3, 12],
        [21, 30]])

Shape: torch.Size([2, 2])
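Summing along dim=2, the last axis, adds up the three entries of each innermost row. As a sketch, the same result can be built by adding the three slices along that axis, selected with Python's ellipsis indexing.

```python
import torch

tensor3 = torch.arange(12).reshape(2, 2, 3)

# tensor3[..., k] picks entry k of every innermost row, giving a 2x2 tensor
manual = tensor3[..., 0] + tensor3[..., 1] + tensor3[..., 2]

print(manual)
print(torch.equal(manual, torch.sum(tensor3, dim=2)))  # True
```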
keepdim=True
When keepdim=True, the number of dimensions of the input tensor is retained: dim 2 is kept with size 1, so the result is a three-dimensional tensor of size 2×2×1.
output = torch.sum(tensor3, dim=2, keepdim=True)
print(output)
print('\nShape: ' + str(output.shape))
tensor([[[ 3],
         [12]],

        [[21],
         [30]]])

Shape: torch.Size([2, 2, 1])
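The examples above all follow one shape rule: the output shape is the input shape with each summed dimension removed (keepdim=False) or replaced by 1 (keepdim=True). The helper expected_shape below is our own hypothetical function, not part of PyTorch; the loop checks the rule against torch.sum() on tensor3 for several choices of dim.

```python
import torch

def expected_shape(shape, dims, keepdim):
    """Hypothetical helper: predict the shape that torch.sum() returns."""
    dims = {d % len(shape) for d in dims}  # turn negative dims into positive ones
    if keepdim:
        # Summed dimensions stay in place with size 1
        return torch.Size([1 if i in dims else size for i, size in enumerate(shape)])
    # Summed dimensions are dropped entirely
    return torch.Size([size for i, size in enumerate(shape) if i not in dims])

tensor3 = torch.arange(12).reshape(2, 2, 3)

for dims in [(0,), (1,), (2,), (0, 1), (0, 1, 2)]:
    for keepdim in (False, True):
        actual = torch.sum(tensor3, dim=dims, keepdim=keepdim).shape
        assert actual == expected_shape(tensor3.shape, dims, keepdim)
        print(dims, keepdim, actual)
```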
 Reference: PyTorch Documentation

MLK is a knowledge sharing community platform for machine learning enthusiasts, beginners and experts. Let us create a powerful hub together to Make AI Simple for everyone.