PyTorch torch.permute() rearranges the dimensions of a tensor into the specified order and returns a view of the original tensor. The returned tensor contains the same data and the same number of elements as the original; only the order of its dimensions changes.
Syntax: Tensor.permute(*dims) or torch.permute(input, dims)
Parameters:
- dims: sequence of indices giving the desired ordering of the tensor's dimensions (indexing starts from zero).
Return: a tensor with its dimensions in the desired order.
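As a quick reference, the same permutation can be written either as a method on the tensor or through the module-level function. Below is a minimal sketch (the variable names are illustrative, and it assumes a PyTorch version recent enough to expose torch.permute as a module-level function):
Python3
# import pytorch library
import torch

# a 2 x 3 tensor for illustration
x = torch.randn(2, 3)

# method form used in the examples below
y1 = x.permute(1, 0)

# equivalent module-level function form: torch.permute(input, dims)
y2 = torch.permute(x, (1, 0))

# both calls produce the same 3 x 2 result
print(y1.size())            # torch.Size([3, 2])
print(torch.equal(y1, y2))  # True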
Let’s see this concept with the help of a few examples:
Example 1: Create a two-dimensional tensor of size 2 × 4 and then permute its dimensions.
Python3
# import pytorch library
import torch

# create a tensor of size 2 x 4
input_var = torch.randn(2, 4)

# print size
print(input_var.size())
print(input_var)

# dimensions permuted
input_var = input_var.permute(1, 0)

# print size
print(input_var.size())
print(input_var)
Output:
torch.Size([2, 4])
tensor([[ 0.9801,  0.5296,  0.5449, -1.1481],
        [-0.6762, -0.1161,  0.6360, -0.5371]])
torch.Size([4, 2])
tensor([[ 0.9801, -0.6762],
        [ 0.5296, -0.1161],
        [ 0.5449,  0.6360],
        [-1.1481, -0.5371]])
Example 2: Create a three-dimensional tensor of size 3 × 5 × 2 and then permute its dimensions.
Python3
# import pytorch library
import torch

# creating a tensor with random
# values of dimension 3 x 5 x 2
input_var = torch.randn(3, 5, 2)

# print size
print(input_var.size())
print(input_var)

# dimensions permuted
input_var = input_var.permute(2, 0, 1)

# print size
print(input_var.size())
print(input_var)
Output:
torch.Size([3, 5, 2])
tensor([[[ 0.2059, -0.7165],
         [-1.1305,  0.5886],
         [-0.1247, -0.4969],
         [-0.5788,  0.0159],
         [ 1.4304,  0.6014]],

        [[ 2.4882, -0.3910],
         [-0.5558,  0.6903],
         [-0.4219, -0.5498],
         [-0.5346, -0.0703],
         [ 1.1497, -0.3252]],

        [[-0.5075,  0.5752],
         [ 1.3738, -0.3321],
         [-0.3317, -0.9209],
         [-1.6677, -1.1471],
         [-0.9269, -0.6493]]])
torch.Size([2, 3, 5])
tensor([[[ 0.2059, -1.1305, -0.1247, -0.5788,  1.4304],
         [ 2.4882, -0.5558, -0.4219, -0.5346,  1.1497],
         [-0.5075,  1.3738, -0.3317, -1.6677, -0.9269]],

        [[-0.7165,  0.5886, -0.4969,  0.0159,  0.6014],
         [-0.3910,  0.6903, -0.5498, -0.0703, -0.3252],
         [ 0.5752, -0.3321, -0.9209, -1.1471, -0.6493]]])
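One point worth noting: permute() does not copy the data. The returned tensor is a view that shares storage with the original, so writes through the permuted tensor are reflected in the original as well. Below is a minimal sketch illustrating this (the tensor names and values are illustrative):
Python3
# import pytorch library
import torch

# a small 2 x 4 tensor of zeros
a = torch.zeros(2, 4)

# permuted view of the same underlying data
b = a.permute(1, 0)

# writing through the view also changes the original tensor
b[0, 1] = 7.0
print(a[1, 0])            # tensor(7.)

# the permuted view is typically non-contiguous;
# call .contiguous() if a compact copy is needed
print(b.is_contiguous())  # False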