Add Stopping Criteria for loop (#286)
* Add Stopping Criteria for loop

* Changes according to review
Dref360 authored May 13, 2024
1 parent a82665d commit 4171b7a
Showing 3 changed files with 122 additions and 3 deletions.
63 changes: 63 additions & 0 deletions baal/active/stopping_criteria.py
@@ -0,0 +1,63 @@
from typing import Iterable, Dict

import numpy as np

from baal import ActiveLearningDataset


class StoppingCriterion:
    def __init__(self, active_dataset: ActiveLearningDataset):
        self._active_ds = active_dataset

    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
        raise NotImplementedError


class LabellingBudgetStoppingCriterion(StoppingCriterion):
    """Stops when the labelling budget is exhausted."""

    def __init__(self, active_dataset: ActiveLearningDataset, labelling_budget: int):
        super().__init__(active_dataset)
        self._start_length = len(active_dataset)
        self.labelling_budget = labelling_budget

    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
        # Stop once `labelling_budget` samples have been labelled since this
        # criterion was created. Signature matches the base class so that
        # criteria stay interchangeable.
        return (len(self._active_ds) - self._start_length) >= self.labelling_budget


class LowAverageUncertaintyStoppingCriterion(StoppingCriterion):
    """Stops when the average uncertainty falls below a threshold."""

    def __init__(self, active_dataset: ActiveLearningDataset, avg_uncertainty_thresh: float):
        super().__init__(active_dataset)
        self.avg_uncertainty_thresh = avg_uncertainty_thresh

    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
        return np.mean(uncertainty) < self.avg_uncertainty_thresh


class EarlyStoppingCriterion(StoppingCriterion):
    """Early stopping on a particular metric.

    Notes:
        We do not want a mandatory dependency on any early-stopping
        implementation, so we provide our own.
    """

    def __init__(
        self,
        active_dataset: ActiveLearningDataset,
        metric_name: str,
        patience: int = 10,
        epsilon: float = 1e-4,
    ):
        super().__init__(active_dataset)
        self.metric_name = metric_name
        self.patience = patience
        self.epsilon = epsilon
        self._acc = []

    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
        self._acc.append(metrics[self.metric_name])
        # The metric has plateaued when the last `patience + 1` recorded values
        # (or all of them, when fewer are available) lie within `epsilon` of
        # the most recent one.
        near_threshold = np.isclose(np.array(self._acc), self._acc[-1], atol=self.epsilon)
        return len(near_threshold) >= self.patience and near_threshold[-(self.patience + 1) :].all()
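
For illustration, here is a minimal sketch (not part of the commit) of how EarlyStoppingCriterion evaluates its patience window; the metric name "val_loss" and the loss values are arbitrary choices for the example:

import numpy as np
from baal.active.dataset import ActiveNumpyArray
from baal.active.stopping_criteria import EarlyStoppingCriterion

ds = ActiveNumpyArray((np.random.randn(10, 3), np.random.randint(0, 3, 10)))
criterion = EarlyStoppingCriterion(ds, "val_loss", patience=2, epsilon=1e-4)

# The first values still improve; only the last three (patience + 1) lie
# within epsilon of each other, so stopping triggers on the final step.
for loss in [0.9, 0.5, 0.3, 0.3, 0.3]:
    stop = criterion.should_stop(metrics={"val_loss": loss}, uncertainty=[])
print(stop)  # True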
9 changes: 6 additions & 3 deletions experiments/mlp_mcdropout.py
@@ -9,6 +9,7 @@
 from baal import ActiveLearningDataset, ModelWrapper
 from baal.active import ActiveLearningLoop
 from baal.active.heuristics import BALD
+from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
 from baal.bayesian.dropout import patch_module

 use_cuda = torch.cuda.is_available()
@@ -54,8 +55,11 @@

 # Following Gal 2016, we reset the weights at the beginning of each step.
 initial_weights = deepcopy(model.state_dict())
+stopping_criterion = LabellingBudgetStoppingCriterion(
+    active_dataset=al_dataset, labelling_budget=10
+)

-for step in range(100):
+while True:
     model.load_state_dict(initial_weights)
     train_loss = wrapper.train_on_dataset(
         al_dataset, optimizer=optimizer, batch_size=32, epoch=10, use_cuda=use_cuda
@@ -64,6 +68,5 @@

     pprint(wrapper.get_metrics())
     flag = al_loop.step()
-    if not flag:
-        # We are done labelling! stopping
+    if stopping_criterion.should_stop(metrics={}, uncertainty=[]) or not flag:
         break
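
Since every criterion implements the same should_stop(metrics, uncertainty) interface, another criterion can be dropped into the same loop. A minimal, self-contained sketch under that assumption, where fake_pool_scores is a purely illustrative stand-in for per-sample uncertainties (e.g. BALD scores) that shrink as the model grows confident:

import numpy as np
from baal.active.dataset import ActiveNumpyArray
from baal.active.stopping_criteria import LowAverageUncertaintyStoppingCriterion

ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
criterion = LowAverageUncertaintyStoppingCriterion(ds, avg_uncertainty_thresh=0.1)

for step in range(20):
    # Illustrative scores only; a real loop would use the heuristic's output.
    fake_pool_scores = np.random.normal(0.5 / (step + 1), scale=0.01, size=(100,))
    if criterion.should_stop(metrics={}, uncertainty=fake_pool_scores):
        print(f"Stopping at step {step}: average pool uncertainty below 0.1")
        break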
53 changes: 53 additions & 0 deletions tests/active/criterion_test.py
@@ -0,0 +1,53 @@
from baal.active.stopping_criteria import (
    LabellingBudgetStoppingCriterion,
    EarlyStoppingCriterion,
    LowAverageUncertaintyStoppingCriterion,
)
from baal.active.dataset import ActiveNumpyArray
import numpy as np


def test_labelling_budget():
    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
    ds.label_randomly(10)
    criterion = LabellingBudgetStoppingCriterion(ds, labelling_budget=50)
    assert not criterion.should_stop(metrics={}, uncertainty=[])

    ds.label_randomly(10)
    assert not criterion.should_stop(metrics={}, uncertainty=[])

    ds.label_randomly(40)
    assert criterion.should_stop(metrics={}, uncertainty=[])


def test_early_stopping():
    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
    criterion = EarlyStoppingCriterion(ds, "test_loss", patience=5)

    for i in range(10):
        assert not criterion.should_stop(
            metrics={"test_loss": 1 / (i + 1)}, uncertainty=[]
        )

    for _ in range(4):
        assert not criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
    assert criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])

    # Fresh criterion: stability for fewer than `patience` steps must not stop.
    criterion = EarlyStoppingCriterion(ds, "test_loss", patience=5)
    for _ in range(4):
        assert not criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])
    assert criterion.should_stop(metrics={"test_loss": 0.1}, uncertainty=[])


def test_low_average():
    ds = ActiveNumpyArray((np.random.randn(100, 3), np.random.randint(0, 3, 100)))
    criterion = LowAverageUncertaintyStoppingCriterion(
        active_dataset=ds, avg_uncertainty_thresh=0.1
    )
    assert not criterion.should_stop(
        metrics={}, uncertainty=np.random.normal(0.5, scale=0.8, size=(100,))
    )
    assert criterion.should_stop(
        metrics={}, uncertainty=np.random.normal(0.05, scale=0.01, size=(100,))
    )
