Skip to content

Commit

Permalink
Add stop training method (#249)
Browse files Browse the repository at this point in the history
### Description

If CAREamics is used in a multithreaded application (e.g. a napari
plugin), then we need a way to stop the training programmatically. This
PR adds a `stop_training` method to the `CAREamist`.

Additionally, this PR removes the unused parameter `experiment_name`
from the CAREamist.

- **What**: Add stop training method.
- **Why**: Enable stopping training in multithreaded applications (e.g.
napari)
- **How**: Add `stop_training` method that raises the `should_stop` flag
of Lightning.

### Changes Made

- **Modified**: `careamist.py`, by adding the `stop_training` function
and cleaning up unused keywords.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
jdeschamps authored Oct 3, 2024
1 parent 7435ce2 commit a7f88cd
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
31 changes: 16 additions & 15 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ class CAREamist:
work_dir : str, optional
Path to working directory in which to save checkpoints and logs,
by default None.
experiment_name : str, by default "CAREamics"
Experiment name used for checkpoints.
callbacks : list of Callback, optional
List of callbacks to use during training and prediction, by default None.
Expand All @@ -76,7 +74,6 @@ def __init__( # numpydoc ignore=GL08
self,
source: Union[Path, str],
work_dir: Optional[Union[Path, str]] = None,
experiment_name: str = "CAREamics",
callbacks: Optional[list[Callback]] = None,
) -> None: ...

Expand All @@ -85,15 +82,13 @@ def __init__( # numpydoc ignore=GL08
self,
source: Configuration,
work_dir: Optional[Union[Path, str]] = None,
experiment_name: str = "CAREamics",
callbacks: Optional[list[Callback]] = None,
) -> None: ...

def __init__(
self,
source: Union[Path, str, Configuration],
work_dir: Optional[Union[Path, str]] = None,
experiment_name: str = "CAREamics",
callbacks: Optional[list[Callback]] = None,
) -> None:
"""
Expand All @@ -106,18 +101,13 @@ def __init__(
If no working directory is provided, the current working directory is used.
If `source` is a checkpoint, then `experiment_name` is used to name the
checkpoint, and is recorded in the configuration.
Parameters
----------
source : pathlib.Path or str or CAREamics Configuration
Path to a configuration file or a trained model.
work_dir : str, optional
work_dir : str or pathlib.Path, optional
Path to working directory in which to save checkpoints and logs,
by default None.
experiment_name : str, optional
Experiment name used for checkpoints, by default "CAREamics".
callbacks : list of Callback, optional
List of callbacks to use during training and prediction, by default None.
Expand Down Expand Up @@ -257,6 +247,12 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
EarlyStopping(self.cfg.training_config.early_stopping_callback)
)

def stop_training(self) -> None:
    """Interrupt the current training loop.

    Raises Lightning's ``should_stop`` flag so the trainer exits cleanly,
    and disables validation so no extra validation pass delays the stop.
    """
    # skip any pending validation batches, then signal Lightning to halt
    self.trainer.limit_val_batches = 0  # 0 -> no validation
    self.trainer.should_stop = True

# TODO: is there a more elegant way than calling train again after _train_on_paths
def train(
self,
Expand Down Expand Up @@ -403,9 +399,14 @@ def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
datamodule : TrainDataModule
Datamodule to train on.
"""
# record datamodule
# register datamodule
self.train_datamodule = datamodule

# set defaults (in case `stop_training` was called before)
self.trainer.should_stop = False
self.trainer.limit_val_batches = 1.0 # 100%

# train
self.trainer.fit(self.model, datamodule=datamodule)

def _train_on_array(
Expand Down Expand Up @@ -521,7 +522,7 @@ def predict( # numpydoc ignore=GL08
tile_overlap: tuple[int, ...] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["tiff", "custom"]] = None,
tta_transforms: bool = True,
tta_transforms: bool = False,
dataloader_params: Optional[dict] = None,
read_source_func: Optional[Callable] = None,
extension_filter: str = "",
Expand All @@ -537,7 +538,7 @@ def predict( # numpydoc ignore=GL08
tile_overlap: tuple[int, ...] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["array"]] = None,
tta_transforms: bool = True,
tta_transforms: bool = False,
dataloader_params: Optional[dict] = None,
) -> Union[list[NDArray], NDArray]: ...

Expand All @@ -550,7 +551,7 @@ def predict(
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
tta_transforms: bool = True,
tta_transforms: bool = False,
dataloader_params: Optional[dict] = None,
read_source_func: Optional[Callable] = None,
extension_filter: str = "",
Expand Down
31 changes: 31 additions & 0 deletions tests/test_careamist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from threading import Thread
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -899,3 +900,33 @@ def test_error_passing_careamics_callback(tmp_path, minimum_configuration):

with pytest.raises(ValueError):
CAREamist(source=config, work_dir=tmp_path, callbacks=[hyper_params])


def test_stop_training(tmp_path: Path, minimum_configuration: dict):
    """Check that a running training loop can be interrupted via `stop_training`."""
    # input image used for training
    train_array = random_array((32, 32))

    # configuration with enough epochs that training cannot finish on its own
    config = Configuration(**minimum_configuration)
    config.data_config.data_type = SupportedData.ARRAY.value
    config.data_config.axes = "YX"
    config.data_config.batch_size = 2
    config.data_config.patch_size = (8, 8)
    config.training_config.num_epochs = 1_000

    # instantiate CAREamist
    careamist = CAREamist(source=config, work_dir=tmp_path)

    # launch training in a background thread
    training_thread = Thread(
        target=careamist.train, kwargs={"train_source": train_array}
    )
    training_thread.start()

    # request interruption and wait for the worker to return
    careamist.stop_training()
    training_thread.join()

    assert careamist.trainer.should_stop

0 comments on commit a7f88cd

Please sign in to comment.