torch.meshgrid()函数的主要功能是生成网格,用于生成坐标。
输入:两个数据类型相同的一维张量
输出:两个张量,行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数
作用:从某个维度顺序依次获得各个点的x坐标值和y坐标值。
import torch
x_range = torch.tensor([0, 1, 2, 3, 4])
y_range = torch.tensor([0, 2, 3])
y, x = torch.meshgrid(y_range, x_range)
print(x)
print(y)
>>>
tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]])
tensor([[0, 0, 0, 0, 0],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3]])
x轴坐标在第二个参数时,一行一行读取各个点的x和y坐标值。
如图,共有3*5个点,x存的是每个点的横坐标,图中红点位于整个“点阵”的第2行第2列(下标从0开始,对应索引为[1][1])那么,分别将X,Y中对应位置元素取出(X[1][1]、Y[1][1])就是红点对应的坐标(1,2)。
因此,在x中存储的是所有点的x坐标,y中存储的是所有点的纵坐标。
y轴坐标在第二个参数时,一列一列的获得各个点的x和y坐标值。
上述操作获得按照x轴依次获得各个点的横坐标。如图所示:
x, y = torch.meshgrid(x_range, y_range)
print(x)
print(y)
>>>
tensor([[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]])
tensor([[0, 2, 3],
[0, 2, 3],
[0, 2, 3],
[0, 2, 3],
[0, 2, 3]])