Bootstrap

解析Torch中多头注意力`MultiheadAttention`

前沿: 这部分内容是《Attention Is All You Need》出来之后就深入研究了这篇文章的模型结构,也是之后工作那一年进行实际落地的一小部分内容。最近再次使用它,顺带读了torch官方的实现,大家风范的实现,注意很多细节,值得我们学习,也顺带放在这,之后就不再了解这块内容了,过去式了。下面是内容:

MultiheadAttention

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        """

        :param embed_dim: 模型的嵌入维度,即查询、键和值的维度
        :param num_heads: 并行的注意力头的数量
        :param dropout: 一个可选参数,用于定义在注意力权重上的 Dropout 层的丢弃率,默认为0
        :param bias: 一个布尔值,用于确定是否在投影权重中添加偏置,默认为 True
        :param add_bias_kv: 一个布尔值,用于确定是否在键和值的序列中添加额外的偏置项,默认为 False
        :param add_zero_attn: 一个布尔值,用于确定是否在键和值序列中添加一个新的零向量,默认为 False
        :param kdim: 键的维度,如果没有提供则默认等于 embed_dim
        :param vdim: 值的维度,如果没有提供则默认等于 embed_dim
        :param batch_first: 一个布尔值,用于确定输入和输出张量的维度顺序是否为 (batch, seq, feature),默认为 False,即 (seq, batch, feature)
        :param device: 用于指定参数的设备(CPU/GPU)
        :param dtype: 用于指定参数的数据类型
        """
        # 创建一个字典 factory_kwargs 来存储设备和数据类型的信息,这些信息将在后面创建参数时使用。
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()

        #  设置 embed_dim、kdim 和 vdim 的值,如果 kdim 和 vdim 没有指定,则它们等于 embed_dim
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim

        # 判断键、值和嵌入维度是否相同,这影响到投影权重的创建方式
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        # 设置 num_heads、dropout 和 batch_first。
        # 计算每个头的维度 head_dim。
        # 确保 embed_dim 能够被 num_heads 整除,这是多头注意力机制的要求
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # 根据 _qkv_same_embed_dim 的真假创建投影权重。如果维度相同,则创建一个大型的共享权重矩阵;如果不相同,则为查询、键和值创建单独的权重矩阵
        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        # 如果 bias 为 True,则创建一个偏置参数;否则,不创建偏置参数
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)

        # 创建输出投影层,这是一个线性层,用于在多头注意力计算之后将输出映射回 embed_dim 维度
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        # 如果 add_bias_kv 为 True,则创建键和值的偏置参数;否则,不创建这些偏置参数
        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        # 设置 add_zero_attn,用于控制是否在键和值序列中添加零向量
        self.add_zero_attn = add_zero_attn

        # 用 _reset_parameters 方法来初始化所有的参数,例如通过 Xavier 初始化等方法来确保参数的合理范围。这一步是确保模型在训练开始时有一个良好的初始状态
        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        参数说明:
            query(查询)、key(键)、value(值):将一个查询和一组键值对映射到一个输出。更多细节请参考论文《Attention Is All You Need》。
            key_padding_mask:如果提供,那么在键中的指定填充元素将被注意力机制忽略。当给定一个二进制掩码并且某个值为True时,注意力层上对应的值将被忽略。当给定一个字节掩码且某个值非零时,注意力层上对应的值也将被忽略。
            need_weights:输出注意力权重attn_output_weights。
            attn_mask:2D或3D的掩码,用于阻止对某些位置的注意力。一个2D掩码将为所有批次广播,而一个3D掩码允许为每个批次的条目指定不同的掩码。

        输入张量的形状:
            query:数学表示为(L, N, E),其中L是目标序列长度,N是批次大小,E是嵌入维度。如果batch_first设置为True,则表示为(N, L, E)。
            key:数学表示为(S, N, E),其中S是源序列长度,N是批次大小,E是嵌入维度。如果batch_first设置为True,则表示为(N, S, E)。
            value:数学表示为(S, N, E),其中S是源序列长度,N是批次大小,E是嵌入维度。如果batch_first设置为True,则表示为(N, S, E)。
            key_padding_mask:数学表示为(N, S),其中N是批次大小,S是源序列长度。如果提供了ByteTensor,非零位置将被忽略,而零位置保持不变。如果提供了BoolTensor,值为True的位置将被忽略,而值为False的位置保持不变。
            attn_mask:如果是一个2D掩码,数学表示为(L, S),其中L是目标序列长度,S是源序列长度。

              如果是一个3D掩码,数学表示为`(N\cdot\text{num_heads}, L, S)`,其中N是批次大小,L是目标序列长度,S是源序列长度。
              `attn_mask`确保位置i可以关注未被掩码的位置。如果提供了ByteTensor,非零位置不允许参与注意力计算,而零位置保持不变。
              如果提供了BoolTensor,值为`True`的位置不允许参与注意力计算,而`False`值保持不变。如果提供了FloatTensor,它将被加到注意力权重上。

        输出张量的形状:
            attn_output:数学表示为(L, N, E),其中L是目标序列长度,N是批次大小,E是嵌入维度。如果batch_first设置为True,则表示为(N, L, E)。
            attn_output_weights:数学表示为(N, L, S),其中N是批次大小,L是目标序列长度,S是源序列长度。
        """
        # 维度转换:如果 batch_first 设定为 True,则会将输入张量的维度顺序从 (N, L, E) 转换为 (L, N, E),以便与内部计算兼容
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

multi_head_attention_forward

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    参数说明:
        query(查询),key(键),value(值):将一个查询和一组键值对映射到一个输出。更多信息请参见论文《Attention Is All You Need》。
        embed_dim_to_check:模型的总维度。
        num_heads:并行注意力头的数量。
        in_proj_weight,in_proj_bias:输入投影的权重和偏置。
        bias_k,bias_v:要添加到键和值序列上的偏置,在第0维添加。
        add_zero_attn:在键和值序列的第1维添加一个新的零批次。
        dropout_p:元素被置零的概率。
        out_proj_weight,out_proj_bias:输出投影的权重和偏置。
        training:如果为True,则应用dropout。
        key_padding_mask:如果提供,指定在键中的填充元素将被注意力机制忽略。这是一个二进制掩码。当值为True时,注意力层上对应的位置将被填充为-inf。
        need_weights:输出注意力权重attn_output_weights。
        attn_mask:2D或3D的掩码,用于阻止对某些位置的注意力。2D掩码将为所有批次广播,而3D掩码允许为每个批次的不同条目指定不同的掩码。
        use_separate_proj_weight:函数接受查询、键和值的投影权重的不同形式。如果为假,将使用in_proj_weight,它是q_proj_weight、k_proj_weight和v_proj_weight的组合。
        q_proj_weight,k_proj_weight,v_proj_weight,in_proj_bias:输入投影的权重和偏置。
        static_k,static_v:用于注意力运算符的静态键和值。
    形状说明:
        输入:
            query:数学表示为(L, N, E),其中L是目标序列长度,N是批次大小,E是嵌入维度。
            key:数学表示为(S, N, E),其中S是源序列长度,N是批次大小,E是嵌入维度。
            value:数学表示为(S, N, E),其中S是源序列长度,N是批次大小,E是嵌入维度。
            key_padding_mask:数学表示为(N, S),其中N是批次大小,S是源序列长度。如果提供了ByteTensor,非零位置将被忽略,而零位置保持不变。如果提供了BoolTensor,值为True的位置将被忽略,而值为False的位置保持不变。
            attn_mask:2D掩码数学表示为(L, S),其中L是目标序列长度,S是源序列长度。3D掩码数学表示为(N*num_heads, L, S),其中N是批次大小,L是目标序列长度,S是源序列长度。attn_mask确保位置i被允许关注未被掩码的位置。
            如果提供了ByteTensor,非零位置不允许参与注意力计算,而零位置保持不变。如果提供了BoolTensor,值为True的位置不允许参与注意力计算,而False值保持不变。如果提供了FloatTensor,它将被加到注意力权重上。
            static_k:数学表示为(N*num_heads, S, E/num_heads),其中S是源序列长度,N是批次大小,E是嵌入维度。E/num_heads是头的维度。
            static_v:数学表示为(N*num_heads, S, E/num_heads),其中S是源序列长度,N是批次大小,E是嵌入维度。E/num_heads是头的维度。
        输出:
            attn_output:数学表示为(L, N, E),其中L是目标序列长度,N是批次大小,E是嵌入维度。
            attn_output_weights:数学表示为(N, L, S),其中N是批次大小,L是目标序列长度,S是源序列长度。
    """
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
        )

    # set up shape vars
    # 初始化形状变量
    tgt_len, bsz, embed_dim = query.shape  # (目标序列长度, 批次大小, 嵌入维度)
    src_len, _, _ = key.shape              # (源序列长度, 批次大小, 嵌入维度)
    # 检查 embed_dim 是否等于 embed_dim_to_check
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"

    # 计算每个注意力头的维度 head_dim。如果 embed_dim 是一个 torch.Tensor 类型(这在使用 PyTorch 的 Just-In-Time (JIT) 编译器时可能发生),
    # 则使用 .div() 方法除以 num_heads 并使用 'trunc' 模式舍去小数点后的部分。否则,直接使用 Python 的整数除法 // 计算 head_dim
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads

    # 确保嵌入维度可被头数整除
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

    # 检查键和值张量的形状
    # 当使用单独的投影权重时 (use_separate_proj_weight 为 True),仅检查键和值张量的前两个维度(即序列长度和批次大小)是否一致。这是因为当使用独立的权重时,键和值可以有不同的嵌入维度。
    # 如果不使用单独的投影权重,则键和值张量的形状必须完全相同
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        # 使用一个共享的权重矩阵 in_proj_weight 和一个共享的偏置向量 in_proj_bias 来同时投影查询、键和值。这种情况下,调用 _in_projection_packed 函数完成投影
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        # 需要确保查询、键和值的投影权重 q_proj_weight、k_proj_weight 和 v_proj_weight 分别存在。这三组权重将分别用于投影查询、键和值张量
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"

        # 如果 in_proj_bias 存在,它会被分割成三个部分,分别作为查询、键和值的偏置向量 b_q、b_k 和 b_v。如果 in_proj_bias 不存在,则将三个偏置向量设为 None
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        # 使用独立的投影权重和偏置向量完成查询、键和值的投影
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # prep attention mask
    if attn_mask is not None:
        # 检查掩码类型
        # 如果 attn_mask 是 uint8 类型,发出警告建议使用布尔类型(bool)替代。
        # 确认 attn_mask 的类型为浮点数(floating point)或布尔类型(bool)。在多头注意力中,掩码通常用于标识应被屏蔽的位置
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        # 验证掩码维度
        # 如果 attn_mask 的维度为 2,确保其形状为 (tgt_len, src_len),然后将其扩展到 3 维,形状变为 (1, tgt_len, src_len),以适应多头注意力的计算
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        # 如果 attn_mask 的维度为 3,确保其形状为 (bsz * num_heads, tgt_len, src_len),这与多头注意力机制中批量和头数相关联
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        # 如果 attn_mask 的维度既不是 2 也不是 3,抛出运行时错误
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    # 指示在键(key)序列中哪些位置应当被忽略(即被视为“填充”或“pad”)的掩码
    # 如果key_padding_mask存在,并且其数据类型为uint8,则发出一个警告,推荐使用布尔类型(bool)的张量代替uint8类型
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second)
    # 处理了在多头注意力机制中可能存在的偏差项(bias)的添加。偏差项 bias_k 和 bias_v 分别用于键(key)和值(value)的增强,它们可以增加模型的表达能力
    # 首先确认 bias_k 和 bias_v 是否存在。如果两者都存在,则进行后续的处理
    # 如果 bias_k 和 bias_v 存在,那么需要确认此时不存在静态键(static_k)和静态值(static_v)。这是因为静态键和值通常是在训练过程中预先计算好的,不希望在此处添加额外的偏差
    # 如果偏差项存在,将 bias_k 和 bias_v 添加到键 k 和值 v 张量的末尾
    # 如果存在注意力掩码 attn_mask 或键填充掩码 key_padding_mask,在添加了偏差项后,这两个掩码也需要相应地进行扩展,以保持与新的键和值张量形状的一致性。这里使用 pad 函数在最后一个维度上添加了一个额外的位置
    #
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    # 负责将查询(query)、键(key)和值(value)张量重塑以适应多头注意力机制,并将它们转换为以批次为首要维度的格式
    # 查询张量 q 被连续化(通过 contiguous()),然后重塑为 (tgt_len, bsz * num_heads, head_dim) 的形状,最后转置为 (bsz * num_heads, tgt_len, head_dim),即以批次和头数的乘积为首要维度
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    # 如果没有提供静态键 static_k,键张量 k 也将被连续化和重塑,最终转置为 (bsz * num_heads, src_len, head_dim) 的形状
    # 如果提供了静态键 static_k,则跳过内部投影步骤,直接使用静态键。这里还包含了对静态键尺寸的检查,确保其第一个维度等于批次大小与头数的乘积,第二个维度等于序列长度,第三个维度等于头的维度
    if static_k is None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    # 类似于键张量的处理,如果没有提供静态值 static_v,值张量 v 将被连续化和重塑,最终转置为 (bsz * num_heads, src_len, head_dim) 的形状。
    # 如果提供了静态值 static_v,同样跳过内部投影步骤,直接使用静态值。这里也包含对静态值尺寸的检查,确保其尺寸满足同样的要求
    if static_v is None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    # 在多头注意力机制中添加零注意力的情况,具体来说,就是在键(key)和值(value)张量的批量维度上追加一个全零的张量,以及相应的对注意力掩码和键填充掩码的调整
    # 通过 add_zero_attn 标志来判断是否需要在键和值张量上添加一个全零的张量。这个标志通常在某些特定的应用场景下使用,比如为了增加模型的泛化能力或者在训练过程中引入某种形式的正则化
    if add_zero_attn:
        # 使用 torch.cat 函数将全零的张量与原有的键和值张量沿着批量维度(现在的第一个维度)进行拼接。
        # 这样做的目的是为了在多头注意力计算中引入一个额外的注意力头,该头的键和值均为全零,不会对实际的注意力权重产生贡献,但可以影响模型的学习过程
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        # 如果存在注意力掩码 attn_mask 或键填充掩码 key_padding_mask,在添加了零注意力张量后,这两个掩码也需要相应地进行扩展,以保持与新的键和值张量形状的一致性。
        # 这里同样使用 pad 函数在最后一个维度上添加了一个额外的位置,通常是为了保持掩码与键和值张量的维度对齐,确保注意力计算的正确性
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    #
    # update source sequence length after adjustments
    # 由于可能进行了诸如添加零注意力等操作,源序列长度(src_len)可能发生了变化。这里重新计算 src_len 为键张量 k 的第二个维度的大小,确保它反映了最新的序列长度
    src_len = k.size(1)

    # merge key padding and attention masks
    # 负责更新和合并注意力掩码(attention mask)和键填充掩码(key padding mask),以确保在接下来的多头注意力计算中,模型能够正确地忽略掉那些应该被屏蔽的部分
    # 如果 key_padding_mask 存在,首先检查其形状是否符合预期,即 (bsz, src_len)。然后,将 key_padding_mask 的形状调整为 (bsz * num_heads, 1, src_len),使其能够与注意力掩码 attn_mask 在多头注意力计算中正确地配合使用
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        # 如果 key_padding_mask 存在,且 attn_mask 为空,则直接将 key_padding_mask 作为 attn_mask 使用。
        if attn_mask is None:
            attn_mask = key_padding_mask
        # 如果两者都存在,根据 attn_mask 的数据类型(布尔类型或浮点类型)采取不同的合并策略:
        # 如果 attn_mask 是布尔类型,使用逻辑或操作(logical_or)将两个掩码合并,这相当于对两个掩码进行按位或操作,确保任一掩码中标记为 True 的位置在合并后的掩码中仍标记为 True。
        # 如果 attn_mask 是浮点类型,使用 masked_fill 函数将 key_padding_mask 中标记为 True 的位置填充为负无穷大(float("-inf"))。在注意力计算中,负无穷大将导致这些位置的注意力权重被计算为零,从而达到屏蔽的效果
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

    # convert mask to float
    # 将布尔类型的注意力掩码(attention mask)转换为浮点类型,以便在后续的多头注意力计算中正确应用
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    # 执行多头注意力机制的核心计算——缩放点积注意力(Scaled Dot-Product Attention)计算,并进行输出投影
    # 调用 _scaled_dot_product_attention 函数,输入是查询(q)、键(k)、值(v)以及注意力掩码(attn_mask)和丢弃率(dropout_p
    # 重塑注意力输出:
    # 将注意力输出 attn_output 的维度进行转换,首先是通过 transpose 将批量和头数的乘积维度和序列长度维度互换,然后通过 contiguous 确保张量是连续存储的,最后通过 view 将其重塑为 (tgt_len, bsz, embed_dim) 的形状,恢复到最初的形状,但经过了多头注意力的处理。
    # 执行输出投影:
    # 使用 linear 函数对注意力输出进行线性变换,输入是经过调整的注意力输出,权重是 out_proj_weight,偏置是 out_proj_bias。这一步骤将注意力输出从多头注意力的维度空间转换回原始的嵌入维度空间,为后续可能的层或操作做准备
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    # 是注意力机制的一部分,它处理了注意力权重的后处理步骤,具体取决于是否需要返回这些权重(need_weights 参数)。下面是对这段代码的详细解释:
    #
    # 条件判断:
    # 检查 need_weights 布尔变量,如果为 True,则代码会继续计算并返回平均注意力权重;如果为 False,则只返回注意力输出,不返回任何权重信息。
    # 重新塑造注意力权重:
    # 当 need_weights 为 True 时,首先将注意力权重张量 attn_output_weights 的形状从 (bsz * num_heads, tgt_len, src_len) 改变为 (bsz, num_heads, tgt_len, src_len)。这里,bsz 是批量大小,num_heads 是注意力头的数量,tgt_len 和 src_len 分别是目标序列和源序列的长度。
    # 计算平均注意力权重:
    # 接下来,对每个样本(bsz)和每个目标位置(tgt_len),计算所有注意力头的注意力权重的平均值。这是通过在维度 1 上求和(sum(dim=1))然后除以 num_heads 来实现的。这样就得到了一个形状为 (bsz, tgt_len, src_len) 的平均注意力权重张量,其中每一行代表一个样本,每一列代表目标序列的一个位置,而每个元素表示该位置对源序列中每个位置的平均注意力权重。
    # 返回结果:
    # 如果 need_weights 为 True,函数返回经过多头注意力处理的输出 attn_output 和平均注意力权重。否则,仅返回 attn_output,并且权重设为 None
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None

_scaled_dot_product_attention

def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    计算基于查询、键和值张量的缩放点积注意力,可选地使用传递的注意力掩码,并在指定了大于0.0的概率时应用丢弃(dropout)。
    返回一个包含注意力计算后的值和注意力权重的张量对。

    参数:
        q, k, v: 查询、键和值张量。请参阅Shape部分了解形状细节。
        attn_mask: 可选张量,包含要添加到计算注意力中的掩码值。可能是2D或3D;请参阅Shape部分了解详情。
        dropout_p: 丢弃概率。如果大于0.0,则应用丢弃。

    形状:
        - q: :math:`(B, Nt, E)` 其中B是批量大小,Nt是目标序列长度,E是嵌入维度。
        - key: :math:`(B, Ns, E)` 其中B是批量大小,Ns是源序列长度,E是嵌入维度。
        - value: :math:`(B, Ns, E)` 其中B是批量大小,Ns是源序列长度,E是嵌入维度。
        - attn_mask: 可能是3D张量,形状为 :math:`(B, Nt, Ns)` 或2D张量,形状为 :math:`(Nt, Ns)`。

        - 输出: 注意力值具有形状 :math:`(B, Nt, E)`;注意力权重具有形状 :math:`(B, Nt, Ns)`
    计算步骤:
        缩放查询向量:
        首先,将查询张量 q 沿着嵌入维度(E)进行缩放,即除以 sqrt(E)。缩放的目的是防止点积结果过大,从而导致softmax函数的数值不稳定。
        计算注意力得分:
        使用 torch.bmm 函数计算查询张量 q 和键张量 k 转置后的点积,得到一个形状为 (B, Nt, Ns) 的注意力得分矩阵。这里的 bmm 是批量矩阵乘法,适用于批量的矩阵乘法计算。
        应用注意力掩码:
        如果 attn_mask 存在,则将其加到注意力得分矩阵上。这一步骤允许屏蔽掉某些位置的注意力计算,例如,通过设置掩码矩阵中相应位置为负无穷大,可以使得softmax之后的注意力权重为0。
        计算注意力权重:
        使用 softmax 函数对注意力得分矩阵进行归一化,得到注意力权重矩阵。softmax 函数沿最后一个维度(Ns)应用,确保每个目标位置的注意力权重之和为1。
        应用丢弃:
        如果 dropout_p 大于0,则对注意力权重矩阵应用 dropout 函数。丢弃是一种正则化技术,通过随机关闭一部分神经元,可以减少过拟合,提高模型的泛化能力。
        计算注意力输出:
        最后,使用 torch.bmm 函数计算注意力权重矩阵和值张量 v 的点积,得到注意力输出。输出的形状为 (B, Nt, E),与输入的查询张量 q 相同。
        返回结果:
        函数返回注意力输出和注意力权重矩阵,以便后续可能的使用,如可视化注意力模式或进一步的计算。
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
;