generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
May 8, 2024
1 parent
848e6b8
commit 8d1d8cb
Showing
1 changed file
with
94 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import torch | ||
from torch import nn, Tensor | ||
import torch.nn.functional as F | ||
|
||
|
||
class GeneticDiffusionModule(nn.Module): | ||
""" | ||
Diffusion Module from AlphaFold 3. | ||
This module directly predicts raw atom coordinates via a generative diffusion process. | ||
It leverages a diffusion model trained to denoise 'noised' atomic coordinates back to their | ||
true state. The diffusion process captures both local and global structural information | ||
through a series of noise scales. | ||
Attributes: | ||
channels (int): The number of channels in the input feature map, corresponding to atomic features. | ||
num_diffusion_steps (int): The number of diffusion steps or noise levels to use. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
channels: int, | ||
num_diffusion_steps: int = 1000, | ||
training: bool = False, | ||
): | ||
""" | ||
Initializes the DiffusionModule with the specified number of channels and diffusion steps. | ||
Args: | ||
channels (int): Number of feature channels for the input. | ||
num_diffusion_steps (int): Number of diffusion steps (time steps in the diffusion process). | ||
""" | ||
super(GeneticDiffusionModule, self).__init__() | ||
self.channels = channels | ||
self.num_diffusion_steps = num_diffusion_steps | ||
self.training = training | ||
self.noise_scale = nn.Parametr( | ||
torch.linspace(1.0, 0.01, num_diffusion_steps) | ||
) | ||
self.prediction_network = nn.Sequential( | ||
nn.Linear(channels, channels * 2), | ||
nn.ReLU(), | ||
nn.Linear(channels * 2, channels), | ||
) | ||
|
||
def forward(self, x: Tensor = None, ground_truth: Tensor = None): | ||
""" | ||
Forward pass of the DiffusionModule. Applies a sequence of noise and denoising operations to | ||
the input coordinates to simulate the diffusion process. | ||
Args: | ||
x (torch.Tensor): Input tensor of shape (batch_size, num_atoms, channels) | ||
representing the atomic features including coordinates. | ||
Returns: | ||
torch.Tensor: Output tensor of shape (batch_size, num_atoms, channels) with | ||
denoised atom coordinates. | ||
""" | ||
batch_size, num_atoms, channels = x.shape | ||
noisy_x = x.clone() | ||
|
||
for step in range(self.num_diffusion_steps): | ||
# Generate noise scaled by the noise level for the current step | ||
noise_level = self.noise_scale[step] | ||
noise = ( | ||
torch.randn( | ||
batch_size, num_atoms, channels, device=x.device | ||
) | ||
* noise_level | ||
) | ||
|
||
# Add noise to the input | ||
noisy_x = x + noise | ||
|
||
# Predict and denoise the noisy input | ||
noisy_x = self.prediction_network(noisy_x) | ||
|
||
if self.training and ground_truth is not None: | ||
loss = F.mse_loss(noisy_x, ground_truth) | ||
return noisy_x, loss | ||
|
||
return noisy_x | ||
|
||
|
||
# Example usage | ||
if __name__ == "__main__": | ||
model = GeneticDiffusionModule( | ||
channels=3 | ||
) # Assuming 3D coordinates | ||
input_coords = torch.randn( | ||
10, 100, 3 | ||
) # Example with batch size 10 and 100 atoms | ||
output_coords = model(input_coords) | ||
print(output_coords.shape) # Should be (10, 100, 3) |