diff --git a/SuperResolution/models.py b/SuperResolution/models.py index e9c0998..3b41c40 100644 --- a/SuperResolution/models.py +++ b/SuperResolution/models.py @@ -225,9 +225,12 @@ def check_performance( # Disable gradient computation for evaluation with torch.no_grad(): # Iterate over batches in the data loader - for (X_val_batch,) in tqdm(data_loader): + for (y_val_batch,) in tqdm(data_loader): # Move the batch to the specified device - X_val_batch = X_val_batch.to(device) + y_val_batch = y_val_batch.to(device) + X_val_batch = y_val_batch[ + :, :, ::downsample, ::downsample, ::downsample + ].clone() # Forward pass: get predictions from the network y_val_pred, y_val_interp = net(X_val_batch)