From d9dc245ed96b14cfcc56c113c5fc40890336557b Mon Sep 17 00:00:00 2001 From: thelfer1 Date: Wed, 7 Aug 2024 06:50:49 -0400 Subject: [PATCH] minor bug fix --- SuperResolution/models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/SuperResolution/models.py b/SuperResolution/models.py index e9c0998..3b41c40 100644 --- a/SuperResolution/models.py +++ b/SuperResolution/models.py @@ -225,9 +225,12 @@ def check_performance( # Disable gradient computation for evaluation with torch.no_grad(): # Iterate over batches in the data loader - for (X_val_batch,) in tqdm(data_loader): + for (y_val_batch,) in tqdm(data_loader): # Move the batch to the specified device - X_val_batch = X_val_batch.to(device) + y_val_batch = y_val_batch.to(device) + X_val_batch = y_val_batch[ + :, :, ::downsample, ::downsample, ::downsample + ].clone() # Forward pass: get predictions from the network y_val_pred, y_val_interp = net(X_val_batch)