add load function to train_util to resume training (#152)
ssenan committed Jun 23, 2023
1 parent 8776054 commit d93906b
Showing 3 changed files with 21 additions and 8 deletions.
src/dnadiffusion/data/dataloader.py (4 changes: 2 additions & 2 deletions)
@@ -25,7 +25,7 @@ def load_data(
     limit_total_sequences: int = 0,
     num_sampling_to_compare_cells: int = 1000,
     load_saved_data: bool = False,
-    batch_size: int = 240,
+    batch_size: int = 960,
 ):
     # Preprocessing data
     if load_saved_data:
@@ -66,7 +66,7 @@ def load_data(
     # Wrapping data into dataloader
     tf = T.Compose([T.ToTensor()])
     seq_dataset = SequenceDataset(seqs=X_train, c=x_train_cell_type, transform=tf)
-    train_dl = DataLoader(seq_dataset, batch_size, shuffle=True, num_workers=96, pin_memory=True)
+    train_dl = DataLoader(seq_dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)
 
     # Collecting variables into a dict
     encode_data_dict = {
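
Two defaults change here: the batch size grows from 240 to 960, and the worker count drops from 96 to 8, a far more portable setting. A standalone illustration of the same DataLoader configuration on toy tensors (shapes are made up; this is not the project's SequenceDataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for SequenceDataset, just to exercise the new loader settings.
seqs = torch.randn(4096, 4, 200)        # illustrative one-hot-like sequence tensors
cells = torch.randint(0, 4, (4096,))    # illustrative cell-type labels
loader = DataLoader(
    TensorDataset(seqs, cells),
    batch_size=960,    # new default from this commit
    shuffle=True,
    num_workers=8,     # was 96, which oversubscribes most machines
    pin_memory=True,   # faster host-to-GPU transfers
)
x, c = next(iter(loader))
print(x.shape)  # torch.Size([960, 4, 200])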
src/dnadiffusion/utils/train_util.py (23 changes: 18 additions & 5 deletions)
@@ -45,34 +45,36 @@ def __init__(
         self.train_kl, self.test_kl, self.shuffle_kl = 1, 1, 1
         self.seq_similarity = 1
 
+        self.start_epoch = 0
+
+    def train_loop(self):
         # Prepare for training
         self.model, self.optimizer, self.train_dl = self.accelerator.prepare(self.model, self.optimizer, self.train_dl)
 
-    def train_loop(self):
         # Initialize wandb
         if self.accelerator.is_main_process:
             self.accelerator.init_trackers(
                 "dnadiffusion",
                 init_kwargs={"wandb": {"notes": "testing wandb accelerate script"}},
             )
 
-        for epoch in tqdm(range(self.epochs)):
+        for epoch in tqdm(range(self.start_epoch, self.epochs)):
             self.model.train()
 
             # Getting loss of current batch
             for _, batch in enumerate(self.train_dl):
                 loss = self.train_step(batch)
 
             # Logging loss
-            if epoch % self.loss_show_epoch == 0 and self.accelerator.is_main_process:
+            if (epoch + 1) % self.loss_show_epoch == 0 and self.accelerator.is_main_process:
                 self.log_step(loss, epoch)
 
             # Sampling
-            if epoch % self.sample_epoch == 0 and self.accelerator.is_main_process:
+            if (epoch + 1) % self.sample_epoch == 0 and self.accelerator.is_main_process:
                 self.sample()
 
             # Saving model
-            if epoch % self.save_epoch == 0:
+            if (epoch + 1) % self.save_epoch == 0 and self.accelerator.is_main_process:
                 self.save_model(epoch)
 
     def train_step(self, batch):
@@ -138,3 +140,14 @@ def save_model(self, epoch):
             checkpoint_dict,
             f"dnadiffusion/checkpoints/epoch_{epoch}_{self.model_name}.pt",
         )
+
+    def load(self, path):
+        checkpoint_dict = torch.load(path)
+        self.model.load_state_dict(checkpoint_dict["model"])
+        self.optimizer.load_state_dict(checkpoint_dict["optimizer"])
+        self.start_epoch = checkpoint_dict["epoch"]
+
+        if self.accelerator.is_main_process:
+            self.ema_model.load_state_dict(checkpoint_dict["ema_model"])
+
+        self.train_loop()
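
The new load mirrors save_model: it expects a checkpoint dict containing the four keys it reads. The construction of that dict is collapsed in this diff, so the sketch below of save_model's relevant body is inferred from those keys, not copied from the source:

# Inferred checkpoint layout (an assumption; save_model's full body is
# collapsed above, so this is reconstructed from the keys load() reads):
checkpoint_dict = {
    "model": self.model.state_dict(),
    "optimizer": self.optimizer.state_dict(),
    "ema_model": self.ema_model.state_dict(),
    "epoch": epoch,
}
torch.save(
    checkpoint_dict,
    f"dnadiffusion/checkpoints/epoch_{epoch}_{self.model_name}.pt",
)

Note the switch from epoch % n == 0 to (epoch + 1) % n == 0: logging, sampling, and saving no longer fire on epoch 0, and saving is now guarded by is_main_process so only one rank writes the checkpoint. Since load assigns the saved 0-based epoch index to start_epoch, the last saved epoch is run again once on resume.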
train_hf.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ def train():
         limit_total_sequences=0,
         num_sampling_to_compare_cells=1000,
         load_saved_data=True,
-        batch_size=240,
+        batch_size=960,
     )
 
     unet = UNet(
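
With load in place, a script like train_hf.py can restart from a checkpoint instead of from scratch. A sketch, assuming the trainer class in train_util is named TrainLoop and is constructed with the same objects train() builds (neither the class name nor its constructor is visible in this diff):

# Hypothetical resume entry point; TrainLoop and its constructor arguments
# are assumptions, and the checkpoint filename is illustrative.
from dnadiffusion.utils.train_util import TrainLoop

def resume(checkpoint_path: str) -> None:
    trainer = TrainLoop()  # construct exactly as train() does for a fresh run
    # load() restores model, optimizer, and EMA weights, sets start_epoch from
    # the checkpoint, and then calls train_loop() itself, so no separate
    # train_loop() call is needed here.
    trainer.load(checkpoint_path)

if __name__ == "__main__":
    resume("dnadiffusion/checkpoints/epoch_99_model.pt")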
