[FEAT][PairFormer]
Kye committed May 8, 2024
1 parent 3f998d2 commit 848e6b8
Showing 4 changed files with 298 additions and 21 deletions.
13 changes: 0 additions & 13 deletions .readthedocs.yml

This file was deleted.

69 changes: 61 additions & 8 deletions alphafold3/model.py
@@ -230,14 +230,26 @@ def forward(
class AxialAttention(nn.Module):
def __init__(
self,
dim,
heads,
row_attn=True,
col_attn=True,
accept_edges=False,
global_query_attn=False,
dim: int,
heads: int,
row_attn: bool = True,
col_attn: bool = True,
accept_edges: bool = False,
global_query_attn: bool = False,
**kwargs,
):
"""
Axial Attention module.
Args:
dim (int): The input dimension.
heads (int): The number of attention heads.
row_attn (bool, optional): Whether to perform row attention. Defaults to True.
col_attn (bool, optional): Whether to perform column attention. Defaults to True.
accept_edges (bool, optional): Whether to accept edges for attention bias. Defaults to False.
global_query_attn (bool, optional): Whether to perform global query attention. Defaults to False.
**kwargs: Additional keyword arguments for the Attention module.
"""
super().__init__()
assert not (
not row_attn and not col_attn
@@ -260,7 +272,23 @@ def __init__(
else None
)

def forward(self, x, edges=None, mask=None):
def forward(
self,
x: torch.Tensor,
edges: torch.Tensor = None,
mask: torch.Tensor = None,
) -> torch.Tensor:
"""
Forward pass of the Axial Attention module.
Args:
x (torch.Tensor): The input tensor of shape (batch_size, height, width, dim).
edges (torch.Tensor, optional): The edges tensor for attention bias. Defaults to None.
mask (torch.Tensor, optional): The mask tensor for masking attention. Defaults to None.
Returns:
torch.Tensor: The output tensor of shape (batch_size, height, width, dim).
"""
assert (
self.row_attn ^ self.col_attn
), "has to be either row or column attention, but not both"
@@ -485,4 +513,29 @@ class ABlock(nn.Module):
"""_summary_
Triangular update -> self attention -> transition -> self attention -> triangular update
"""
"""

def __init__(
self,
dim: int,
seq_len: int,
heads: int,
dim_head: int,
dropout: float,
global_column_attn: bool,
):
super().__init__()
self.dim = dim
self.seq_len = seq_len
self.heads = heads
self.dim_head = dim_head
self.dropout = dropout
self.global_column_attn = global_column_attn

self.msa = MsaAttentionBlock(
dim=dim,
seq_len=seq_len,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
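
Note: the sketch below is a minimal, hypothetical usage of the updated AxialAttention signature. It assumes extra keyword arguments such as dim_head are forwarded to the inner Attention module (as the instantiations in pairformer.py below suggest) and that shapes follow the forward docstring.

import torch
from alphafold3.model import AxialAttention

# hypothetical sizes, row-only attention over an MSA-like tensor
row_attn = AxialAttention(dim=64, heads=8, dim_head=32, row_attn=True, col_attn=False)
x = torch.randn(1, 4, 16, 64)       # (batch, height, width, dim)
mask = torch.ones(1, 4, 16).bool()  # (batch, height, width)
out = row_attn(x, mask=mask)        # output keeps the input shape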
230 changes: 230 additions & 0 deletions alphafold3/pairformer.py
@@ -0,0 +1,230 @@
from alphafold2_pytorch.utils import *
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential
from typing import Tuple, Optional
import torch
from alphafold3.model import (
FeedForward,
AxialAttention,
TriangleMultiplicativeModule,
)

# structure module


def default(val, d):
return val if val is not None else d


def exists(val):
return val is not None


# PairFormer blocks


class OuterMean(nn.Module):
def __init__(self, dim, hidden_dim=None, eps=1e-5):
super().__init__()
self.eps = eps
self.norm = nn.LayerNorm(dim)
hidden_dim = default(hidden_dim, dim)

self.left_proj = nn.Linear(dim, hidden_dim)
self.right_proj = nn.Linear(dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, dim)

def forward(self, x, mask=None):
x = self.norm(x)
left = self.left_proj(x)
right = self.right_proj(x)
outer = rearrange(left, "b m i d -> b m i () d") * rearrange(
right, "b m j d -> b m () j d"
)

if exists(mask):
# masked mean, if there are padding in the rows of the MSA
mask = rearrange(
mask, "b m i -> b m i () ()"
) * rearrange(mask, "b m j -> b m () j ()")
outer = outer.masked_fill(~mask, 0.0)
# masked mean: sum the valid entries, then divide by the count of valid rows
outer = outer.sum(dim=1) / (mask.sum(dim=1) + self.eps)
else:
outer = outer.mean(dim=1)

return self.proj_out(outer)


class PairwiseAttentionBlock(nn.Module):
def __init__(
self,
dim,
seq_len,
heads,
dim_head,
dropout=0.0,
global_column_attn=False,
):
super().__init__()
self.outer_mean = OuterMean(dim)

self.triangle_attention_outgoing = AxialAttention(
dim=dim,
heads=heads,
dim_head=dim_head,
row_attn=True,
col_attn=False,
accept_edges=True,
)
self.triangle_attention_ingoing = AxialAttention(
dim=dim,
heads=heads,
dim_head=dim_head,
row_attn=False,
col_attn=True,
accept_edges=True,
global_query_attn=global_column_attn,
)
self.triangle_multiply_outgoing = (
TriangleMultiplicativeModule(dim=dim, mix="outgoing")
)
self.triangle_multiply_ingoing = TriangleMultiplicativeModule(
dim=dim, mix="ingoing"
)

def forward(self, x, mask=None, msa_repr=None, msa_mask=None):
if exists(msa_repr):
x = x + self.outer_mean(msa_repr, mask=msa_mask)

x = self.triangle_multiply_outgoing(x, mask=mask) + x
x = self.triangle_multiply_ingoing(x, mask=mask) + x
x = (
self.triangle_attention_outgoing(x, edges=x, mask=mask)
+ x
)
x = self.triangle_attention_ingoing(x, edges=x, mask=mask) + x
return x


class MsaAttentionBlock(nn.Module):
def __init__(self, dim, seq_len, heads, dim_head, dropout=0.0):
super().__init__()
self.row_attn = AxialAttention(
dim=dim,
heads=heads,
dim_head=dim_head,
row_attn=True,
col_attn=False,
accept_edges=True,
)
self.col_attn = AxialAttention(
dim=dim,
heads=heads,
dim_head=dim_head,
row_attn=False,
col_attn=True,
)

def forward(self, x, mask=None, pairwise_repr=None):
x = self.row_attn(x, mask=mask, edges=pairwise_repr) + x
x = self.col_attn(x, mask=mask) + x
return x


# main PairFormer class
class PairFormerBlock(nn.Module):
def __init__(
self,
*,
dim: int,
seq_len: int,
heads: int,
dim_head: int,
attn_dropout: float,
ff_dropout: float,
global_column_attn: bool = False,
):
"""
PairFormer Block module.
Args:
dim: The input dimension.
seq_len: The length of the sequence.
heads: The number of attention heads.
dim_head: The dimension of each attention head.
attn_dropout: The dropout rate for attention layers.
ff_dropout: The dropout rate for feed-forward layers.
global_column_attn: Whether to use global column attention in pairwise attention block.
"""
super().__init__()
self.layer = nn.ModuleList(
[
PairwiseAttentionBlock(
dim=dim,
seq_len=seq_len,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
global_column_attn=global_column_attn,
),
FeedForward(dim=dim, dropout=ff_dropout),
MsaAttentionBlock(
dim=dim,
seq_len=seq_len,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
),
FeedForward(dim=dim, dropout=ff_dropout),
]
)

def forward(
self,
inputs: Tuple[
torch.Tensor,
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
],
) -> Tuple[
torch.Tensor,
torch.Tensor,
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
Forward pass of the PairFormer Block.
Args:
inputs: A tuple containing the input tensors (x, m, mask, msa_mask).
Returns:
A tuple containing the output tensors (x, m, mask, msa_mask).
"""
x, m, mask, msa_mask = inputs
attn, ff, msa_attn, msa_ff = self.layer

# msa attention and transition
m = msa_attn(m, mask=msa_mask, pairwise_repr=x)
m = msa_ff(m) + m

# pairwise attention and transition
x = attn(x, mask=mask, msa_repr=m, msa_mask=msa_mask)
x = ff(x) + x

return x, m, mask, msa_mask


class PairFormer(nn.Module):
def __init__(self, *, depth, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[PairFormerBlock(**kwargs) for _ in range(depth)]
)

def forward(self, x, m, mask=None, msa_mask=None):
inp = (x, m, mask, msa_mask)
x, m, *_ = checkpoint_sequential(self.layers, 1, inp)
return x, m
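
Note: a minimal end-to-end sketch of the new PairFormer stack with hypothetical sizes. It assumes the package is importable as alphafold3.pairformer, that FeedForward accepts dim and dropout keyword arguments (as PairFormerBlock calls it), and that masks are boolean tensors shaped like the pairwise and MSA representations without the feature dimension.

import torch
from alphafold3.pairformer import PairFormer

dim, seq_len, n_msa = 64, 16, 4  # hypothetical sizes

model = PairFormer(
    depth=2,
    dim=dim,
    seq_len=seq_len,
    heads=8,
    dim_head=32,
    attn_dropout=0.0,
    ff_dropout=0.0,
)

x = torch.randn(1, seq_len, seq_len, dim)      # pairwise representation
m = torch.randn(1, n_msa, seq_len, dim)        # MSA representation
mask = torch.ones(1, seq_len, seq_len).bool()  # pairwise mask
msa_mask = torch.ones(1, n_msa, seq_len).bool()

x_out, m_out = model(x, m, mask=mask, msa_mask=msa_mask)

checkpoint_sequential re-runs each block's forward pass during backpropagation, trading extra compute for lower activation memory across the depth of the stack.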
7 changes: 7 additions & 0 deletions input_type.py
@@ -0,0 +1,7 @@
import torch

batch_size = 1
num_nodes = 5
num_features = 64

x = torch.randn(batch_size, num_nodes, num_nodes, num_features)
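
Note: the tensor above matches the pairwise-representation layout (batch, nodes, nodes, features) that PairFormer expects for x. A hypothetical companion MSA tensor would add a separate alignment dimension, e.g.:

num_alignments = 3  # hypothetical
m = torch.randn(batch_size, num_alignments, num_nodes, num_features)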
