Skip to content

Commit

Permalink
Merge pull request #105 from MPI-Dortmund/train-speed-2
Browse files Browse the repository at this point in the history
Minor optimization of training
  • Loading branch information
thorstenwagner authored Oct 2, 2024
2 parents ada76c7 + 9ff6d46 commit 12d6323
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tomotwin/modules/training/torchtrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def classification_f1_score(self, test_loader: DataLoader) -> float:

with torch.no_grad():
for _, batch in enumerate(t):
anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
positive_vol = batch["positive"].to(self.device, non_blocking=True)
negative_vol = batch["negative"].to(self.device, non_blocking=True)
full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
anchor_vol = batch["anchor"]
positive_vol = batch["positive"]
negative_vol = batch["negative"]
full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0).to(self.device, non_blocking=True)
filenames = batch["filenames"]
with autocast():
out = self.model.forward(full_input)
Expand Down Expand Up @@ -257,10 +257,10 @@ def run_batch(self, batch: Dict):
:param batch: Dictionary with batch data
:return: Loss of the batch
"""
anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
positive_vol = batch["positive"].to(self.device, non_blocking=True)
negative_vol = batch["negative"].to(self.device, non_blocking=True)
full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
anchor_vol = batch["anchor"]
positive_vol = batch["positive"]
negative_vol = batch["negative"]
full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0).to(self.device, non_blocking=True)
with autocast():
out = self.model.forward(full_input)
out = torch.split(out, anchor_vol.shape[0], dim=0)
Expand Down

0 comments on commit 12d6323

Please sign in to comment.