Skip to content

Commit

Permalink
src/dnadiffusion rewrite (#150)
Browse files Browse the repository at this point in the history
* revamp src/dnadiffusion: removes lightning, adds diffusion class, new train loop
* add filtered dataset back to src/dnadiffusion/data
* move dnadiffusion.py to notebooks/ to fix tests/library import
  • Loading branch information
ssenan committed Jun 10, 2023
1 parent 9938d14 commit 8776054
Show file tree
Hide file tree
Showing 35 changed files with 2,107 additions and 1,820 deletions.
1,265 changes: 1,265 additions & 0 deletions notebooks/dnadiffusion.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,7 @@ def train(self):
print("KL_SHUFFLE", self.shuffle_kl, "KL")

if epoch != 0 and epoch % 500 == 0 and self.accelerator.is_main_process:
model_path = f"./models/epoch_{str(epoch)}_{self.model_name}.pt"
model_path = f"./models/epoch_{epoch!s}_{self.model_name}.pt"
self.save(epoch, model_path)


Expand Down
File renamed without changes.
File renamed without changes.
65 changes: 65 additions & 0 deletions sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch

from dnadiffusion.data.dataloader import load_data
from dnadiffusion.models.diffusion import Diffusion
from dnadiffusion.models.unet import UNet
from dnadiffusion.utils.sample_util import create_sample


def sample(model_path: str, num_samples: int = 1000):
# Instantiating data and model

print("Loading data")
encode_data, _ = load_data(
data_path="dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
saved_data_path="dnadiffusion/data/encode_data.pkl",
subset_list=[
"GM12878_ENCLB441ZZZ",
"hESCT0_ENCLB449ZZZ",
"K562_ENCLB843GMH",
"HepG2_ENCLB029COU",
],
limit_total_sequences=0,
num_sampling_to_compare_cells=1000,
load_saved_data=True,
batch_size=240,
)

print("Instantiating unet")
unet = UNet(
dim=200,
channels=1,
dim_mults=(1, 2, 4),
resnet_block_groups=4,
)

print("Instantiating diffusion class")
diffusion = Diffusion(
unet,
timesteps=50,
)

# Load checkpoint
print("Loading checkpoint")
checkpoint_dict = torch.load(model_path)
diffusion.load_state_dict(checkpoint_dict["model"])

# Generating cell specific samples
cell_num_list = encode_data["tag_to_numeric"].values()

for i in cell_num_list:
print(f"Generating {num_samples} samples for cell {encode_data['numeric_to_tag'][i]}")
create_sample(
diffusion,
conditional_numeric_to_tag=encode_data["numeric_to_tag"],
cell_types=encode_data["cell_types"],
num_sampling_to_compare_cells=int(num_samples / 10),
specific_group=True,
group_number=i,
cond_weight_to_metric=1,
save_timestep_dataframe=True,
)


if __name__ == "__main__":
sample()
53 changes: 0 additions & 53 deletions src/dnadiffusion/callbacks/ema.py

This file was deleted.

65 changes: 0 additions & 65 deletions src/dnadiffusion/callbacks/sampling.py

This file was deleted.

File renamed without changes.
101 changes: 0 additions & 101 deletions src/dnadiffusion/configs.py

This file was deleted.

38 changes: 0 additions & 38 deletions src/dnadiffusion/data/README.md

This file was deleted.

Loading

0 comments on commit 8776054

Please sign in to comment.