PyTorch Tutorial for Reshape, Squeeze, Unsqueeze, Flatten and View

Introduction

In this PyTorch tutorial, we will learn about some of the built-in functions that alter the shape of tensors. We will go through the PyTorch functions Reshape, Squeeze, Unsqueeze, Flatten, and View, along with their syntax and examples. These functions are very useful for manipulating tensor shapes in your PyTorch deep learning projects.

First, let us import the PyTorch library.

In [1]:
import torch

1. PyTorch Reshape : torch.reshape()

The reshape function in PyTorch returns a tensor with the same values and number of elements as the input tensor; only the shape changes, as specified by the user. When possible, the result is a view of the input that shares its data; otherwise, it is a copy.

The new shape must hold exactly as many elements as the original tensor, otherwise reshape() raises an error.
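As a minimal sketch of this rule (the tensor and the shapes are chosen only for illustration), a 9-element tensor can be reshaped to (3, 3), but asking for (2, 4), which holds only 8 elements, raises a RuntimeError:

```python
import torch

a = torch.arange(9.)  # 9 elements: 0., 1., ..., 8.

# Compatible: 3 * 3 == 9, so the reshape succeeds
ok = torch.reshape(a, (3, 3))
print(ok.shape)  # torch.Size([3, 3])

# Incompatible: 2 * 4 == 8 != 9, so PyTorch raises a RuntimeError
error_message = None
try:
    torch.reshape(a, (2, 4))
except RuntimeError as err:
    error_message = str(err)

print("reshape to (2, 4) failed:", error_message is not None)  # True
```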

The syntax of PyTorch reshape() is shown below.

Syntax

torch.reshape(input, shape)

The parameters of the function are listed below.

Parameters

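  • input (Tensor) – The tensor whose shape has to be changed.
  • shape (tuple of int) – The new shape. One dimension may be -1, in which case its size is inferred from the remaining dimensions and the number of elements.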

Output

The output is a tensor with the same values as the input but a different shape.
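A small sketch of both points above, using an arbitrary 6-element tensor: one dimension is given as -1 and inferred, and the elements keep their row-major order after reshaping.

```python
import torch

a = torch.arange(6)  # tensor([0, 1, 2, 3, 4, 5])

# -1 lets PyTorch infer this dimension from the element count: 6 // 2 = 3
r = torch.reshape(a, (2, -1))
print(r.shape)     # torch.Size([2, 3])
print(r.tolist())  # [[0, 1, 2], [3, 4, 5]]

# Same number of elements, same values in the same (row-major) order
print(r.numel() == a.numel())              # True
print(r.flatten().tolist() == a.tolist())  # True
```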

Example 1: Simple Reshape Example in PyTorch

An example of PyTorch Reshape is shown below.

Here we build a tensor using the arange() function and then use the reshape() function to reshape it into a 3×3 tensor.

In [2]:
a = torch.arange(9.)

print('Input Tensor:')
print(a)

print('Input Tensor Shape:')
print(a.shape)
Output:
Input Tensor:
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.])

Input Tensor Shape:
torch.Size([9])

In [2]:

# Reshape function
r = torch.reshape(a, (3, 3))

print('Reshaped Tensor:')
print(r)

print('Reshaped Tensor Shape:')
print(r.shape)
Output:
Reshaped Tensor:
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
Reshaped Tensor Shape:
torch.Size([3, 3])

Example 2: Flatten Tensor in PyTorch with Reshape()

We can flatten a PyTorch tensor with the reshape() function by passing (-1,) as the shape parameter.

In this example, a 2×2 tensor is flattened into a 1-D tensor of four elements by passing it to reshape() with the shape (-1,).

In [3]:
b = torch.tensor([[45, 56], [27, 34]])
torch.reshape(b, (-1,))
Output:
tensor([45, 56, 27, 34])

2. PyTorch Squeeze : torch.squeeze()

The squeeze function in PyTorch returns a tensor with all the dimensions of size 1 removed.

For instance, if the input tensor has shape (A×1×B×C×1×D), the output tensor will have shape (A×B×C×D).
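As a concrete instance of this shape rule (the sizes A=2, B=3, C=4, D=5 are arbitrary choices for illustration):

```python
import torch

# (A x 1 x B x C x 1 x D) with A=2, B=3, C=4, D=5
x = torch.zeros(2, 1, 3, 4, 1, 5)
y = torch.squeeze(x)  # removes every dimension of size 1

print(x.shape)  # torch.Size([2, 1, 3, 4, 1, 5])
print(y.shape)  # torch.Size([2, 3, 4, 5])
```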

The syntax of the PyTorch squeeze() function is given below.

Syntax

torch.squeeze(input, dim=None, *, out=None)

There are two parameters for the PyTorch squeeze function.

Parameters

  • input (Tensor) – The tensor provided as input.
  • dim (int, optional) – If given, only this dimension is squeezed, and only if it has size 1.

Keyword Arguments

  • out (Tensor, optional) – The output tensor.
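A minimal sketch of the dim argument, using an illustrative tensor of shape (A×1×B) with A = 2 and B = 3. Squeezing a dimension whose size is not 1 leaves the tensor unchanged:

```python
import torch

x = torch.zeros(2, 1, 3)  # shape (A x 1 x B) with A = 2, B = 3

# Dimension 0 has size 2, not 1, so squeezing it leaves the shape unchanged
s0 = torch.squeeze(x, 0)
# Dimension 1 has size 1, so squeezing it removes that dimension
s1 = torch.squeeze(x, 1)

print(s0.shape)  # torch.Size([2, 1, 3])
print(s1.shape)  # torch.Size([2, 3])
```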

Example of PyTorch Squeeze

Now let's look at an example where we first create a tensor using PyTorch's zeros() function and then check its size. We can see that several of its dimensions have size 1.

In [4]:
x = torch.zeros(1,2,1,2,1)
print(x)
print('Dimension of input tensor is', x.size())
Output:
tensor([[[[[0.],
           [0.]]],


         [[[0.],
           [0.]]]]])
Dimension of input tensor is torch.Size([1, 2, 1, 2, 1])

In the code snippet below, we apply PyTorch's squeeze function. As we can see, all the dimensions of size 1 are dropped, leaving a 2×2 tensor.

In [9]:
y = torch.squeeze(x)
print(y)
print('Dimension of output tensor is', y.size())
Output:
tensor([[0., 0.],
        [0., 0.]])
Dimension of output tensor is torch.Size([2, 2])

3. PyTorch Unsqueeze : torch.unsqueeze()

The PyTorch unsqueeze function returns a new tensor with a dimension of size 1 inserted at the desired position.

As with squeeze, the data and all the elements remain the same in the output tensor; only the shape changes.
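The PyTorch documentation states that the tensor returned by unsqueeze() shares the same underlying data as its input. A short sketch (values chosen arbitrarily) shows that writing through the unsqueezed tensor also changes the original:

```python
import torch

x = torch.tensor([50, 25, 75])
u = torch.unsqueeze(x, 0)  # shape (1, 3), same underlying data as x

# Writing through the unsqueezed tensor also changes the original
u[0, 0] = 999
print(x.tolist())  # [999, 25, 75]
```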

Let us see the syntax of the PyTorch unsqueeze() function below.

Syntax

torch.unsqueeze(input, dim)

The PyTorch unsqueeze function takes two parameters and returns a tensor.

Parameters

  • input (Tensor) – The input tensor.
  • dim (int) – The index at which the singleton dimension is inserted.
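The dim argument may also be negative. Per the PyTorch documentation, a negative dim in the range [-input.dim() - 1, input.dim() + 1) is converted to dim + input.dim() + 1. A quick sketch on a 1-D tensor:

```python
import torch

x = torch.tensor([50, 25, 75, 100, 150])  # 1-D, so valid dims are -2, -1, 0, 1

# A negative dim counts from the end: dim + x.dim() + 1
print(torch.unsqueeze(x, -1).shape)  # torch.Size([5, 1]), same as dim=1
print(torch.unsqueeze(x, -2).shape)  # torch.Size([1, 5]), same as dim=0
```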

Output

As an output, we get a tensor with one more dimension than the input.

Example – 1: PyTorch Unsqueeze along Dimension = 0

The following code snippet shows how the PyTorch unsqueeze function adds a new singleton dimension along dimension = 0 (i.e. axis = 0) of the original tensor. The resulting output tensor has the new shape 1×5.

In [10]:
x = torch.tensor([50, 25, 75, 100, 150])

print("Shape of Input Tensor :", x.shape)

u = torch.unsqueeze(x, 0) 

print("Output Tensor :" ,u)

print("Shape of Output Tensor :", u.shape)
Output:
Shape of Input Tensor : torch.Size([5])

Output Tensor : tensor([[ 50,  25,  75, 100, 150]])

Shape of Output Tensor : torch.Size([1, 5])

Example – 2: PyTorch Unsqueeze along Dimension = 1

In this second case, we use the same input tensor but unsqueeze it along dimension = 1 (i.e. axis = 1), and we can see that the output tensor has shape 5×1.

In [10]:
x = torch.tensor([50, 25, 75, 100, 150])

print("Shape of Input Tensor :", x.shape)

u = torch.unsqueeze(x, 1) 

print("Output Tensor :" ,u)

print("Shape of Output Tensor :", u.shape)
Output:
Shape of Input Tensor : torch.Size([5])

Output Tensor : tensor([[ 50],
        [ 25],
        [ 75],
        [100],
        [150]])

Shape of Output Tensor : torch.Size([5, 1])
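A common use of unsqueeze is adding a leading batch dimension of size 1 to a single sample. The sketch below assumes a hypothetical 3-channel, 32×32 image; squeeze() with the same dim removes the batch dimension again:

```python
import torch

# Hypothetical single image: 3 channels, 32 x 32 pixels
image = torch.rand(3, 32, 32)

# Insert a batch dimension of size 1 at the front
batch = torch.unsqueeze(image, 0)  # equivalently: image.unsqueeze(0)
print(batch.shape)  # torch.Size([1, 3, 32, 32])

# squeeze() on the same dim undoes it
restored = torch.squeeze(batch, 0)
print(restored.shape)                # torch.Size([3, 32, 32])
print(torch.equal(restored, image))  # True
```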

4. PyTorch Flatten : torch.flatten()

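The PyTorch flatten function reshapes a tensor into a one-dimensional tensor. If start_dim or end_dim is given, only the dimensions from start_dim to end_dim (inclusive) are flattened.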

Below is the syntax of flatten() function of PyTorch.

Syntax

torch.flatten(input, start_dim=0, end_dim=-1)

Parameters

  • input (Tensor) – The input tensor.
  • start_dim (int) – The first dimension to flatten (default: 0).
  • end_dim (int) – The last dimension to flatten (default: -1, i.e. the last dimension).
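A minimal sketch of start_dim and end_dim used together, on an arbitrary 4-D tensor of zeros. Only the selected range of dimensions is merged, and negative indices count from the end:

```python
import torch

t = torch.zeros(2, 3, 4, 5)

# Flatten only dimensions 1 through 2 (sizes 3 and 4) into one of size 12
f = torch.flatten(t, start_dim=1, end_dim=2)
print(f.shape)  # torch.Size([2, 12, 5])

# Negative indices count from the end: merge the last two dims (4 and 5)
g = torch.flatten(t, start_dim=-2)
print(g.shape)  # torch.Size([2, 3, 20])
```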

Example – 1: Simple use of PyTorch Flatten

In this example, we take a 2×2×2 tensor and use the PyTorch flatten function to get a one-dimensional tensor of size 8.

In [12]:
t = torch.tensor([[[19, 22],
                   [37, 43]],
                  [[52, 69],
                   [71, 85]]])

print('Shape of input tensor :', t.shape)

f = torch.flatten(t)

print('Output tensor :', f)
print('Shape of output tensor :', f.shape)
Output:
Shape of input tensor : torch.Size([2, 2, 2])

Output tensor : tensor([19, 22, 37, 43, 52, 69, 71, 85])

Shape of output tensor : torch.Size([8])

Example – 2: PyTorch Flatten with start_dim = 1

In this example, we use the same tensor but do not flatten it completely. Instead, we flatten it starting from dimension 1, which gives an output tensor of shape 2×4.

In [14]:
t = torch.tensor([[[19, 22],
                   [37, 43]],
                  [[52, 69],
                   [71, 85]]])

print('Shape of input tensor :', t.shape)

f = torch.flatten(t, start_dim=1)

print('Output tensor :', f)
print('Shape of output tensor :', f.shape)
Output:
Shape of input tensor : torch.Size([2, 2, 2])

Output tensor : tensor([[19, 22, 37, 43],
        [52, 69, 71, 85]])

Shape of output tensor : torch.Size([2, 4])
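Flattening from start_dim=1 keeps the first dimension (often a batch dimension) and merges the rest. As a sketch, it gives the same result as reshaping with the first size kept and -1 for the remainder, a pattern commonly used before feeding convolutional features into a linear layer:

```python
import torch

t = torch.tensor([[[19, 22], [37, 43]],
                  [[52, 69], [71, 85]]])

# Keep the first dimension, merge everything else
f = torch.flatten(t, start_dim=1)
r = t.reshape(t.shape[0], -1)

print(torch.equal(f, r))  # True
print(r.shape)            # torch.Size([2, 4])
```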

5. PyTorch View : Tensor.view()

In PyTorch, you can create a view on top of an existing tensor with the view() method. A view does not copy the data; it shares the underlying data of the base tensor. Avoiding a separate copy makes reshaping, slicing, and element-wise operations fast and memory-efficient.
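Because a view reuses the existing memory, view() only works when the new shape is compatible with the tensor's memory layout (its strides). The sketch below uses a transposed 2×3 tensor as an illustrative non-contiguous input: view() raises a RuntimeError, while reshape() falls back to copying, and calling contiguous() first also works:

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
xt = x.t()  # transpose: same storage, swapped strides, not contiguous
print(xt.is_contiguous())  # False

view_error = None
try:
    xt.view(6)  # view() cannot express this layout without copying
except RuntimeError as err:
    view_error = str(err)
print("view failed:", view_error is not None)  # True

# reshape() returns a copy when a view is impossible
print(xt.reshape(6).tolist())  # [0, 3, 1, 4, 2, 5]

# Alternatively, make a contiguous copy first, then view it
print(xt.contiguous().view(6).tolist())  # [0, 3, 1, 4, 2, 5]
```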

Example – 1 : Simple PyTorch View

In this example of the PyTorch view, we create a 1-D tensor with 16 elements and then create a view of shape 2×8 on top of it.

In [16]:
a = torch.arange(1., 17.)  # 1., 2., ..., 16. (torch.range() is deprecated)

print('Base tensor :',a)
print('Shape of base tensor :',a.shape)
v = a.view(2, 8) 
print('View on tensor :',v)
print('Shape of view :',v.shape)
Output:
Base tensor : tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16.])

Shape of base tensor : torch.Size([16])

View on tensor : tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12., 13., 14., 15., 16.]])

Shape of view : torch.Size([2, 8])
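As with reshape(), one dimension passed to view() can be -1 and is inferred from the element count. A short sketch on the same 16-element tensor:

```python
import torch

a = torch.arange(1., 17.)  # 16 elements: 1., 2., ..., 16.

# -1 is inferred as 16 // 4 = 4
v = a.view(4, -1)
print(v.shape)        # torch.Size([4, 4])
print(v[1].tolist())  # [5.0, 6.0, 7.0, 8.0]
```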

Example – 2: PyTorch View Shares the Same Memory

In this example, we show that a PyTorch view shares memory with its base tensor, so any change to the view affects the base tensor as well.

In [17]:
t = torch.rand(5, 4)

c = t.view(4, 5)

t.storage().data_ptr() == c.storage().data_ptr() 
Output:
True  # i.e. `t` and `c` share the same underlying data.

Since views share underlying data with their base tensor, if you edit the data in the view, it will be reflected in the base tensor as well.

In [19]:
# Modifying view tensor changes base tensor as well.
c[0][0] = 9.97
t[0][0]
Output:
tensor(9.9700) # Here base tensor value also changes
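If you need a reshaped tensor that does not affect the original, a simple option is to copy the data with clone() before calling view(). A minimal sketch contrasting the two, using an arbitrary 2×3 tensor of zeros:

```python
import torch

t = torch.zeros(2, 3)

c = t.view(3, 2)          # shares memory with t
d = t.clone().view(3, 2)  # clone() copies the data first

c[0, 0] = 1.0
d[0, 1] = 2.0

print(t[0, 0].item())  # 1.0: the write through the view reached t
print(t[0, 1].item())  # 0.0: the cloned copy is independent of t
```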


Conclusion

In this tutorial, we saw the following functions for manipulating PyTorch tensors: Reshape, Squeeze, Unsqueeze, Flatten, and View. We covered the syntax of each function and its important parameters, along with multiple examples of each.

Reference – PyTorch Documentation

  • Palash Sharma

    I am Palash Sharma, an undergraduate student who loves to explore and garner in-depth knowledge in the fields like Artificial Intelligence and Machine Learning. I am captivated by the wonders these fields have produced with their novel implementations. With this, I have a desire to share my knowledge with others in all my capacity.
