提示:代码来源慕课网PyTorch入门到进阶,仅供自学使用
前言
第三章 PyTorch入门基础精讲3-21到3-26
提示:以下是本篇文章正文内容,下面案例可供参考
一、Tensor的索引与数据筛选
torch.where(condition,x,y):按照条件从x,y中选出满足条件的元素组成新的tensor
tensor.gather(input,dim,index,out=None):在指定维度上按照索引赋值输出tensor
torch.index_select(input,dim,index,out=None):按照指定索引输出tensor
torch.masked_select(input,mask,out=None):按照mask输出tensor,输出为向量
torch.take(input,indices):将输入看成1D-tensor,按照索引得到输出tensor
torch.nonzero(input,out=None):输出非0元素的坐标
import torch
# tensor的索引与数据筛选
# torch.where(condition,x,y)
# 按照条件从x,y中选出满足条件的元素组成新的tensor
# 下例:选出a张量中大于0.5的值,不符合的位置用b张量里的值
# 应用:
# 利用一个阈值来对tensor进行二值化
# 计算loss时,只计算大于某一个值的loss,可将小于某一个值的loss定义为0
print("===================torch.where===================")
a = torch.rand(4, 4)
b = torch.rand(4, 4)
print(a)
print(b)
out = torch.where(a > 0.5, a, b)
print(out)
# torch.index_select(input,dim,index,out=None)
# 按照指定索引输出tensor
# 下例,将输入张量的第0行,第3行,第1行重新组成为一个3*4的张量
print("==================torch.index_select====================")
a = torch.rand(4, 4)
print(a)
out = torch.index_select(a, dim=0,
index=torch.tensor([0, 3, 1]))
print(out, out.shape)
# torch.gather(input,dim,index,out=None)
# 在指定维度上按照索引赋值输出tensor
# linspace从1开始,到16,16个数字,view改变张量的维度,4*4阶
# 在第0维上进行索引,按列进行索引,[第1列第0个,第2列第1个,第3列第1个,第4列第1个]
print("====================torch.gather====================")
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
out = torch.gather(a, dim=0,
index=torch.tensor([[0, 1, 1, 1],
[0, 1, 2, 2],
[0, 1, 3</