Bootstrap

【python】30、矩阵加法 tensor.sum

文章目录

一、tensor.sum

为了更好地理解 `torch.sum` 函数中 `dim` 参数的作用,我们可以将三维张量的求和过程分解,并通过具体的例子来说明不同 `dim` 参数的效果。

### 三维张量的结构

假设我们有一个 3x2x2 的张量,如下所示:
import torch

# 其中 [1, 2] 中的 1 和 2 是 x 方向
# 其中 [[1,2], [3,4]] 中的 [1, 2] 和 [3, 4] 是 y 方向
# 其中 [[[1,2][3,4]],[[5,6][7,8]],[[9,10][11,12]]] 中的 [[1,2][3,4]] 和 [[5,6][7,8]] 和 [[9,10][11,12]] 是 z 方向
tensor = torch.tensor([[[1, 2],
                        [3, 4]],

                       [[5, 6],
                        [7, 8]],

                       [[9, 10],
                        [11, 12]]])

print(tensor)

这个张量可以看作是包含三个 2x2 矩阵的集合:
[
 [[ 1,  2],
  [ 3,  4]],
 [[ 5,  6],
  [ 7,  8]],
 [[ 9, 10],
  [11, 12]]
]

### 沿指定维度求和的效果

#### 不指定 `dim` 参数
默认情况下,`torch.sum` 会对所有元素求和:
total_sum = torch.sum(tensor)
print(total_sum)  # 输出: tensor(78)
解释:1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 = 78

#### 指定 `dim=0`
`dim=0` 表示沿最外层维度求和,即对每个 2x2 矩阵的对应位置元素求和:
即制定dim=x, 表示沿着z方向求和(即消灭z方向)

sum_dim0 = torch.sum(tensor, dim=0)
print(sum_dim0)
输出:
tensor([[15, 18],
        [21, 24]])
解释:
第一个位置:[1 + 5 + 9, 2 + 6 + 10] = [15, 18]
第二个位置:[3 + 7 + 11, 4 + 8 + 12] = [21, 24]

#### 指定 `dim=1`
`dim=1` 表示沿每个 2x2 矩阵的行方向求和:
即制定dim=y, 表示沿着y方向求和(即消灭y方向)
sum_dim1 = torch.sum(tensor, dim=1)
print(sum_dim1)
输出:
tensor([[ 4,  6],
        [12, 14],
        [20, 22]])
解释:
对第一个二维矩阵:行和 [1 + 3, 2 + 4] = [4, 6]
对第二个二维矩阵:行和 [5 + 7, 6 + 8] = [12, 14]
对第三个二维矩阵:行和 [9 + 11, 10 + 12] = [20, 22]

#### 指定 `dim=2`
`dim=2` 表示沿每个 2x2 矩阵的列方向求和:
即制定dim=z, 表示沿着x方向求和(即消灭x方向)

sum_dim2 = torch.sum(tensor, dim=2)
print(sum_dim2)
输出:
tensor([[ 3,  7],
        [11, 15],
        [19, 23]])
解释:
对第一个二维矩阵:列和 [1 + 2, 3 + 4] = [3, 7]
对第二个二维矩阵:列和 [5 + 6, 7 + 8] = [11, 15]
对第三个二维矩阵:列和 [9 + 10, 11 + 12] = [19, 23]
### 总结
dim=0:沿最外层维度求和,结果是一个 2x2 矩阵,每个元素是对应位置上所有二维矩阵元素的和。
dim=1:沿每个二维矩阵的行方向求和,结果是一个 3x2 矩阵,每个元素是对应位置上行的和。
dim=2:沿每个二维矩阵的列方向求和,结果是一个 3x2 矩阵,每个元素是对应位置上列的和。
;