diff --git a/README.md b/README.md index 01f8447..d8421ab 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,56 @@ Implementation of Alpha Fold 3 from the paper: "Accurate structure prediction of ## install `$pip install alphafold3` +## Input Tensor Size Example + +```python +import torch + +# Define the batch size, number of nodes, and number of features +batch_size = 1 +num_nodes = 5 +num_features = 64 + +# Generate random pair representations using torch.randn +# Shape: (batch_size, num_nodes, num_nodes, num_features) +pair_representations = torch.randn( + batch_size, num_nodes, num_nodes, num_features +) + +# Generate random single representations using torch.randn +# Shape: (batch_size, num_nodes, num_features) +single_representations = torch.randn( + batch_size, num_nodes, num_features +) +``` + +## Genetic Diffusion +Need review but basically it operates on atomic coordinates. + +```python +import torch +from alphafold3.diffusion import GeneticDiffusionModule + +# Create an instance of the GeneticDiffusionModule +model = GeneticDiffusionModule(channels=3, training=True) + +# Generate random input coordinates +input_coords = torch.randn(10, 100, 100, 3) + +# Generate random ground truth coordinates +ground_truth = torch.randn(10, 100, 100, 3) + +# Pass the input coordinates and ground truth coordinates through the model +output_coords, loss = model(input_coords, ground_truth) + +# Print the output coordinates +print(output_coords) + +# Print the loss value +print(loss) +``` + + # Citation ```bibtex @article{Abramson2024-fj, @@ -35,6 +85,10 @@ Implementation of Alpha Fold 3 from the paper: "Accurate structure prediction of ``` + +sequences, ligands, ,covalent bonds -> input embedder [3] -> + + # Todo - [ ] Implement Figure A, implement triangle update, transition, diff --git a/alphafold3/model.py b/alphafold3/model.py index d4abf48..275233d 100644 --- a/alphafold3/model.py +++ b/alphafold3/model.py @@ -539,3 +539,24 @@ def __init__( dim_head=dim_head, dropout=dropout, ) + + + + +class AlphaFold3(nn.module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + self.confidence_projection = nn.Linear(dim, 1) + + def forward( + self, + pair_representation: Tensor, + single_representation: Tensor, + return_loss: bool = False, + ground_truth: Tensor = None, + return_confidence: bool = False + ) -> Tensor: + pass + \ No newline at end of file diff --git a/alphafold3/pair_transition.py b/alphafold3/pair_transition.py new file mode 100644 index 0000000..1a34782 --- /dev/null +++ b/alphafold3/pair_transition.py @@ -0,0 +1,99 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch.nn as nn + +from torch.nn import Linear, LayerNorm +from openfold.utils.chunk_utils import chunk_layer + + +class PairTransition(nn.Module): + """ + Implements Algorithm 15. + """ + + def __init__(self, c_z, n): + """ + Args: + c_z: + Pair transition channel dimension + n: + Factor by which c_z is multiplied to obtain hidden channel + dimension + """ + super(PairTransition, self).__init__() + + self.c_z = c_z + self.n = n + + self.layer_norm = LayerNorm(self.c_z) + self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") + + def _transition(self, z, mask): + # [*, N_res, N_res, C_z] + z = self.layer_norm(z) + + # [*, N_res, N_res, C_hidden] + z = self.linear_1(z) + z = self.relu(z) + + # [*, N_res, N_res, C_z] + z = self.linear_2(z) + z = z * mask + + return z + + @torch.jit.ignore + def _chunk(self, + z: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {"z": z, "mask": mask}, + chunk_size=chunk_size, + no_batch_dims=len(z.shape[:-2]), + ) + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + # DISCREPANCY: DeepMind forgets to apply the mask in this module. + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + # [*, N_res, N_res, 1] + mask = mask.unsqueeze(-1) + + if chunk_size is not None: + z = self._chunk(z, mask, chunk_size) + else: + z = self._transition(z=z, mask=mask) + + return z \ No newline at end of file diff --git a/diffusion_example.py b/diffusion_example.py new file mode 100644 index 0000000..bca9656 --- /dev/null +++ b/diffusion_example.py @@ -0,0 +1,20 @@ +import torch +from alphafold3.diffusion import GeneticDiffusionModule + +# Create an instance of the GeneticDiffusionModule +model = GeneticDiffusionModule(channels=3, training=True) + +# Generate random input coordinates +input_coords = torch.randn(10, 100, 100, 3) + +# Generate random ground truth coordinates +ground_truth = torch.randn(10, 100, 100, 3) + +# Pass the input coordinates and ground truth coordinates through the model +output_coords, loss = model(input_coords, ground_truth) + +# Print the output coordinates +print(output_coords) + +# Print the loss value +print(loss) diff --git a/input_type.py b/input_type.py index 7519669..fb9e9af 100644 --- a/input_type.py +++ b/input_type.py @@ -1,12 +1,18 @@ import torch +# Define the batch size, number of nodes, and number of features batch_size = 1 num_nodes = 5 num_features = 64 +# Generate random pair representations using torch.randn +# Shape: (batch_size, num_nodes, num_nodes, num_features) pair_representations = torch.randn( batch_size, num_nodes, num_nodes, num_features ) + +# Generate random single representations using torch.randn +# Shape: (batch_size, num_nodes, num_features) single_representations = torch.randn( batch_size, num_nodes, num_features ) diff --git a/pyproject.toml b/pyproject.toml index 3c27547..5190309 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ python = "^3.10" zetascale = "*" torch = "*" einops = "*" +openfold = "*" [tool.poetry.group.lint.dependencies]