Update UNETR to enable resize to longest side (#192)
Add resize functionality to the UNETR and use it by default
constantinpape committed Jan 1, 2024
1 parent 7111f72 commit fad7fb6
Showing 2 changed files with 53 additions and 17 deletions.
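
For orientation, here is a minimal sketch of the behavior this commit introduces. The UNETR import, the resize_input flag, and the 512 x 512 input shape come from the diff below; everything else is illustrative:

    import torch
    from torch_em.model import UNETR

    # Resizing is now the default: the longest input side is scaled to the
    # encoder's expected size before padding, and the output is mapped back
    # to the original spatial shape.
    model = UNETR()
    x = torch.randn(1, 3, 512, 512)
    y = model(x)

    # The previous pad-only behavior remains available:
    model_no_resize = UNETR(resize_input=False)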
test/model/test_unetr.py (18 changes: 14 additions & 4 deletions)

@@ -26,14 +26,24 @@ def test_unetr(self):
         from torch_em.model import UNETR
 
         model = UNETR()
-        self._test_net(model, (1, 3, 256, 256))
+        self._test_net(model, (1, 3, 512, 512))
+
+    def test_unetr_no_resize(self):
+        from torch_em.model import UNETR
+
+        model = UNETR(resize_input=False)
+        self._test_net(model, (1, 3, 512, 512))
 
     @unittest.skipIf(micro_sam is None, "Needs micro_sam")
     def test_unetr_from_sam(self):
-        from torch_em.model import build_unetr_with_sam_intialization
+        from torch_em.model import UNETR
+        from micro_sam.util import models
+
+        model_registry = models()
+        checkpoint = model_registry.fetch("vit_b")
 
-        model = build_unetr_with_sam_intialization()
-        self._test_net(model, (1, 3, 256, 256))
+        model = UNETR(encoder_checkpoint=checkpoint)
+        self._test_net(model, (1, 3, 512, 512))
 
 
 if __name__ == "__main__":
torch_em/model/unetr.py (52 changes: 39 additions & 13 deletions)

@@ -36,8 +36,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
                 )
                 encoder_state = model.image_encoder.state_dict()
             except Exception:
-                # If we have a MAE encoder, then we directly load the encoder state
-                # from the checkpoint.
+                # Try loading the encoder state directly from a checkpoint.
                 encoder_state = torch.load(checkpoint)
 
         elif backbone == "mae":
@@ -68,16 +67,18 @@ def __init__(
         out_channels: int = 1,
         use_sam_stats: bool = False,
         use_mae_stats: bool = False,
+        resize_input: bool = True,
         encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
         final_activation: Optional[Union[str, nn.Module]] = None,
         use_skip_connection: bool = True,
-        embed_dim: Optional[int] = None
+        embed_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
 
         self.use_sam_stats = use_sam_stats
         self.use_mae_stats = use_mae_stats
         self.use_skip_connection = use_skip_connection
+        self.resize_input = resize_input
 
         if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
             print(f"Using {encoder} from {backbone.upper()}")
@@ -152,25 +153,49 @@ def _get_activation(self, activation):
             raise ValueError(f"Invalid activation: {activation}")
         return return_activation()
 
+    @staticmethod
+    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+        """Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
+
+    def resize_longest_side(self, image: torch.Tensor) -> torch.Tensor:
+        """Resizes the image so that the longest side has the correct length.
+
+        Expects batched images with shape BxCxHxW and float format.
+        """
+        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.encoder.img_size)
+        return F.interpolate(
+            image, target_size, mode="bilinear", align_corners=False, antialias=True
+        )
+
     def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = x.device
 
         if self.use_sam_stats:
-            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
-            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
+            pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
+            pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
         elif self.use_mae_stats:
             # TODO: add mean std from mae experiments (or open up arguments for this)
             raise NotImplementedError
         else:
-            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(-1, 1, 1).to(device)
-            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(-1, 1, 1).to(device)
+            pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
+            pixel_std = torch.Tensor([1.0, 1.0, 1.0]).view(1, -1, 1, 1).to(device)
 
+        if self.resize_input:
+            x = self.resize_longest_side(x)
+        input_shape = x.shape[-2:]
+
         x = (x - pixel_mean) / pixel_std
         h, w = x.shape[-2:]
         padh = self.encoder.img_size - h
         padw = self.encoder.img_size - w
         x = F.pad(x, (0, padw, 0, padh))
-        return x
+        return x, input_shape
 
     def postprocess_masks(
         self,
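
An aside on the arithmetic of get_preprocess_shape: the scale factor is chosen from the longer side and both sides are rounded to the nearest integer, so the aspect ratio is preserved. A standalone check that mirrors the static method above (illustrative, not part of the commit):

    from typing import Tuple

    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        # Same arithmetic as the static method in the diff.
        scale = long_side_length * 1.0 / max(oldh, oldw)
        return (int(oldh * scale + 0.5), int(oldw * scale + 0.5))

    assert get_preprocess_shape(768, 384, 1024) == (1024, 512)
    assert get_preprocess_shape(256, 512, 1024) == (512, 1024)  # aspect ratio is kept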
@@ -189,10 +214,11 @@ def postprocess_masks(
         return masks
 
     def forward(self, x):
-        org_shape = x.shape[-2:]
+        original_shape = x.shape[-2:]
 
-        # backbone used for reshaping inputs to the desired "encoder" shape
-        x = torch.stack([self.preprocess(e) for e in x], dim=0)
+        # Reshape the inputs to the shape expected by the encoder
+        # and normalize the inputs if normalization is part of the model.
+        x, input_shape = self.preprocess(x)
 
         use_skip_connection = getattr(self, "use_skip_connection", True)

@@ -236,7 +262,7 @@ def forward(self, x):
         if self.final_activation is not None:
             x = self.final_activation(x)
 
-        x = self.postprocess_masks(x, org_shape, org_shape)
+        x = self.postprocess_masks(x, input_shape, original_shape)
         return x
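
Note the fix in this last hunk: previously the masks were interpolated from org_shape to org_shape, which ignored the resizing and padding applied in preprocess. Passing input_shape (the shape after resizing, before padding) together with original_shape lets postprocess_masks undo both steps. The method body is not part of this diff; what follows is a hedged sketch of the two-step mapping, modeled on the analogous SAM-style postprocessing, with postprocess_masks_sketch and img_size as hypothetical names:

    import torch
    import torch.nn.functional as F

    def postprocess_masks_sketch(masks, input_shape, original_shape, img_size=1024):
        # Illustrative only: upsample to the padded encoder resolution,
        # crop away the padding, then resize to the original image shape.
        masks = F.interpolate(masks, (img_size, img_size), mode="bilinear", align_corners=False)
        masks = masks[..., :input_shape[0], :input_shape[1]]
        return F.interpolate(masks, original_shape, mode="bilinear", align_corners=False)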

