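The BaseTriangleMultiplicativeUpdate class is an abstract base class (ABC) shared by the implementations of AlphaFold's Algorithms 11 and 12, the triangular multiplicative updates using "outgoing" and "incoming" edges. Its main job is to update the pairwise representation tensor through triangular multiplication.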
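For reference, the core of the two updates can be written as below. This is a sketch in the notation of the AlphaFold 2 supplementary material, not something defined in the code: $a$ and $b$ are assumed to be gated linear projections of the layer-normalized pair representation $z$, $g$ is the output gate, and $\odot$ is element-wise multiplication over channels.

```latex
\begin{aligned}
\text{Algorithm 11 (outgoing edges):}\quad
  \tilde{z}_{ij} &= g_{ij} \odot \operatorname{Linear}\Big(\operatorname{LayerNorm}\Big(\sum_{k} a_{ik} \odot b_{jk}\Big)\Big) \\
\text{Algorithm 12 (incoming edges):}\quad
  \tilde{z}_{ij} &= g_{ij} \odot \operatorname{Linear}\Big(\operatorname{LayerNorm}\Big(\sum_{k} a_{ki} \odot b_{kj}\Big)\Big)
\end{aligned}
```

The only difference between the two algorithms is which index of $a$ and $b$ is summed over, which is what the `_outgoing` flag selects.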
Source code:
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from src.models.components.primitives import Linear
from src.utils.chunk_utils import chunk_layer
from src.utils.tensor_utils import add, permute_final_dims


class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
    """
    Implements Algorithms 11 and 12.
    """
    @abstractmethod
    def __init__(self, c_z, c_hidden, _outgoing):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                True for the outgoing-edge update, False for incoming
        """
        super(BaseTriangleMultiplicativeUpdate, self).__init__()

        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
        _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        # Move channels in front of the two residue dimensions so that
        # matmul multiplies one residue-by-residue matrix per channel.
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            # Overwrites `a` chunk by chunk along the channel dimension.
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
                a[..., i: i + _inplace_chunk_size, :, :] = (
                    torch.matmul(
                        a_chunk,
                        b_chunk,
                    )
                )

            p = a
        else:
            p = torch.matmul(a, b)

        # Restore channel-last layout.
        return permute_final_dims(p, (1, 2, 0))

    @abstractmethod
    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        pass
Code walkthrough:
1. Constructor (__init__)
def __init__(self, c_z, c_hidden, _outgoing):
    super(BaseTriangleMultiplicativeUpdate, self).__init__()
    self.c_z = c_z
    self.c_hidden = c_hidden
    self._outgoing = _outgoing
    self.linear_g = Linear(self.c_z, self.c_z, init="gating")
    self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
    self.layer_norm_in = LayerNorm(self.c_z)
    self.layer_norm_out = LayerNorm(self.c_hidden)
    self.sigmoid = nn.Sigmoid()
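- Parameters:
  - c_z: input channel dimension of the pair representation.
  - c_hidden: hidden channel dimension.
  - _outgoing: boolean flag selecting the direction of the triangular multiplication (outgoing or incoming edges).
- Main components:
  - linear_g and linear_z: linear layers for feature transformation. linear_g maps c_z to c_z with "gating" initialization; linear_z projects the c_hidden-channel result back to c_z.
  - layer_norm_in and layer_norm_out: layer normalization over the c_z input channels and the c_hidden hidden channels respectively, used to improve training stability.
  - sigmoid: the nonlinearity applied to the gating tensor.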
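To make the role of each component concrete, here is a minimal sketch of how a subclass might wire them together for the outgoing update, following the steps of Algorithm 11. It uses NumPy in place of PyTorch so it stays dependency-light. Several things are assumptions for illustration and do not appear in the base class: the weight names a_p, a_g, b_p and b_g, the omission of biases, and a layer norm without learnable scale and shift.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalize over the last (channel) axis; no learnable parameters."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def outgoing_update(z, w, mask=None):
    """Sketch of Algorithm 11; `w` maps hypothetical layer names to weights."""
    if mask is None:
        mask = np.ones(z.shape[:-1])
    mask = mask[..., None]                             # [N, N, 1]

    z = layer_norm(z)                                  # role of layer_norm_in
    a = mask * sigmoid(z @ w["a_g"]) * (z @ w["a_p"])  # [N, N, c_hidden]
    b = mask * sigmoid(z @ w["b_g"]) * (z @ w["b_p"])  # [N, N, c_hidden]

    x = np.einsum("ikc,jkc->ijc", a, b)                # sum over third residue k
    x = layer_norm(x)                                  # role of layer_norm_out
    x = x @ w["z"]                                     # role of linear_z
    g = sigmoid(z @ w["g"])                            # role of linear_g + sigmoid
    return g * x                                       # gated update, [N, N, c_z]


rng = np.random.default_rng(0)
n_res, c_z, c_hidden = 6, 8, 4
w = {
    "a_p": rng.standard_normal((c_z, c_hidden)),
    "a_g": rng.standard_normal((c_z, c_hidden)),
    "b_p": rng.standard_normal((c_z, c_hidden)),
    "b_g": rng.standard_normal((c_z, c_hidden)),
    "g": rng.standard_normal((c_z, c_z)),
    "z": rng.standard_normal((c_hidden, c_z)),
}
z = rng.standard_normal((n_res, n_res, c_z))
out = outgoing_update(z, w)
print(out.shape)
```

The output has the same shape as the input, so the update can be added back to the pair representation as a residual.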
2. The _combine_projections method
def _combine_projections(self, a, b, _inplace_chunk_size=None):
    if self._outgoing:
        a = permute_final_dims(a, (2, 0, 1))
        b = permute_final_dims(b, (2, 1, 0))
    else:
        a = permute_final_dims(a, (2, 1, 0))
        b = permute_final_dims(b, (2, 0, 1))
    if _inplace_chunk_size is not None:
        for i in range(0, a.shape[-3], _inplace_chunk_size):
            a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
            b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
            a[..., i: i + _inplace_chunk_size, :, :] = (
                torch.matmul(a_chunk, b_chunk)
            )
        p = a
    else:
        p = torch.matmul(a, b)
    return permute_final_dims(p, (1, 2, 0))
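- permute_final_dims: reorders the last three dimensions so that the channel dimension comes first and the two residue dimensions form the matrices to be multiplied.
- torch.matmul: multiplies a and b as batched matrices, one pair per channel, which performs the sum over the third residue index.
- _inplace_chunk_size: when set, the product is computed in chunks along the channel dimension and written back into a, which saves memory for large tensors.
- Return value: the result tensor p, with its dimensions permuted back to channel-last order.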
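The permute-then-matmul pattern above can be checked against an explicit einsum. The sketch below uses NumPy rather than PyTorch, with np.transpose standing in for permute_final_dims; the helper names combine_projections and combine_chunked are made up for this illustration. It also mirrors the chunked in-place path, which is only shape-compatible because the two residue dimensions have equal size.

```python
import numpy as np


def permute_final_dims(x, inds):
    """NumPy stand-in: reorder only the last len(inds) axes of x."""
    n_lead = x.ndim - len(inds)
    return np.transpose(x, list(range(n_lead)) + [n_lead + i for i in inds])


def _permute_inputs(a, b, outgoing):
    if outgoing:
        return permute_final_dims(a, (2, 0, 1)), permute_final_dims(b, (2, 1, 0))
    return permute_final_dims(a, (2, 1, 0)), permute_final_dims(b, (2, 0, 1))


def combine_projections(a, b, outgoing):
    a, b = _permute_inputs(a, b, outgoing)
    return permute_final_dims(np.matmul(a, b), (1, 2, 0))


def combine_chunked(a, b, outgoing, chunk_size):
    a = a.copy()  # the loop below overwrites its first argument
    a, b = _permute_inputs(a, b, outgoing)
    for i in range(0, a.shape[-3], chunk_size):
        # Each channel's product depends only on that channel, so
        # overwriting finished chunks of `a` is safe.
        a[..., i:i + chunk_size, :, :] = np.matmul(
            a[..., i:i + chunk_size, :, :],
            b[..., i:i + chunk_size, :, :],
        )
    return permute_final_dims(a, (1, 2, 0))


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 5, 5, 3))  # [batch, N_res, N_res, C]
b = rng.standard_normal((2, 5, 5, 3))

# Outgoing: sum_k a[i, k] * b[j, k]; incoming: sum_k a[k, i] * b[k, j].
out_ref = np.einsum("...ikc,...jkc->...ijc", a, b)
in_ref = np.einsum("...kic,...kjc->...ijc", a, b)

outgoing_ok = np.allclose(combine_projections(a, b, True), out_ref)
incoming_ok = np.allclose(combine_projections(a, b, False), in_ref)
chunked_ok = np.allclose(combine_chunked(a, b, True, chunk_size=2), out_ref)
print(outgoing_ok, incoming_ok, chunked_ok)
```

In the original method the chunked path has no defensive copy, and because permute_final_dims returns a view, the caller's a tensor is overwritten. That is the point of the memory saving, but it also means a should not be reused afterwards.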
3. The forward method (abstract)
@abstractmethod
def forward(self, z, mask=None, inplace_safe=False, _add_with_inplace=False):
    pass
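- Parameters:
  - z: [*, N_res, N_res, C_z], the input pair representation tensor.
  - mask: optional [*, N_res, N_res] mask tensor, marking positions that should not take part in the update.
  - inplace_safe: whether in-place operations are safe to use.
  - _add_with_inplace: whether to perform the addition in place.
- Purpose: the actual triangular multiplicative update logic is implemented in subclasses.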
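Because both __init__ and forward are marked with @abstractmethod, the base class cannot be instantiated directly, and a subclass has to override both. The sketch below shows that pattern with the standard library only. BaseUpdate and OutgoingUpdate are hypothetical stand-ins without the torch parts, and the forward body is a placeholder rather than the real update.

```python
from abc import ABC, abstractmethod


class BaseUpdate(ABC):
    """Hypothetical stand-in for the base class, without nn.Module."""

    @abstractmethod
    def __init__(self, c_z, c_hidden, _outgoing):
        # An abstract method may still have a body that subclasses
        # reach through super().
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

    @abstractmethod
    def forward(self, z, mask=None):
        pass


class OutgoingUpdate(BaseUpdate):
    """Concrete subclass: fixes the direction and implements forward."""

    def __init__(self, c_z, c_hidden):
        super().__init__(c_z, c_hidden, _outgoing=True)

    def forward(self, z, mask=None):
        return z  # placeholder; the real update logic would go here


try:
    BaseUpdate(8, 4, True)
    base_error = None
except TypeError as exc:
    # Instantiating a class with abstract methods raises TypeError.
    base_error = exc

update = OutgoingUpdate(c_z=8, c_hidden=4)
print(type(base_error).__name__, update._outgoing)
```

With this design the direction flag is set once in the subclass constructor, so callers choose a direction by picking a class rather than by passing a flag.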
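Role and significance
- Triangular multiplicative update:
  - Updates the pair representation tensor, capturing geometric and physical relationships between residues in the protein sequence.
  - Depending on _outgoing, handles information flow in the different directions of the triangle.
- Feature extraction and integration:
  - Combines higher-order hidden-layer features through linear transformations (linear_g and linear_z) and matrix multiplication.
- Application:
  - Used mainly in AlphaFold for protein structure prediction, where it helps model complex interactions between residues.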