AlphaFold3 _attention
函数位于 src.models.components.primitives模块,是一个标准的注意力机制的实现,主要用于计算输入的查询 (query
)、键 (key
) 和值 (value
) 张量之间的注意力权重,并将其应用于值张量。_attention
函数被Attention类调用,实现定制化的多头注意力机制。
源代码:
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
"""A stock PyTorch implementation of the attention mechanism.
Args:
query:
[*, H, Q, C_hidden] query tensor
key:
[*, H, K/V, C_hidden] key tensor
value:
[*, H, K/V, C_value] value tensor
biases:
a list of biases that broadcast to [*, H, Q, K]
Returns:
the resultant tensor [*, H, Q, C_value]
"""
# [*, H, C_hidden, K]
key = permute_final_dims(key, (1, 0))
# [*, H, Q, K]
a = torch.matmul(query,