Bootstrap

pytorch中的.clone() 和 .detach()

在PyTorch中,.clone().detach() 是两个用于处理张量(Tensor)的方法,它们各自有不同的用途:

  1. .clone()

    • .clone() 方法用于创建一个张量的副本(深拷贝)。这意味着原始张量和新张量将有不同的内存地址,并且对新张量的任何修改都不会影响原始张量。
    • 这个操作会复制张量的所有数据,包括梯度信息(如果张量需要梯度的话)。
    • 示例代码:
       

      python

      import torch
      tensor = torch.tensor([1, 2, 3], requires_grad=True)
      cloned_tensor = tensor.clone()
      cloned_tensor[0] = 10  # 修改克隆的张量不会影响原始张量
      print(tensor)  # 输出: tensor([1, 2, 3])
  2. .detach()

    • .detach() 方法用于从当前计算图中分离出一个张量,返回一个新的张量,这个新的张量不会在反向传播中计算梯度。
    • 这个操作通常用于评估模型时,当你不希望某些张量参与梯度计算时使用。
    • .detach() 返回的张量与原始张量共享数据,但是不会跟踪梯度。这意味着对返回的张量的修改可能会影响原始张量的数据,但是不会影响梯度计算。
    • 示例代码:
       

      python

      import torch
      tensor = torch.tensor([1, 2, 3], requires_grad=True)
      detached_tensor = tensor.detach()
      detached_tensor[0] = 10  # 修改分离的张量会影响原始张量的数据
      print(tensor)  # 输出: tensor([10, 2, 3], requires_grad=True)

总结来说,.clone() 是用来创建张量的深拷贝,而 .detach() 是用来从计算图中分离张量,返回一个不会计算梯度的张量。在使用时,需要根据具体的需求选择合适的方法。

;