From a7f88cd8deead497c8fb6f88186b55cbec45a48d Mon Sep 17 00:00:00 2001
From: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Thu, 3 Oct 2024 13:49:56 +0200
Subject: [PATCH] Add stop training method (#249)

### Description

If CAREamics is used in a multithreaded application (e.g. a napari plugin),
then we need a way to stop the training programmatically. This PR adds a
`stop_training` method to the `CAREamist`.

Additionally, this PR removes the unused parameter `experiment_name` from the
`CAREamist`.

- **What**: Add a stop training method.
- **Why**: Enable stopping training in multithreaded applications (e.g. napari).
- **How**: Add a `stop_training` method that raises the `should_stop` flag of
  Lightning.

### Changes Made

- **Modified**: `careamist.py`, by adding the `stop_training` method and
  cleaning up unused keywords.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
---
 src/careamics/careamist.py | 31 ++++++++++++++++---------------
 tests/test_careamist.py    | 31 +++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py
index 6601bd46..d913c521 100644
--- a/src/careamics/careamist.py
+++ b/src/careamics/careamist.py
@@ -48,8 +48,6 @@ class CAREamist:
     work_dir : str, optional
         Path to working directory in which to save checkpoints and logs,
        by default None.
-    experiment_name : str, by default "CAREamics"
-        Experiment name used for checkpoints.
     callbacks : list of Callback, optional
        List of callbacks to use during training and prediction, by default
        None.
@@ -76,7 +74,6 @@ def __init__(  # numpydoc ignore=GL08
         self,
         source: Union[Path, str],
         work_dir: Optional[Union[Path, str]] = None,
-        experiment_name: str = "CAREamics",
         callbacks: Optional[list[Callback]] = None,
     ) -> None: ...

@@ -85,7 +82,6 @@ def __init__(  # numpydoc ignore=GL08
         self,
         source: Configuration,
         work_dir: Optional[Union[Path, str]] = None,
-        experiment_name: str = "CAREamics",
         callbacks: Optional[list[Callback]] = None,
     ) -> None: ...

@@ -93,7 +89,6 @@ def __init__(
         self,
         source: Union[Path, str, Configuration],
         work_dir: Optional[Union[Path, str]] = None,
-        experiment_name: str = "CAREamics",
         callbacks: Optional[list[Callback]] = None,
     ) -> None:
         """
@@ -106,18 +101,13 @@ def __init__(
         If no working directory is provided, the current working directory is
         used.

-        If `source` is a checkpoint, then `experiment_name` is used to name the
-        checkpoint, and is recorded in the configuration.
-
         Parameters
         ----------
         source : pathlib.Path or str or CAREamics Configuration
             Path to a configuration file or a trained model.
-        work_dir : str, optional
+        work_dir : str or pathlib.Path, optional
             Path to working directory in which to save checkpoints and logs,
             by default None.
-        experiment_name : str, optional
-            Experiment name used for checkpoints, by default "CAREamics".
         callbacks : list of Callback, optional
            List of callbacks to use during training and prediction, by default
            None.
@@ -257,6 +247,12 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None
                 EarlyStopping(self.cfg.training_config.early_stopping_callback)
             )

+    def stop_training(self) -> None:
+        """Stop the training loop."""
+        # raise stop training flag
+        self.trainer.should_stop = True
+        self.trainer.limit_val_batches = 0  # skip validation
+
     # TODO: is there are more elegant way than calling train again after _train_on_paths
     def train(
         self,
@@ -403,9 +399,14 @@ def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
         datamodule : TrainDataModule
             Datamodule to train on.
         """
-        # record datamodule
+        # register datamodule
         self.train_datamodule = datamodule

+        # set defaults (in case `stop_training` was called before)
+        self.trainer.should_stop = False
+        self.trainer.limit_val_batches = 1.0  # 100%
+
+        # train
         self.trainer.fit(self.model, datamodule=datamodule)

     def _train_on_array(
@@ -521,7 +522,7 @@ def predict(  # numpydoc ignore=GL08
         tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["tiff", "custom"]] = None,
-        tta_transforms: bool = True,
+        tta_transforms: bool = False,
         dataloader_params: Optional[dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
@@ -537,7 +538,7 @@ def predict(  # numpydoc ignore=GL08
         tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array"]] = None,
-        tta_transforms: bool = True,
+        tta_transforms: bool = False,
         dataloader_params: Optional[dict] = None,
     ) -> Union[list[NDArray], NDArray]: ...

@@ -550,7 +551,7 @@ def predict(
         tile_overlap: Optional[tuple[int, ...]] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array", "tiff", "custom"]] = None,
-        tta_transforms: bool = True,
+        tta_transforms: bool = False,
         dataloader_params: Optional[dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
diff --git a/tests/test_careamist.py b/tests/test_careamist.py
index 85e5345e..8b8f0632 100644
--- a/tests/test_careamist.py
+++ b/tests/test_careamist.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from threading import Thread
 from typing import Tuple

 import numpy as np
@@ -899,3 +900,33 @@ def test_error_passing_careamics_callback(tmp_path, minimum_configuration):

     with pytest.raises(ValueError):
         CAREamist(source=config, work_dir=tmp_path, callbacks=[hyper_params])
+
+
+def test_stop_training(tmp_path: Path, minimum_configuration: dict):
+    """Test that CAREamics can stop the training"""
+    # training data
+    train_array = random_array((32, 32))
+
+    # create configuration
+    config = Configuration(**minimum_configuration)
+    config.training_config.num_epochs = 1_000
+    config.data_config.axes = "YX"
+    config.data_config.batch_size = 2
+    config.data_config.data_type = SupportedData.ARRAY.value
+    config.data_config.patch_size = (8, 8)
+
+    # instantiate CAREamist
+    careamist = CAREamist(source=config, work_dir=tmp_path)
+
+    def _train():
+        careamist.train(train_source=train_array)
+
+    # create thread
+    thread = Thread(target=_train)
+    thread.start()
+
+    # stop training
+    careamist.stop_training()
+    thread.join()
+
+    assert careamist.trainer.should_stop
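
A minimal usage sketch of the new `stop_training` method, mirroring the added
`test_stop_training`: training runs in a background thread (as it would in a
napari plugin) and the main thread requests a stop. The `train_and_stop` helper
and its arguments are illustrative only and not part of this PR; the
`CAREamist` instance is assumed to be already configured.

```python
from threading import Thread

import numpy as np

from careamics import CAREamist


def train_and_stop(careamist: CAREamist, train_array: np.ndarray) -> None:
    """Illustrative helper: start training in a background thread, then stop it.

    Assumes `careamist` is an already-configured CAREamist, as in the test.
    """
    # run training off the main thread, as a multithreaded GUI would
    thread = Thread(target=lambda: careamist.train(train_source=train_array))
    thread.start()

    # later, e.g. from a "Stop" button callback on the main thread
    careamist.stop_training()  # raises Lightning's `should_stop` flag
    thread.join()

    # the Trainer exits its loop at the next opportunity
    assert careamist.trainer.should_stop
```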