Bootstrap

pytorch torch.vmap函数介绍

torch.vmap 是 PyTorch 提供的一个高效矢量化映射函数,用于对批量数据上的操作进行自动矢量化。它可以显著提高代码的性能和可读性,避免显式使用循环来操作批量数据。


torch.vmap 的核心功能

  • 对函数进行批量化操作。
  • 自动扩展函数,使其可以作用于批量输入(即 N 个样本)。
  • 提供对批量维度的灵活控制,包括指定输入输出的批量维度。

函数签名

torch.vmap(func, in_dims=0, out_dims=0)
参数
  1. func:

    • 要矢量化的函数(可以是用户定义函数,也可以是 PyTorch 函数)。
    • 必须接收张量作为输入,并返回张量或元组。
  2. in_dims:

    • 指定输入张量的批量维度,默认为 0
    • 如果输入是多个张量,可以传递一个元组,表示每个输入的批量维度。
    • 若 in_dims=None,表示输入不需要矢量化。
  3. out_dims:

    • 指定函数输出的批量维度,默认为 0
;