diff --git a/.readthedocs.yml b/.readthedocs.yml
deleted file mode 100644
index fbdc74e..0000000
--- a/.readthedocs.yml
+++ /dev/null
@@ -1,13 +0,0 @@
-version: 2
-
-build:
-  os: ubuntu-22.04
-  tools:
-    python: "3.11"
-
-mkdocs:
-  configuration: mkdocs.yml
-
-python:
-  install:
-    - requirements: requirements.txt
\ No newline at end of file
diff --git a/alphafold3/model.py b/alphafold3/model.py
index 5d26b37..d4abf48 100644
--- a/alphafold3/model.py
+++ b/alphafold3/model.py
@@ -230,14 +230,26 @@ def forward(
 class AxialAttention(nn.Module):
     def __init__(
         self,
-        dim,
-        heads,
-        row_attn=True,
-        col_attn=True,
-        accept_edges=False,
-        global_query_attn=False,
+        dim: int,
+        heads: int,
+        row_attn: bool = True,
+        col_attn: bool = True,
+        accept_edges: bool = False,
+        global_query_attn: bool = False,
         **kwargs,
     ):
+        """
+        Axial Attention module.
+
+        Args:
+            dim (int): The input dimension.
+            heads (int): The number of attention heads.
+            row_attn (bool, optional): Whether to perform row attention. Defaults to True.
+            col_attn (bool, optional): Whether to perform column attention. Defaults to True.
+            accept_edges (bool, optional): Whether to accept edges for attention bias. Defaults to False.
+            global_query_attn (bool, optional): Whether to perform global query attention. Defaults to False.
+            **kwargs: Additional keyword arguments for the Attention module.
+        """
         super().__init__()
         assert not (
             not row_attn and not col_attn
@@ -260,7 +272,23 @@ def __init__(
             else None
         )
 
-    def forward(self, x, edges=None, mask=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        edges: torch.Tensor = None,
+        mask: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the Axial Attention module.
+
+        Args:
+            x (torch.Tensor): The input tensor of shape (batch_size, height, width, dim).
+            edges (torch.Tensor, optional): The edges tensor for attention bias. Defaults to None.
+            mask (torch.Tensor, optional): The mask tensor for masking attention. Defaults to None.
+
+        Returns:
+            torch.Tensor: The output tensor of shape (batch_size, height, width, dim).
+        """
         assert (
             self.row_attn ^ self.col_attn
         ), "has to be either row or column attention, but not both"
@@ -485,4 +513,29 @@ class ABlock(nn.Module):
     """_summary_
 
     Triangular update -> self attention -> transition -> self attention -> triangular update
-    """
\ No newline at end of file
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        seq_len: int,
+        heads: int,
+        dim_head: int,
+        dropout: float,
+        global_column_attn: bool,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.seq_len = seq_len
+        self.heads = heads
+        self.dim_head = dim_head
+        self.dropout = dropout
+        self.global_column_attn = global_column_attn
+
+        self.msa = MsaAttentionBlock(
+            dim=dim,
+            seq_len=seq_len,
+            heads=heads,
+            dim_head=dim_head,
+            dropout=dropout,
+        )
diff --git a/alphafold3/pairformer.py b/alphafold3/pairformer.py
new file mode 100644
index 0000000..efe4121
--- /dev/null
+++ b/alphafold3/pairformer.py
@@ -0,0 +1,230 @@
+from alphafold2_pytorch.utils import *
+from einops import rearrange
+from torch import nn
+from torch.utils.checkpoint import checkpoint_sequential
+from typing import Tuple, Optional
+import torch
+from alphafold3.model import (
+    FeedForward,
+    AxialAttention,
+    TriangleMultiplicativeModule,
+)
+
+# structure module
+
+
+def default(val, d):
+    return val if val is not None else d
+
+
+def exists(val):
+    return val is not None
+
+
+# PairFormer blocks
+
+
+class OuterMean(nn.Module):
+    def __init__(self, dim, hidden_dim=None, eps=1e-5):
+        super().__init__()
+        self.eps = eps
+        self.norm = nn.LayerNorm(dim)
+        hidden_dim = default(hidden_dim, dim)
+
+        self.left_proj = nn.Linear(dim, hidden_dim)
+        self.right_proj = nn.Linear(dim, hidden_dim)
+        self.proj_out = nn.Linear(hidden_dim, dim)
+
+    def forward(self, x, mask=None):
+        x = self.norm(x)
+        left = self.left_proj(x)
+        right = self.right_proj(x)
+        outer = rearrange(left, "b m i d -> b m i () d") * rearrange(
+            right, "b m j d -> b m () j d"
+        )
+
+        if exists(mask):
+            # masked mean, if there are padding in the rows of the MSA
+            mask = rearrange(
+                mask, "b m i -> b m i () ()"
+            ) * rearrange(mask, "b m j -> b m () j ()")
+            outer = outer.masked_fill(~mask, 0.0)
+            outer = outer.mean(dim=1) / (mask.sum(dim=1) + self.eps)
+        else:
+            outer = outer.mean(dim=1)
+
+        return self.proj_out(outer)
+
+
+class PairwiseAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        seq_len,
+        heads,
+        dim_head,
+        dropout=0.0,
+        global_column_attn=False,
+    ):
+        super().__init__()
+        self.outer_mean = OuterMean(dim)
+
+        self.triangle_attention_outgoing = AxialAttention(
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            row_attn=True,
+            col_attn=False,
+            accept_edges=True,
+        )
+        self.triangle_attention_ingoing = AxialAttention(
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            row_attn=False,
+            col_attn=True,
+            accept_edges=True,
+            global_query_attn=global_column_attn,
+        )
+        self.triangle_multiply_outgoing = (
+            TriangleMultiplicativeModule(dim=dim, mix="outgoing")
+        )
+        self.triangle_multiply_ingoing = TriangleMultiplicativeModule(
+            dim=dim, mix="ingoing"
+        )
+
+    def forward(self, x, mask=None, msa_repr=None, msa_mask=None):
+        if exists(msa_repr):
+            x = x + self.outer_mean(msa_repr, mask=msa_mask)
+
+        x = self.triangle_multiply_outgoing(x, mask=mask) + x
+        x = self.triangle_multiply_ingoing(x, mask=mask) + x
+        x = (
+            self.triangle_attention_outgoing(x, edges=x, mask=mask)
+            + x
+        )
+        x = self.triangle_attention_ingoing(x, edges=x, mask=mask) + x
+        return x
+
+
+class MsaAttentionBlock(nn.Module):
+    def __init__(self, dim, seq_len, heads, dim_head, dropout=0.0):
+        super().__init__()
+        self.row_attn = AxialAttention(
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            row_attn=True,
+            col_attn=False,
+            accept_edges=True,
+        )
+        self.col_attn = AxialAttention(
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            row_attn=False,
+            col_attn=True,
+        )
+
+    def forward(self, x, mask=None, pairwise_repr=None):
+        x = self.row_attn(x, mask=mask, edges=pairwise_repr) + x
+        x = self.col_attn(x, mask=mask) + x
+        return x
+
+
+# main PairFormer class
+class PairFormerBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim: int,
+        seq_len: int,
+        heads: int,
+        dim_head: int,
+        attn_dropout: float,
+        ff_dropout: float,
+        global_column_attn: bool = False,
+    ):
+        """
+        PairFormer Block module.
+
+        Args:
+            dim: The input dimension.
+            seq_len: The length of the sequence.
+            heads: The number of attention heads.
+            dim_head: The dimension of each attention head.
+            attn_dropout: The dropout rate for attention layers.
+            ff_dropout: The dropout rate for feed-forward layers.
+            global_column_attn: Whether to use global column attention in pairwise attention block.
+        """
+        super().__init__()
+        self.layer = nn.ModuleList(
+            [
+                PairwiseAttentionBlock(
+                    dim=dim,
+                    seq_len=seq_len,
+                    heads=heads,
+                    dim_head=dim_head,
+                    dropout=attn_dropout,
+                    global_column_attn=global_column_attn,
+                ),
+                FeedForward(dim=dim, dropout=ff_dropout),
+                MsaAttentionBlock(
+                    dim=dim,
+                    seq_len=seq_len,
+                    heads=heads,
+                    dim_head=dim_head,
+                    dropout=attn_dropout,
+                ),
+                FeedForward(dim=dim, dropout=ff_dropout),
+            ]
+        )
+
+    def forward(
+        self,
+        inputs: Tuple[
+            torch.Tensor,
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+        ],
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+    ]:
+        """
+        Forward pass of the PairFormer Block.
+
+        Args:
+            inputs: A tuple containing the input tensors (x, m, mask, msa_mask).
+
+        Returns:
+            A tuple containing the output tensors (x, m, mask, msa_mask).
+        """
+        x, m, mask, msa_mask = inputs
+        attn, ff, msa_attn, msa_ff = self.layer
+
+        # msa attention and transition
+        m = msa_attn(m, mask=msa_mask, pairwise_repr=x)
+        m = msa_ff(m) + m
+
+        # pairwise attention and transition
+        x = attn(x, mask=mask, msa_repr=m, msa_mask=msa_mask)
+        x = ff(x) + x
+
+        return x, m, mask, msa_mask
+
+
+class PairFormer(nn.Module):
+    def __init__(self, *, depth, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [PairFormerBlock(**kwargs) for _ in range(depth)]
+        )
+
+    def forward(self, x, m, mask=None, msa_mask=None):
+        inp = (x, m, mask, msa_mask)
+        x, m, *_ = checkpoint_sequential(self.layers, 1, inp)
+        return x, m
diff --git a/input_type.py b/input_type.py
new file mode 100644
index 0000000..4fadb58
--- /dev/null
+++ b/input_type.py
@@ -0,0 +1,7 @@
+import torch
+
+batch_size = 1
+num_nodes = 5
+num_features = 64
+
+x = torch.randn(batch_size, num_nodes, num_nodes, num_features)
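For reviewers, a minimal smoke-test sketch of how the new PairFormer stack might be driven with a tensor shaped like the one in input_type.py. It is not part of the patch: the MSA shape, depth, head count, head dimension, and dropout values are illustrative assumptions, and it presumes that FeedForward, AxialAttention, and TriangleMultiplicativeModule imported from alphafold3.model accept these arguments the way their alphafold2-pytorch counterparts do.

# Hypothetical usage sketch (not part of this diff); shapes and hyperparameters are assumptions.
import torch

from alphafold3.pairformer import PairFormer

batch_size, num_nodes, num_features = 1, 5, 64

# Pairwise representation: (batch, i, j, dim), matching the tensor built in input_type.py.
x = torch.randn(batch_size, num_nodes, num_nodes, num_features)
# MSA representation: (batch, msa_rows, seq_len, dim); a single row is used here for illustration.
m = torch.randn(batch_size, 1, num_nodes, num_features)

model = PairFormer(
    depth=2,            # number of stacked PairFormerBlocks (assumed)
    dim=num_features,
    seq_len=num_nodes,
    heads=8,            # assumed head count
    dim_head=32,        # forwarded to AxialAttention via **kwargs
    attn_dropout=0.0,
    ff_dropout=0.0,
)

x_out, m_out = model(x, m)         # masks default to None
print(x_out.shape, m_out.shape)    # expected: torch.Size([1, 5, 5, 64]) torch.Size([1, 1, 5, 64])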