Bootstrap

AF3 BaseTriangleMultiplicativeUpdate类解读

BaseTriangleMultiplicativeUpdate 类是一个抽象基类 (ABC),用于实现 AlphaFold 相关算法(具体为算法 11 和 12)。它的主要功能是通过三角形乘法更新成对表示张量(pairwise representation tensor)。

源代码:


from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch.nn import LayerNorm
from src.models.components.primitives import Linear
from src.utils.chunk_utils import chunk_layer
from src.utils.tensor_utils import add, permute_final_dims


class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
    """
    Implements Algorithms 11 and 12.
    """

    @abstractmethod
    def __init__(self, c_z, c_hidden, _outgoing):
        """
        Args:
            c_z:
                Input channel dimension
            c:
                Hidden channel dimension
        """
        super(BaseTriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
                             a: torch.Tensor,
                             b: torch.Tensor,
                             _inplace_chunk_size: Optional[int] = None
                             ) -> torch.Tensor:
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
                a[..., i: i + _inplace_chunk_size, :, :] = (
                    torch.matmul(
                        a_chunk,
                        b_chunk,
                    )
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    @abstractmethod
    def forward(self,
                z: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                inplace_safe: bool = False,
                _add_with_inplace: bool = False
                ) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_res, N_res, C_z] x tensor
            mask:
                [*, N_res, N_res] x mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        pass

代码解读:

1. 构造方法(__init__
def __init__(self, c_z, c_hidden, _outgoing):
    super(BaseTriangleMultiplicativeUpdate, self).__init__()
    self.c_z = c_z
    self.c_hidden = c_hidden
    self._outgoing = _outgoing

    self.linear_g = Linear(self.c_z, self.c_z, init="gating")
    self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

    self.layer_norm_in = LayerNorm(self.c_z)
    self.layer_norm_out = LayerNorm(self.c_hidden)

    self.sigmoid = nn.Sigmoid()
  • c_z: 输入的通道维度。
  • c_hidden: 隐藏层的通道维度。
  • _outgoing: 一个布尔标志,用于指示操作的方向(用于三角形乘法:出发或抵达方向)。
  • 主要组件:
    • linear_g 和 linear_z:线性层,用于特征变换。
    • layer_norm_in 和 layer_norm_out:层归一化,用于提高训练稳定性。
    • sigmoid:对门控张量进行非线性变换。

2. _combine_projections 方法
def _combine_projections(self, a, b, _inplace_chunk_size=None):
    if self._outgoing:
        a = permute_final_dims(a, (2, 0, 1))
        b = permute_final_dims(b, (2, 1, 0))
    else:
        a = permute_final_dims(a, (2, 1, 0))
        b = permute_final_dims(b, (2, 0, 1))

    if _inplace_chunk_size is not None:
        for i in range(0, a.shape[-3], _inplace_chunk_size):
            a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
            b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
            a[..., i: i + _inplace_chunk_size, :, :] = (
                torch.matmul(a_chunk, b_chunk)
            )
        p = a
    else:
        p = torch.matmul(a, b)

    return permute_final_dims(p, (1, 2, 0))
  • permute_final_dims:交换张量的最后几个维度以适配矩阵乘法。
  • torch.matmul:计算 a 和 b 的矩阵乘法,用于更新表示。
  • _inplace_chunk_size:支持按块处理较大的张量以节省内存。
  • 返回值:重新排列维度后的结果张量 p

3. forward 方法(抽象)
@abstractmethod
def forward(self, z, mask=None, inplace_safe=False, _add_with_inplace=False):
    pass
  • 参数说明
    • z: [*, N_res, N_res, C_z],输入的成对表示张量。
    • mask: 可选的掩码张量,用于处理不需要更新的部分。
    • inplace_safe:是否安全地进行原地操作。
    • _add_with_inplace:是否在加法中使用原地操作。
  • 作用:具体的三角形乘法更新逻辑将在子类中实现。

作用与意义

  1. 三角形乘法更新

    • 用于更新成对表示张量,捕捉蛋白质序列中残基之间的几何和物理关系。
    • 根据 _outgoing,可以处理三角形结构中不同方向的信息流。
  2. 特征提取与整合

    • 通过线性变换(linear_g 和 linear_z)和矩阵乘法,整合隐藏层的高阶特征。
  3. 应用场景

    • 主要在 AlphaFold 中用于蛋白质结构预测,帮助建模残基之间的复杂相互作用。
;