From dd97234c716b812c083999f55f3bc293e8c57204 Mon Sep 17 00:00:00 2001 From: Kye Date: Wed, 8 May 2024 14:39:18 -0400 Subject: [PATCH] [cleanup] --- alphafold3/__init__.py | 5 +++++ alphafold3/diffusion.py | 30 +++++++++++++++--------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/alphafold3/__init__.py b/alphafold3/__init__.py index e69de29..b6080b7 100644 --- a/alphafold3/__init__.py +++ b/alphafold3/__init__.py @@ -0,0 +1,5 @@ +from alphafold3.diffusion import GeneticDiffusionModule + +__all__ = [ + "GeneticDiffusionModule", +] diff --git a/alphafold3/diffusion.py b/alphafold3/diffusion.py index f871f47..98263ce 100644 --- a/alphafold3/diffusion.py +++ b/alphafold3/diffusion.py @@ -63,7 +63,7 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None): # Generate noise scaled by the noise level for the current step noise_level = self.noise_scale[step] noise = torch.randn_like(x) * noise_level - + # Add noise to the input noisy_x = x + noise @@ -77,17 +77,17 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None): return noisy_x -# Example usage -if __name__ == "__main__": - model = GeneticDiffusionModule( - channels=3, training=True - ) # Assuming 3D coordinates - input_coords = torch.randn( - 10, 100, 100, 3 - ) # Example with batch size 10 and 100 atoms - ground_truth = torch.randn( - 10, 100, 100, 3 - ) # Example with batch size 10 and 100 atoms - output_coords, loss = model(input_coords, ground_truth) - print(output_coords) # Should be (10, 100, 3) - print(loss) # Should be a scalar (MSE loss value +# # Example usage +# if __name__ == "__main__": +# model = GeneticDiffusionModule( +# channels=3, training=True +# ) # Assuming 3D coordinates +# input_coords = torch.randn( +# 10, 100, 100, 3 +# ) # Example with batch size 10 and 100 atoms +# ground_truth = torch.randn( +# 10, 100, 100, 3 +# ) # Example with batch size 10 and 100 atoms +# output_coords, loss = model(input_coords, ground_truth) +# print(output_coords) # Should be (10, 100, 3) +# print(loss) # Should be a scalar (MSE loss value