Bootstrap

相对位置偏置代码解析

1. 初始化相对位置偏置嵌入

self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

     假设window_size=7、slef.heads=4,则 2 * window_size - 1 = 13;嵌入层的大小为13*13=169,创建一个大小为169*4的嵌入矩阵。

2. 创建位置索引

pos = torch.arange(window_size)   # tensor([0, 1, 2, 3, 4, 5, 6])

   pos 是一个从 0window_size-1 的一维张量

3. 生成二维网格

grid = torch.meshgrid(pos, pos, indexing='ij')

        torch.meshgrid(pos, pos, indexing='ij') 创建一个二维网格,表示每个位置的(x,y)坐标。pos 是形状为 (N,) 的张量,那么输出将是两个形状为 (N, N) 的张量。第一个张量沿着行变化,第二个张量沿着列变化。

# torch.meshgrid(pos, pos, indexing='ij')返回两个张量,并作为一个元组返回
(tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6, 6]]), tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]]))

# 如果想要输出的是两个单独的张量,则可以使用下面的代码
# grid_x, grid_y = torch.meshgrid(pos, pos, indexing='ij')

# 打印单独的张量
# print(grid_x)
# print(grid_y)

3.1 torch.stack 将上述两个二维张量沿新的维度堆叠 

grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'))
tensor([[[0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6]],

        [[0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6]]])

3.2 Rearrange函数 

# 其中 c=2 代表两个网格(x 和 y 坐标),i=7 和 j=7 代表网格的维度
grid = rearrange(grid, 'c i j -> (i j) c')

# (i j) c 表示将原始张量的维度重新排列成 (i j) 和 c。
# 即将网格的每一个点((i, j))展平为一维,并将每个点的 c 个值(这里是两个值)放在新的维度 c 中

        张量 grid 的形状从 (2, window_size, window_size) 变为 (window_size * window_size, 2)指的意思就是将第二个维度i和第三个维度j合并为一个维度。形状为 (49, 2)。

此时输出的维度为2 

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [1, 5],
        [1, 6],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [2, 4],
        [2, 5],
        [2, 6],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [3, 5],
        [3, 6],
        [4, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4],
        [4, 5],
        [4, 6],
        [5, 0],
        [5, 1],
        [5, 2],
        [5, 3],
        [5, 4],
        [5, 5],
        [5, 6],
        [6, 0],
        [6, 1],
        [6, 2],
        [6, 3],
        [6, 4],
        [6, 5],
        [6, 6]])

        在重排后的张量中,每个元素代表窗口内一个位置的 (x, y) 坐标对。通过这种方式,可以方便地处理窗口内位置的相对关系。 

3.2.1 将其展平为一维

# 使用reshape方法转换为一维张量
one_dim = grid.reshape(-1)

print(one_dim)
tensor([0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 1, 0, 1, 1, 1, 2, 1, 3, 1, 4,
        1, 5, 1, 6, 2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6, 3, 0, 3, 1, 3, 2,
        3, 3, 3, 4, 3, 5, 3, 6, 4, 0, 4, 1, 4, 2, 4, 3, 4, 4, 4, 5, 4, 6, 5, 0,
        5, 1, 5, 2, 5, 3, 5, 4, 5, 5, 5, 6, 6, 0, 6, 1, 6, 2, 6, 3, 6, 4, 6, 5,
        6, 6])

4. 计算相对位置

rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rearrange(grid, 'i ... -> i 1 ...') 

        将这个张量的第一个维度保持不变,同时在第二个维度的位置插入一个新的维度,新的维度大小为1。

rearrange(grid, 'i ... -> i 1 ...'):
tensor([
  [[0, 0]], [[0, 1]], [[0, 2]], [[0, 3]], [[0, 4]], [[0, 5]], [[0, 6]],
  [[1, 0]], [[1, 1]], [[1, 2]], [[1, 3]], [[1, 4]], [[1, 5]], [[1, 6]],
  [[2, 0]], [[2, 1]], [[2, 2]], [[2, 3]], [[2, 4]], [[2, 5]], [[2, 6]],
  [[3, 0]], [[3, 1]], [[3, 2]], [[3, 3]], [[3, 4]], [[3, 5]], [[3, 6]],
  [[4, 0]], [[4, 1]], [[4, 2]], [[4, 3]], [[4, 4]], [[4, 5]], [[4, 6]],
  [[5, 0]], [[5, 1]], [[5, 2]], [[5, 3]], [[5, 4]], [[5, 5]], [[5, 6]],
  [[6, 0]], [[6, 1]], [[6, 2]], [[6, 3]], [[6, 4]], [[6, 5]], [[6, 6]]
])

rearrange(grid, 'j ... -> 1 j ...'):
tensor([
  [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
   [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6],
   [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6],
   [3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5], [3, 6],
   [4, 0], [4, 1], [4, 2], [4, 3], [4, 4], [4, 5], [4, 6],
   [5, 0], [5, 1], [5, 2], [5, 3], [5, 4], [5, 5], [5, 6],
   [6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6]]
])

        grid变形为(49,1,2)和(1,49,2),则相减时, i 维度的 49 会与第二个张量的 49 对齐,第二个维度 11 对齐。... 维度大小为 2,这两个维度自然对齐。

减法操作通过广播机制进行:

        对每个 (i, j) 对应的元素,计算 A[i, 0, :] - B[0, j, :]最终形状: (49, 49, 2)。因为 49(第一个维度)和 49(第二个维度)都被保留下来,并且 2 是原始的最后一个维度。

rel_pos += window_size - 1
# 相减得到相对位置,并加上 window_size - 1 调整索引为正:
tensor([
  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]]
])

5. 计算相对位置索引

rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)

# rel_pose的形状为torch.Size([49, 49, 2])
# torch.tensor([13, 1]):这是一个形状为 (2,) 的一维张量。它包含两个元素:13 和 1

# 当其相乘时,torch.tensor([13, 1])权重张量会被扩展为形状为(1,1,2)的张量,结果仍然是(49,49,2)的三维张量
# 三维张量中每个 (x, y) 对都被替换为其与权重张量的乘积,这一步将对每个 (13*x, 1*y) 对进行求和

 相对位置乘以 [13, 1] 并求和得到唯一索引:

# 给定的张量表示了一个二维网格中每个位置的线性索引。每个数值在张量中指示了在一维数组中的线性位置
tensor([
  [84, 83, 82, 81, 80, 79, 78, 71, 70, 69, 68, 67, 66, 65],
  [72, 71, 70, 69, 68, 67, 66, 59, 58, 57, 56, 55, 54, 53],
  [60, 59, 58, 57, 56, 55, 54, 47, 46, 45, 44, 43, 42, 41],
  [48, 47, 46, 45, 44, 43, 42, 35, 34, 33, 32, 31, 30, 29],
  [36, 35, 34, 33, 32, 31, 30, 23, 22, 21, 20, 19, 18, 17],
  [24, 23, 22, 21, 20, 19, 18, 11, 10, 9,  8,  7,  6,  5],
  [12, 11, 10,  9,  8,  7,  6, -1, -2, -3, -4, -5, -6, -7]
])

   rel_pos 是一个张量,通常表示元素之间的相对位置关系。

6. 注册缓冲区

self.register_buffer('rel_pos_indices', rel_pos_indices, persistent=False)

        注册缓冲区(或称缓冲区,Buffer)的主要作用是在内存中预留指定大小的存储空间,用于对输入/输出(I/O)的数据进行临时存储。

        rel_pos_indices: 这是一个张量(tensor),它将被注册为缓冲区。这个张量可以是任何需要在前向传播中使用的、但不希望被优化器更新的数据。

    persistent=False: 这是一个可选参数。默认情况下,persistent 是 True。当 persistent=True 时,这个缓冲区会被保存在模型的 state_dict 中,这样当模型被保存和加载时,这个缓冲区也会被保存和恢复。当 persistent=False 时,这个缓冲区不会被保存在 state_dict 中。这意味着如果你保存模型然后加载它,这个缓冲区的数据将不会被恢复。

        为什么有时我们需要一个 persistent=False 的缓冲区呢?因为这个缓冲区的数据可以在模型初始化时重新计算或获取,或者这个缓冲区中的数据不是模型状态的重要部分,因此不需要在模型保存和加载时一起保存。 

;