diff --git a/tomotwin/modules/training/torchtrainer.py b/tomotwin/modules/training/torchtrainer.py
index 4ce6522..0e9207e 100644
--- a/tomotwin/modules/training/torchtrainer.py
+++ b/tomotwin/modules/training/torchtrainer.py
@@ -217,10 +217,10 @@ def classification_f1_score(self, test_loader: DataLoader) -> float:
 
         with torch.no_grad():
             for _, batch in enumerate(t):
-                anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
-                positive_vol = batch["positive"].to(self.device, non_blocking=True)
-                negative_vol = batch["negative"].to(self.device, non_blocking=True)
-                full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
+                anchor_vol = batch["anchor"]
+                positive_vol = batch["positive"]
+                negative_vol = batch["negative"]
+                full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0).to(self.device, non_blocking=True)
                 filenames = batch["filenames"]
                 with autocast():
                     out = self.model.forward(full_input)
@@ -257,10 +257,10 @@ def run_batch(self, batch: Dict):
         :param batch: Dictionary with batch data
         :return: Loss of the batch
         """
-        anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
-        positive_vol = batch["positive"].to(self.device, non_blocking=True)
-        negative_vol = batch["negative"].to(self.device, non_blocking=True)
-        full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
+        anchor_vol = batch["anchor"]
+        positive_vol = batch["positive"]
+        negative_vol = batch["negative"]
+        full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0).to(self.device, non_blocking=True)
         with autocast():
             out = self.model.forward(full_input)
             out = torch.split(out, anchor_vol.shape[0], dim=0)