Friday, November 15, 2024

Tensor.detach() Method in Python PyTorch

In this article, we will look at the Tensor.detach() method in PyTorch using Python.

PyTorch is an open-source deep learning framework with Python and C++ interfaces; in Python it is provided by the torch module. In PyTorch, input data has to be represented as tensors. PyTorch also includes a module (autograd) that calculates gradients automatically for backpropagation. The Tensor.detach() method separates a tensor from the computational graph by returning a new tensor that does not require a gradient. Note that detach() does not move a tensor between devices: moving a tensor from the Graphics Processing Unit (GPU) to the Central Processing Unit (CPU) is done with .cpu() (or .to("cpu")). The two are often chained, as in tensor.detach().cpu(), because converting a tensor that requires a gradient to NumPy with .numpy() raises an error unless it is detached first. The method takes no parameters and returns the detached tensor.
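The sketch below illustrates how detach() stops gradient flow. It multiplies a tensor by a detached copy of itself, so autograd treats the detached factor as a constant. It also shows the tensor.detach().cpu() pattern, where .cpu() is the call that actually changes the device. The variable names are illustrative, and the code assumes it should fall back to the CPU when no GPU is available.

```python
import torch

# use a GPU if one is available, otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=device)

# x.detach() is treated as a constant, so dy/dx_i = x_i instead of 2 * x_i
y = (x * x.detach()).sum()
y.backward()
print(x.grad)  # values 1, 2, 3 (not 2, 4, 6)

# detach() only drops graph tracking; .cpu() is what moves the data to the CPU
x_cpu = x.detach().cpu()
print(x_cpu.requires_grad, x_cpu.device)  # False cpu
```

Without the detach() call, y would be the sum of x squared and the gradient would be 2 * x.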

Syntax: tensor.detach()

Return: a new tensor that is detached from the computational graph and does not require a gradient
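The returned tensor is not a copy: per the PyTorch documentation, it shares storage with the original tensor, so an in-place change made through one is visible in the other. A minimal sketch (the names a and b are illustrative):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = a.detach()

print(b.requires_grad)               # False
print(b.grad_fn)                     # None
print(b.data_ptr() == a.data_ptr())  # True: same underlying memory

# an in-place change through the detached tensor also changes the original
b[0] = 10.0
print(a)  # the first element of a is now 10.0, and a still requires grad
```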

Example 1:

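In this example, we create a one-dimensional tensor with requires_grad=True and detach it using the tensor.detach() method. The requires_grad parameter takes a boolean value; here it is set to True, and the output shows that the detached tensor no longer carries requires_grad=True.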

Python3

# import the torch module
import torch

# create a one-dimensional tensor with 5 elements and
# requires_grad set to True
tensor1 = torch.tensor([7.8, 3.2, 4.4, 4.3, 3.3], requires_grad=True)
print(tensor1)

# detach the tensor from the computational graph
print(tensor1.detach())


Output:

tensor([7.8000, 3.2000, 4.4000, 4.3000, 3.3000], requires_grad=True)

tensor([7.8000, 3.2000, 4.4000, 4.3000, 3.3000])
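Because the detached tensor is no longer part of the graph, a result computed only from it cannot be backpropagated. The sketch below reuses the tensor from Example 1 and catches the RuntimeError that PyTorch raises in this case; the exact error message may differ between PyTorch versions, and the names detached, loss and raised are illustrative.

```python
import torch

tensor1 = torch.tensor([7.8, 3.2, 4.4, 4.3, 3.3], requires_grad=True)
detached = tensor1.detach()

# a loss built only from the detached tensor has no grad_fn
loss = (detached * 2).sum()
print(loss.requires_grad)  # False

raised = False
try:
    loss.backward()
except RuntimeError as err:
    raised = True
    print("backward() failed:", err)

# the original tensor received no gradient
print(tensor1.grad)  # None
```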

Example 2:

In this example, we create a two-dimensional tensor with requires_grad=False and detach it using the tensor.detach() method. As the output shows, detaching a tensor that does not require a gradient leaves its printed form unchanged, because there was no gradient tracking to remove.

Python3

# import the torch module
import torch

# create a two-dimensional (2 x 5) tensor with
# requires_grad set to False
tensor1 = torch.tensor([[7.8, 3.2, 4.4, 4.3, 3.3],
                        [3., 6., 7., 3., 2.]], requires_grad=False)
print(tensor1)

# detach the tensor from the computational graph
print(tensor1.detach())


Output:

tensor([[7.8000, 3.2000, 4.4000, 4.3000, 3.3000],
        [3.0000, 6.0000, 7.0000, 3.0000, 2.0000]])
        
tensor([[7.8000, 3.2000, 4.4000, 4.3000, 3.3000],
        [3.0000, 6.0000, 7.0000, 3.0000, 2.0000]])
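Even when requires_grad is False, as in Example 2, detach() returns a tensor that shares memory with the original. Code that needs an independent copy commonly chains clone() after detach(). A short sketch using the tensor from Example 2 (the names shared and copy are illustrative):

```python
import torch

tensor1 = torch.tensor([[7.8, 3.2, 4.4, 4.3, 3.3],
                        [3., 6., 7., 3., 2.]], requires_grad=False)

shared = tensor1.detach()        # shares memory with tensor1
copy = tensor1.detach().clone()  # has its own memory

copy[0, 0] = 0.0     # does not affect tensor1
shared[1, 0] = -1.0  # also changes tensor1[1, 0]

print(tensor1)
```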
