isort, add vit reg
georgeyiasemis committed Nov 4, 2024
1 parent b16eff6 commit 5ca113b
Showing 3 changed files with 206 additions and 11 deletions.
35 changes: 27 additions & 8 deletions direct/nn/registration/config.py
@@ -12,12 +12,13 @@
@dataclass
class RegistrationModelConfig(ModelConfig):
warp_num_integration_steps: int = 1
train_end_to_end: bool = False


@dataclass
class OpticalFlowILKRegistration2dModelConfig(RegistrationModelConfig):
radius: int = 7
num_warp: int = 10
radius: int = 5
num_warp: int = 3
gaussian: bool = False
prefilter: bool = True

@@ -26,10 +27,10 @@ class OpticalFlowILKRegistration2dModelConfig(RegistrationModelConfig):
class OpticalFlowTVL1Registration2dModelConfig(RegistrationModelConfig):
attachment: float = 15
tightness: float = 0.3
num_warp: int = 5
num_iter: int = 10
tol: float = 1e-3
prefilter: bool = True
num_warp: int = 3
num_iter: int = 5
tol: float = 1e-2
prefilter: bool = False


@dataclass
@@ -49,13 +50,31 @@ class UnetRegistration2dModelConfig(RegistrationModelConfig):
unet_num_pool_layers: int = 4
unet_dropout_probability: float = 0.0
unet_normalized: bool = False
train_end_to_end: bool = True


@dataclass
class VxmDenseModelConfig(RegistrationModelConfig):
class VxmDenseConfig(RegistrationModelConfig):
inshape: tuple = (512, 246)
nb_unet_features: int = 16
nb_unet_levels: int = 4
nb_unet_conv_per_level: int = 1
int_downsize: int = 2


@dataclass
class ViTRegistration2dModelConfig(RegistrationModelConfig):
max_seq_len: int = 12
average_size: tuple[int, int] = (320, 320)
patch_size: tuple[int, int] = (16, 16)
embedding_dim: int = 64
depth: int = 8
num_heads: int = 9
mlp_ratio: float = 4.0
qkv_bias: bool = False
qk_scale: Optional[float] = None
drop_rate: float = 0.0
attn_drop_rate: float = 0.0
dropout_path_rate: float = 0.0
gpsa_interval: tuple[int, int] = (-1, -1)
locality_strength: float = 1.0
use_pos_embedding: bool = True
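
For context, the new config can be used like any other dataclass-based model config in this file. A minimal sketch, assuming the module path `direct.nn.registration.config` matches the file shown above; the overridden values are arbitrary examples, not recommended settings:

```python
from direct.nn.registration.config import ViTRegistration2dModelConfig

# Hypothetical instantiation; overrides a few of the defaults above.
config = ViTRegistration2dModelConfig(
    max_seq_len=12,
    average_size=(320, 320),
    patch_size=(16, 16),
    embedding_dim=64,
)

print(config.num_heads)         # 9 (default above)
print(config.train_end_to_end)  # False (inherited from RegistrationModelConfig)
```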
177 changes: 176 additions & 1 deletion direct/nn/registration/registration.py
@@ -9,13 +9,13 @@
import torch.nn as nn

from direct.nn.registration.voxelmorph import VxmDense
from direct.nn.transformers.vit import VisionTransformer2D
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d
from direct.registration.demons import DemonsFilterType, multiscale_demons_displacement
from direct.registration.optical_flow import OpticalFlowEstimatorType, optical_flow_displacement
from direct.registration.registration import DISCPLACEMENT_FIELD_2D_DIMENSIONS
from direct.registration.warp import warp


__all__ = [
"OpticalFlowILKRegistration2dModel",
"OpticalFlowTVL1Registration2dModel",
@@ -145,6 +145,7 @@ def __init__(
gaussian: bool = False,
prefilter: bool = True,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
super().__init__(
estimator_type=OpticalFlowEstimatorType.ILK,
@@ -167,6 +168,7 @@ def __init__(
tol: float = 1e-3,
prefilter: bool = True,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
super().__init__(
estimator_type=OpticalFlowEstimatorType.TV_L1,
@@ -191,6 +193,7 @@ def __init__(
demons_intensity_difference_threshold: float | None = None,
demons_maximum_rms_error: float | None = None,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
"""Inits :class:`DemonsRegistration2dModel`.
@@ -355,3 +358,175 @@ def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) ->
warped_image.reshape(batch_size, seq_len, height, width),
displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
)


class ViTRegistration2dModel(nn.Module):
"""Vision Transformer registration model for 2D images.
Parameters
----------
max_seq_len : int
Maximum sequence length expected in the moving image.
average_size : int or tuple[int, int]
The average size of the input image. If an int is provided, this will be determined by the
`dimensionality`, i.e., (average_size, average_size) for 2D and
(average_size, average_size, average_size) for 3D. Default: 320.
patch_size : int or tuple[int, int]
The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e.,
(patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
embedding_dim : int
Dimension of the output embedding.
depth : int
Number of transformer blocks.
num_heads : int
Number of attention heads.
mlp_ratio : float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
qkv_bias : bool
Whether to add bias to the query, key, and value projections. Default: False.
qk_scale : float or None
The scale factor for the query-key dot product. Default: None.
drop_rate : float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
attn_drop_rate : float
The dropout probability for the attention layer. Default: 0.0.
dropout_path_rate : float
The dropout probability for the dropout path. Default: 0.0.
gpsa_interval : tuple[int, int]
The interval of the blocks where the GPSA layer is used. Default: (-1, -1).
locality_strength : float
The strength of the locality assumption in initialization. Default: 1.0.
use_pos_embedding : bool
Whether to use positional embeddings. Default: True.
warp_num_integration_steps : int
Number of integration steps to perform when warping the moving image. Default: 1.
"""

def __init__(
self,
max_seq_len: int,
average_size: int | tuple[int, int] = 320,
patch_size: int | tuple[int, int] = 16,
embedding_dim: int = 64,
depth: int = 8,
num_heads: int = 9,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
qk_scale: float | None = None,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
dropout_path_rate: float = 0.0,
gpsa_interval: tuple[int, int] = (-1, -1),
locality_strength: float = 1.0,
use_pos_embedding: bool = True,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
"""Inits :class:`ViTRegistration2dModel`.
Parameters
----------
max_seq_len : int
Maximum sequence length expected in the moving image.
average_size : int or tuple[int, int]
The average size of the input image. If an int is provided, this will be determined by the
`dimensionality`, i.e., (average_size, average_size) for 2D and
(average_size, average_size, average_size) for 3D. Default: 320.
patch_size : int or tuple[int, int]
The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e.,
(patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
embedding_dim : int
Dimension of the output embedding.
depth : int
Number of transformer blocks.
num_heads : int
Number of attention heads.
mlp_ratio : float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
qkv_bias : bool
Whether to add bias to the query, key, and value projections. Default: False.
qk_scale : float or None
The scale factor for the query-key dot product. Default: None.
drop_rate : float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
attn_drop_rate : float
The dropout probability for the attention layer. Default: 0.0.
dropout_path_rate : float
The dropout probability for the dropout path. Default: 0.0.
gpsa_interval : tuple[int, int]
The interval of the blocks where the GPSA layer is used. Default: (-1, -1).
locality_strength : float
The strength of the locality assumption in initialization. Default: 1.0.
use_pos_embedding : bool
Whether to use positional embeddings. Default: True.
warp_num_integration_steps : int
Number of integration steps to perform when warping the moving image. Default: 1.
"""
super().__init__()
self.transformer = VisionTransformer2D(
average_img_size=average_size,
patch_size=patch_size,
in_channels=max_seq_len + 1,
out_channels=max_seq_len * DISCPLACEMENT_FIELD_2D_DIMENSIONS,
embedding_dim=embedding_dim,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
dropout_path_rate=dropout_path_rate,
gpsa_interval=gpsa_interval,
locality_strength=locality_strength,
use_pos_embedding=use_pos_embedding,
)
self.max_seq_len = max_seq_len
self.warp_num_integration_steps = warp_num_integration_steps

def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass of :class:`UnetRegistration2dModel`.
Parameters
----------
moving_image : torch.Tensor
Moving image tensor of shape (batch_size, seq_len, height, width).
reference_image : torch.Tensor
Reference image tensor of shape (batch_size, height, width).

Returns
-------
tuple[torch.Tensor, torch.Tensor]
Tuple containing the warped image tensor of shape (batch_size, seq_len, height, width)
and the displacement field tensor of shape (batch_size, seq_len, 2, height, width).
"""
batch_size, seq_len, height, width = moving_image.shape

# Pad the moving image to the maximum sequence length
x = nn.functional.pad(moving_image, (0, 0, 0, 0, 0, self.max_seq_len - moving_image.shape[1]))
# Add the reference image as the first channel
x = torch.cat((reference_image.unsqueeze(1), x), dim=1)
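# x now has shape (batch_size, max_seq_len + 1, height, width); channel 0 is the reference frame.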

# Forward pass through the model
displacement_field = self.transformer(x)

# Model outputs the displacement field for each time step with 2 channels (x and y displacements)
displacement_field = displacement_field.reshape(
batch_size, self.max_seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
) # (batch_size, max_seq_len, 2, height, width)

# Crop the displacement field to the actual sequence length
displacement_field = displacement_field[:, :seq_len] # (batch_size, seq_len, 2, height, width)

# Reshape the displacement field and moving image to be compatible with the warp module
displacement_field = displacement_field.reshape(
batch_size * seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
)
moving_image = moving_image.reshape(batch_size * seq_len, 1, height, width)

# Warp the moving image
warped_image = warp(moving_image, displacement_field, num_integration_steps=self.warp_num_integration_steps)
return (
warped_image.reshape(batch_size, seq_len, height, width),
displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
)
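
The forward pass above pads the moving image to `max_seq_len`, prepends the reference frame as an extra channel, predicts a per-frame two-channel displacement field, and warps each frame. A minimal usage sketch, assuming the module path below, that `seq_len` does not exceed `max_seq_len` (otherwise the padding step would fail), and that height/width are compatible with `patch_size`:

```python
import torch

from direct.nn.registration.registration import ViTRegistration2dModel

# Hypothetical shapes: batch of 2, sequence of 8 frames, 320 x 320 images.
model = ViTRegistration2dModel(max_seq_len=12, average_size=(320, 320), patch_size=(16, 16))

moving_image = torch.rand(2, 8, 320, 320)   # (batch_size, seq_len, height, width)
reference_image = torch.rand(2, 320, 320)   # (batch_size, height, width)

warped_image, displacement_field = model(moving_image, reference_image)
print(warped_image.shape)        # torch.Size([2, 8, 320, 320])
print(displacement_field.shape)  # torch.Size([2, 8, 2, 320, 320])
```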
5 changes: 3 additions & 2 deletions direct/nn/registration/voxelmorph.py
@@ -168,10 +168,10 @@ def __init__(
# cache some parameters
self.half_res = half_res

enc_nf = [nb_features * (2 ** i) for i in range(nb_levels)]
enc_nf = [nb_features * (2**i) for i in range(nb_levels)]
dec_nf = enc_nf[::-1] + [nb_features]

enc_nf = [nb_features * (2 ** i) for i in range(nb_levels)]
enc_nf = [nb_features * (2**i) for i in range(nb_levels)]
dec_nf = enc_nf[::-1] + [nb_features]

nb_dec_convs = len(enc_nf)
@@ -278,6 +278,7 @@ def __init__(
int_downsize=2,
src_feats=1,
trg_feats=1,
**kwargs,
) -> None:
super().__init__()

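As a quick sanity check on the encoder/decoder feature lists computed above, here is the same arithmetic evaluated standalone. The values nb_features=16 and nb_levels=4 are an assumption, taken from the apparently corresponding `nb_unet_features` and `nb_unet_levels` defaults in `VxmDenseConfig` above:

```python
nb_features, nb_levels = 16, 4  # assumed defaults, per VxmDenseConfig above

enc_nf = [nb_features * (2**i) for i in range(nb_levels)]
dec_nf = enc_nf[::-1] + [nb_features]

print(enc_nf)  # [16, 32, 64, 128]
print(dec_nf)  # [128, 64, 32, 16, 16]
```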
