Introduction
In this tutorial, we are going to look at the PyTorch stack() and cat() functions, which are used for joining tensors. We will go through their syntax with examples and finally understand the difference between stack and cat, which confuses many people.
PyTorch Cat()
cat() in PyTorch concatenates a sequence of tensors along an existing dimension. The tensors must have the same shape in every dimension except the one being concatenated (an empty 1-D tensor of shape (0,) is also accepted).
Let’s look at the syntax of the PyTorch cat() function.
Syntax
torch.cat(tensors, dim=0, *, out=None)
Parameters Info:
- tensors (sequence of Tensors) – A Python sequence (for example, a tuple or list) of the tensors to concatenate.
- dim (int, optional) – The dimension along which the tensors are concatenated. Defaults to 0.
- out (Tensor, optional) – The output tensor.
Example 1: Concatenating Multiple Tensors
This example shows how to concatenate four different tensors to make one tensor with cat().
We first import the PyTorch library and then create the desired tensors with the torch.tensor() function.
import torch
t1 = torch.tensor([1,2,3,4])
t2 = torch.tensor([5,6,7,8])
t3 = torch.tensor([9,10,11,12])
t4 = torch.tensor([13,14,15,16])
We now concatenate them. As the result shows, the four tensors have been joined into one single tensor along dimension 0, which is the only axis these 1-D tensors have.
torch.cat((t1, t2, t3, t4), dim=0)
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
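The example above uses 1-D tensors, so dim=0 is the only choice. The shape rule for cat() matters more with 2-D tensors. Below is a minimal sketch (the tensors a, b and c are illustrative and not part of the original example) showing that cat() only needs the non-concatenating dimensions to match, and that a mismatch raises an error.

```python
import torch

# Two 2-D tensors that differ only in dimension 0 (2 rows vs. 1 row)
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])   # shape (2, 3)
b = torch.tensor([[7, 8, 9]])   # shape (1, 3)

# Joining along dim=0 works because dimension 1 (3 columns) matches
rows = torch.cat((a, b), dim=0)
print(rows.shape)  # torch.Size([3, 3])

# A tensor with the same number of rows can be joined along dim=1
c = torch.tensor([[10],
                  [11]])        # shape (2, 1)
cols = torch.cat([a, c], dim=1)  # a list works as well as a tuple
print(cols.shape)  # torch.Size([2, 4])

# Joining a and c along dim=0 fails: dimension 1 differs (3 vs. 1)
try:
    torch.cat((a, c), dim=0)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```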
PyTorch Stack()
The second function in this tutorial is stack(). Like cat(), it joins a sequence of tensors, but it does so along a new dimension. Here, all tensors must have exactly the same shape.
Let’s look at the syntax of the stack() function in PyTorch.
Syntax
torch.stack(tensors, dim=0, *, out=None)
Parameters Info:
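- tensors (sequence of Tensors) – The sequence of tensors to be stacked.
- dim (int, optional) – The index at which the new dimension is inserted. It must be between 0 and the number of dimensions of the input tensors (inclusive). Defaults to 0.
- out (Tensor, optional) – The output tensor.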
Example 1: Stacking Tensors Using Dimension as 0
In this example, we join 3 tensors along a new dimension 0. Note that each individual tensor has only a single axis, whereas the combined tensor has two: the new axis at dimension 0 plus the original one.
For this, we pass the three tensors and dim=0 to the stack() function and get a resulting tensor of shape 3×4.
t1 = torch.tensor([1,2,3,4])
t2 = torch.tensor([5,6,7,8])
t3 = torch.tensor([9,10,11,12])
torch.stack((t1, t2, t3), dim=0)
tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12]])
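A quick way to confirm the result is to check its shape: stacking three tensors of shape (4,) along dim=0 gives shape (3, 4), and indexing the new axis gives back one of the original tensors. Unlike cat(), stack() needs every tensor to have exactly the same shape. The sketch below reuses t1, t2 and t3 from this example; t_short is a hypothetical extra tensor introduced only to trigger the error.

```python
import torch

t1 = torch.tensor([1, 2, 3, 4])
t2 = torch.tensor([5, 6, 7, 8])
t3 = torch.tensor([9, 10, 11, 12])

stacked = torch.stack((t1, t2, t3), dim=0)
print(stacked.shape)  # torch.Size([3, 4]): the new axis has size 3, one slot per tensor
print(stacked[1])     # tensor([5, 6, 7, 8]): indexing the new axis recovers t2

# Hypothetical tensor with 3 elements instead of 4
t_short = torch.tensor([13, 14, 15])
try:
    torch.stack((t1, t_short), dim=0)
    stack_mismatch_raised = False
except RuntimeError:
    stack_mismatch_raised = True
print(stack_mismatch_raised)  # True
```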
Example 2: Stacking Tensors Using Dimension as 1
In this example, we again use the three tensors, but this time we stack them along a new dimension 1. As before, each tensor individually has a single axis, and stack() joins them by creating a new axis at dimension 1.
For this, we pass the three tensors to the stack() function with dim=1, which gives a tensor of shape 4×3.
t1 = torch.tensor([1,2,3,4])
t2 = torch.tensor([5,6,7,8])
t3 = torch.tensor([9,10,11,12])
torch.stack((t1, t2, t3), dim=1)
tensor([[ 1, 5, 9], [ 2, 6, 10], [ 3, 7, 11], [ 4, 8, 12]])
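For 1-D inputs, the dim=1 result is simply the transpose of the dim=0 result: each original tensor becomes a column instead of a row. Because stack() inserts a new axis, dim can also be given as a negative index counted from the end of the output shape, so dim=-1 here means the same as dim=1. A short check, assuming the same three tensors:

```python
import torch

t1 = torch.tensor([1, 2, 3, 4])
t2 = torch.tensor([5, 6, 7, 8])
t3 = torch.tensor([9, 10, 11, 12])

by_row = torch.stack((t1, t2, t3), dim=0)  # shape (3, 4): each tensor is a row
by_col = torch.stack((t1, t2, t3), dim=1)  # shape (4, 3): each tensor is a column

# Stacking along dim=1 gives the transpose of stacking along dim=0
print(torch.equal(by_col, by_row.T))  # True

# For a 2-D result, dim=-1 refers to the last axis, i.e. dim=1
print(torch.equal(by_col, torch.stack((t1, t2, t3), dim=-1)))  # True
print(by_col[:, 0])  # tensor([1, 2, 3, 4]): column 0 is t1
```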
PyTorch Stack vs Cat
The two functions discussed above are often confused because both join PyTorch tensors. Let us understand the difference between stack and cat in PyTorch.
In the cat() function, the tensors are concatenated along an existing axis, whereas in the stack() function they are joined along a new axis that does not exist in the individual tensors.
Let us understand this with the below example.
t1 = torch.tensor([1,2,3,4])
t2 = torch.tensor([5,6,7,8])
t3 = torch.tensor([9,10,11,12])
t4 = torch.tensor([13,14,15,16])
Concatenation with cat()
As you can see, cat() concatenates the tensors along their existing axis, so the result is still a 1-D tensor, now with 16 elements.
torch.cat((t1, t2, t3, t4), dim=0)
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
Concatenation with stack()
Here, using the stack() function, the tensors are joined along a new axis (axis 0 in this example), giving a 4×4 tensor.
torch.stack((t1, t2, t3, t4), dim=0)
tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16]])
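The difference is easiest to see in the shapes. A minimal sketch, using three hypothetical tensors of shape (2, 4): cat() keeps the number of dimensions and grows an existing one, while stack() adds a new dimension whose size equals the number of tensors joined.

```python
import torch

# Three hypothetical tensors, each of shape (2, 4)
ts = (torch.zeros(2, 4), torch.ones(2, 4), torch.full((2, 4), 2.0))

# cat() grows an existing dimension, so the result stays 2-D
cat0 = torch.cat(ts, dim=0)
cat1 = torch.cat(ts, dim=1)
print(cat0.shape, cat1.shape)  # torch.Size([6, 4]) torch.Size([2, 12])

# stack() inserts a new dimension of size 3 (one per tensor), so the result is 3-D
for d in range(3):
    print(d, torch.stack(ts, dim=d).shape)
# 0 torch.Size([3, 2, 4])
# 1 torch.Size([2, 3, 4])
# 2 torch.Size([2, 4, 3])
```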
In fact, we can reproduce the result of stack() with cat() as well, using one extra step: add a new axis to each tensor and then concatenate along it.
To do this, we apply the unsqueeze(0) operation to each tensor, which inserts a new axis at position 0 and turns each tensor of shape (4,) into one of shape (1, 4). We then pass them to cat() to concatenate along dim=0. The resulting output is the same as the output of the stack() function we saw above.
torch.cat(
    (t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0), t4.unsqueeze(0)),
    dim=0,
)
tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16]])
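The unsqueeze-then-cat pattern works for any valid dim, not just 0. Below is a compact sketch; the helper name stack_via_cat is hypothetical and not a PyTorch function. It uses a list comprehension instead of writing out each unsqueeze() call, and checks the result against torch.stack() for every valid dim of 1-D inputs.

```python
import torch

def stack_via_cat(tensors, dim=0):
    """Hypothetical helper: emulate torch.stack() by unsqueezing each tensor, then concatenating."""
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

t1 = torch.tensor([1, 2, 3, 4])
t2 = torch.tensor([5, 6, 7, 8])
t3 = torch.tensor([9, 10, 11, 12])
t4 = torch.tensor([13, 14, 15, 16])
ts = (t1, t2, t3, t4)

# For 1-D inputs, the valid dims for stack() are -2, -1, 0 and 1
for d in (-2, -1, 0, 1):
    print(d, torch.equal(stack_via_cat(ts, dim=d), torch.stack(ts, dim=d)))  # True for each d
```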
Conclusion
We have reached the end of this tutorial. In it, we learned about the PyTorch functions stack() and cat(), which are used for joining tensors. We also looked at the syntax and examples for both functions, along with the key difference between stack and cat.
Reference – PyTorch Documentation
I am Palash Sharma, an undergraduate student who loves to explore and gain in-depth knowledge in fields like Artificial Intelligence and Machine Learning. I am captivated by the wonders these fields have produced with their novel implementations, and I want to share my knowledge with others as best I can.