
A Walkthrough of the AF3 _attention Function

AlphaFold3's _attention function lives in the src.models.components.primitives module. It is a standard implementation of the attention mechanism: it computes attention weights between the input query, key, and value tensors and applies those weights to the value tensor. _attention is called by the Attention class, which builds a customized multi-head attention mechanism on top of it.
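In other words, for each head the function computes softmax(Q·Kᵀ + Σᵢ bᵢ)·V over the last two dimensions, where each bᵢ is one of the bias terms passed in through the biases argument; note that no 1/sqrt(C_hidden) scaling happens inside this function.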

Source code:

from typing import List

import torch


def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
    """A stock PyTorch implementation of the attention mechanism.
    Args:
        query:
            [*, H, Q, C_hidden] query tensor
        key:
            [*, H, K/V, C_hidden] key tensor
        value:
            [*, H, K/V, C_value] value tensor
        biases:
            a list of biases that broadcast to [*, H, Q, K]
    Returns:
        the resultant tensor [*, H, Q, C_value]
    """

    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 0))

    # [*, H, Q, K] -- raw attention scores between every query and key position
    a = torch.matmul(query, key)

    # Add each bias term; all of them broadcast to [*, H, Q, K]
    for b in biases:
        a += b

    # Normalize the scores over the key dimension (the repository may wrap this
    # in a custom softmax helper; torch.softmax is the standard equivalent)
    a = torch.softmax(a, dim=-1)

    # [*, H, Q, C_value] -- weighted sum of the value vectors
    a = torch.matmul(a, value)

    return a
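The function relies on permute_final_dims, which is defined elsewhere in src.models.components.primitives and not shown above. A minimal sketch of such a helper, assuming it simply reorders a tensor's trailing dimensions, could look like this:

def permute_final_dims(tensor: torch.Tensor, inds) -> torch.Tensor:
    # Reorder the last len(inds) dimensions of `tensor` according to `inds`.
    # permute_final_dims(key, (1, 0)) swaps the final two axes, turning the
    # key tensor from [*, H, K, C_hidden] into [*, H, C_hidden, K].
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

To make the tensor shapes concrete, the hypothetical example below builds random query/key/value tensors with a batch of 1, H = 4 heads, Q = K = 16 positions, and C_hidden = C_value = 32, plus one mask-style bias, then calls _attention directly. All of the sizes and the mask_bias name are illustrative and not taken from the repository; in the real model the Attention class is presumably responsible for projecting the inputs into per-head tensors and scaling the query before calling _attention.

import torch

batch, heads, n_q, n_kv, c_hidden, c_value = 1, 4, 16, 16, 32, 32

query = torch.randn(batch, heads, n_q, c_hidden)   # [*, H, Q, C_hidden]
key = torch.randn(batch, heads, n_kv, c_hidden)    # [*, H, K, C_hidden]
value = torch.randn(batch, heads, n_kv, c_value)   # [*, H, K, C_value]

# One additive bias that broadcasts to [*, H, Q, K]; a padding mask would put a
# large negative value at masked key positions so softmax drives them to ~0.
mask_bias = torch.zeros(batch, 1, 1, n_kv)

out = _attention(query, key, value, biases=[mask_bias])
print(out.shape)  # torch.Size([1, 4, 16, 32]) -> [*, H, Q, C_value]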