Skip to content

Commit

Permalink
[README]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed May 9, 2024
1 parent 6400894 commit 84067ad
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 0 deletions.
54 changes: 54 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,56 @@ Implementation of Alpha Fold 3 from the paper: "Accurate structure prediction of
## install
`$pip install alphafold3`

## Input Tensor Size Example

```python
import torch

# Define the batch size, number of nodes, and number of features
batch_size = 1
num_nodes = 5
num_features = 64

# Generate random pair representations using torch.randn
# Shape: (batch_size, num_nodes, num_nodes, num_features)
pair_representations = torch.randn(
batch_size, num_nodes, num_nodes, num_features
)

# Generate random single representations using torch.randn
# Shape: (batch_size, num_nodes, num_features)
single_representations = torch.randn(
batch_size, num_nodes, num_features
)
```

## Genetic Diffusion
Need review but basically it operates on atomic coordinates.

```python
import torch
from alphafold3.diffusion import GeneticDiffusionModule

# Create an instance of the GeneticDiffusionModule
model = GeneticDiffusionModule(channels=3, training=True)

# Generate random input coordinates
input_coords = torch.randn(10, 100, 100, 3)

# Generate random ground truth coordinates
ground_truth = torch.randn(10, 100, 100, 3)

# Pass the input coordinates and ground truth coordinates through the model
output_coords, loss = model(input_coords, ground_truth)

# Print the output coordinates
print(output_coords)

# Print the loss value
print(loss)
```


# Citation
```bibtex
@article{Abramson2024-fj,
Expand Down Expand Up @@ -35,6 +85,10 @@ Implementation of Alpha Fold 3 from the paper: "Accurate structure prediction of
```



sequences, ligands, ,covalent bonds -> input embedder [3] ->


# Todo

- [ ] Implement Figure A, implement triangle update, transition,
Expand Down
21 changes: 21 additions & 0 deletions alphafold3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,24 @@ def __init__(
dim_head=dim_head,
dropout=dropout,
)




class AlphaFold3(nn.module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim

self.confidence_projection = nn.Linear(dim, 1)

def forward(
self,
pair_representation: Tensor,
single_representation: Tensor,
return_loss: bool = False,
ground_truth: Tensor = None,
return_confidence: bool = False
) -> Tensor:
pass

99 changes: 99 additions & 0 deletions alphafold3/pair_transition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
import torch.nn as nn

from torch.nn import Linear, LayerNorm
from openfold.utils.chunk_utils import chunk_layer


class PairTransition(nn.Module):
"""
Implements Algorithm 15.
"""

def __init__(self, c_z, n):
"""
Args:
c_z:
Pair transition channel dimension
n:
Factor by which c_z is multiplied to obtain hidden channel
dimension
"""
super(PairTransition, self).__init__()

self.c_z = c_z
self.n = n

self.layer_norm = LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")

def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)

# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)

# [*, N_res, N_res, C_z]
z = self.linear_2(z)
z = z * mask

return z

@torch.jit.ignore
def _chunk(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"z": z, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)

def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if mask is None:
mask = z.new_ones(z.shape[:-1])

# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)

if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
z = self._transition(z=z, mask=mask)

return z
20 changes: 20 additions & 0 deletions diffusion_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from alphafold3.diffusion import GeneticDiffusionModule

# Create an instance of the GeneticDiffusionModule
model = GeneticDiffusionModule(channels=3, training=True)

# Generate random input coordinates
input_coords = torch.randn(10, 100, 100, 3)

# Generate random ground truth coordinates
ground_truth = torch.randn(10, 100, 100, 3)

# Pass the input coordinates and ground truth coordinates through the model
output_coords, loss = model(input_coords, ground_truth)

# Print the output coordinates
print(output_coords)

# Print the loss value
print(loss)
6 changes: 6 additions & 0 deletions input_type.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import torch

# Define the batch size, number of nodes, and number of features
batch_size = 1
num_nodes = 5
num_features = 64

# Generate random pair representations using torch.randn
# Shape: (batch_size, num_nodes, num_nodes, num_features)
pair_representations = torch.randn(
batch_size, num_nodes, num_nodes, num_features
)

# Generate random single representations using torch.randn
# Shape: (batch_size, num_nodes, num_features)
single_representations = torch.randn(
batch_size, num_nodes, num_features
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ python = "^3.10"
zetascale = "*"
torch = "*"
einops = "*"
openfold = "*"


[tool.poetry.group.lint.dependencies]
Expand Down

0 comments on commit 84067ad

Please sign in to comment.