pytorch小记(三):pytorch中的最大值操作:x.max()
在 PyTorch 中,x.max(dim=n)
表示沿指定维度 dim
求张量的最大值,并返回 最大值 和 最大值的索引。我们逐步分析 dim=0
, dim=1
, 和 dim=2
的行为。
初始化张量:
x = torch.arange(8).reshape(2, 2, 2)
print(x)
输出:
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
x
是一个 3D 张量,形状为(2, 2, 2)
。- 第一维度(
dim=0
)有 2 个块。 - 第二维度(
dim=1
)有 2 行。 - 第三维度(
dim=2
)有 2 列。
- 第一维度(
1. x.max(dim=0)
含义:
- 沿着第 0 维(块的方向)比较,保留 其他维度。
- 比较时,将每个位置的两个块中元素逐个比较,选出最大值。
计算过程:
[[0, 1], [2, 3]] # 块 0
| | | |
[[4, 5], [6, 7]] # 块 1
我们沿第 0 维比较,即逐个元素比较块 0 和块 1 对应位置的值,得到最大值:
对比每个位置的值:
- 对位置
[0, 0]
比较:块 0: 0, 块 1: 4 => 最大值是 4
- 对位置
[0, 1]
比较:块 0: 1, 块 1: 5 => 最大值是 5
- 对位置
[1, 0]
比较:块 0: 2, 块 1: 6 => 最大值是 6
- 对位置
[1, 1]
比较:块 0: 3, 块 1: 7 => 最大值是 7
合并最大值:
将每个位置的最大值整合成新的张量:
[[4, 5],
[6, 7]]
输出:
Max: torch.return_types.max(
values=tensor([[4, 5],
[6, 7]]),
indices=tensor([[1, 1],
[1, 1]]))
- 最大值为
tensor([[4, 5], [6, 7]])
,形状为(2, 2)
。 - 索引为
tensor([[1, 1], [1, 1]])
,表示在第 0 维中,每个位置最大值来自第 1 块(索引1
)。
2. x.max(dim=1)
含义:
- 沿着第 1 维(行的方向)比较,保留 其他维度。
- 比较时,将每个块中的两行逐个比较,选出最大值。
计算过程:
块 0:
[[0, 1], => [2, 3]
[2, 3]]
块 1:
[[4, 5], => [6, 7]
[6, 7]]
输出:
Max: torch.return_types.max(
values=tensor([[2, 3],
[6, 7]]),
indices=tensor([[1, 1],
[1, 1]]))
- 最大值为
tensor([[2, 3], [6, 7]])
,形状为(2, 2)
。 - 索引为
tensor([[1, 1], [1, 1]])
,表示在第 1 维中,每个位置最大值来自第 1 行(索引1
)。
3. x.max(dim=2)
含义:
- 沿着第 2 维(列的方向)比较,保留 其他维度。
- 比较时,将每行中的两列逐个比较,选出最大值。
计算过程:
块 0:
[0, 1], => [1]
[2, 3], => [3]
块 1:
[4, 5], => [5]
[6, 7] => [7]
输出:
Max: torch.return_types.max(
values=tensor([[1, 3],
[5, 7]]),
indices=tensor([[1, 1],
[1, 1]]))
- 最大值为
tensor([[1, 3], [5, 7]])
,形状为(2, 2)
。 - 索引为
tensor([[1, 1], [1, 1]])
,表示在第 2 维中,每个位置最大值来自第 1 列(索引1
)。
总结
dim=0
:在 块的方向 比较,结果保留每个位置的 行和列。dim=1
:在 行的方向 比较,结果保留每个位置的 块和列。dim=2
:在 列的方向 比较,结果保留每个位置的 块和行。
图示化总结:
dim=0: [[0, 1], [[4, 5], => [[4, 5],
[2, 3]] [6, 7]] [6, 7]]
dim=1: [[0, 1], [[2, 3], => [[2, 3],
[2, 3]] [6, 7]] [6, 7]]
dim=2: [[0, 1], [[1], => [[1, 3],
[2, 3]] [3]] [5, 7]]