Refactor: Tiling applied post prediction (in contrast to during loop) (#141)

### Description

With the regular Lightning prediction loop and tiled datasets,
`trainer.predict` also returns the tile information. Tiling can
therefore be applied after prediction instead of during the prediction
loop, so we don't have to write a custom loop. This simplifies the code,
reduces coupling with Lightning, and will make life easier when we add
saving predictions (tiff or zarr).
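
As a rough sketch of what this looks like from the caller's side, assume each batch output of `trainer.predict` is a pair `(tile_predictions, tile_infos)`. That layout and the helper name `flatten_batch_outputs` are assumptions for illustration, not the CAREamics API.

```python
def flatten_batch_outputs(batch_outputs):
    """Turn per-batch (tiles, infos) pairs into two flat per-tile lists."""
    tiles, infos = [], []
    for batch_tiles, batch_infos in batch_outputs:
        # Every tile in a batch must have exactly one matching info entry
        if len(batch_tiles) != len(batch_infos):
            raise ValueError("Each tile needs exactly one tile-information entry.")
        tiles.extend(batch_tiles)
        infos.extend(batch_infos)
    return tiles, infos


# Placeholder strings stand in for arrays and tile-information objects
outputs = [(["t0", "t1"], ["i0", "i1"]), (["t2"], ["i2"])]
flat_tiles, flat_infos = flatten_batch_outputs(outputs)
```

Once the outputs are flat, batch boundaries no longer matter, which is what makes stitching after the loop possible.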

**What**: Removes `CAREamicsPredictionLoop` (where tiling was previously
implemented) and applies tiling in `CAREamist.predict` after calling
`trainer.predict` with the regular Lightning prediction loop.

**Why**: Altering the Lightning prediction loop was overcomplicated,
hard to maintain and made adding an option to save predictions more
difficult.

**How**: Predictions are returned together with their tiling
information, so the predicted tiles are stitched together once
prediction has finished.
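
A minimal sketch of post-prediction stitching, under stated assumptions: `TileInfo` is a hypothetical, reduced stand-in for `TileInformation` (only the fields visible in this diff), and `stitch_single` / `stitch_all` are illustrative names rather than the functions added in `prediction_utils`. Each tile is cropped to its non-overlapping region and written to its destination, with tiles grouped per image by `sample_id`.

```python
from collections import defaultdict
from dataclasses import dataclass

import numpy as np


@dataclass
class TileInfo:
    """Hypothetical stand-in for TileInformation, reduced to what stitching needs."""

    array_shape: tuple  # shape of the full image the tile was cut from
    overlap_crop_coords: tuple  # (start, stop) per axis: region of the tile to keep
    stitch_coords: tuple  # (start, stop) per axis: destination in the full image
    sample_id: int  # which image the tile belongs to


def stitch_single(tiles, infos):
    """Stitch the tiles of one image back together."""
    out = np.zeros(infos[0].array_shape, dtype=tiles[0].dtype)
    for tile, info in zip(tiles, infos):
        crop = tuple(slice(a, b) for a, b in info.overlap_crop_coords)
        dest = tuple(slice(a, b) for a, b in info.stitch_coords)
        out[dest] = tile[crop]  # drop the overlap, then place the remainder
    return out


def stitch_all(tiles, infos):
    """Group tiles by sample_id and stitch each image separately."""
    groups = defaultdict(lambda: ([], []))
    for tile, info in zip(tiles, infos):
        groups[info.sample_id][0].append(tile)
        groups[info.sample_id][1].append(info)
    return [stitch_single(*groups[sid]) for sid in sorted(groups)]


# Two 4x6 images, each split into two 4x4 tiles that overlap by two columns
images = [np.arange(24).reshape(4, 6), np.arange(24, 48).reshape(4, 6)]
tiles, infos = [], []
for sid, img in enumerate(images):
    tiles.append(img[:, 0:4])
    infos.append(TileInfo((4, 6), ((0, 4), (0, 3)), ((0, 4), (0, 3)), sid))
    tiles.append(img[:, 2:6])
    infos.append(TileInfo((4, 6), ((0, 4), (1, 4)), ((0, 4), (3, 6)), sid))

stitched = stitch_all(tiles, infos)
```

Because grouping relies only on `sample_id`, the sketch does not depend on how tiles were split across batches.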

### Changes Made

- **Added**:
   - Multi-image stitch prediction function (takes the name of the old
     stitching function, `stitch_prediction`).
   - New `prediction_utils` package.
   - Module `prediction_outputs` in the new `prediction_utils` package.
- **Modified**:
   - `Trainer` in `CAREamist` no longer has its prediction loop replaced.
   - At the end of `CAREamist.predict`, predictions are stitched and/or
     converted to match the old `CAREamicsPredictionLoop` outputs.
   - Creation of `CAREamicsPredictData` has been moved to the new
     `prediction_utils` package.
   - `stitch_prediction` function has been moved to the
     `prediction_utils` package.
- **Removed**: `CAREamicsPredictionLoop`
- **Tests**
   - Added: test for the new `stitch_prediction` function.
   - Added: test for prediction output conversion.
   - Modified: moved `stitch_prediction` tests to match the new file
     structure of src.
   - Modified: `test_predict_on_array_tiled` and
     `test_predict_arrays_no_tiling` – parametrised with `samples` and
     `batch_size`; squeeze `train_array` when asserting size equality.

### Related Issues

- Resolves #140: Now a `BasePredictionWriter` Callback can be written
more easily.
- Resolves #143: New logic stops tiles from being skipped at the end of
batches.

### Breaking changes

Any code that instantiated a Lightning `Trainer` and added the
`CAREamicsPredictionLoop` will break.
There might also be some unforeseen changes to the dimensions of
prediction outputs that the tests do not catch, i.e. added S & C dims.

### TODO: (for future)
- Change prediction so that list is always output (currently, if there
is only 1 prediction it will not output a list).
- Dimensions of outputs always have all SC(Z)YX, or match input
dimensions. Currently, there is some inconsistency between tiled and not
tiled prediction output dimensions.
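
Until the first item is addressed, calling code can normalise the return value itself. The helper below is a hypothetical caller-side workaround, not part of CAREamics; it assumes a lone prediction is returned as a bare object and multiple predictions as a `list`.

```python
def as_prediction_list(prediction):
    """Wrap a single prediction in a list; return lists unchanged."""
    return prediction if isinstance(prediction, list) else [prediction]


# Placeholder strings stand in for prediction arrays
single = as_prediction_list("one_image")
many = as_prediction_list(["image_a", "image_b"])
```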

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
melisande-c and pre-commit-ci[bot] authored Jun 20, 2024
1 parent 37e9ad4 commit 0a29ea2
Showing 19 changed files with 693 additions and 411 deletions.
140 changes: 39 additions & 101 deletions src/careamics/careamist.py
@@ -16,19 +16,18 @@
from careamics.callbacks import ProgressBarCallback
from careamics.config import (
Configuration,
create_inference_configuration,
load_configuration,
)
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
from careamics.dataset.dataset_utils import reshape_array
from careamics.lightning_datamodule import CAREamicsTrainData
from careamics.lightning_module import CAREamicsModule
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.prediction_utils import convert_outputs, create_pred_datamodule
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback
from .lightning_prediction_datamodule import CAREamicsPredictData

logger = get_logger(__name__)

@@ -193,9 +192,6 @@ def __init__(
logger=self.experiment_logger,
)

# change the prediction loop, necessary for tiled prediction
self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)

# place holder for the datamodules
self.train_datamodule: Optional[CAREamicsTrainData] = None
self.pred_datamodule: Optional[CAREamicsPredictData] = None
@@ -572,120 +568,62 @@ def predict(
Parameters
----------
source : CAREamicsPredData or pathlib.Path or str or NDArray
source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
Data to predict on.
batch_size : int, optional
Batch size for prediction, by default 1.
batch_size : int, default=1
Batch size for prediction.
tile_size : tuple of int, optional
Size of the tiles to use for prediction, by default None.
tile_overlap : tuple of int, optional
Overlap between tiles, by default (48, 48).
Size of the tiles to use for prediction.
tile_overlap : tuple of int, default=(48, 48)
Overlap between tiles.
axes : str, optional
Axes of the input data, by default None.
data_type : {"array", "tiff", "custom"}, optional
Type of the input data, by default None.
tta_transforms : bool, optional
Whether to apply test-time augmentation, by default True.
Type of the input data.
tta_transforms : bool, default=True
Whether to apply test-time augmentation.
dataloader_params : dict, optional
Parameters to pass to the dataloader, by default None.
Parameters to pass to the dataloader.
read_source_func : Callable, optional
Function to read the source data, by default None.
extension_filter : str, optional
Filter for the file extension, by default "".
Function to read the source data.
extension_filter : str, default=""
Filter for the file extension.
checkpoint : {"best", "last"}, optional
Checkpoint to use for prediction, by default None.
Checkpoint to use for prediction.
**kwargs : Any
Unused.
Returns
-------
list of NDArray or NDArray
Predictions made by the model.
Raises
------
ValueError
If the input is not a CAREamicsPredData instance, a path or a numpy array.
"""
if isinstance(source, CAREamicsPredictData):
# record datamodule
self.pred_datamodule = source

return self.trainer.predict(
model=self.model, datamodule=source, ckpt_path=checkpoint
# Reuse batch size if not provided explicitly
if batch_size is None:
batch_size = (
self.train_datamodule.batch_size
if self.train_datamodule
else self.cfg.data_config.batch_size
)
else:
if self.cfg is None:
raise ValueError(
"No configuration found. Train a model or load from a "
"checkpoint before predicting."
)

# Reuse batch size if not provided explicitly
if batch_size is None:
batch_size = (
self.train_datamodule.batch_size
if self.train_datamodule
else self.cfg.data_config.batch_size
)

# create predict config, reuse training config if parameters missing
prediction_config = create_inference_configuration(
configuration=self.cfg,
tile_size=tile_size,
tile_overlap=tile_overlap,
data_type=data_type,
axes=axes,
tta_transforms=tta_transforms,
batch_size=batch_size,
)

# remove batch from dataloader parameters (priority given to config)
if dataloader_params is None:
dataloader_params = {}
if "batch_size" in dataloader_params:
del dataloader_params["batch_size"]

if isinstance(source, Path) or isinstance(source, str):
# Check the source
source_path = check_path_exists(source)

# create datamodule
datamodule = CAREamicsPredictData(
pred_config=prediction_config,
pred_data=source_path,
read_source_func=read_source_func,
extension_filter=extension_filter,
dataloader_params=dataloader_params,
)

# record datamodule
self.pred_datamodule = datamodule

return self.trainer.predict(
model=self.model, datamodule=datamodule, ckpt_path=checkpoint
)

elif isinstance(source, np.ndarray):
# create datamodule
datamodule = CAREamicsPredictData(
pred_config=prediction_config,
pred_data=source,
dataloader_params=dataloader_params,
)

# record datamodule
self.pred_datamodule = datamodule

return self.trainer.predict(
model=self.model, datamodule=datamodule, ckpt_path=checkpoint
)
self.pred_datamodule = create_pred_datamodule(
source=source,
config=self.cfg,
batch_size=batch_size,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes,
data_type=data_type,
tta_transforms=tta_transforms,
dataloader_params=dataloader_params,
read_source_func=read_source_func,
extension_filter=extension_filter,
)

else:
raise ValueError(
f"Invalid input. Expected a CAREamicsPredData instance, paths or "
f"NDArray (got {type(source)})."
)
predictions = self.trainer.predict(
model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
)
return convert_outputs(predictions, self.pred_datamodule.tiled)

def export_to_bmz(
self,
2 changes: 2 additions & 0 deletions src/careamics/config/tile_information.py
@@ -22,6 +22,7 @@ class TileInformation(BaseModel):
last_tile: bool = False
overlap_crop_coords: tuple[tuple[int, ...], ...]
stitch_coords: tuple[tuple[int, ...], ...]
sample_id: int

@field_validator("array_shape")
@classmethod
@@ -69,4 +70,5 @@ def __eq__(self, other_tile: object):
and self.last_tile == other_tile.last_tile
and self.overlap_crop_coords == other_tile.overlap_crop_coords
and self.stitch_coords == other_tile.stitch_coords
and self.sample_id == other_tile.sample_id
)
1 change: 0 additions & 1 deletion src/careamics/dataset/tiling/__init__.py
@@ -7,5 +7,4 @@
]

from .collate_tiles import collate_tiles
from .stitch_prediction import stitch_prediction
from .tiled_patching import extract_tiles
55 changes: 0 additions & 55 deletions src/careamics/dataset/tiling/stitch_prediction.py

This file was deleted.

1 change: 1 addition & 0 deletions src/careamics/dataset/tiling/tiled_patching.py
@@ -158,6 +158,7 @@ def extract_tiles(
last_tile=last_tile,
overlap_crop_coords=overlap_crop_coords,
stitch_coords=stitch_coords,
sample_id=sample_idx,
)

yield tile, tile_info
2 changes: 1 addition & 1 deletion src/careamics/lightning_module.py
@@ -175,7 +175,7 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
denormalized_output = denorm(patch=output.cpu().numpy())

if len(aux) > 0: # aux can be tiling information
return denormalized_output, aux
return denormalized_output, *aux
else:
return denormalized_output
