In this article, we will learn how to squeeze and unsqueeze a PyTorch tensor.
To squeeze a tensor, we use the torch.squeeze() method, and to unsqueeze a tensor, we use the torch.unsqueeze() method. Let’s look at these methods in detail.
Squeeze a Tensor:
When we squeeze a tensor, its dimensions of size 1 are removed. The elements of the original tensor keep their order and are arranged along the remaining dimensions. For example, if the input tensor has shape (m×1×n×1), the output tensor after squeezing has shape (m×n). The following is the syntax of the torch.squeeze() method.
Syntax: torch.squeeze(input, dim=None, *, out=None)
Parameters:
- input: the input tensor.
- dim: an optional integer. If given, the input is squeezed only in this dimension.
- out: the output tensor, an optional keyword argument.
Return: It returns a tensor with all the size-1 dimensions of the input tensor removed (or, if dim is given, only that dimension removed, provided it has size 1).
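The (m×1×n×1) to (m×n) case described above can be sketched as follows, assuming m = 2 and n = 3 (sizes chosen only for illustration). The sketch also checks that the elements keep their order and that, as the PyTorch documentation states, the squeezed tensor shares storage with its input.

```python
import torch

# m = 2 and n = 3 are arbitrary illustrative sizes
m, n = 2, 3
x = torch.arange(m * n, dtype=torch.float).reshape(m, 1, n, 1)
print(x.shape)  # torch.Size([2, 1, 3, 1])

y = torch.squeeze(x)
print(y.shape)  # torch.Size([2, 3])

# The elements keep their original order; only the shape changes
print(torch.equal(x.flatten(), y.flatten()))  # True

# The squeezed tensor is a view: writing to y also changes x
y[0, 0] = 100.0
print(x[0, 0, 0, 0].item())  # 100.0
```

Because the result is a view, no data is copied; use .clone() on the result if you need an independent tensor.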
Note that we can squeeze the input tensor in a particular dimension dim. In this case, the other dimensions of size 1 remain unchanged, and if dimension dim itself does not have size 1, the shape is not changed at all. Example 2 discusses this in more detail.
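A short sketch of the dim argument, using an illustrative tensor of shape (1, 3, 1). It also uses Tensor.squeeze(), the method form of torch.squeeze(), which takes the same dim argument.

```python
import torch

x = torch.zeros(1, 3, 1)  # illustrative shape with two size-1 dimensions

# Squeeze only dimension 0; the trailing size-1 dimension stays
print(torch.squeeze(x, dim=0).shape)  # torch.Size([3, 1])

# Tensor.squeeze is the method form of torch.squeeze
print(x.squeeze(2).shape)  # torch.Size([1, 3])

# Dimension 1 has size 3, so squeezing it leaves the shape unchanged
print(x.squeeze(1).shape)  # torch.Size([1, 3, 1])
```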
Example 1:
In the example below, we squeeze a 5-D tensor using the torch.squeeze() method. The input tensor has two dimensions of size 1.
```python
# Python program to squeeze a tensor

# importing torch
import torch

# creating the input tensor
input_tensor = torch.randn(3, 1, 2, 1, 4)

# print the input tensor size
print("Input tensor Size:\n", input_tensor.size())

# squeeze the tensor
output = torch.squeeze(input_tensor)

# print the squeezed tensor size
print("Size after squeeze:\n", output.size())
```
Output:
```
Input tensor Size:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze:
 torch.Size([3, 2, 4])
```
Notice that both dimensions of size 1 are removed in the squeezed tensor.
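Because squeezing without dim removes every size-1 dimension, it can also drop a dimension you meant to keep. In the sketch below, a hypothetical (batch, features, 1) tensor with a batch of size 1 loses its batch dimension as well; passing dim restricts the squeeze to the intended axis.

```python
import torch

# Hypothetical layout: (batch, features, 1) with a batch of size 1
batch = torch.randn(1, 3, 1)

# Without dim, the batch dimension is removed too
print(torch.squeeze(batch).shape)  # torch.Size([3])

# With dim=2, only the trailing size-1 dimension is removed
print(torch.squeeze(batch, dim=2).shape)  # torch.Size([1, 3])
```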
Example 2:
In this example, we squeeze the tensor along different dimensions.
```python
# Python program to squeeze a tensor in
# different dimensions

# importing torch
import torch

# creating the input tensor
input_tensor = torch.randn(3, 1, 2, 1, 4)
print("Dimension of input tensor:", input_tensor.dim())
print("Input tensor Size:\n", input_tensor.size())

# squeeze the tensor in each dimension 0 to 4 in turn
for d in range(input_tensor.dim()):
    output = torch.squeeze(input_tensor, dim=d)
    print(f"Size after squeeze with dim={d}:\n", output.size())

# output = torch.squeeze(input_tensor, dim=5)
# Error: a 5-D tensor only has dimensions -5 to 4
```
Output:
```
Dimension of input tensor: 5
Input tensor Size:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=0:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=1:
 torch.Size([3, 2, 1, 4])
Size after squeeze with dim=2:
 torch.Size([3, 1, 2, 1, 4])
Size after squeeze with dim=3:
 torch.Size([3, 1, 2, 4])
Size after squeeze with dim=4:
 torch.Size([3, 1, 2, 1, 4])
```
Notice that squeezing in dimension 0, 2, or 4 leaves the shape of the output tensor unchanged, because these dimensions have sizes 3, 2, and 4 rather than 1. Squeezing in dimension 1 or dimension 3 (both of size 1) removes only that dimension from the output tensor.
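Negative values of dim count from the last dimension. A minimal sketch, reusing the shape from Example 2, checks that dim=-2 and dim=3 refer to the same axis of a 5-D tensor, and that a value outside -5 to 4 raises an IndexError.

```python
import torch

x = torch.randn(3, 1, 2, 1, 4)  # same shape as in Example 2

# For a 5-D tensor, dim=-2 refers to the same axis as dim=3
a = torch.squeeze(x, dim=-2)
b = torch.squeeze(x, dim=3)
print(a.shape, b.shape)  # torch.Size([3, 1, 2, 4]) torch.Size([3, 1, 2, 4])
print(torch.equal(a, b))  # True

# Valid values are -5 to 4; anything else raises an IndexError
try:
    torch.squeeze(x, dim=5)
except IndexError as err:
    print("IndexError:", err)
```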
Unsqueeze a Tensor:
When we unsqueeze a tensor, a new dimension of size 1 is inserted at the specified position. An unsqueeze operation therefore always increases the number of dimensions of the output tensor by one. For example, if the input tensor has shape (m×n) and we insert a new dimension at position 1, the output tensor after unsqueezing has shape (m×1×n). The following is the syntax of the torch.unsqueeze() method:
Syntax: torch.unsqueeze(input, dim)
Parameters:
- input: the input tensor.
- dim: an integer value, the index at which the singleton dimension is inserted.
Return: It returns a new tensor with a dimension of size one inserted at the specified position dim.
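The (m×n) to (m×1×n) case can be sketched as follows, assuming m = 2 and n = 3 for illustration. It also shows that indexing with None (standard NumPy-style indexing in PyTorch) inserts a size-1 dimension at the same position, and that squeezing that dimension undoes the unsqueeze.

```python
import torch

# m = 2 and n = 3 are arbitrary illustrative sizes
x = torch.arange(6).reshape(2, 3)

y = torch.unsqueeze(x, dim=1)
print(y.shape)  # torch.Size([2, 1, 3])

# Indexing with None inserts a size-1 dimension at the same position
print(torch.equal(y, x[:, None, :]))  # True

# Squeezing the same dimension undoes the unsqueeze
print(torch.equal(torch.squeeze(y, dim=1), x))  # True
```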
Note that dim can be any value in the range [-input.dim() - 1, input.dim() + 1). A negative dim corresponds to dim + input.dim() + 1; for example, for a 1-D input, dim = -1 is the same as dim = 1.
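To see this range concretely, the sketch below loops over every valid dim for a 1-D tensor of 8 elements, where input.dim() is 1 and the valid values are therefore -2, -1, 0, and 1.

```python
import torch

x = torch.arange(8, dtype=torch.float)  # 1-D tensor, so x.dim() == 1

# Valid dim values run from -x.dim() - 1 up to x.dim(), i.e. -2 to 1 here
for d in range(-x.dim() - 1, x.dim() + 1):
    # A negative d is treated as d + x.dim() + 1
    print(d, torch.unsqueeze(x, dim=d).shape)

# Prints:
# -2 torch.Size([1, 8])
# -1 torch.Size([8, 1])
# 0 torch.Size([1, 8])
# 1 torch.Size([8, 1])
```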
Example 3:
In the example below, we unsqueeze a 1-D tensor into a 2-D tensor.
```python
# Python program to unsqueeze the input tensor

# importing torch
import torch

# define the input tensor
input_tensor = torch.arange(8, dtype=torch.float)
print("Input tensor:\n", input_tensor)
print("Size of input Tensor before unsqueeze:\n", input_tensor.size())

# insert a new dimension at position 0
output = torch.unsqueeze(input_tensor, dim=0)
print("Tensor after unsqueeze with dim=0:\n", output)
print("Size after unsqueeze with dim=0:\n", output.size())

# insert a new dimension at position 1
output = torch.unsqueeze(input_tensor, dim=1)
print("Tensor after unsqueeze with dim=1:\n", output)
print("Size after unsqueeze with dim=1:\n", output.size())
```
Output:
```
Input tensor:
 tensor([0., 1., 2., 3., 4., 5., 6., 7.])
Size of input Tensor before unsqueeze:
 torch.Size([8])
Tensor after unsqueeze with dim=0:
 tensor([[0., 1., 2., 3., 4., 5., 6., 7.]])
Size after unsqueeze with dim=0:
 torch.Size([1, 8])
Tensor after unsqueeze with dim=1:
 tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.]])
Size after unsqueeze with dim=1:
 torch.Size([8, 1])
```
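One common reason to choose between dim=0 and dim=1 is broadcasting. In the sketch below, the column of shape [8, 1] and the row of shape [1, 8] obtained as in Example 3 broadcast against each other to give an 8×8 table of pairwise products (an outer product), which is cross-checked against torch.outer().

```python
import torch

v = torch.arange(8, dtype=torch.float)

col = torch.unsqueeze(v, dim=1)  # shape [8, 1]
row = torch.unsqueeze(v, dim=0)  # shape [1, 8]

# Broadcasting [8, 1] against [1, 8] yields an [8, 8] result
table = col * row
print(table.shape)  # torch.Size([8, 8])
print(table[3, 5].item())  # 15.0

# Same result as the built-in outer product
print(torch.equal(table, torch.outer(v, v)))  # True
```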