Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Oct 20, 2024
1 parent 8d7cd07 commit d7ca4a5
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,16 @@ def _do_iteration(
else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:]))))
),
)
if "displacement_field" in data:
target_displacement_field = data["displacement_field"]
else:
target_displacement_field = None
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_displacement_field=displacement_field,
target_displacement_field=data["displacement_field"],
target_displacement_field=target_displacement_field,
)
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace)
regularizer_dict = self.compute_loss_on_data(
Expand Down Expand Up @@ -1347,9 +1351,7 @@ def compute_loss_on_data(
elif "displacement_field" in key:
if output_displacement_field is not None:
output = output_displacement_field
target = (
data["displacement_field"] if target_displacement_field is None else target_displacement_field
)
target = target_displacement_field
reconstruction_size = data.get("reconstruction_size", None)
else:
continue
Expand Down

0 comments on commit d7ca4a5

Please sign in to comment.