Introduction
In this PyTorch tutorial, we will learn about some of the built-in functions that help alter the shapes of tensors. We will go through the following PyTorch functions: Reshape, Squeeze, Unsqueeze, Flatten, and View, along with their syntax and examples. These functions are very useful for manipulating tensor shapes in your PyTorch deep learning projects.
First, let us import the PyTorch library.
import torch
1. PyTorch Reshape : torch.reshape()
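The reshape function in PyTorch returns a tensor with the same values and number of elements as the input tensor; only the shape changes, as requested by the user. When possible, the result is a view of the input that shares its data; otherwise it is a copy.
But we have to make sure that the new shape can hold exactly the elements of the original tensor, i.e. the product of the new dimensions must equal the number of elements; otherwise reshape raises an error. One dimension may be given as -1, in which case PyTorch infers its size from the remaining dimensions.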
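The rule above can be checked directly. This is a minimal sketch, assuming a recent PyTorch release; the 12-element tensor `a` and the variable names `valid`, `inferred` and `raised` are illustrative choices, not part of the original tutorial. It shows a valid reshape, a reshape with an inferred `-1` dimension, and a shape that cannot hold the elements.

```python
import torch

a = torch.arange(12)  # illustrative 1-D tensor holding 12 elements

# 3 * 4 == 12, so every element has a place in the new shape
valid = torch.reshape(a, (3, 4))
print(valid.shape)  # torch.Size([3, 4])

# -1 asks PyTorch to infer that dimension: 12 / 2 = 6
inferred = torch.reshape(a, (2, -1))
print(inferred.shape)  # torch.Size([2, 6])

# 5 * 5 == 25 != 12, so the elements cannot fill the shape and reshape raises
raised = False
try:
    torch.reshape(a, (5, 5))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```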
The syntax of PyTorch reshape() is shown below.
Syntax
torch.reshape(input, shape)
The parameter used in the function is mentioned below.
Parameters
 input (Tensor) – The tensor whose shape has to be changed.
 shape (tuple of int) – The new shape.
Output
The output is a tensor with the same values as the input but with a different shape.
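Whether reshape() returns a view or a copy depends on the memory layout of the input. The sketch below, assuming a recent PyTorch release, compares `data_ptr()` values to show both cases: a contiguous input is reshaped without copying, while the transpose of a tensor has to be copied to be flattened. The tensors used are illustrative, and the PyTorch documentation advises not to rely on the view-versus-copy behaviour in real code.

```python
import torch

a = torch.arange(6)

# a is contiguous, so reshape can return a view that shares a's memory
r = torch.reshape(a, (2, 3))
print(r.data_ptr() == a.data_ptr())  # True: same underlying data

# The transpose of r is not contiguous; flattening it needs a different
# element order in memory, so reshape has to make a copy
t = r.t()
flat = torch.reshape(t, (-1,))
print(flat)  # tensor([0, 3, 1, 4, 2, 5])
print(flat.data_ptr() == a.data_ptr())  # False: new memory was allocated
```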
Example 1: Simple Reshape Example in PyTorch
An example of PyTorch Reshape is shown below.
Here we build a tensor using the arange function and then use the reshape() function to reshape it into a 3×3 tensor.
a = torch.arange(9.)
print('Input Tensor:')
print(a)
print('Input Tensor Shape:')
print(a.shape)
Input Tensor:
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.])
Input Tensor Shape:
torch.Size([9])
# Reshape the 9-element tensor into a 3x3 tensor
r = torch.reshape(a, (3, 3))
print('Reshaped Tensor:')
print(r)
print('Reshaped Tensor Shape:')
print(r.shape)
Reshaped Tensor:
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
Reshaped Tensor Shape:
torch.Size([3, 3])
Example 2: Flatten Tensor in PyTorch with Reshape()
We can flatten a PyTorch tensor with the reshape() function by passing -1 as the shape, which tells PyTorch to infer the single dimension from the number of elements.
In this example, a 2×2 tensor is flattened into a 1-D tensor of 4 elements by passing it to reshape() with the shape (-1,). Passing (1,) instead would raise an error, because a shape of 1 cannot hold 4 elements.
b = torch.tensor([[45, 56], [27, 34]])
torch.reshape(b, (-1,))
tensor([45, 56, 27, 34])
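For comparison, here is a small sketch, assuming a recent PyTorch release, of three ways to get the same 1-D result from the tensor `b` above: the function form of reshape, the `Tensor.reshape` method form, and torch.flatten(), which section 4 covers in detail. The names `r1`, `r2` and `r3` are illustrative.

```python
import torch

b = torch.tensor([[45, 56], [27, 34]])

r1 = torch.reshape(b, (-1,))  # function form with an inferred dimension
r2 = b.reshape(-1)            # method form; the shape can be passed without a tuple
r3 = torch.flatten(b)         # dedicated flatten function

print(r1)  # tensor([45, 56, 27, 34])
print(torch.equal(r1, r2) and torch.equal(r1, r3))  # True
```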
2. PyTorch Squeeze : torch.squeeze()
The squeeze function in PyTorch removes all dimensions of size 1 from the input tensor.
For instance, if the input tensor has shape (A×1×B×C×1×D), the output tensor will have shape (A×B×C×D).
The syntax of the PyTorch squeeze() function is given below.
Syntax
torch.squeeze(input, dim=None, *, out=None)
There are two parameters for the PyTorch squeeze function.
Parameters
 input (Tensor) – The input tensor.
 dim (int, optional) – If given, the input is squeezed only in this dimension; if that dimension does not have size 1, the tensor is returned unchanged.
Keyword Arguments
out (Tensor, optional) – The output tensor.
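The effect of the dim parameter is easiest to see side by side. This sketch, assuming a recent PyTorch release, reuses the (1, 2, 1, 2, 1) zeros tensor from the example that follows and squeezes a size-1 dimension, a size-2 dimension, and then all dimensions. The variable names are illustrative.

```python
import torch

x = torch.zeros(1, 2, 1, 2, 1)

# dim=0 has size 1, so only that dimension is removed
only_dim0 = torch.squeeze(x, 0)
print(only_dim0.shape)  # torch.Size([2, 1, 2, 1])

# dim=1 has size 2, so the shape comes back unchanged
unchanged = torch.squeeze(x, 1)
print(unchanged.shape)  # torch.Size([1, 2, 1, 2, 1])

# Without dim, every size-1 dimension is removed
all_removed = torch.squeeze(x)
print(all_removed.shape)  # torch.Size([2, 2])
```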
Example of PyTorch Squeeze
Now let's look at an example where we first create a tensor using PyTorch's zeros function and then check its size. We can see that several of its dimensions have size 1.
x = torch.zeros(1,2,1,2,1)
print(x)
print('Dimension of input tensor is', x.size())
tensor([[[[[0.], [0.]]], [[[0.], [0.]]]]])
Dimension of input tensor is torch.Size([1, 2, 1, 2, 1])
In the code snippet below, we apply PyTorch's squeeze function. As can be seen, all dimensions of size 1 are dropped.
y = torch.squeeze(x)
print(y)
print('Dimension of output tensor is', y.size())
tensor([[0., 0.], [0., 0.]])
Dimension of output tensor is torch.Size([2, 2])
3. PyTorch Unsqueeze : torch.unsqueeze()
The PyTorch unsqueeze function generates a new tensor by inserting a new dimension of size 1 at the desired position.
As with squeeze, the data and elements of the output tensor remain the same as in the input.
Let us see the syntax for PyTorch unsqueeze() function below.
Syntax
torch.unsqueeze(input, dim)
The PyTorch unsqueeze function also takes two parameters and returns a tensor.
Parameters
 input (Tensor) – The input tensor.
 dim (int) – The index at which the singleton dimension is inserted.
Output
The output is a tensor with one more dimension than the input.
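The dim parameter also accepts negative values. According to the PyTorch documentation, valid values lie in the range [-input.dim() - 1, input.dim() + 1), and a negative dim is applied at dim + input.dim() + 1. The sketch below, assuming a recent PyTorch release, checks this on the 1-D tensor used in the examples that follow, and also shows that indexing with None inserts a size-1 dimension in the same way. The names `a` and `b` are illustrative.

```python
import torch

x = torch.tensor([50, 25, 75, 100, 150])  # shape (5,), so x.dim() == 1

# For a 1-D input, valid dim values are -2, -1, 0 and 1
a = torch.unsqueeze(x, -1)  # -1 maps to -1 + x.dim() + 1 = 1
b = torch.unsqueeze(x, 1)
print(a.shape, torch.equal(a, b))  # torch.Size([5, 1]) True

# Indexing with None inserts a size-1 dimension at that position
print(x[None].shape)     # torch.Size([1, 5])
print(x[:, None].shape)  # torch.Size([5, 1])
```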
Example 1: PyTorch Unsqueeze along Dimension = 0
The following code snippet shows how the PyTorch unsqueeze function adds a new dimension of size 1 at dimension 0 (i.e. axis = 0) of the original tensor. The resulting output tensor has shape 1×5.
x = torch.tensor([50, 25, 75, 100, 150])
print("Shape of Input Tensor :", x.shape)
u = torch.unsqueeze(x, 0)
print("Output Tensor :" ,u)
print("Shape of Output Tensor :", u.shape)
Shape of Input Tensor : torch.Size([5])
Output Tensor : tensor([[ 50, 25, 75, 100, 150]])
Shape of Output Tensor : torch.Size([1, 5])
Example 2: PyTorch Unsqueeze along Dimension = 1
In this second case, we use the same input tensor but unsqueeze it along dimension 1 (i.e. axis = 1), and we can see that the output tensor has shape 5×1.
x = torch.tensor([50, 25, 75, 100, 150])
print("Shape of Input Tensor :", x.shape)
u = torch.unsqueeze(x, 1)
print("Output Tensor :" ,u)
print("Shape of Output Tensor :", u.shape)
Shape of Input Tensor : torch.Size([5])
Output Tensor : tensor([[ 50],
        [ 25],
        [ 75],
        [100],
        [150]])
Shape of Output Tensor : torch.Size([5, 1])
4. PyTorch Flatten : torch.flatten()
The PyTorch flatten function reshapes a tensor into a single dimension. If start_dim or end_dim are passed, only the dimensions from start_dim to end_dim are flattened.
Below is the syntax of flatten() function of PyTorch.
Syntax
torch.flatten(input, start_dim=0, end_dim=-1)
Parameters
 input (Tensor) – The input tensor.
 start_dim (int) – The first dimension to flatten (default 0).
 end_dim (int) – The last dimension to flatten (default -1, i.e. the last dimension).
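To see how start_dim and end_dim work together, the sketch below, assuming a recent PyTorch release, flattens different ranges of an illustrative 4-D tensor of shape (2, 3, 4, 5). Only the dimensions from start_dim to end_dim (inclusive) are merged; negative indices count from the end.

```python
import torch

t = torch.zeros(2, 3, 4, 5)  # illustrative 4-D tensor

# Merge only dimensions 1 and 2 (sizes 3 and 4) into one of size 12
middle = torch.flatten(t, start_dim=1, end_dim=2)
print(middle.shape)  # torch.Size([2, 12, 5])

# Defaults start_dim=0, end_dim=-1 flatten everything: 2 * 3 * 4 * 5 = 120
everything = torch.flatten(t)
print(everything.shape)  # torch.Size([120])

# Keep the first dimension and merge the rest: 3 * 4 * 5 = 60
rest = torch.flatten(t, start_dim=1, end_dim=-1)
print(rest.shape)  # torch.Size([2, 60])
```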
Example 1: Simple use of PyTorch Flatten
In this example, we take a tensor of 2x2x2 and use PyTorch flatten function to get a tensor of a single dimension having size 8.
t = torch.tensor([[[19, 22],
                   [37, 43]],
                  [[52, 69],
                   [71, 85]]])
print('Shape of input tensor :', t.shape)
f = torch.flatten(t)
print('Output tensor :', f)
print('Shape of output tensor :', f.shape)
Shape of input tensor : torch.Size([2, 2, 2])
Output tensor : tensor([19, 22, 37, 43, 52, 69, 71, 85])
Shape of output tensor : torch.Size([8])
Example 2: PyTorch Flatten with start_dim = 1
In this example, we use the same tensor but do not flatten it completely. Instead, we flatten it starting from dimension 1, which gives an output tensor of shape 2×4.
t = torch.tensor([[[19, 22],
                   [37, 43]],
                  [[52, 69],
                   [71, 85]]])
print('Shape of input tensor :', t.shape)
f = torch.flatten(t, start_dim=1)
print('Output tensor :', f)
print('Shape of output tensor :', f.shape)
Shape of input tensor : torch.Size([2, 2, 2])
Output tensor : tensor([[19, 22, 37, 43],
        [52, 69, 71, 85]])
Shape of output tensor : torch.Size([2, 4])
5. PyTorch View : Tensor.view()
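In PyTorch, you can create a view on top of an existing tensor. A view does not copy the data; it shares the same underlying data as the base tensor. Avoiding a copy makes reshaping, slicing, and element-wise operations fast and memory-efficient. Because it never copies, view() only works when the new shape is compatible with the size and strides of the original tensor; reshape() falls back to a copy when this is not the case.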
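A common case where view() fails is a transposed tensor, which shares memory with its base but is not contiguous. The sketch below, assuming a recent PyTorch release, shows the error and two usual workarounds: calling contiguous() before view(), or using reshape(). The tensors and the `view_failed` flag are illustrative.

```python
import torch

a = torch.arange(6).view(2, 3)  # contiguous 2x3 tensor
t = a.t()                       # transpose: shares memory, but is not contiguous
print(t.is_contiguous())        # False

# view() never copies, so it cannot lay out t's elements in row-major order
view_failed = False
try:
    t.view(-1)
except RuntimeError:
    view_failed = True
print("view failed:", view_failed)  # view failed: True

# Make a contiguous copy first, or let reshape() copy when it has to
print(t.contiguous().view(-1))  # tensor([0, 3, 1, 4, 2, 5])
print(t.reshape(-1))            # tensor([0, 3, 1, 4, 2, 5])
```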
Example 1: Simple PyTorch View
In this example of the PyTorch view, we create a 1D tensor that has 16 elements. We can then create a view of shape 2×8 on top of this tensor.
a = torch.arange(1., 17.)  # 16 float elements 1., 2., ..., 16. (torch.range is deprecated)
print('Base tensor :',a)
print('Shape of base tensor :',a.shape)
v = a.view(2, 8)
print('View on tensor :',v)
print('Shape of view :',v.shape)
Base tensor : tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.])
Shape of base tensor : torch.Size([16])
View on tensor : tensor([[ 1., 2., 3., 4., 5., 6., 7., 8.],
        [ 9., 10., 11., 12., 13., 14., 15., 16.]])
Shape of view : torch.Size([2, 8])
Example 2: PyTorch View Shares the Same Memory
In this example, we show that a PyTorch view shares the memory of its base tensor, so any change to the view affects the base tensor as well.
t = torch.rand(5, 4)
c = t.view(4, 5)
t.storage().data_ptr() == c.storage().data_ptr()
True  # i.e. `t` and `c` share the same underlying data.
Since views share underlying data with their base tensor, if you edit the data in the view, it will be reflected in the base tensor as well.
# Modifying view tensor changes base tensor as well.
c[0][0] = 9.97
t[0][0]
tensor(9.9700) # Here base tensor value also changes
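The sharing also works in the other direction: an in-place change to the base tensor is visible through the view. When an independent copy is needed, `clone()` allocates new memory. This is a small sketch assuming a recent PyTorch release; the tensors `t`, `v` and `c` are illustrative.

```python
import torch

t = torch.zeros(2, 3)
v = t.view(3, 2)

# In-place change on the base tensor shows up in the view
t[0, 0] = 1.0
print(v[0, 0])  # tensor(1.)

# clone() makes an independent copy, so later edits do not reach it
c = t.view(3, 2).clone()
t[0, 1] = 5.0
print(v[0, 1], c[0, 1])  # tensor(5.) tensor(0.)
```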
Conclusion
In this tutorial, we saw the following functions for manipulating PyTorch tensors: Reshape, Squeeze, Unsqueeze, Flatten, and View. We covered the syntax of each function and the parameters that matter when using it, and worked through multiple examples for each of them.
Reference – PyTorch Documentation

I am Palash Sharma, an undergraduate student who loves to explore and garner in-depth knowledge in fields like Artificial Intelligence and Machine Learning. I am captivated by the wonders these fields have produced with their novel implementations. With this, I have a desire to share my knowledge with others to the best of my capacity.