diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 31d98a96..29bf8940 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -16,19 +16,18 @@ from careamics.callbacks import ProgressBarCallback from careamics.config import ( Configuration, - create_inference_configuration, load_configuration, ) from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger from careamics.dataset.dataset_utils import reshape_array from careamics.lightning_datamodule import CAREamicsTrainData from careamics.lightning_module import CAREamicsModule -from careamics.lightning_prediction_datamodule import CAREamicsPredictData -from careamics.lightning_prediction_loop import CAREamicsPredictionLoop from careamics.model_io import export_to_bmz, load_pretrained +from careamics.prediction_utils import convert_outputs, create_pred_datamodule from careamics.utils import check_path_exists, get_logger from .callbacks import HyperParametersCallback +from .lightning_prediction_datamodule import CAREamicsPredictData logger = get_logger(__name__) @@ -193,9 +192,6 @@ def __init__( logger=self.experiment_logger, ) - # change the prediction loop, necessary for tiled prediction - self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer) - # place holder for the datamodules self.train_datamodule: Optional[CAREamicsTrainData] = None self.pred_datamodule: Optional[CAREamicsPredictData] = None @@ -572,28 +568,28 @@ def predict( Parameters ---------- - source : CAREamicsPredData or pathlib.Path or str or NDArray + source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray Data to predict on. - batch_size : int, optional - Batch size for prediction, by default 1. + batch_size : int, default=1 + Batch size for prediction. tile_size : tuple of int, optional - Size of the tiles to use for prediction, by default None. - tile_overlap : tuple of int, optional - Overlap between tiles, by default (48, 48). + Size of the tiles to use for prediction. + tile_overlap : tuple of int, default=(48, 48) + Overlap between tiles. axes : str, optional Axes of the input data, by default None. data_type : {"array", "tiff", "custom"}, optional - Type of the input data, by default None. - tta_transforms : bool, optional - Whether to apply test-time augmentation, by default True. + Type of the input data. + tta_transforms : bool, default=True + Whether to apply test-time augmentation. dataloader_params : dict, optional - Parameters to pass to the dataloader, by default None. + Parameters to pass to the dataloader. read_source_func : Callable, optional - Function to read the source data, by default None. - extension_filter : str, optional - Filter for the file extension, by default "". + Function to read the source data. + extension_filter : str, default="" + Filter for the file extension. checkpoint : {"best", "last"}, optional - Checkpoint to use for prediction, by default None. + Checkpoint to use for prediction. **kwargs : Any Unused. @@ -601,91 +597,33 @@ def predict( ------- list of NDArray or NDArray Predictions made by the model. - - Raises - ------ - ValueError - If the input is not a CAREamicsPredData instance, a path or a numpy array. 
""" - if isinstance(source, CAREamicsPredictData): - # record datamodule - self.pred_datamodule = source - - return self.trainer.predict( - model=self.model, datamodule=source, ckpt_path=checkpoint + # Reuse batch size if not provided explicitly + if batch_size is None: + batch_size = ( + self.train_datamodule.batch_size + if self.train_datamodule + else self.cfg.data_config.batch_size ) - else: - if self.cfg is None: - raise ValueError( - "No configuration found. Train a model or load from a " - "checkpoint before predicting." - ) - - # Reuse batch size if not provided explicitly - if batch_size is None: - batch_size = ( - self.train_datamodule.batch_size - if self.train_datamodule - else self.cfg.data_config.batch_size - ) - - # create predict config, reuse training config if parameters missing - prediction_config = create_inference_configuration( - configuration=self.cfg, - tile_size=tile_size, - tile_overlap=tile_overlap, - data_type=data_type, - axes=axes, - tta_transforms=tta_transforms, - batch_size=batch_size, - ) - - # remove batch from dataloader parameters (priority given to config) - if dataloader_params is None: - dataloader_params = {} - if "batch_size" in dataloader_params: - del dataloader_params["batch_size"] - - if isinstance(source, Path) or isinstance(source, str): - # Check the source - source_path = check_path_exists(source) - - # create datamodule - datamodule = CAREamicsPredictData( - pred_config=prediction_config, - pred_data=source_path, - read_source_func=read_source_func, - extension_filter=extension_filter, - dataloader_params=dataloader_params, - ) - # record datamodule - self.pred_datamodule = datamodule - - return self.trainer.predict( - model=self.model, datamodule=datamodule, ckpt_path=checkpoint - ) - - elif isinstance(source, np.ndarray): - # create datamodule - datamodule = CAREamicsPredictData( - pred_config=prediction_config, - pred_data=source, - dataloader_params=dataloader_params, - ) - - # record datamodule - self.pred_datamodule = datamodule - - return self.trainer.predict( - model=self.model, datamodule=datamodule, ckpt_path=checkpoint - ) + self.pred_datamodule = create_pred_datamodule( + source=source, + config=self.cfg, + batch_size=batch_size, + tile_size=tile_size, + tile_overlap=tile_overlap, + axes=axes, + data_type=data_type, + tta_transforms=tta_transforms, + dataloader_params=dataloader_params, + read_source_func=read_source_func, + extension_filter=extension_filter, + ) - else: - raise ValueError( - f"Invalid input. Expected a CAREamicsPredData instance, paths or " - f"NDArray (got {type(source)})." - ) + predictions = self.trainer.predict( + model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint + ) + return convert_outputs(predictions, self.pred_datamodule.tiled) def export_to_bmz( self, diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py index 5b7feacc..e24482be 100644 --- a/src/careamics/config/tile_information.py +++ b/src/careamics/config/tile_information.py @@ -22,6 +22,7 @@ class TileInformation(BaseModel): last_tile: bool = False overlap_crop_coords: tuple[tuple[int, ...], ...] stitch_coords: tuple[tuple[int, ...], ...] 
+ sample_id: int @field_validator("array_shape") @classmethod @@ -69,4 +70,5 @@ def __eq__(self, other_tile: object): and self.last_tile == other_tile.last_tile and self.overlap_crop_coords == other_tile.overlap_crop_coords and self.stitch_coords == other_tile.stitch_coords + and self.sample_id == other_tile.sample_id ) diff --git a/src/careamics/dataset/tiling/__init__.py b/src/careamics/dataset/tiling/__init__.py index f7b9643a..67f876f5 100644 --- a/src/careamics/dataset/tiling/__init__.py +++ b/src/careamics/dataset/tiling/__init__.py @@ -7,5 +7,4 @@ ] from .collate_tiles import collate_tiles -from .stitch_prediction import stitch_prediction from .tiled_patching import extract_tiles diff --git a/src/careamics/dataset/tiling/stitch_prediction.py b/src/careamics/dataset/tiling/stitch_prediction.py deleted file mode 100644 index 54f94604..00000000 --- a/src/careamics/dataset/tiling/stitch_prediction.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Prediction utility functions.""" - -from typing import List - -import numpy as np - -from careamics.config.tile_information import TileInformation - - -def stitch_prediction( - tiles: List[np.ndarray], - tile_infos: List[TileInformation], -) -> np.ndarray: - """Stitch tiles back together to form a full image. - - Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a - singleton dimension. - - Parameters - ---------- - tiles : List[np.ndarray] - Cropped tiles and their respective stitching coordinates. - tile_infos : List[TileInformation] - List of information and coordinates obtained from - `dataset.tiled_patching.extract_tiles`. - - Returns - ------- - np.ndarray - Full image. - """ - # retrieve whole array size - input_shape = tile_infos[0].array_shape - predicted_image = np.zeros(input_shape, dtype=np.float32) - - for tile, tile_info in zip(tiles, tile_infos): - n_channels = tile.shape[0] - - # Compute coordinates for cropping predicted tile - slices = (slice(0, n_channels),) + tuple( - [slice(c[0], c[1]) for c in tile_info.overlap_crop_coords] - ) - - # Crop predited tile according to overlap coordinates - cropped_tile = tile[slices] - - # Insert cropped tile into predicted image using stitch coordinates - predicted_image[ - ( - ..., - *[slice(c[0], c[1]) for c in tile_info.stitch_coords], - ) - ] = cropped_tile.astype(np.float32) - - return predicted_image diff --git a/src/careamics/dataset/tiling/tiled_patching.py b/src/careamics/dataset/tiling/tiled_patching.py index 10fe695c..0ef0eb13 100644 --- a/src/careamics/dataset/tiling/tiled_patching.py +++ b/src/careamics/dataset/tiling/tiled_patching.py @@ -158,6 +158,7 @@ def extract_tiles( last_tile=last_tile, overlap_crop_coords=overlap_crop_coords, stitch_coords=stitch_coords, + sample_id=sample_idx, ) yield tile, tile_info diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py index 4ece4e91..7c7b21ee 100644 --- a/src/careamics/lightning_module.py +++ b/src/careamics/lightning_module.py @@ -175,7 +175,7 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: denormalized_output = denorm(patch=output.cpu().numpy()) if len(aux) > 0: # aux can be tiling information - return denormalized_output, aux + return denormalized_output, *aux else: return denormalized_output diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py deleted file mode 100644 index 2ef6b6a5..00000000 --- a/src/careamics/lightning_prediction_loop.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Lithning prediction loop allowing tiling.""" - 
-from typing import List, Optional - -import numpy as np -import pytorch_lightning as L -from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher -from pytorch_lightning.loops.utilities import _no_grad_context -from pytorch_lightning.trainer import call -from pytorch_lightning.utilities.types import _PREDICT_OUTPUT - -from careamics.config.tile_information import TileInformation -from careamics.dataset.tiling import stitch_prediction - - -class CAREamicsPredictionLoop(L.loops._PredictionLoop): - """ - CAREamics prediction loop. - - This class extends the PyTorch Lightning `_PredictionLoop` class to include - the stitching of the tiles into a single prediction result. - """ - - def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: - """Call `on_predict_epoch_end` hook. - - Adapted from the parent method. - - Returns - ------- - Optional[_PREDICT_OUTPUT] - Prediction output. - """ - trainer = self.trainer - call._call_callback_hooks(trainer, "on_predict_epoch_end") - call._call_lightning_module_hook(trainer, "on_predict_epoch_end") - - if self.return_predictions: - ######################################################## - ################ CAREamics specific code ############### - if len(self.predicted_array) == 1: - # single array, already a numpy array - return self.predicted_array[0] # todo why not return the list here? - else: - return self.predicted_array - ######################################################## - return None - - @_no_grad_context - def run(self) -> Optional[_PREDICT_OUTPUT]: - """Run the prediction loop. - - Adapted from the parent method in order to stitch the predictions. - - Returns - ------- - Optional[_PREDICT_OUTPUT] - Prediction output. - """ - self.setup_data() - if self.skip: - return None - self.reset() - self.on_run_start() - data_fetcher = self._data_fetcher - assert data_fetcher is not None - - self.predicted_array = [] - self.tiles: List[np.ndarray] = [] - self.tile_information: List[TileInformation] = [] - - while True: - try: - if isinstance(data_fetcher, _DataLoaderIterDataFetcher): - dataloader_iter = next(data_fetcher) - # hook's batch_idx and dataloader_idx arguments correctness cannot - # be guaranteed in this setting - batch = data_fetcher._batch - batch_idx = data_fetcher._batch_idx - dataloader_idx = data_fetcher._dataloader_idx - else: - dataloader_iter = None - batch, batch_idx, dataloader_idx = next(data_fetcher) - self.batch_progress.is_last_batch = data_fetcher.done - - # run step hooks - self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter) - - ######################################################## - ################ CAREamics specific code ############### - is_tiled = self.trainer.datamodule.tiled - if is_tiled: - # a numpy array of shape BC(Z)YX - tile_batch = self.predictions[batch_idx][0] - - # split the tiles into C(Z)YX (skip singleton S) and - # add them to the tiles list - self.tiles.extend( - np.split(tile_batch, tile_batch.shape[0], axis=0)[0] - ) - - # tile information is passed as a list of list of TileInformation - # TODO why list of list? 
- tile_info = self.predictions[batch_idx][1][0] - self.tile_information.extend(tile_info) - - # if last tile, stitch the tiles and add array to the prediction - last_tiles = [t.last_tile for t in self.tile_information] - if any(last_tiles): - predicted_batches = stitch_prediction( - self.tiles, self.tile_information - ) - self.predicted_array.append(predicted_batches) - self.tiles.clear() - self.tile_information.clear() - else: - # simply add the prediction to the list - self.predicted_array.append(self.predictions[batch_idx]) - ######################################################## - except StopIteration: - break - finally: - self._restarting = False - return self.on_run_end() - - # TODO predictions aren't stacked, list returned diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py deleted file mode 100644 index eefbb6e5..00000000 --- a/src/careamics/prediction/stitch_prediction.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Prediction utility functions.""" - -from typing import List - -import numpy as np - - -def stitch_prediction( - tiles: List[np.ndarray], - stitching_data: List[List[np.ndarray]], -) -> np.ndarray: - """ - Stitch tiles back together to form a full image. - - Parameters - ---------- - tiles : List[np.ndarray] - Cropped tiles and their respective stitching coordinates. - stitching_data : List[List[np.ndarray]] - List of lists containing the overlap crop coordinates and stitch coordinates. - - Returns - ------- - np.ndarray - Full image. - """ - # retrieve whole array size, there is two cases to consider: - # 1. the tiles are stored in a list - # 2. the tiles are stored in a list with batches along the first dim - if tiles[0].shape[0] > 1: - input_shape = np.array( - [el.numpy() for el in stitching_data[0][0][0]], dtype=int - ).squeeze() - else: - input_shape = np.array( - [el.numpy() for el in stitching_data[0][0]], dtype=int - ).squeeze() - - # TODO should use torch.zeros instead of np.zeros - predicted_image = np.zeros(input_shape, dtype=np.float32) - - for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip( - tiles, stitching_data - ): - for batch_idx in range(tile_batch.shape[0]): - # Compute coordinates for cropping predicted tile - slices = tuple( - [ - slice(c[0][batch_idx], c[1][batch_idx]) - for c in overlap_crop_coords_batch - ] - ) - - # Crop predited tile according to overlap coordinates - cropped_tile = tile_batch[batch_idx].squeeze()[slices] - - # Insert cropped tile into predicted image using stitch coordinates - predicted_image[ - ( - ..., - *[ - slice(c[0][batch_idx], c[1][batch_idx]) - for c in stitch_coords_batch - ], - ) - ] = cropped_tile.astype(np.float32) - - return predicted_image diff --git a/src/careamics/prediction_utils/__init__.py b/src/careamics/prediction_utils/__init__.py new file mode 100644 index 00000000..fccce02e --- /dev/null +++ b/src/careamics/prediction_utils/__init__.py @@ -0,0 +1,12 @@ +"""Package to house various prediction utilies.""" + +__all__ = [ + "create_pred_datamodule", + "stitch_prediction", + "stitch_prediction_single", + "convert_outputs", +] + +from .create_pred_datamodule import create_pred_datamodule +from .prediction_outputs import convert_outputs +from .stitch_prediction import stitch_prediction, stitch_prediction_single diff --git a/src/careamics/prediction_utils/create_pred_datamodule.py b/src/careamics/prediction_utils/create_pred_datamodule.py new file mode 100644 index 00000000..5baa3005 --- /dev/null +++ 
b/src/careamics/prediction_utils/create_pred_datamodule.py
@@ -0,0 +1,185 @@
+"""Module containing functions to create `CAREamicsPredictData`."""
+
+from pathlib import Path
+from typing import Callable, Dict, Literal, Optional, Tuple, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+from careamics.config import Configuration, create_inference_configuration
+from careamics.utils import check_path_exists
+
+from ..lightning_prediction_datamodule import CAREamicsPredictData
+
+
+def create_pred_datamodule(
+    source: Union[CAREamicsPredictData, Path, str, NDArray],
+    config: Configuration,
+    batch_size: Optional[int] = None,
+    tile_size: Optional[Tuple[int, ...]] = None,
+    tile_overlap: Tuple[int, ...] = (48, 48),
+    axes: Optional[str] = None,
+    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+    tta_transforms: bool = True,
+    dataloader_params: Optional[Dict] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+) -> CAREamicsPredictData:
+    """
+    Create a `CAREamicsPredictData` module.
+
+    Parameters
+    ----------
+    source : CAREamicsPredictData, pathlib.Path, str or numpy.ndarray
+        Data to predict on.
+    config : Configuration
+        Global configuration.
+    batch_size : int, optional
+        Batch size for prediction, defaults to the configured batch size.
+    tile_size : tuple of int, optional
+        Size of the tiles to use for prediction.
+    tile_overlap : tuple of int, default=(48, 48)
+        Overlap between tiles.
+    axes : str, optional
+        Axes of the input data.
+    data_type : {"array", "tiff", "custom"}, optional
+        Type of the input data.
+    tta_transforms : bool, default=True
+        Whether to apply test-time augmentation.
+    dataloader_params : dict, optional
+        Parameters to pass to the dataloader.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, default=""
+        Filter for the file extension.
+
+    Returns
+    -------
+    prediction datamodule: CAREamicsPredictData
+        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
+
+    Raises
+    ------
+    ValueError
+        If the input is not a CAREamicsPredictData instance, a path or a numpy array.
+    """
+    # Reuse batch size if not provided explicitly
+    if batch_size is None:
+        batch_size = config.data_config.batch_size
+
+    # create predict config, reuse training config if parameters missing
+    prediction_config = create_inference_configuration(
+        configuration=config,
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+        data_type=data_type,
+        axes=axes,
+        tta_transforms=tta_transforms,
+        batch_size=batch_size,
+    )
+
+    # remove batch from dataloader parameters (priority given to config)
+    if dataloader_params is None:
+        dataloader_params = {}
+    if "batch_size" in dataloader_params:
+        del dataloader_params["batch_size"]
+
+    if isinstance(source, CAREamicsPredictData):
+        pred_datamodule = source
+    elif isinstance(source, Path) or isinstance(source, str):
+        pred_datamodule = _create_from_path(
+            source=source,
+            pred_config=prediction_config,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
+        )
+    elif isinstance(source, np.ndarray):
+        pred_datamodule = _create_from_array(
+            source=source,
+            pred_config=prediction_config,
+            dataloader_params=dataloader_params,
+        )
+    else:
+        raise ValueError(
+            f"Invalid input. Expected a CAREamicsPredictData instance, paths or "
+            f"NDArray (got {type(source)})."
+        )
+
+    return pred_datamodule
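For orientation, here is a minimal, hypothetical usage sketch of `create_pred_datamodule` (not part of the diff): the configuration file name `config.yml` and the random image are placeholders, and the call mirrors what `CAREamist.predict` now does internally.

```python
import numpy as np

from careamics.config import load_configuration
from careamics.prediction_utils import create_pred_datamodule

# placeholder inputs: a previously saved training configuration and a 2D image
config = load_configuration("config.yml")
image = np.random.rand(64, 64).astype(np.float32)

# build the prediction datamodule from an in-memory array with tiling enabled
pred_datamodule = create_pred_datamodule(
    source=image,
    config=config,
    tile_size=(16, 16),
    tile_overlap=(4, 4),
    axes="YX",
)
# `Trainer.predict(model=..., datamodule=pred_datamodule)` then yields the raw
# batched outputs that `convert_outputs` post-processes.
```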
+
+
+def _create_from_path(
+    source: Union[Path, str],
+    pred_config: Configuration,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    dataloader_params: Optional[Dict] = None,
+    **kwargs,
+) -> CAREamicsPredictData:
+    """
+    Create `CAREamicsPredictData` from path.
+
+    Parameters
+    ----------
+    source : Path or str
+        Data to predict on.
+    pred_config : Configuration
+        Prediction configuration.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, default=""
+        Filter for the file extension.
+    dataloader_params : Optional[Dict], optional
+        Parameters to pass to the dataloader.
+    **kwargs
+        Unused.
+
+    Returns
+    -------
+    prediction datamodule: CAREamicsPredictData
+        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
+    """
+    source_path = check_path_exists(source)
+
+    datamodule = CAREamicsPredictData(
+        pred_config=pred_config,
+        pred_data=source_path,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        dataloader_params=dataloader_params,
+    )
+    return datamodule
+
+
+def _create_from_array(
+    source: NDArray,
+    pred_config: Configuration,
+    dataloader_params: Optional[Dict] = None,
+    **kwargs,
+) -> CAREamicsPredictData:
+    """
+    Create `CAREamicsPredictData` from array.
+
+    Parameters
+    ----------
+    source : numpy.ndarray
+        Data to predict on.
+    pred_config : Configuration
+        Prediction configuration.
+    dataloader_params : Optional[Dict], optional
+        Parameters to pass to the dataloader.
+    **kwargs
+        Unused. Added for compatible function signature with `_create_from_path`.
+
+    Returns
+    -------
+    prediction datamodule: CAREamicsPredictData
+        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
+    """
+    datamodule = CAREamicsPredictData(
+        pred_config=pred_config,
+        pred_data=source,
+        dataloader_params=dataloader_params,
+    )
+    return datamodule
diff --git a/src/careamics/prediction_utils/prediction_outputs.py b/src/careamics/prediction_utils/prediction_outputs.py
new file mode 100644
index 00000000..423164e3
--- /dev/null
+++ b/src/careamics/prediction_utils/prediction_outputs.py
@@ -0,0 +1,165 @@
+"""Module containing functions to convert prediction outputs to desired form."""
+
+from typing import Any, List, Literal, Tuple, Union, overload
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..config.tile_information import TileInformation
+from .stitch_prediction import stitch_prediction
+
+
+def convert_outputs(
+    predictions: List[Any], tiled: bool
+) -> Union[List[NDArray], NDArray]:
+    """
+    Convert the outputs to the desired form.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    tiled : bool
+        Whether the predictions are tiled.
+
+    Returns
+    -------
+    list of numpy.ndarray or numpy.ndarray
+        List of arrays with the axes SC(Z)YX. If there is only 1 output it will not
+        be in a list.
+    """
+    if len(predictions) == 0:
+        return predictions
+
+    # this layout is to stop mypy complaining
+    if tiled:
+        predictions_comb = combine_batches(predictions, tiled)
+        # remove sample dimension (always 1) `stitch_predict` func expects no S dim
+        tiles = [pred[0] for pred in predictions_comb[0]]
+        tile_infos = predictions_comb[1]
+        predictions_output = stitch_prediction(tiles, tile_infos)
+    else:
+        predictions_output = combine_batches(predictions, tiled)
+
+    # TODO: add this in? Returns output with same axes as input
+    # Won't work with tiling rn because stitch_prediction func removes S axis
+    # predictions = reshape(predictions, axes)
+    # At least make sure stitched prediction and non-tiled prediction have matching axes
+
+    # TODO: might want to remove this
+    if len(predictions_output) == 1:
+        return predictions_output[0]
+    return predictions_output
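To illustrate the data flow handled by `convert_outputs`, a small sketch mirroring `tests/prediction_utils/test_prediction_outputs.py` further down; the identity "predictions" are hypothetical stand-ins for real network output.

```python
import numpy as np

from careamics.dataset.tiling import extract_tiles
from careamics.prediction_utils import convert_outputs

# two single-channel 16x16 samples, SCYX
arr = np.arange(2 * 1 * 16 * 16, dtype=np.float32).reshape(2, 1, 16, 16)

# pretend the network returned each tile unchanged, batched one tile at a time,
# matching the (tile_batch, tile_info_list) tuples produced during tiled prediction
batches = [
    (tile[np.newaxis], [tile_info])
    for tile, tile_info in extract_tiles(arr, (8, 8), (2, 2))
]

stitched = convert_outputs(batches, tiled=True)  # one stitched array per sample
```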
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Literal[True]
+) -> Tuple[List[NDArray], List[TileInformation]]: ...
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Literal[False]
+) -> List[NDArray]: ...
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
+) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...
+
+
+def combine_batches(
+    predictions: List[Any], tiled: bool
+) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
+    """
+    Combine predictions that are grouped into batches.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    tiled : bool
+        Whether the predictions are tiled.
+
+    Returns
+    -------
+    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
+        Combined batches.
+    """
+    if tiled:
+        return _combine_tiled_batches(predictions)
+    else:
+        return _combine_untiled_batches(predictions)
+
+
+def _combine_tiled_batches(
+    predictions: List[Tuple[NDArray, List[TileInformation]]]
+) -> Tuple[List[NDArray], List[TileInformation]]:
+    """
+    Combine batches from tiled output.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+
+    Returns
+    -------
+    tuple of (list of numpy.ndarray, list of TileInformation)
+        Combined batches.
+    """
+    # turn list of lists into single list
+    tile_infos = [
+        tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+    ]
+    prediction_tiles: List[NDArray] = _combine_untiled_batches(
+        [preds for preds, _ in predictions]
+    )
+    return prediction_tiles, tile_infos
+
+
+def _combine_untiled_batches(predictions: List[NDArray]) -> List[NDArray]:
+    """
+    Combine batches from un-tiled output.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+
+    Returns
+    -------
+    list of numpy.ndarray
+        Combined batches.
+    """
+    prediction_concat: NDArray = np.concatenate(predictions, axis=0)
+    prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
+    return prediction_split
+
+
+def reshape(predictions: List[NDArray], axes: str) -> List[NDArray]:
+    """
+    Reshape predictions to have the dimensions of the input.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    axes : str
+        Axes SC(Z)YX.
+
+    Returns
+    -------
+    List[NDArray]
+        Reshaped predictions.
+ """ + if "C" not in axes: + predictions = [pred[:, 0] for pred in predictions] + if "S" not in axes: + predictions = [pred[0] for pred in predictions] + return predictions diff --git a/src/careamics/prediction_utils/stitch_prediction.py b/src/careamics/prediction_utils/stitch_prediction.py new file mode 100644 index 00000000..eace8381 --- /dev/null +++ b/src/careamics/prediction_utils/stitch_prediction.py @@ -0,0 +1,98 @@ +"""Prediction utility functions.""" + +from typing import List + +import numpy as np + +from careamics.config.tile_information import TileInformation + + +# TODO: why not allow input and output of torch.tensor ? +def stitch_prediction( + tiles: List[np.ndarray], + tile_infos: List[TileInformation], +) -> List[np.ndarray]: + """ + Stitch tiles back together to form a full image(s). + + Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a + singleton dimension. + + Parameters + ---------- + tiles : list of numpy.ndarray + Cropped tiles and their respective stitching coordinates. Can contain tiles + from multiple images. + tile_infos : list of TileInformation + List of information and coordinates obtained from + `dataset.tiled_patching.extract_tiles`. + + Returns + ------- + list of numpy.ndarray + Full image(s). + """ + # Find where to split the lists so that only info from one image is contained. + # Do this by locating the last tiles of each image. + last_tiles = [tile_info.last_tile for tile_info in tile_infos] + last_tile_position = np.where(last_tiles)[0] + image_slices = [ + slice(None if i == 0 else last_tile_position[i - 1], last_tile_position[i] + 1) + for i in range(len(last_tile_position)) + ] + image_predictions = [] + # slice the lists and apply stitch_prediction_single to each in turn. + for image_slice in image_slices: + image_predictions.append( + stitch_prediction_single(tiles[image_slice], tile_infos[image_slice]) + ) + return image_predictions + + +def stitch_prediction_single( + tiles: List[np.ndarray], + tile_infos: List[TileInformation], +) -> np.ndarray: + """ + Stitch tiles back together to form a full image. + + Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a + singleton dimension. + + Parameters + ---------- + tiles : list of numpy.ndarray + Cropped tiles and their respective stitching coordinates. + tile_infos : list of TileInformation + List of information and coordinates obtained from + `dataset.tiled_patching.extract_tiles`. + + Returns + ------- + numpy.ndarray + Full image. 
+ """ + # retrieve whole array size + input_shape = tile_infos[0].array_shape + predicted_image = np.zeros(input_shape, dtype=np.float32) + + for tile, tile_info in zip(tiles, tile_infos): + n_channels = tile.shape[0] + + # Compute coordinates for cropping predicted tile + slices = (slice(0, n_channels),) + tuple( + [slice(c[0], c[1]) for c in tile_info.overlap_crop_coords] + ) + + # Crop predited tile according to overlap coordinates + cropped_tile = tile[slices] + + # Insert cropped tile into predicted image using stitch coordinates + predicted_image[ + ( + ..., + *[slice(c[0], c[1]) for c in tile_info.stitch_coords], + ) + ] = cropped_tile.astype(np.float32) + + return predicted_image diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py index 1f10b24b..34ccbca1 100644 --- a/tests/config/test_tile_information.py +++ b/tests/config/test_tile_information.py @@ -10,6 +10,7 @@ def test_defaults(): array_shape=np.zeros((6, 6)).shape, overlap_crop_coords=((1, 2),), stitch_coords=((3, 4),), + sample_id=0, ) assert tile_info.array_shape == (6, 6) @@ -39,12 +40,14 @@ def test_tile_equality(): last_tile=True, overlap_crop_coords=((1, 2),), stitch_coords=((3, 4),), + sample_id=0, ) t2 = TileInformation( array_shape=(6, 6), last_tile=True, overlap_crop_coords=((1, 2),), stitch_coords=((3, 4),), + sample_id=0, ) assert t1 == t2 diff --git a/tests/dataset/prediction/test_stitch_prediction.py b/tests/dataset/prediction/test_stitch_prediction.py deleted file mode 100644 index 062f48ec..00000000 --- a/tests/dataset/prediction/test_stitch_prediction.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import pytest - -from careamics.dataset.tiling import extract_tiles, stitch_prediction - - -@pytest.mark.parametrize( - "input_shape, tile_size, overlaps", - [ - ((1, 1, 8, 8), (4, 4), (2, 2)), - ((1, 2, 8, 8), (4, 4), (2, 2)), - ((2, 1, 8, 8), (4, 4), (2, 2)), - ((2, 2, 8, 8), (4, 4), (2, 2)), - ((1, 1, 7, 9), (4, 4), (2, 2)), - ((1, 3, 7, 9), (4, 4), (2, 2)), - ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), - ((1, 1, 321, 481), (256, 256), (48, 48)), - ((2, 1, 321, 481), (256, 256), (48, 48)), - ((1, 4, 321, 481), (256, 256), (48, 48)), - ((4, 3, 321, 481), (256, 256), (48, 48)), - ], -) -def test_stitch_tiles(ordered_array, input_shape, tile_size, overlaps): - """Test stitching tiles back together.""" - arr = ordered_array(input_shape, dtype=int) - n_samples = input_shape[0] - - # extract tiles - all_tiles = list(extract_tiles(arr, tile_size, overlaps)) - - tiles = [] - tile_infos = [] - sample_id = 0 - for tile, tile_info in all_tiles: - # create lists mimicking the output of the prediction loop - tiles.append(tile) - tile_infos.append(tile_info) - - # if we reached the last tile - if tile_info.last_tile: - result = stitch_prediction(tiles, tile_infos) - - # check equality with the correct sample - assert np.array_equal(result, arr[sample_id].squeeze()) - sample_id += 1 - - # clear the lists - tiles.clear() - tile_infos.clear() - - assert sample_id == n_samples diff --git a/tests/dataset/prediction/test_collate_tiles.py b/tests/dataset/tiling/test_collate_tiles.py similarity index 100% rename from tests/dataset/prediction/test_collate_tiles.py rename to tests/dataset/tiling/test_collate_tiles.py diff --git a/tests/dataset/prediction/test_tiled_patching.py b/tests/dataset/tiling/test_tiled_patching.py similarity index 100% rename from tests/dataset/prediction/test_tiled_patching.py rename to tests/dataset/tiling/test_tiled_patching.py diff --git 
a/tests/prediction_utils/test_prediction_outputs.py b/tests/prediction_utils/test_prediction_outputs.py new file mode 100644 index 00000000..4b9a7711 --- /dev/null +++ b/tests/prediction_utils/test_prediction_outputs.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest + +from careamics.dataset.tiling import extract_tiles +from careamics.prediction_utils import convert_outputs + + +@pytest.mark.parametrize("n_samples", [1, 2]) +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_convert_outputs_tiled(ordered_array, batch_size, n_samples): + """Test conversion of output for when prediction is tiled""" + # --- simulate outputs from trainer.predict + tile_size = (4, 4) + overlaps = (2, 2) + arr = ordered_array((n_samples, 1, 16, 16)) + all_tiles = list(extract_tiles(arr, tile_size, overlaps)) + # combine into batches to match output of trainer.predict + prediction_batches = [] + for i in range(0, len(all_tiles), batch_size): + tiles = np.concatenate( + [tile[np.newaxis] for tile, _ in all_tiles[i : i + batch_size]], axis=0 + ) + tile_infos = [tile_info for _, tile_info in all_tiles[i : i + batch_size]] + prediction_batches.append((tiles, tile_infos)) + + predictions = convert_outputs(prediction_batches, tiled=True) + # TODO: fix convert_outputs so output shape is the same as input shape + # (Or always SC(Z)YX) + assert np.array_equal(np.array(predictions), arr.squeeze()) + + +@pytest.mark.parametrize("batch_size, n_samples", [(1, 1), (1, 2), (2, 2)]) +def test_convert_outputs_not_tiled(ordered_array, batch_size, n_samples): + """Test conversion of output for when prediction is not tiled""" + # --- simulate outputs from trainer.predict + # TODO: could test for case with different size batch at the end + prediction_batches = [ + ordered_array((batch_size, 1, 16, 16)) for _ in range(n_samples // batch_size) + ] + predictions = convert_outputs(prediction_batches, tiled=False) + if not isinstance(predictions, list): # single predictions not returned as list + predictions = [predictions] + assert np.array_equal( + np.concatenate(predictions, axis=0), np.concatenate(prediction_batches, axis=0) + ) diff --git a/tests/prediction_utils/test_stitch_prediction.py b/tests/prediction_utils/test_stitch_prediction.py new file mode 100644 index 00000000..d0cd62ce --- /dev/null +++ b/tests/prediction_utils/test_stitch_prediction.py @@ -0,0 +1,90 @@ +import numpy as np +import pytest + +from careamics.dataset.tiling import extract_tiles +from careamics.prediction_utils import stitch_prediction, stitch_prediction_single + + +@pytest.mark.parametrize( + "input_shape, tile_size, overlaps", + [ + ((1, 1, 8, 8), (4, 4), (2, 2)), + ((1, 2, 8, 8), (4, 4), (2, 2)), + ((2, 1, 8, 8), (4, 4), (2, 2)), + ((2, 2, 8, 8), (4, 4), (2, 2)), + ((1, 1, 7, 9), (4, 4), (2, 2)), + ((1, 3, 7, 9), (4, 4), (2, 2)), + ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), + ((1, 1, 321, 481), (256, 256), (48, 48)), + ((2, 1, 321, 481), (256, 256), (48, 48)), + ((1, 4, 321, 481), (256, 256), (48, 48)), + ((4, 3, 321, 481), (256, 256), (48, 48)), + ], +) +def test_stitch_tiles_single(ordered_array, input_shape, tile_size, overlaps): + """Test stitching tiles back together.""" + arr = ordered_array(input_shape, dtype=int) + n_samples = input_shape[0] + + # extract tiles + all_tiles = list(extract_tiles(arr, tile_size, overlaps)) + + tiles = [] + tile_infos = [] + sample_id = 0 + for tile, tile_info in all_tiles: + # create lists mimicking the output of the prediction loop + tiles.append(tile) + tile_infos.append(tile_info) + + # if we reached the 
last tile + if tile_info.last_tile: + result = stitch_prediction_single(tiles, tile_infos) + + # check equality with the correct sample + assert np.array_equal(result, arr[sample_id].squeeze()) + sample_id += 1 + + # clear the lists + tiles.clear() + tile_infos.clear() + + assert sample_id == n_samples + + +@pytest.mark.parametrize( + "input_shape, tile_size, overlaps", + [ + ((1, 1, 8, 8), (4, 4), (2, 2)), + ((1, 2, 8, 8), (4, 4), (2, 2)), + ((2, 1, 8, 8), (4, 4), (2, 2)), + ((2, 2, 8, 8), (4, 4), (2, 2)), + ((1, 1, 7, 9), (4, 4), (2, 2)), + ((1, 3, 7, 9), (4, 4), (2, 2)), + ((1, 1, 9, 7, 8), (4, 4, 4), (2, 2, 2)), + ((1, 1, 321, 481), (256, 256), (48, 48)), + ((2, 1, 321, 481), (256, 256), (48, 48)), + ((1, 4, 321, 481), (256, 256), (48, 48)), + ((4, 3, 321, 481), (256, 256), (48, 48)), + ], +) +def test_stitch_tiles_multi(ordered_array, input_shape, tile_size, overlaps): + """Test stitching tiles back together.""" + arr = ordered_array(input_shape, dtype=int) + n_samples = input_shape[0] + + # extract tiles + all_tiles = list(extract_tiles(arr, tile_size, overlaps)) + + tiles = [] + tile_infos = [] + for tile, tile_info in all_tiles: + # create lists mimicking the output of the prediction loop + tiles.append(tile) + tile_infos.append(tile_info) + + stitched = stitch_prediction(tiles, tile_infos) + for sample_id, result in enumerate(stitched): + assert np.array_equal(result, arr[sample_id].squeeze()) + + assert len(stitched) == n_samples diff --git a/tests/test_careamist.py b/tests/test_careamist.py index 49cc29a7..9178093c 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -504,18 +504,19 @@ def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: d assert (tmp_path / "model.zip").exists() +@pytest.mark.parametrize("samples", [1, 2, 4]) @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_on_array_tiled( - tmp_path: Path, minimum_configuration: dict, batch_size + tmp_path: Path, minimum_configuration: dict, batch_size, samples ): """Test that CAREamics can predict on arrays.""" # training data - train_array = random_array((32, 32)) + train_array = random_array((samples, 32, 32)) # create configuration config = Configuration(**minimum_configuration) config.training_config.num_epochs = 1 - config.data_config.axes = "YX" + config.data_config.axes = "SYX" config.data_config.batch_size = 2 config.data_config.data_type = SupportedData.ARRAY.value config.data_config.patch_size = (8, 8) @@ -530,8 +531,9 @@ def test_predict_on_array_tiled( predicted = careamist.predict( train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4) ) + predicted_squeeze = [p.squeeze() for p in predicted] - assert predicted.squeeze().shape == train_array.shape + assert np.array(predicted_squeeze).shape == train_array.squeeze().shape # export to BMZ careamist.export_to_bmz( @@ -544,10 +546,14 @@ def test_predict_on_array_tiled( assert (tmp_path / "model.zip").exists() -def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict): +@pytest.mark.parametrize("samples", [1, 2, 4]) +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_predict_arrays_no_tiling( + tmp_path: Path, minimum_configuration: dict, batch_size, samples +): """Test that CAREamics can predict on arrays without tiling.""" # training data - train_array = random_array((4, 32, 32)) + train_array = random_array((samples, 32, 32)) # create configuration config = Configuration(**minimum_configuration) @@ -564,9 +570,9 @@ def test_predict_arrays_no_tiling(tmp_path: 
Path, minimum_configuration: dict): careamist.train(train_source=train_array) # predict CAREamist - predicted = careamist.predict(train_array) + predicted = careamist.predict(train_array, batch_size=batch_size) - assert np.concatenate(predicted).squeeze().shape == train_array.shape + assert np.concatenate(predicted).squeeze().shape == train_array.squeeze().shape # export to BMZ careamist.export_to_bmz( @@ -579,6 +585,44 @@ def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict): assert (tmp_path / "model.zip").exists() +@pytest.mark.skip( + reason=( + "This might be a problem at the PyTorch level during `forward`. Values up to " + "0.001 different." + ) +) +def test_batched_prediction(tmp_path: Path, minimum_configuration: dict): + "Compare outputs when a batch size of 1 or 2 is used" + + tile_size = (16, 16) + tile_overlap = (4, 4) + shape = (32, 32) + + train_array = random_array(shape) + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # predict with batch size 1 and batch size 2 + pred_bs_1 = careamist.predict( + train_array, batch_size=1, tile_size=tile_size, tile_overlap=tile_overlap + ) + pred_bs_2 = careamist.predict( + train_array, batch_size=2, tile_size=tile_size, tile_overlap=tile_overlap + ) + + assert np.array_equal(pred_bs_1, pred_bs_2) + + @pytest.mark.parametrize("independent_channels", [False, True]) @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_tiled_channel(