From c4ed6070a8de09390552c6cc597416076cc731ca Mon Sep 17 00:00:00 2001
From: Kye
Date: Thu, 9 May 2024 16:32:50 -0400
Subject: [PATCH] [FEAT][Main class]

---
 README.md                       |   6 +-
 alphafold3/__init__.py          |   4 +-
 alphafold3/diffusion.py         |  67 ++++++++++++++-
 alphafold3/pair_transition.py   |  99 -----------------------
 alphafold3/pairformer.py        |   2 +-
 diffusion_example.py            |   6 +-
 alphafold3/model.py => model.py | 139 ++++++++++++++++++++++++++------
 requirements.txt                |   3 +-
 8 files changed, 188 insertions(+), 138 deletions(-)
 delete mode 100644 alphafold3/pair_transition.py
 rename alphafold3/model.py => model.py (79%)

diff --git a/README.md b/README.md
index d8421ab..51a0a0f 100644
--- a/README.md
+++ b/README.md
@@ -35,10 +35,10 @@ Need review but basically it operates on atomic coordinates.
 ```python
 import torch
-from alphafold3.diffusion import GeneticDiffusionModule
+from alphafold3.diffusion import GeneticDiffusionModuleBlock
 
-# Create an instance of the GeneticDiffusionModule
-model = GeneticDiffusionModule(channels=3, training=True)
+# Create an instance of the GeneticDiffusionModuleBlock
+model = GeneticDiffusionModuleBlock(channels=3, training=True)
 
 # Generate random input coordinates
 input_coords = torch.randn(10, 100, 100, 3)
diff --git a/alphafold3/__init__.py b/alphafold3/__init__.py
index b6080b7..0a6b229 100644
--- a/alphafold3/__init__.py
+++ b/alphafold3/__init__.py
@@ -1,5 +1,5 @@
-from alphafold3.diffusion import GeneticDiffusionModule
+from alphafold3.diffusion import GeneticDiffusionModuleBlock
 
 __all__ = [
-    "GeneticDiffusionModule",
+    "GeneticDiffusionModuleBlock",
 ]
diff --git a/alphafold3/diffusion.py b/alphafold3/diffusion.py
index 98263ce..c3b1520 100644
--- a/alphafold3/diffusion.py
+++ b/alphafold3/diffusion.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 
 
-class GeneticDiffusionModule(nn.Module):
+class GeneticDiffusionModuleBlock(nn.Module):
     """
     Diffusion Module from AlphaFold 3.
 
@@ -22,6 +22,7 @@ def __init__(
         channels: int,
         num_diffusion_steps: int = 1000,
         training: bool = False,
+        depth: int = 30,
     ):
         """
         Initializes the DiffusionModule with the specified number of channels and diffusion steps.
@@ -30,10 +31,13 @@
         Args:
             channels (int): Number of feature channels for the input.
             num_diffusion_steps (int): Number of diffusion steps (time steps in the diffusion process).
+            training (bool): Whether to compute a denoising loss against a ground truth.
+            depth (int): Stored for parity with GeneticDiffusion; a single block does not iterate over it.
         """
-        super(GeneticDiffusionModule, self).__init__()
+        super(GeneticDiffusionModuleBlock, self).__init__()
         self.channels = channels
         self.num_diffusion_steps = num_diffusion_steps
         self.training = training
+        self.depth = depth
         self.noise_scale = nn.Parameter(
             torch.linspace(1.0, 0.01, num_diffusion_steps)
         )
@@ -77,9 +79,68 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None):
 
         return noisy_x
 
 
+class GeneticDiffusion(nn.Module):
+    """
+    GeneticDiffusion module for performing genetic diffusion.
+
+    Args:
+        channels (int): Number of input channels.
+        num_diffusion_steps (int): Number of diffusion steps to perform.
+        training (bool): Whether the module is in training mode or not.
+        depth (int): Number of diffusion module blocks to stack.
+
+    Attributes:
+        channels (int): Number of input channels.
+        num_diffusion_steps (int): Number of diffusion steps to perform.
+        training (bool): Whether the module is in training mode or not.
+        depth (int): Number of diffusion module blocks to stack.
+        layers (nn.ModuleList): List of GeneticDiffusionModuleBlock instances.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        num_diffusion_steps: int = 1000,
+        training: bool = False,
+        depth: int = 30,
+    ):
+        super(GeneticDiffusion, self).__init__()
+        self.channels = channels
+        self.num_diffusion_steps = num_diffusion_steps
+        self.training = training
+        self.depth = depth
+
+        # Layers
+        self.layers = nn.ModuleList(
+            [
+                GeneticDiffusionModuleBlock(
+                    channels, num_diffusion_steps, training, depth
+                )
+                for _ in range(depth)
+            ]
+        )
+
+    def forward(self, x: Tensor = None, ground_truth: Tensor = None):
+        """
+        Forward pass of the GeneticDiffusion module.
+
+        Args:
+            x (Tensor): Input tensor.
+            ground_truth (Tensor): Ground truth tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: Output tensor and loss tensor (None when
+            no loss is computed).
+        """
+        loss = None
+        for layer in self.layers:
+            # A block returns (noisy_x, loss) when it computes a loss and a
+            # bare tensor otherwise, so unpack defensively.
+            out = layer(x, ground_truth)
+            x, loss = out if isinstance(out, tuple) else (out, loss)
+        return x, loss
+
+
 # # Example usage
 # if __name__ == "__main__":
-#     model = GeneticDiffusionModule(
+#     model = GeneticDiffusionModuleBlock(
 #         channels=3, training=True
 #     )  # Assuming 3D coordinates
 #     input_coords = torch.randn(
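A quick usage sketch for the new stacked `GeneticDiffusion` wrapper (not part of the patch; shapes mirror the README example, and `depth=2` is chosen only to keep the smoke test fast):

```python
import torch
from alphafold3.diffusion import GeneticDiffusion

# Two stacked blocks rather than the default 30, purely for a fast test
model = GeneticDiffusion(channels=3, training=True, depth=2)

input_coords = torch.randn(10, 100, 100, 3)   # noisy input coordinates
ground_truth = torch.randn(10, 100, 100, 3)   # target coordinates

output, loss = model(input_coords, ground_truth)
print(output.shape)  # torch.Size([10, 100, 100, 3])
print(loss)          # denoising loss from the final block (None if no loss)
```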
diff --git a/alphafold3/pair_transition.py b/alphafold3/pair_transition.py
deleted file mode 100644
index 1a34782..0000000
--- a/alphafold3/pair_transition.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright 2021 AlQuraishi Laboratory
-# Copyright 2021 DeepMind Technologies Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from torch.nn import Linear, LayerNorm
-from openfold.utils.chunk_utils import chunk_layer
-
-
-class PairTransition(nn.Module):
-    """
-    Implements Algorithm 15.
-    """
-
-    def __init__(self, c_z, n):
-        """
-        Args:
-            c_z:
-                Pair transition channel dimension
-            n:
-                Factor by which c_z is multiplied to obtain hidden channel
-                dimension
-        """
-        super(PairTransition, self).__init__()
-
-        self.c_z = c_z
-        self.n = n
-
-        self.layer_norm = LayerNorm(self.c_z)
-        self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
-        self.relu = nn.ReLU()
-        self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
-
-    def _transition(self, z, mask):
-        # [*, N_res, N_res, C_z]
-        z = self.layer_norm(z)
-
-        # [*, N_res, N_res, C_hidden]
-        z = self.linear_1(z)
-        z = self.relu(z)
-
-        # [*, N_res, N_res, C_z]
-        z = self.linear_2(z)
-        z = z * mask
-
-        return z
-
-    @torch.jit.ignore
-    def _chunk(self,
-        z: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_size: int,
-    ) -> torch.Tensor:
-        return chunk_layer(
-            self._transition,
-            {"z": z, "mask": mask},
-            chunk_size=chunk_size,
-            no_batch_dims=len(z.shape[:-2]),
-        )
-
-    def forward(self,
-        z: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None,
-    ) -> torch.Tensor:
-        """
-        Args:
-            z:
-                [*, N_res, N_res, C_z] pair embedding
-        Returns:
-            [*, N_res, N_res, C_z] pair embedding update
-        """
-        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
-        if mask is None:
-            mask = z.new_ones(z.shape[:-1])
-
-        # [*, N_res, N_res, 1]
-        mask = mask.unsqueeze(-1)
-
-        if chunk_size is not None:
-            z = self._chunk(z, mask, chunk_size)
-        else:
-            z = self._transition(z=z, mask=mask)
-
-        return z
\ No newline at end of file
diff --git a/alphafold3/pairformer.py b/alphafold3/pairformer.py
index efe4121..f6856c8 100644
--- a/alphafold3/pairformer.py
+++ b/alphafold3/pairformer.py
@@ -4,7 +4,7 @@
 from torch.utils.checkpoint import checkpoint_sequential
 from typing import Tuple, Optional
 import torch
-from alphafold3.model import (
+from model import (
     FeedForward,
     AxialAttention,
     TriangleMultiplicativeModule,
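One structural consequence of the rename worth calling out: `model.py` now imports `PairFormer` from `alphafold3.pairformer`, while `pairformer.py` imports its building blocks from the top-level `model` module, so whichever module loads first triggers a circular import. A minimal sketch of one way to break the cycle by deferring the import (the helper name is illustrative, not part of the patch):

```python
# alphafold3/pairformer.py (sketch): resolve model's building blocks lazily
# so importing this module no longer requires model to be fully initialized.
def _model_components():
    from model import (
        FeedForward,
        AxialAttention,
        TriangleMultiplicativeModule,
    )
    return FeedForward, AxialAttention, TriangleMultiplicativeModule
```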
diff --git a/diffusion_example.py b/diffusion_example.py
index bca9656..178cec4 100644
--- a/diffusion_example.py
+++ b/diffusion_example.py
@@ -1,8 +1,8 @@
 import torch
-from alphafold3.diffusion import GeneticDiffusionModule
+from alphafold3.diffusion import GeneticDiffusionModuleBlock
 
-# Create an instance of the GeneticDiffusionModule
-model = GeneticDiffusionModule(channels=3, training=True)
+# Create an instance of the GeneticDiffusionModuleBlock
+model = GeneticDiffusionModuleBlock(channels=3, training=True)
 
 # Generate random input coordinates
 input_coords = torch.randn(10, 100, 100, 3)
diff --git a/alphafold3/model.py b/model.py
similarity index 79%
rename from alphafold3/model.py
rename to model.py
index 275233d..3bee67d 100644
--- a/alphafold3/model.py
+++ b/model.py
@@ -1,13 +1,14 @@
-import torch
-from torch import nn, einsum
-from inspect import isfunction
 from dataclasses import dataclass
-import torch.nn.functional as F
+from inspect import isfunction
 
+import torch
+import torch.nn.functional as F
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
+from torch import Tensor, einsum, nn
 
-from alphafold2_pytorch.utils import *
+from alphafold3.pairformer import PairFormer
+from alphafold3.diffusion import GeneticDiffusion
 
 # structure module
@@ -509,10 +510,21 @@ def forward(self, x, mask=None, pairwise_repr=None):
 
 # main evoformer class
 
-class ABlock(nn.Module):
-    """_summary_
-
-    Triangular update -> self attention -> transition -> self attention -> triangular update
+class AlphaFold3(nn.Module):
+    """
+    AlphaFold3 model implementation.
+
+    Args:
+        dim (int): Dimension of the model.
+        seq_len (int): Length of the sequence.
+        heads (int): Number of attention heads.
+        dim_head (int): Dimension of each attention head.
+        attn_dropout (float): Dropout rate for attention layers.
+        ff_dropout (float): Dropout rate for feed-forward layers.
+        global_column_attn (bool, optional): Whether to use global column attention. Defaults to False.
+        pair_former_depth (int, optional): Depth of the PairFormer blocks. Defaults to 48.
+        num_diffusion_steps (int, optional): Number of diffusion steps. Defaults to 1000.
+        diffusion_depth (int, optional): Depth of the diffusion module. Defaults to 30.
""" def __init__( @@ -521,42 +533,117 @@ def __init__( seq_len: int, heads: int, dim_head: int, - dropout: float, - global_column_attn: bool, + attn_dropout: float, + ff_dropout: float, + global_column_attn: bool = False, + pair_former_depth: int = 48, + num_diffusion_steps: int = 1000, + diffusion_depth: int = 30, ): super().__init__() self.dim = dim self.seq_len = seq_len self.heads = heads self.dim_head = dim_head - self.dropout = dropout + self.attn_dropout = attn_dropout + self.ff_dropout = ff_dropout self.global_column_attn = global_column_attn - self.msa = MsaAttentionBlock( + self.confidence_projection = nn.Linear(dim, 1) + + # Pairformer blocks + self.pairformer = PairFormer( dim=dim, seq_len=seq_len, heads=heads, dim_head=dim_head, - dropout=dropout, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + global_column_attn=global_column_attn, + depth=pair_former_depth, ) + # Diffusion module + self.diffuser = GeneticDiffusion( + channels=dim, + num_diffusion_steps=1000, + training=False, + depth=diffusion_depth, + ) - - -class AlphaFold3(nn.module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - self.confidence_projection = nn.Linear(dim, 1) - def forward( self, pair_representation: Tensor, single_representation: Tensor, return_loss: bool = False, ground_truth: Tensor = None, - return_confidence: bool = False + return_confidence: bool = False, + return_embeddings: bool = True, ) -> Tensor: - pass - \ No newline at end of file + """ + Forward pass of the AlphaFold3 model. + + Args: + pair_representation (Tensor): Pair representation tensor. + single_representation (Tensor): Single representation tensor. + return_loss (bool, optional): Whether to return the loss. Defaults to False. + ground_truth (Tensor, optional): Ground truth tensor. Defaults to None. + return_confidence (bool, optional): Whether to return the confidence. Defaults to False. + return_embeddings (bool, optional): Whether to return the embeddings. Defaults to False. + + Returns: + Tensor: Output tensor based on the specified return type. + """ + # Recycle bins + # recyle_bins = [] + + # TODO: Input + # TODO: Template + # TODO: MSA + + b, n, n_two, dim = pair_representation.shape + b_two, n_two, dim_two = single_representation.shape + + # Concat + x = torch.cat( + [pair_representation, single_representation], dim=1 + ) + + # Apply the 48 blocks of PairFormer + x = self.pairformer(x) + print(x.shape) + + # Add the embeddings to the recycle bins + # recyle_bins.append(x) + + + # Diffusion + x = self.diffuser(x, ground_truth) + + # If return_confidence is True, return the confidence + if return_confidence is True: + x = self.confidence_projection(x) + return x + + # If return_loss is True, return the loss + if return_embeddings is True: + return x + + +x = torch.randn(1, 5, 5, 64) +y = torch.randn(1, 5, 64) + +model = AlphaFold3( + dim=64, + seq_len=5, + heads=8, + dim_head=64, + attn_dropout=0.0, + ff_dropout=0.0, + global_column_attn=False, + pair_former_depth=48, + num_diffusion_steps=1000, + diffusion_depth=30, +) +output = model(x, y) +print(output.shape) diff --git a/requirements.txt b/requirements.txt index 236a195..d994736 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch zetascale -swarms +einops +alphafold2-pytorch \ No newline at end of file