In this article, we will look at the Tensor.detach() method in PyTorch using Python.
PyTorch is an open-source deep learning framework with Python and C++ interfaces. In Python, it is used through the torch module. In PyTorch, input data has to be processed in the form of tensors. PyTorch also includes a module, autograd, that calculates gradients automatically for backpropagation.

The Tensor.detach() method separates a tensor from the computational graph. It returns a new tensor that shares the same data as the original but does not require gradients and is no longer tracked by autograd. Note that detach() does not move a tensor between devices: the detached tensor stays on the same device as the original. To move a tensor from the Graphics Processing Unit (GPU) to the Central Processing Unit (CPU), for example before converting it to a NumPy array, detach() is usually combined with .cpu(), as in tensor.detach().cpu(). The method takes no parameters and returns the detached tensor.
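The following is a minimal sketch of the common detach().cpu().numpy() pattern. It assumes a GPU may or may not be present and falls back to the CPU otherwise; the variable names are illustrative only. It also shows that calling .numpy() directly on a tensor that requires gradients raises a RuntimeError.

```python
import torch

# A tensor that autograd tracks
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Calling numpy() on a tensor that requires grad raises a RuntimeError
numpy_failed = False
try:
    x.numpy()
except RuntimeError as err:
    numpy_failed = True
    print("numpy() failed:", err)

# Assumption: use a GPU if one is available, otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
y = (x * 2).to(device)

# detach() removes y from the graph (it does NOT change the device);
# cpu() performs the device move; numpy() then converts to an array
arr = y.detach().cpu().numpy()
print(arr)  # [2. 4. 6.]
```

If the tensor is already on the CPU, .cpu() simply returns it unchanged, so the same line works on both kinds of machine.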
Syntax: tensor.detach()
Return: the detached tensor
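As a short illustration of what "detached" means, the sketch below compares a tensor produced by an autograd operation with its detached counterpart. The example values are arbitrary.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 3        # y is produced by an operation that autograd records
d = y.detach()   # same values, but no link back to the graph

print(y.requires_grad, y.grad_fn is not None)  # True True
print(d.requires_grad, d.grad_fn)              # False None
print(d.is_leaf)                               # True
```

The detached tensor has no grad_fn, so autograd cannot trace any computation through it back to x.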
Example 1:
In this example, we will create a one-dimensional tensor with gradient tracking enabled and detach it using the tensor.detach() method. The requires_grad parameter takes a boolean value; here it is set to True.
Python3
# import the torch module
import torch

# create a one-dimensional tensor with 5 elements and the
# requires_grad parameter set to True
tensor1 = torch.tensor([7.8, 3.2, 4.4, 4.3, 3.3], requires_grad=True)
print(tensor1)

# detach the tensor
print(tensor1.detach())
Output:
tensor([7.8000, 3.2000, 4.4000, 4.3000, 3.3000], requires_grad=True)
tensor([7.8000, 3.2000, 4.4000, 4.3000, 3.3000])
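The output above only shows that requires_grad=True disappears. A minimal sketch of the practical effect: when a detached copy is used in a computation, autograd treats it as a constant, so no gradient flows through it. The small tensor below is an illustrative stand-in for tensor1.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Without detach: d/dx sum(x * x) = 2x
(x * x).sum().backward()
grad_full = x.grad.clone()
print(grad_full)  # tensor([2., 4., 6.])

x.grad = None  # clear the accumulated gradient before the second pass

# With detach: the second factor is a constant c holding x's values,
# so d/dx sum(x * c) = c
(x * x.detach()).sum().backward()
print(x.grad)  # tensor([1., 2., 3.])
```

The gradient is halved because only one of the two factors is still part of the graph.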
Example 2:
In this example, we will create a two-dimensional tensor with requires_grad=False and detach it using the tensor.detach() method. As the output shows, detach() makes no visible difference here: the tensor was never tracked by autograd, so the detached tensor prints exactly like the original.
Python3
# import the torch module
import torch

# create a two-dimensional tensor (2 rows x 5 columns) with the
# requires_grad parameter set to False
tensor1 = torch.tensor([[7.8, 3.2, 4.4, 4.3, 3.3],
                        [3., 6., 7., 3., 2.]], requires_grad=False)
print(tensor1)

# detach the tensor
print(tensor1.detach())
Output:
tensor([[7.8000, 3.2000, 4.4000, 4.3000, 3.3000],
        [3.0000, 6.0000, 7.0000, 3.0000, 2.0000]])
tensor([[7.8000, 3.2000, 4.4000, 4.3000, 3.3000],
        [3.0000, 6.0000, 7.0000, 3.0000, 2.0000]])
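Even when the outputs look identical, detach() still returns a new tensor object that shares memory with the original. The sketch below, using an arbitrary 2x2 tensor, shows that an in-place change to the detached tensor is visible in the original, and that clone() is the usual way to get an independent copy.

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # requires_grad defaults to False
d = t.detach()

print(d is t)                        # False: a new tensor object
print(d.data_ptr() == t.data_ptr())  # True: same underlying memory

d[0, 0] = 100.0  # in-place write through the detached tensor
print(t[0, 0])   # tensor(100.)

# Use clone() when an independent copy is needed
c = t.detach().clone()
c[0, 0] = -1.0
print(t[0, 0])   # still tensor(100.)
```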