How to use torch.gather() Function in PyTorch with Examples

Introduction

In this article, we will see how to use the torch.gather() function in PyTorch. We first explain what gather() is used for and describe its syntax. We then work through a few examples so that beginners can follow the concept easily.

What is torch.gather() function in PyTorch

As the name suggests, the torch.gather() function creates a new tensor by gathering elements from an input tensor along a specific dimension, at positions given by an index tensor.

Syntax & Parameters

torch.gather(input, dim, index)

  • input : The input tensor from which the values have to be gathered.
  • dim : The dimension of the input tensor along which the elements have to be gathered.
  • index : A tensor of integer (int64) indices, with the same number of dimensions as input, giving the positions along dim from which the elements have to be gathered. The output tensor has the same shape as index.
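To make the three parameters concrete, here is a minimal sketch. The tensors source and idx are made-up values for illustration and are not part of the examples later in this article.

```python
import torch

# Hypothetical input tensor of shape (2, 3), chosen only for illustration.
source = torch.tensor([[10, 11, 12],
                       [20, 21, 22]])

# The index tensor must hold integers (int64) and have as many
# dimensions as the input; here it has shape (1, 3).
idx = torch.tensor([[1, 0, 1]])

# Gather along dim=0: each index value selects a row within its own column.
gathered = torch.gather(source, 0, idx)

print(gathered)        # tensor([[20, 11, 22]])
print(gathered.shape)  # torch.Size([1, 3]), the same shape as idx
```

Note that the shape of the result follows the index tensor, not the input tensor.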

Important Note About Dimension (Dim) in PyTorch

It is important to note that the way the dimension argument (dim) works in PyTorch is often counterintuitive. As a result, functions that take dim as a parameter can produce results that do not match what people expect. The same holds true for torch.gather().

For example, in a 2-D tensor dim=0 refers to the rows, so you might expect gather to pick elements from each row. However, it ends up gathering elements column-wise. Similarly, for dim=1 you might expect it to gather elements column-wise, but the output is actually built row-wise.

The reason is that dim names the coordinate that the index values replace. When dim=0, PyTorch "collapses" along the rows (think top to bottom) within each column, so output[i][j] = input[index[i][j]][j]. When dim=1, it "collapses" along the columns (think left to right) within each row, so output[i][j] = input[i][index[i][j]].

We have already covered this concept in great detail, with diagrams, in our tutorial for torch.sum(), and we advise you to go through it for more conceptual and visual clarity.
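One way to check your understanding of dim is to rebuild gather with plain loops. The helper gather_2d_manual below is a hypothetical name written for this sketch, not a PyTorch function, and it handles only 2-D tensors.

```python
import torch

def gather_2d_manual(inp, dim, index):
    """Loop-based version of gather for 2-D tensors (illustration only)."""
    out = torch.empty_like(index, dtype=inp.dtype)
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                # The index value replaces the row coordinate; column j is kept.
                out[i, j] = inp[index[i, j], j]
            else:
                # The index value replaces the column coordinate; row i is kept.
                out[i, j] = inp[i, index[i, j]]
    return out

grid = torch.arange(9).reshape(3, 3)
idx = torch.tensor([[0, 1, 2],
                    [2, 1, 0]])

# The manual version should agree with torch.gather for both dimensions.
for dim in (0, 1):
    same = torch.equal(gather_2d_manual(grid, dim, idx), torch.gather(grid, dim, idx))
    print(f"dim={dim}: matches torch.gather -> {same}")
```

The loops are slow and serve only to show which coordinate the index values replace. In real code, use torch.gather() itself.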

Examples of torch.gather() with 1-D Tensor

Before we start with the examples of the torch.gather() function, let us first import the torch library.

In [0]:

import torch

Creating 1-D Tensor

We start by creating a one-dimensional tensor of size 10, as shown below.

In [1]:

tensor1 = torch.arange(10)*2

print(tensor1)
print('\nShape: ' + str(tensor1.shape))
Out[1]:
tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])

Shape: torch.Size([10])

Applying torch.gather() Function

This example of torch.gather() is very straightforward: we create an output tensor by gathering the elements at indices 8, 4, and 2 of the input tensor that we created above.

In [2]:

output = torch.gather(input=tensor1, dim=0, index=torch.tensor([8, 4, 2]))

output
Out[2]:
tensor([16,  8,  4])
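For a 1-D tensor, gathering along dim=0 gives the same result as indexing the tensor with an index tensor. The short check below is a sketch that rebuilds tensor1 so it can run on its own. The names idx, via_gather, and via_indexing are introduced only for this illustration.

```python
import torch

tensor1 = torch.arange(10) * 2
idx = torch.tensor([8, 4, 2])

via_gather = torch.gather(tensor1, 0, idx)
via_indexing = tensor1[idx]  # indexing with an integer tensor

print(via_gather)                             # tensor([16,  8,  4])
print(torch.equal(via_gather, via_indexing))  # True
```

The difference between the two approaches only shows up with two or more dimensions, where gather uses a separate index for every output element.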

Examples of torch.gather() with 2-D Tensor

Creating 2-D Tensor

Next, we create a 2-dimensional tensor of size 3×3 to use in the subsequent examples of the torch.gather() function.

In [3]:
tensor2 = torch.arange(9).reshape(3,3)

print(tensor2)
print('\nShape: ' + str(tensor2.shape))
Out[3]:
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

Shape: torch.Size([3, 3])

Applying torch.gather() with Dim=0

 

torch.gather() with Dim=0 Example 1

For dim=0 on a 2-D tensor, let us visualize how the torch.gather() function works with the index tensor [[0, 1, 2], [2, 1, 0]]. Here each column of the index tensor corresponds to the same column of the input tensor (highlighted in gray in the illustration). Each value in the index tensor is a row position: from the matching column of the input tensor, pick the element at that row.

Let us understand this example step by step –

  • In the 0th column, the index values are 0 and 2, and the elements at those row indices of the input tensor are 0 and 6 respectively.
  • In the 1st column, the index value is 1 both times, and the element at that row index of the input tensor is 4.
  • Similarly, in the 2nd column, the index values are 2 and 0, and the elements at those row indices of the input tensor are 8 and 2 respectively.

This is implemented in the code as below –

In [4]:

torch.gather(input=tensor2, dim=0, index=torch.tensor([[0, 1, 2],
                                                       [2, 1, 0]]))
Out[4]:
tensor([[0, 4, 8],
        [6, 4, 2]])
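The column-by-column walk-through for dim=0 can also be reproduced in code. The sketch below rebuilds tensor2 so that it runs on its own. The variable names index, result, input_column, and picked are chosen only for this illustration.

```python
import torch

tensor2 = torch.arange(9).reshape(3, 3)
index = torch.tensor([[0, 1, 2],
                      [2, 1, 0]])
result = torch.gather(tensor2, 0, index)

# For dim=0, column c of the output is built only from column c of the input.
for c in range(index.shape[1]):
    input_column = tensor2[:, c]        # for c = 0 this is [0, 3, 6]
    picked = input_column[index[:, c]]  # rows requested for this column
    print(f"column {c}: picked {picked.tolist()}, gather gave {result[:, c].tolist()}")
```

For each column, the list picked by hand should be identical to the matching column of the gather output.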

Applying torch.gather() with Dim=1

torch.gather() with Dim=1 Example 1

The illustration for this example shows how the torch.gather() function works on a 2-D tensor with dim=1, using the index tensor [[2, 1, 0], [0, 2, 1]].

Here each row of the index tensor corresponds to the same row of the input tensor (highlighted in gray in the illustration). Each value in the index tensor is a column position: from the matching row of the input tensor, pick the element at that column.

Let us understand this example step by step –

  • In the 0th row, the elements at column indices 2, 1, and 0 of the input tensor are 2, 1, and 0 respectively.
  • In the 1st row, the elements at column indices 0, 2, and 1 of the input tensor are 3, 5, and 4 respectively.

The code implementation is shown below –

In [5]:

torch.gather(input=tensor2, dim=1, index=torch.tensor([[2, 1, 0],
                                                       [0, 2, 1]]))
Out[5]:
tensor([[2, 1, 0],
        [3, 5, 4]])
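A frequent use of dim=1 is to pick exactly one element from every row. The sketch below assumes a made-up list of column positions named cols, which is not part of the original examples. It reshapes cols into a column vector so that the index tensor has the same number of dimensions as the input.

```python
import torch

tensor2 = torch.arange(9).reshape(3, 3)

# Hypothetical choice: column 2 from row 0, column 0 from row 1, column 1 from row 2.
cols = torch.tensor([2, 0, 1])

# unsqueeze(1) turns shape (3,) into (3, 1); squeeze(1) flattens the result back.
per_row = torch.gather(tensor2, 1, cols.unsqueeze(1)).squeeze(1)

print(per_row)  # tensor([2, 3, 7])
```

The index tensor here has only one column, which is allowed because the index may be smaller than the input along the gathered dimension.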

 

Reference – PyTorch Documentation