文章目录
前言
1、torch.cat 函数
torch.cat 函数将两个张量拼接起来,具体地是在第三个维度(dim=2)上进行拼接。注:dim取值范围是0~2
node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)
其中所用参数为:
node_xy = reset_state.node_xy
# shape: (batch, problem, 2)
node_demand = reset_state.node_demand
# shape: (batch, problem)
若要拼接node_xy 与node_demand 需要将node_demand 进行维度拓展即 node_demand[:, :, None])
node_xy = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])
node_demand = torch.tensor([[[10], [20]],
[[30], [40]]])
node_xy_demand = torch.tensor([[[ 1, 2, 10], [ 3, 4, 20]],
[[ 5, 6, 30], [ 7, 8, 40]]])
2、索引、维度扩展和张量的广播
_ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1, self.decoder.regret_embedding.size(-1))
self.decoder.regret_embedding
是一个张量。- self.decoder.regret_embedding[None, None, :]增加regret_embedding的维度。维度扩展成 (1, 1, D)
.expand(encoded_nodes.size(0), 1, self.decoder.regret_embedding.size(-1))
- expand 用来沿特定维度复制张量,以实现广播。
- encoded_nodes.size(0) 返回的是 encoded_nodes 张量的第一个维度大小。
- 1 表示第二个维度的大小。
- self.decoder.regret_embedding.size(-1) 返回的是 self.decoder.regret_embedding 的最后一个维度的大小,也就是嵌入的维度 D
总结: 将张量建立为所需维度在此为三维,使用expand沿着新建维度进行拓展到所需形状。
3、切片操作
3.1、 encoded_first_node
encoded_first_node = self.encoded_nodes[:, [0], :]
这行代码中的切片操作是从 self.encoded_nodes
中提取特定的数据部分:
:
表示选择所有批次的样本,保留第一个维度(batch
)。[0]
表示选择每个样本中的第一个节点,因此提取的是第一个节点的嵌入向量。:
表示选择该节点的所有嵌入维度,即保留第三个维度(embedding
)的所有值。
最终,经过这些操作,encoded_first_node
的形状为 (batch, 1, embedding)
,即每个样本只包含第一个节点的嵌入向量,保留了嵌入维度。
3.2、probs
probs[:, :, :-1]
- 这是对 probs 张量的切片操作,作用是从 probs 的第三个维度(即最后一个维度)中移除最后一列。
selected = probs.argmax(dim=2)
-
argmax(dim=2)
表示在probs
张量的第3维度(类别维度)上,找到每个样本中概率最大的类别索引。 -
argmax
返回的是最大值的索引,而不是最大值本身。
4、长难代码分析
4.1、selected
selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)
prob的shape: (batch, pomo, problem+1)
-
probs.reshape(batch_size * pomo_size, -1)
:- 这一步将 probs 的形状从 (batch, pomo, problem + 1) 转变为 (batch * pomo, problem + 1)。
-1
:表示自动推算出第二维的大小(即 problem + 1)- 新的形状 (batch * pomo, problem + 1)。
-
multinomial(1)
:multinomial(1)
用于从给定的概率分布中选择一个类别。它会返回一个形状为(batch_size * pomo_size, 1)
的张量,每一行选择一个元素的索引,代表从probs
中选择的元素。
-
.squeeze(dim=1)
:squeeze(dim=1)
是去除第二个维度(索引维度),将形状变为(batch_size * pomo_size)
。
-
.reshape(batch_size, pomo_size)
:- 最后,通过
reshape(batch_size, pomo_size)
将张量恢复到原来的形状(batch_size, pomo_size)
,即每个批次对应一个选择的元素索引。
- 最后,通过
4.1.1、multinomial(1)工作原理:
-
输入:
multinomial(1)
需要一个形状为(N, C)
的张量,其中N
是样本的数量,C
是类别的数量。这个张量表示每个样本在各个类别下的概率分布。 -
输出:
multinomial(1)
返回一个形状为(N, 1)
的张量,每个元素是该样本选择的类别的索引。
具体来说,multinomial(1)
会根据每个类别的概率,从概率分布中选取一个类别。这个选择是随机的,但是会遵循给定的概率分布,即概率较大的类别被选中的几率较高,概率较小的类别被选中的几率较低。