diff --git a/pyproject.toml b/pyproject.toml
index b38b1c85..39687c3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,6 +62,7 @@ dev = [
     "pre-commit",
     "pytest",
     "pytest-cov",
+    "onnx",
     "sybil",  # doctesting
 ]
 
diff --git a/src/careamics/models/layers.py b/src/careamics/models/layers.py
index 3cc621fe..40e9211e 100644
--- a/src/careamics/models/layers.py
+++ b/src/careamics/models/layers.py
@@ -459,7 +459,8 @@ def __init__(
         self.stride = stride
         self.max_pool_size = max_pool_size
         self.ceil_mode = ceil_mode
-        self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
+        kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
+        self.register_buffer("kernel", kernel, persistent=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass of the function.
@@ -474,11 +475,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             Output tensor.
         """
-        self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
+        kernel = self.kernel.to(dtype=x.dtype)
+        num_channels = int(x.size(1))
         if self.dim == 2:
             return _max_blur_pool_by_kernel2d(
                 x,
-                self.kernel.repeat((x.size(1), 1, 1, 1)),
+                kernel.repeat((num_channels, 1, 1, 1)),
                 self.stride,
                 self.max_pool_size,
                 self.ceil_mode,
@@ -486,7 +488,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             return _max_blur_pool_by_kernel3d(
                 x,
-                self.kernel.repeat((x.size(1), 1, 1, 1, 1)),
+                kernel.repeat((num_channels, 1, 1, 1, 1)),
                 self.stride,
                 self.max_pool_size,
                 self.ceil_mode,
diff --git a/tests/lightning/test_lightning_module_onnx_exportability.py b/tests/lightning/test_lightning_module_onnx_exportability.py
new file mode 100644
index 00000000..b68679d9
--- /dev/null
+++ b/tests/lightning/test_lightning_module_onnx_exportability.py
@@ -0,0 +1,59 @@
+import pytest
+import torch
+from onnx import checker
+
+from careamics.config import FCNAlgorithmConfig
+from careamics.lightning.lightning_module import FCNModule
+
+
+@pytest.mark.parametrize(
+    "algorithm, architecture, conv_dim, n2v2, loss, shape",
+    [
+        ("n2n", "UNet", 2, False, "mae", (16, 16)),  # n2n 2D model
+        ("n2n", "UNet", 3, False, "mae", (8, 16, 16)),  # n2n 3D model
+        ("n2v", "UNet", 2, False, "n2v", (16, 16)),  # n2v 2D model
+        ("n2v", "UNet", 3, False, "n2v", (8, 16, 16)),  # n2v 3D model
+        ("n2v", "UNet", 2, True, "n2v", (16, 16)),  # n2v2 2D model
+        ("n2v", "UNet", 3, True, "n2v", (8, 16, 16)),  # n2v2 3D model
+    ],
+)
+def test_onnx_export(tmp_path, algorithm, architecture, conv_dim, n2v2, loss, shape):
+    """Test model exportability to ONNX."""
+
+    algo_config = {
+        "algorithm": algorithm,
+        "model": {
+            "architecture": architecture,
+            "conv_dims": conv_dim,
+            "in_channels": 1,
+            "num_classes": 1,
+            "depth": 3,
+            "n2v2": n2v2,
+        },
+        "loss": loss,
+    }
+    algo_config = FCNAlgorithmConfig(**algo_config)
+
+    # instantiate the FCNModule lightning module
+    model = FCNModule(algo_config)
+    # set the model to evaluation mode to avoid a batch dimension error
+    model.model.eval()
+    # create a sample input of shape BC(Z)XY
+    x = torch.rand((1, 1, *shape))
+
+    # create dynamic axes from the shape of x
+    dynamic_axes = {"input": {}, "output": {}}
+    for i in range(len(x.shape)):
+        dynamic_axes["input"][i] = f"dim_{i}"
+        dynamic_axes["output"][i] = f"dim_{i}"
+
+    torch.onnx.export(
+        model,
+        x,
+        f"{tmp_path}/test_model.onnx",
+        input_names=["input"],  # the model's input names
+        output_names=["output"],  # the model's output names
+        dynamic_axes=dynamic_axes,  # variable length axes
+    )
+
+    checker.check_model(f"{tmp_path}/test_model.onnx")
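
For context on the layers.py change, here is a minimal standalone sketch of the buffer pattern the patch adopts: registering the blur kernel with register_buffer ties it to the module, so .to(device) and .cuda() move it automatically, and the forward pass no longer reassigns a module attribute, which is the kind of side effect that previously broke ONNX tracing. The module name, kernel values, and output filename below are purely illustrative, not part of careamics.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BlurPoolSketch(nn.Module):
    """Hypothetical module demonstrating the non-persistent buffer pattern."""

    def __init__(self) -> None:
        super().__init__()
        # a fixed 3x3 averaging kernel stands in for the Pascal kernel;
        # persistent=False keeps it out of the state_dict while still
        # letting .to(device) / .cuda() move it with the module
        kernel = torch.full((1, 1, 3, 3), 1 / 9)
        self.register_buffer("kernel", kernel, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # read-only use of the buffer: no attribute reassignment in forward,
        # so tracing for ONNX export sees a side-effect-free graph
        kernel = self.kernel.to(dtype=x.dtype)
        # depthwise convolution, one copy of the kernel per input channel
        return F.conv2d(
            x, kernel.repeat(x.size(1), 1, 1, 1), padding=1, groups=x.size(1)
        )

# export succeeds where a forward that reassigned self.kernel could fail
module = BlurPoolSketch().eval()
torch.onnx.export(module, torch.rand(1, 1, 16, 16), "blur_sketch.onnx")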