torch.vmap
是 PyTorch 提供的一个高效矢量化映射函数,用于对批量数据上的操作进行自动矢量化。它可以显著提高代码的性能和可读性,避免显式使用循环来操作批量数据。
torch.vmap
的核心功能
- 对函数进行批量化操作。
- 自动扩展函数,使其可以作用于批量输入(即
N
个样本)。 - 提供对批量维度的灵活控制,包括指定输入输出的批量维度。
函数签名
torch.vmap(func, in_dims=0, out_dims=0)
参数
-
func
:- 要矢量化的函数(可以是用户定义函数,也可以是 PyTorch 函数)。
- 必须接收张量作为输入,并返回张量或元组。
-
in_dims
:- 指定输入张量的批量维度,默认为
0
。 - 如果输入是多个张量,可以传递一个元组,表示每个输入的批量维度。
- 若
in_dims=None
,表示输入不需要矢量化。
- 指定输入张量的批量维度,默认为
-
out_dims
:- 指定函数输出的批量维度,默认为
0
。
- 指定函数输出的批量维度,默认为