Skip to content

Commit

Permalink
Inputs to BMZ should not be normalized (#132)
Browse files Browse the repository at this point in the history
### Description

The inputs to the BMZ model should not be normalized, as the BMZ core
applies the normalization itself. Since all our tests were using tensors
with values [-1, 1], this went unnoticed.

This PR uses a quick fix by applying denormalization to the patches
extracted from the datamodules. This part of the code was refactored
into a tested member function of `CAREamist`. I also scaled all tensors
in `CAREamit` tests to [0, 255] to be more realistic.

- **What**: Add denormalization to BMZ export input data.
- **Why**: The current code export normalized tensors, which leads to
failing validation of the BMZ archive
- **How**: `Denormalize` is now called in a the (refactored)
`_create_data_for_bmz` method in `CAREamist`.

### Changes Made

- **Added**: `_create_data_for_bmz` method in `CAREamist`, specific
tests.
- **Modified**: Existing `CAREamist` tests.

### Related Issues

Issues reported by users.

### Additional Notes and Examples

The choice here is **quick fix** one. The current way example inputs are
generated prior to exporting to the BMZ format is by pulling patches
from the existing dataloaders (either prediction, or training one):

https://github.com/CAREamics/careamics/blob/f613d724fd7e9f8f6ea102ee108701bb44dfae44/src/careamics/careamist.py#L694

This fix here is to simply call `Denormalize` on the array:
```python
# unpack a batch, ignore masks or targets
input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))

# convert torch.Tensor to numpy
input_patch = input_patch.numpy()

# denormalize
denormalize = Denormalize(
    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch, _ = denormalize(input_patch)
```

An alternative would be to find a way to create a new Dataloader with
the data and remove normalization, but this seems possible only after
training because the `CAREamist` remembers the source data (unless it is
a DataLoader). However, we want to be able to export after loading a
checkpoint for instance (choosing best model, loading it, exporting to
BMZ).

So I think this is the best pragmatic choice.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] No change to the documentation needed
  • Loading branch information
jdeschamps authored Jun 6, 2024
1 parent a19a770 commit 763d965
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 72 deletions.
115 changes: 80 additions & 35 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
load_configuration,
)
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
from careamics.dataset.dataset_utils import reshape_array
from careamics.lightning_datamodule import CAREamicsTrainData
from careamics.lightning_module import CAREamicsModule
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.transforms import Denormalize
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback
Expand Down Expand Up @@ -651,58 +653,65 @@ def predict(
f"np.ndarray (got {type(source)})."
)

def export_to_bmz(
def _create_data_for_bmz(
self,
path: Union[Path, str],
name: str,
authors: List[dict],
input_array: Optional[np.ndarray] = None,
general_description: str = "",
channel_names: Optional[List[str]] = None,
data_description: Optional[str] = None,
) -> None:
"""Export the model to the BioImage Model Zoo format.
) -> np.ndarray:
"""Create data for BMZ export.
Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
If no `input_array` is provided, this method checks if there is a prediction
datamodule, or a training data module, to extract a patch. If none exists,
then a random aray is created.
If there is a non-singleton batch dimension, this method returns only the first
element.
Parameters
----------
path : Union[Path, str]
Path to save the model.
name : str
Name of the model.
authors : List[dict]
List of authors of the model.
input_array : Optional[np.ndarray], optional
Input array for the model, must be of shape SC(Z)YX, by default None.
general_description : str
General description of the model, used in the metadata of the BMZ archive.
channel_names : Optional[List[str]], optional
Channel names, by default None.
data_description : Optional[str], optional
Description of the data, by default None.
Input array, by default None.
Returns
-------
np.ndarray
Input data for BMZ export.
Raises
------
ValueError
If mean and std are not provided in the configuration.
"""
if input_array is None:
if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
raise ValueError(
"Mean and std cannot be None in the configuration in order to"
"export to the BMZ format. Was the model trained?"
)

# generate images, priority is given to the prediction data module
if self.pred_datamodule is not None:
# unpack a batch, ignore masks or targets
input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))

# convert torch.Tensor to numpy
input_patch = input_patch.numpy()

# denormalize
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch, _ = denormalize(input_patch)

elif self.train_datamodule is not None:
input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
input_patch = input_patch.numpy()
else:
if (
self.cfg.data_config.mean is None
or self.cfg.data_config.std is None
):
raise ValueError(
"Mean and std cannot be None in the configuration in order to"
"export to the BMZ format. Was the model trained?"
)

# denormalize
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch, _ = denormalize(input_patch)
else:
# create a random input array
input_patch = np.random.normal(
loc=self.cfg.data_config.mean,
Expand All @@ -712,11 +721,47 @@ def export_to_bmz(
np.newaxis, np.newaxis, ...
] # add S & C dimensions
else:
input_patch = input_array
# potentially correct shape
input_patch = reshape_array(input_array, self.cfg.data_config.axes)

# if there is a batch dimension
# if this a batch
if input_patch.shape[0] > 1:
input_patch = input_patch[0:1, ...] # keep singleton dim
input_patch = input_patch[[0], ...] # keep singleton dim

return input_patch

def export_to_bmz(
self,
path: Union[Path, str],
name: str,
authors: List[dict],
input_array: Optional[np.ndarray] = None,
general_description: str = "",
channel_names: Optional[List[str]] = None,
data_description: Optional[str] = None,
) -> None:
"""Export the model to the BioImage Model Zoo format.
Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
Parameters
----------
path : Union[Path, str]
Path to save the model.
name : str
Name of the model.
authors : List[dict]
List of authors of the model.
input_array : Optional[np.ndarray], optional
Input array for the model, must be of shape SC(Z)YX, by default None.
general_description : str
General description of the model, used in the metadata of the BMZ archive.
channel_names : Optional[List[str]], optional
Channel names, by default None.
data_description : Optional[str], optional
Description of the data, by default None.
"""
input_patch = self._create_data_for_bmz(input_array)

# axes need to be reformated for the export because reshaping was done in the
# datamodule
Expand Down
6 changes: 3 additions & 3 deletions src/careamics/model_io/bmz_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def export_to_bmz(
authors : List[dict]
Authors of the model.
input_array : np.ndarray
Input array.
Input array, should not have been normalized.
output_array : np.ndarray
Output array.
Output array, should have been denormalized.
channel_names : Optional[List[str]], optional
Channel names, by default None.
data_description : Optional[str], optional
Expand Down Expand Up @@ -178,7 +178,7 @@ def export_to_bmz(
)

# test model description
summary: ValidationSummary = test_model(model_description, decimal=0)
summary: ValidationSummary = test_model(model_description, decimal=2)
if summary.status == "failed":
raise ValueError(f"Model description test failed: {summary}")

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def overlaps() -> Tuple[int, int]:
def pre_trained(tmp_path, minimum_configuration):
"""Fixture to create a pre-trained CAREamics model."""
# training data
train_array = np.arange(32 * 32).reshape((32, 32))
train_array = np.arange(32 * 32).reshape((32, 32)).astype(float)

# create configuration
config = Configuration(**minimum_configuration)
Expand Down
Loading

0 comments on commit 763d965

Please sign in to comment.