From adb18fd8bf29433c9bdd96e69af3209463717c73 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 24 Oct 2024 16:43:11 +0200 Subject: [PATCH] Allow for non end to end training --- direct/nn/mri_models.py | 39 ++++++++++++++++--- direct/nn/registration/config.py | 1 + direct/nn/registration/registration.py | 1 + direct/nn/vsharp/vsharp_engine.py | 54 ++++++++++++++++++-------- 4 files changed, 73 insertions(+), 22 deletions(-) diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index e0541209..26ec0286 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -127,7 +127,16 @@ def _do_iteration( output_kspace: TensorOrNone with autocast(enabled=self.mixed_precision): + + if self.ndim == 3 and "registration_model" in self.models: + # Freeze registration model weights + if self.cfg.additional_models.registration_model.train_end_to_end: + if len(list(self.models["registration_model"].parameters())) > 0: + for param in self.models["registration_model"].parameters(): + param.requires_grad = False + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + data = self.perform_sampling(data) output_image, output_kspace = self.forward_function(data) output_image = T.modulus_if_complex(output_image, complex_axis=self._complex_dim) @@ -137,9 +146,33 @@ def _do_iteration( k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() } + loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace) + regularizer_dict = self.compute_loss_on_data( + regularizer_dict, regularizer_fns, data, output_image, output_kspace + ) + if self.ndim == 3 and "registration_model" in self.models: + + if self.cfg.additional_models.registration_model.train_end_to_end: + if len(list(self.models["registration_model"].parameters())) > 0: + for param in self.models["registration_model"].parameters(): + param.requires_grad = True + for param in self.model.parameters(): + param.requires_grad = False + for model in self.models: + if model != "registration_model": + for param in self.models[model].parameters(): + param.requires_grad = False + # Perform registration and compute loss on registered image and displacement field - registered_image, displacement_field = self.do_registration(data, output_image) + registered_image, displacement_field = self.do_registration( + data, + ( + output_image.detach() + if self.cfg.additional_models.registration_model.train_end_to_end + else output_image + ), + ) # If DL-based model calculate loss if len(list(self.models["registration_model"].parameters())) > 0: @@ -166,10 +199,6 @@ def _do_iteration( output_displacement_field=displacement_field, target_displacement_field=target_displacement_field, ) - loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace) - regularizer_dict = self.compute_loss_on_data( - regularizer_dict, regularizer_fns, data, output_image, output_kspace - ) loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore diff --git a/direct/nn/registration/config.py b/direct/nn/registration/config.py index 28a46eb6..b42c2756 100644 --- a/direct/nn/registration/config.py +++ b/direct/nn/registration/config.py @@ -49,3 +49,4 @@ class UnetRegistration2dModelConfig(RegistrationModelConfig): unet_num_pool_layers: int = 4 unet_dropout_probability: float = 0.0 unet_normalized: bool = False + train_end_to_end: bool = True diff --git a/direct/nn/registration/registration.py b/direct/nn/registration/registration.py index 9a50c778..960de235 100644 --- a/direct/nn/registration/registration.py +++ b/direct/nn/registration/registration.py @@ -274,6 +274,7 @@ def __init__( unet_dropout_probability: float = 0.0, unet_normalized: bool = False, warp_num_integration_steps: int = 1, + **kwargs, ) -> None: """Inits :class:`UnetRegistration2dModel`. diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py index 526ab0b1..9efb16e4 100644 --- a/direct/nn/vsharp/vsharp_engine.py +++ b/direct/nn/vsharp/vsharp_engine.py @@ -108,14 +108,49 @@ def _do_iteration( output_kspace: TensorOrNone with autocast(enabled=self.mixed_precision): + + if "registration_model" in self.models: + # Freeze registration model weights + if self.cfg.additional_models.registration_model.train_end_to_end: + if len(list(self.models["registration_model"].parameters())) > 0: + for param in self.models["registration_model"].parameters(): + param.requires_grad = False + output_images, output_kspace = self.forward_function(data) output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images] loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) + for i, output_image in enumerate(output_images): + loss_dict = self.compute_loss_on_data( + loss_dict, + loss_fns, + data, + output_image=output_image, + output_kspace=None, + weight=auxiliary_loss_weights[i], + ) + # Compute loss on k-space + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, output_image=None, output_kspace=output_kspace + ) + if "registration_model" in self.models: - # Perform registration and compute loss on registered image and displacement field - registered_image, displacement_field = self.do_registration(data, output_images[-1]) + if self.cfg.additional_models.registration_model.train_end_to_end: + if len(list(self.models["registration_model"].parameters())) > 0: + for param in self.models["registration_model"].parameters(): + param.requires_grad = True + for param in self.model.parameters(): + param.requires_grad = False + for model in self.models: + if model != "registration_model": + for param in self.models[model].parameters(): + param.requires_grad = False + # Perform registration and compute loss on registered image and displacement field + registered_image, displacement_field = self.do_registration(data, output_images[-1].detach()) + else: + registered_image, displacement_field = self.do_registration(data, output_images[-1]) # If DL-based model calculate loss if len(list(self.models["registration_model"].parameters())) > 0: @@ -139,21 +174,6 @@ def _do_iteration( target_displacement_field=data["displacement_field"], ) - auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) - for i, output_image in enumerate(output_images): - loss_dict = self.compute_loss_on_data( - loss_dict, - loss_fns, - data, - output_image=output_image, - output_kspace=None, - weight=auxiliary_loss_weights[i], - ) - # Compute loss on k-space - loss_dict = self.compute_loss_on_data( - loss_dict, loss_fns, data, output_image=None, output_kspace=output_kspace - ) - loss = sum(loss_dict.values()) # type: ignore if self.model.training: