diff --git a/baal/active/stopping_criteria.py b/baal/active/stopping_criteria.py
new file mode 100644
index 0000000..ac9e669
--- /dev/null
+++ b/baal/active/stopping_criteria.py
@@ -0,0 +1,70 @@
+from typing import Dict, Iterable, List, Optional
+
+import numpy as np
+
+from baal import ActiveLearningDataset
+
+
+class StoppingCriterion:
+    def __init__(self, active_dataset: ActiveLearningDataset):
+        self._active_ds = active_dataset
+
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
+        raise NotImplementedError
+
+
+class LabellingBudgetStoppingCriterion(StoppingCriterion):
+    """Stops when the labelling budget is exhausted."""
+
+    def __init__(self, active_dataset: ActiveLearningDataset, labelling_budget: int):
+        super().__init__(active_dataset)
+        self._start_length = len(active_dataset)
+        self.labelling_budget = labelling_budget
+
+    def should_stop(
+        self,
+        metrics: Optional[Dict[str, float]] = None,
+        uncertainty: Optional[Iterable[float]] = None,
+    ) -> bool:
+        # Neither argument is used: stopping depends only on how many items have
+        # been labelled since this criterion was created.
+        return (len(self._active_ds) - self._start_length) >= self.labelling_budget
+
+
+class LowAverageUncertaintyStoppingCriterion(StoppingCriterion):
+    """Stops when the average uncertainty falls below a threshold."""
+
+    def __init__(self, active_dataset: ActiveLearningDataset, avg_uncertainty_thresh: float):
+        super().__init__(active_dataset)
+        self.avg_uncertainty_thresh = avg_uncertainty_thresh
+
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
+        return np.mean(uncertainty) < self.avg_uncertainty_thresh
+
+
+class EarlyStoppingCriterion(StoppingCriterion):
+    """Early stopping on a particular metric.
+
+    Notes:
+        We do not want a mandatory dependency on an external early-stopping
+        implementation, so we provide our own.
+    """
+
+    def __init__(
+        self,
+        active_dataset: ActiveLearningDataset,
+        metric_name: str,
+        patience: int = 10,
+        epsilon: float = 1e-4,
+    ):
+        super().__init__(active_dataset)
+        self.metric_name = metric_name
+        self.patience = patience
+        self.epsilon = epsilon
+        self._acc: List[float] = []
+
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
+        self._acc.append(metrics[self.metric_name])
+        # Stop once the metric has stayed within `epsilon` of its latest value for `patience` steps.
+        near_threshold = np.isclose(np.array(self._acc), self._acc[-1], atol=self.epsilon)
+        return len(near_threshold) >= self.patience and near_threshold[-(self.patience + 1) :].all()
diff --git a/experiments/mlp_mcdropout.py b/experiments/mlp_mcdropout.py
index 34f3d4a..b769052 100644
--- a/experiments/mlp_mcdropout.py
+++ b/experiments/mlp_mcdropout.py
@@ -9,6 +9,7 @@
 from baal import ActiveLearningDataset, ModelWrapper
 from baal.active import ActiveLearningLoop
 from baal.active.heuristics import BALD
+from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
 from baal.bayesian.dropout import patch_module
 
 use_cuda = torch.cuda.is_available()
@@ -54,8 +55,11 @@
 # Following Gal 2016, we reset the weights at the beginning of each step.
 initial_weights = deepcopy(model.state_dict())
+stopping_criterion = LabellingBudgetStoppingCriterion(
+    active_dataset=al_dataset, labelling_budget=10
+)
 
-for step in range(100):
+while True:
     model.load_state_dict(initial_weights)
     train_loss = wrapper.train_on_dataset(
         al_dataset, optimizer=optimizer, batch_size=32, epoch=10, use_cuda=use_cuda
     )
@@ -64,6 +68,5 @@
 
     pprint(wrapper.get_metrics())
     flag = al_loop.step()
-    if not flag:
-        # We are done labelling! stopping
+    if stopping_criterion.should_stop() or not flag:
         break
diff --git a/tests/active/criterion_test.py b/tests/active/criterion_test.py
new file mode 100644
index 0000000..a586d0b
--- /dev/null
+++ b/tests/active/criterion_test.py
@@ -0,0 +1,53 @@
+from baal.active.stopping_criteria import (
+    LabellingBudgetStoppingCriterion,
+    EarlyStoppingCriterion,
+    LowAverageUncertaintyStoppingCriterion,
+)
+from baal.active.dataset import ActiveNumpyArray
+import numpy as np
+
+
+def test_labelling_budget():
+    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
+    ds.label_randomly(10)
+    criterion = LabellingBudgetStoppingCriterion(ds, labelling_budget=50)
+    assert not criterion.should_stop()
+
+    ds.label_randomly(10)
+    assert not criterion.should_stop()
+
+    ds.label_randomly(40)
+    assert criterion.should_stop()
+
+
+def test_early_stopping():
+    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
+    criterion = EarlyStoppingCriterion(ds, "test_loss", patience=5)
+
+    for i in range(10):
+        assert not criterion.should_stop(
+            metrics={"test_loss": 1 / (i + 1)}, uncertainty=[]
+        )
+
+    for _ in range(4):
+        assert not criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+    assert criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+
+    # Fewer than `patience` stable values must not trigger stopping.
+    criterion = EarlyStoppingCriterion(ds, "test_loss", patience=5)
+    for _ in range(4):
+        assert not criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+    assert criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
+
+
+def test_low_average():
+    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
+    criterion = LowAverageUncertaintyStoppingCriterion(
+        active_dataset=ds, avg_uncertainty_thresh=0.1
+    )
+    assert not criterion.should_stop(
+        metrics={}, uncertainty=np.random.normal(0.5, scale=0.8, size=(100,))
+    )
+    assert criterion.should_stop(
+        metrics={}, uncertainty=np.random.normal(0.05, scale=0.01, size=(100,))
+    )
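
For reviewers who want to see the new API in isolation, here is a standalone sketch (not part of the patch) of how EarlyStoppingCriterion behaves on a stream of metrics. It assumes only what the tests already rely on: an ActiveNumpyArray to satisfy the constructor and a "test_loss" entry in the metrics dict; the loss values, patience and epsilon below are illustrative.

import numpy as np

from baal.active.dataset import ActiveNumpyArray
from baal.active.stopping_criteria import EarlyStoppingCriterion

# Dummy dataset: EarlyStoppingCriterion never reads it, but the base class stores one.
ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
criterion = EarlyStoppingCriterion(ds, metric_name="test_loss", patience=3, epsilon=1e-3)

losses = [0.9, 0.5, 0.3, 0.2, 0.2, 0.2, 0.2]  # the metric plateaus at 0.2
for step, loss in enumerate(losses):
    if criterion.should_stop(metrics={"test_loss": loss}, uncertainty=[]):
        # Fires on the 4th consecutive 0.2, i.e. once the last patience + 1 values
        # are all within epsilon of the latest one.
        print(f"stopping at step {step}")
        break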