Skip to content

Commit

Permalink
[FEAT][GeneticDiffusionModule]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed May 8, 2024
1 parent 848e6b8 commit 8d1d8cb
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions alphafold3/diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F


class GeneticDiffusionModule(nn.Module):
"""
Diffusion Module from AlphaFold 3.
This module directly predicts raw atom coordinates via a generative diffusion process.
It leverages a diffusion model trained to denoise 'noised' atomic coordinates back to their
true state. The diffusion process captures both local and global structural information
through a series of noise scales.
Attributes:
channels (int): The number of channels in the input feature map, corresponding to atomic features.
num_diffusion_steps (int): The number of diffusion steps or noise levels to use.
"""

def __init__(
self,
channels: int,
num_diffusion_steps: int = 1000,
training: bool = False,
):
"""
Initializes the DiffusionModule with the specified number of channels and diffusion steps.
Args:
channels (int): Number of feature channels for the input.
num_diffusion_steps (int): Number of diffusion steps (time steps in the diffusion process).
"""
super(GeneticDiffusionModule, self).__init__()
self.channels = channels
self.num_diffusion_steps = num_diffusion_steps
self.training = training
self.noise_scale = nn.Parametr(
torch.linspace(1.0, 0.01, num_diffusion_steps)
)
self.prediction_network = nn.Sequential(
nn.Linear(channels, channels * 2),
nn.ReLU(),
nn.Linear(channels * 2, channels),
)

def forward(self, x: Tensor = None, ground_truth: Tensor = None):
"""
Forward pass of the DiffusionModule. Applies a sequence of noise and denoising operations to
the input coordinates to simulate the diffusion process.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, num_atoms, channels)
representing the atomic features including coordinates.
Returns:
torch.Tensor: Output tensor of shape (batch_size, num_atoms, channels) with
denoised atom coordinates.
"""
batch_size, num_atoms, channels = x.shape
noisy_x = x.clone()

for step in range(self.num_diffusion_steps):
# Generate noise scaled by the noise level for the current step
noise_level = self.noise_scale[step]
noise = (
torch.randn(
batch_size, num_atoms, channels, device=x.device
)
* noise_level
)

# Add noise to the input
noisy_x = x + noise

# Predict and denoise the noisy input
noisy_x = self.prediction_network(noisy_x)

if self.training and ground_truth is not None:
loss = F.mse_loss(noisy_x, ground_truth)
return noisy_x, loss

return noisy_x


# Example usage
if __name__ == "__main__":
model = GeneticDiffusionModule(
channels=3
) # Assuming 3D coordinates
input_coords = torch.randn(
10, 100, 3
) # Example with batch size 10 and 100 atoms
output_coords = model(input_coords)
print(output_coords.shape) # Should be (10, 100, 3)

0 comments on commit 8d1d8cb

Please sign in to comment.