From 5ca113bd68541b2b7f6314a6c716cf75401e2175 Mon Sep 17 00:00:00 2001
From: George Yiasemis
Date: Mon, 4 Nov 2024 20:57:27 +0100
Subject: [PATCH] isort, add vit reg

---
 direct/nn/registration/config.py       |  35 +++--
 direct/nn/registration/registration.py | 177 ++++++++++++++++++++++++-
 direct/nn/registration/voxelmorph.py   |   5 +-
 3 files changed, 206 insertions(+), 11 deletions(-)

diff --git a/direct/nn/registration/config.py b/direct/nn/registration/config.py
index ba471312..ae9f334c 100644
--- a/direct/nn/registration/config.py
+++ b/direct/nn/registration/config.py
@@ -12,12 +12,13 @@
 @dataclass
 class RegistrationModelConfig(ModelConfig):
     warp_num_integration_steps: int = 1
+    train_end_to_end: bool = False
 
 
 @dataclass
 class OpticalFlowILKRegistration2dModelConfig(RegistrationModelConfig):
-    radius: int = 7
-    num_warp: int = 10
+    radius: int = 5
+    num_warp: int = 3
     gaussian: bool = False
     prefilter: bool = True
 
@@ -26,10 +27,10 @@ class OpticalFlowILKRegistration2dModelConfig(RegistrationModelConfig):
 class OpticalFlowTVL1Registration2dModelConfig(RegistrationModelConfig):
     attachment: float = 15
     tightness: float = 0.3
-    num_warp: int = 5
-    num_iter: int = 10
-    tol: float = 1e-3
-    prefilter: bool = True
+    num_warp: int = 3
+    num_iter: int = 5
+    tol: float = 1e-2
+    prefilter: bool = False
 
 
 @dataclass
@@ -49,13 +50,31 @@ class UnetRegistration2dModelConfig(RegistrationModelConfig):
     unet_num_pool_layers: int = 4
     unet_dropout_probability: float = 0.0
     unet_normalized: bool = False
-    train_end_to_end: bool = True
 
 
 @dataclass
-class VxmDenseModelConfig(RegistrationModelConfig):
+class VxmDenseConfig(RegistrationModelConfig):
     inshape: tuple = (512, 246)
     nb_unet_features: int = 16
     nb_unet_levels: int = 4
     nb_unet_conv_per_level: int = 1
     int_downsize: int = 2
+
+
+@dataclass
+class ViTRegistration2dModelConfig(RegistrationModelConfig):
+    max_seq_len: int = 12
+    average_size: tuple[int, int] = (320, 320)
+    patch_size: tuple[int, int] = (16, 16)
+    embedding_dim: int = 64
+    depth: int = 8
+    num_heads: int = 9
+    mlp_ratio: float = 4.0
+    qkv_bias: bool = False
+    qk_scale: Optional[float] = None
+    drop_rate: float = 0.0
+    attn_drop_rate: float = 0.0
+    dropout_path_rate: float = 0.0
+    gpsa_interval: tuple[int, int] = (-1, -1)
+    locality_strength: float = 1.0
+    use_pos_embedding: bool = True
diff --git a/direct/nn/registration/registration.py b/direct/nn/registration/registration.py
index 79e2fb4b..136c699b 100644
--- a/direct/nn/registration/registration.py
+++ b/direct/nn/registration/registration.py
@@ -9,13 +9,13 @@
 import torch.nn as nn
 
 from direct.nn.registration.voxelmorph import VxmDense
+from direct.nn.transformers.vit import VisionTransformer2D
 from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d
 from direct.registration.demons import DemonsFilterType, multiscale_demons_displacement
 from direct.registration.optical_flow import OpticalFlowEstimatorType, optical_flow_displacement
 from direct.registration.registration import DISCPLACEMENT_FIELD_2D_DIMENSIONS
 from direct.registration.warp import warp
 
-
 __all__ = [
     "OpticalFlowILKRegistration2dModel",
     "OpticalFlowTVL1Registration2dModel",
@@ -145,6 +145,7 @@ def __init__(
         gaussian: bool = False,
         prefilter: bool = True,
         warp_num_integration_steps: int = 1,
+        **kwargs,
     ) -> None:
         super().__init__(
             estimator_type=OpticalFlowEstimatorType.ILK,
@@ -167,6 +168,7 @@ def __init__(
         tol: float = 1e-3,
         prefilter: bool = True,
         warp_num_integration_steps: int = 1,
+        **kwargs,
     ) -> None:
         super().__init__(
             estimator_type=OpticalFlowEstimatorType.TV_L1,
@@ -191,6 +193,7 @@ def __init__(
         demons_intensity_difference_threshold: float | None = None,
         demons_maximum_rms_error: float | None = None,
         warp_num_integration_steps: int = 1,
+        **kwargs,
     ) -> None:
         """Inits :class:`DemonsRegistration2dModel`.
 
@@ -355,3 +358,175 @@ def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) ->
             warped_image.reshape(batch_size, seq_len, height, width),
             displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
         )
+
+
+class ViTRegistration2dModel(nn.Module):
+    """Vision Transformer registration model for 2D images.
+
+    Parameters
+    ----------
+    max_seq_len : int
+        Maximum sequence length expected in the moving image.
+    average_size : int or tuple[int, int]
+        The average size of the input image. If an int is provided, this will be determined by the
+        `dimensionality`, i.e., (average_size, average_size) for 2D and
+        (average_size, average_size, average_size) for 3D. Default: 320.
+    patch_size : int or tuple[int, int]
+        The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e.,
+        (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
+    embedding_dim : int
+        Dimension of the output embedding.
+    depth : int
+        Number of transformer blocks.
+    num_heads : int
+        Number of attention heads.
+    mlp_ratio : float
+        The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
+    qkv_bias : bool
+        Whether to add bias to the query, key, and value projections. Default: False.
+    qk_scale : float
+        The scale factor for the query-key dot product. Default: None.
+    drop_rate : float
+        The dropout probability for all dropout layers except dropout_path. Default: 0.0.
+    attn_drop_rate : float
+        The dropout probability for the attention layer. Default: 0.0.
+    dropout_path_rate : float
+        The dropout probability for the dropout path. Default: 0.0.
+    gpsa_interval : tuple[int, int]
+        The interval of the blocks where the GPSA layer is used. Default: (-1, -1).
+    locality_strength : float
+        The strength of the locality assumption in initialization. Default: 1.0.
+    use_pos_embedding : bool
+        Whether to use positional embeddings. Default: True.
+    warp_num_integration_steps : int
+        Number of integration steps to perform when warping the moving image. Default: 1.
+    """
+
+    def __init__(
+        self,
+        max_seq_len: int,
+        average_size: int | tuple[int, int] = 320,
+        patch_size: int | tuple[int, int] = 16,
+        embedding_dim: int = 64,
+        depth: int = 8,
+        num_heads: int = 9,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        qk_scale: float | None = None,
+        drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        dropout_path_rate: float = 0.0,
+        gpsa_interval: tuple[int, int] = (-1, -1),
+        locality_strength: float = 1.0,
+        use_pos_embedding: bool = True,
+        warp_num_integration_steps: int = 1,
+        **kwargs,
+    ) -> None:
+        """Inits :class:`ViTRegistration2dModel`.
+
+        Parameters
+        ----------
+        max_seq_len : int
+            Maximum sequence length expected in the moving image.
+        average_size : int or tuple[int, int]
+            The average size of the input image. If an int is provided, this will be determined by the
+            `dimensionality`, i.e., (average_size, average_size) for 2D and
+            (average_size, average_size, average_size) for 3D. Default: 320.
+        patch_size : int or tuple[int, int]
+            The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e.,
+            (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
+        embedding_dim : int
+            Dimension of the output embedding.
+        depth : int
+            Number of transformer blocks.
+        num_heads : int
+            Number of attention heads.
+        mlp_ratio : float
+            The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
+        qkv_bias : bool
+            Whether to add bias to the query, key, and value projections. Default: False.
+        qk_scale : float
+            The scale factor for the query-key dot product. Default: None.
+        drop_rate : float
+            The dropout probability for all dropout layers except dropout_path. Default: 0.0.
+        attn_drop_rate : float
+            The dropout probability for the attention layer. Default: 0.0.
+        dropout_path_rate : float
+            The dropout probability for the dropout path. Default: 0.0.
+        gpsa_interval : tuple[int, int]
+            The interval of the blocks where the GPSA layer is used. Default: (-1, -1).
+        locality_strength : float
+            The strength of the locality assumption in initialization. Default: 1.0.
+        use_pos_embedding : bool
+            Whether to use positional embeddings. Default: True.
+        warp_num_integration_steps : int
+            Number of integration steps to perform when warping the moving image. Default: 1.
+        """
+        super().__init__()
+        self.transformer = VisionTransformer2D(
+            average_img_size=average_size,
+            patch_size=patch_size,
+            in_channels=max_seq_len + 1,
+            out_channels=max_seq_len * DISCPLACEMENT_FIELD_2D_DIMENSIONS,
+            embedding_dim=embedding_dim,
+            depth=depth,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            dropout_path_rate=dropout_path_rate,
+            gpsa_interval=gpsa_interval,
+            locality_strength=locality_strength,
+            use_pos_embedding=use_pos_embedding,
+        )
+        self.max_seq_len = max_seq_len
+        self.warp_num_integration_steps = warp_num_integration_steps
+
+    def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass of :class:`ViTRegistration2dModel`.
+
+        Parameters
+        ----------
+        moving_image : torch.Tensor
+            Moving image tensor of shape (batch_size, seq_len, height, width).
+        reference_image : torch.Tensor
+            Reference image tensor of shape (batch_size, height, width).
+
+        Returns
+        -------
+        tuple[torch.Tensor, torch.Tensor]
+            Tuple containing the warped image tensor of shape (batch_size, seq_len, height, width)
+            and the displacement field tensor of shape (batch_size, seq_len, 2, height, width).
+        """
+        batch_size, seq_len, height, width = moving_image.shape
+
+        # Pad the moving image to the maximum sequence length
+        x = nn.functional.pad(moving_image, (0, 0, 0, 0, 0, self.max_seq_len - moving_image.shape[1]))
+        # Add the reference image as the first channel
+        x = torch.cat((reference_image.unsqueeze(1), x), dim=1)
+
+        # Forward pass through the model
+        displacement_field = self.transformer(x)
+
+        # Model outputs the displacement field for each time step with 2 channels (x and y displacements)
+        displacement_field = displacement_field.reshape(
+            batch_size, self.max_seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
+        )  # (batch_size, max_seq_len, 2, height, width)
+
+        # Crop the displacement field to the actual sequence length
+        displacement_field = displacement_field[:, :seq_len]  # (batch_size, seq_len, 2, height, width)
+
+        # Reshape the displacement field and moving image to be compatible with the warp module
+        displacement_field = displacement_field.reshape(
+            batch_size * seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
+        )
+        moving_image = moving_image.reshape(batch_size * seq_len, 1, height, width)
+
+        # Warp the moving image
+        warped_image = warp(moving_image, displacement_field, num_integration_steps=self.warp_num_integration_steps)
+        return (
+            warped_image.reshape(batch_size, seq_len, height, width),
+            displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
+        )
diff --git a/direct/nn/registration/voxelmorph.py b/direct/nn/registration/voxelmorph.py
index f06905ba..6847a945 100644
--- a/direct/nn/registration/voxelmorph.py
+++ b/direct/nn/registration/voxelmorph.py
@@ -168,10 +168,10 @@ def __init__(
         # cache some parameters
         self.half_res = half_res
 
-        enc_nf = [nb_features * (2 ** i) for i in range(nb_levels)]
+        enc_nf = [nb_features * (2**i) for i in range(nb_levels)]
         dec_nf = enc_nf[::-1] + [nb_features]
 
-        enc_nf = [nb_features * (2 ** i) for i in range(nb_levels)]
+        enc_nf = [nb_features * (2**i) for i in range(nb_levels)]
         dec_nf = enc_nf[::-1] + [nb_features]
 
         nb_dec_convs = len(enc_nf)
@@ -278,6 +278,7 @@ def __init__(
         int_downsize=2,
         src_feats=1,
         trg_feats=1,
+        **kwargs,
     ) -> None:
         super().__init__()
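
Usage sketch for the ViTRegistration2dModel added in this patch (not part of the diff itself). The hyperparameter values below are illustrative rather than the config defaults: a small depth keeps the example cheap, num_heads is chosen so it divides embedding_dim, and height/width are assumed to be divisible by the patch size. The import path follows the file touched by this patch.

import torch

from direct.nn.registration.registration import ViTRegistration2dModel

# Build the model with illustrative (non-default) hyperparameters.
model = ViTRegistration2dModel(
    max_seq_len=12,
    average_size=(64, 64),
    patch_size=(16, 16),
    embedding_dim=64,
    depth=2,
    num_heads=8,  # chosen so it divides embedding_dim
)

# A dynamic series of 5 frames registered to a single reference frame.
batch_size, seq_len, height, width = 1, 5, 64, 64
moving = torch.rand(batch_size, seq_len, height, width)
reference = torch.rand(batch_size, height, width)

warped, displacement = model(moving, reference)
# warped:       (1, 5, 64, 64)
# displacement: (1, 5, 2, 64, 64)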