
Add Args to ModelWrapper to simplify common API (#294)
* Add Args to ModelWrapper to simplify common API

* Update experiment scripts
Dref360 authored May 26, 2024
1 parent f9b9ebf commit a9939fa
Showing 33 changed files with 444 additions and 612 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -114,15 +114,15 @@ In conclusion, your script should be similar to this:
 dataset = ActiveLearningDataset(your_dataset)
 dataset.label_randomly(INITIAL_POOL)  # label some data
 model = MCDropoutModule(your_model)
-model = ModelWrapper(model, your_criterion)
+model = ModelWrapper(model, args=TrainingArgs(...))
 active_loop = ActiveLearningLoop(dataset,
                                  get_probabilities=model.predict_on_dataset,
                                  heuristic=heuristics.BALD(),
                                  iterations=20, # Number of MC sampling.
                                  query_size=QUERY_SIZE) # Number of item to label.
 for al_step in range(N_ALSTEP):
-    model.train_on_dataset(dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda)
-    metrics = model.test_on_dataset(test_dataset, BATCH_SIZE)
+    model.train_on_dataset(dataset)
+    metrics = model.test_on_dataset(test_dataset)
     # Label the next most uncertain items.
     if not active_loop.step():
         # We're done!
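
A rough sketch of what the elided TrainingArgs(...) above might contain, using only the field names visible elsewhere in this commit (criterion, optimizer, regularizer, epoch, use_cuda); other fields, defaults, and the hyper-parameter values below are assumptions, and your_model, dataset and test_dataset are the placeholders from the README snippet:

from torch import nn, optim

from baal import ModelWrapper
from baal.modelwrapper import TrainingArgs

# Hypothetical values; only field names shown in this commit are used here.
args = TrainingArgs(
    criterion=nn.CrossEntropyLoss(),
    optimizer=optim.SGD(your_model.parameters(), lr=0.001),
    epoch=5,
    use_cuda=True,
)
model = ModelWrapper(your_model, args=args)
model.train_on_dataset(dataset)                # optimizer/batch size no longer passed per call
metrics = model.test_on_dataset(test_dataset)  # use_cuda no longer passed per call
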
6 changes: 4 additions & 2 deletions baal/active/dataset/base.py
@@ -1,10 +1,12 @@
 import warnings
-from typing import Union, List, Optional, Any, TYPE_CHECKING, Protocol
+from typing import Union, List, Optional, Any, TYPE_CHECKING, Protocol, Tuple

 import numpy as np
 from sklearn.utils import check_random_state
 from torch.utils import data as torchdata

+from baal.utils.equality import assert_not_none
+

 class SizeableDataset(torchdata.Dataset):
     def __len__(self):
@@ -40,7 +42,7 @@ def __init__(
         if last_active_steps == 0 or last_active_steps < -1:
             raise ValueError("last_active_steps must be > 0 or -1 when disabled.")
         self.last_active_steps = last_active_steps
-        self._indices_cache = (-1, None)
+        self._indices_cache: Tuple[int, List[int]] = (-1, [])

     def get_indices_for_active_step(self) -> List[int]:
         """Returns the indices required for the active step.
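
The typed cache above pairs a labelling-step id with the index list computed for that step. For illustration only, a generic memoisation pattern of that shape is sketched below; this is not baal's actual get_indices_for_active_step body, and every name except _indices_cache is hypothetical:

from typing import Callable, List, Tuple

class StepIndexCache:
    """Toy example of a (step, indices) cache shaped like the attribute typed above."""

    def __init__(self) -> None:
        self._indices_cache: Tuple[int, List[int]] = (-1, [])

    def indices_for_step(self, step: int, compute: Callable[[int], List[int]]) -> List[int]:
        cached_step, cached_indices = self._indices_cache
        if cached_step == step:          # already computed for this labelling step
            return cached_indices
        indices = compute(step)          # recompute and memoise for this step
        self._indices_cache = (step, indices)
        return indices
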
28 changes: 4 additions & 24 deletions baal/active/heuristics/heuristics_gpu.py
@@ -67,13 +67,12 @@ class AbstractGPUHeuristic(ModelWrapper):
     def __init__(
         self,
         model: ModelWrapper,
-        criterion,
         shuffle_prop=0.0,
         threshold=None,
         reverse=False,
         reduction="none",
     ):
-        super().__init__(model, criterion)
+        super().__init__(model, model.args)
         self.shuffle_prop = shuffle_prop
         self.threshold = threshold
         self.reversed = reverse
@@ -102,32 +101,15 @@ def get_uncertainties(self, predictions):
     def predict_on_dataset(
         self,
         dataset: Dataset,
-        batch_size: int,
         iterations: int,
-        use_cuda: bool,
-        workers: int = 4,
-        collate_fn: Optional[Callable] = None,
         half=False,
         verbose=True,
     ):
-        return (
-            super()
-            .predict_on_dataset(
-                dataset,
-                batch_size,
-                iterations,
-                use_cuda,
-                workers,
-                collate_fn,
-                half,
-                verbose,
-            )
-            .reshape([-1])
-        )
+        return super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])

-    def predict_on_batch(self, data, iterations=1, use_cuda=False):
+    def predict_on_batch(self, data, iterations=1):
         """Rank the predictions according to their uncertainties."""
-        return self.get_uncertainties(self.model.predict_on_batch(data, iterations, cuda=use_cuda))
+        return self.get_uncertainties(self.model.predict_on_batch(data, iterations))


 class BALDGPUWrapper(AbstractGPUHeuristic):
@@ -139,14 +121,12 @@ class BALDGPUWrapper(AbstractGPUHeuristic):
     def __init__(
         self,
         model: ModelWrapper,
-        criterion,
         shuffle_prop=0.0,
         threshold=None,
         reduction="none",
     ):
         super().__init__(
             model,
-            criterion=criterion,
             shuffle_prop=shuffle_prop,
             threshold=threshold,
             reverse=True,
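
A hedged usage sketch of the slimmed-down GPU heuristic API: criterion, use_cuda, batch_size, workers and collate_fn are gone from these calls and are taken from the wrapped model's TrainingArgs instead. wrapped_model and pool are placeholders for an existing ModelWrapper and an unlabelled Dataset:

from baal.active.heuristics.heuristics_gpu import BALDGPUWrapper

# Constructor as it looks after this change; a criterion is no longer accepted.
bald_gpu = BALDGPUWrapper(wrapped_model, shuffle_prop=0.0, reduction="none")

# Only the MC-sampling count and precision/verbosity flags remain on the call.
uncertainties = bald_gpu.predict_on_dataset(pool, iterations=20, half=False, verbose=True)
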
13 changes: 8 additions & 5 deletions baal/active/stopping_criteria.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Dict
+from typing import Iterable, Dict, List

 import numpy as np

@@ -21,7 +21,7 @@ def __init__(self, active_dataset: ActiveLearningDataset, labelling_budget: int)
         self._start_length = len(active_dataset)
         self.labelling_budget = labelling_budget

-    def should_stop(self, uncertainty: Iterable[float]) -> bool:
+    def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
         return (len(self._active_ds) - self._start_length) >= self.labelling_budget


@@ -33,7 +33,8 @@ def __init__(self, active_dataset: ActiveLearningDataset, avg_uncertainty_thresh
         self.avg_uncertainty_thresh = avg_uncertainty_thresh

     def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
-        return np.mean(uncertainty) < self.avg_uncertainty_thresh
+        arr = np.array(uncertainty)
+        return bool(np.mean(arr) < self.avg_uncertainty_thresh)


 class EarlyStoppingCriterion(StoppingCriterion):
@@ -55,9 +56,11 @@ def __init__(
         self.metric_name = metric_name
         self.patience = patience
         self.epsilon = epsilon
-        self._acc = []
+        self._acc: List[float] = []

     def should_stop(self, metrics: Dict[str, float], uncertainty: Iterable[float]) -> bool:
         self._acc.append(metrics[self.metric_name])
         near_threshold = np.isclose(np.array(self._acc), self._acc[-1], atol=self.epsilon)
-        return len(near_threshold) >= self.patience and near_threshold[-(self.patience + 1) :].all()
+        return len(near_threshold) >= self.patience and bool(
+            near_threshold[-(self.patience + 1) :].all()
+        )
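
With this change every criterion shares the same should_stop(metrics, uncertainty) signature, even when one of the two arguments is unused. A hedged sketch using the budget-based criterion (its class name is assumed here; only its constructor arguments are visible in this diff), with placeholder values and an existing ActiveLearningDataset named active_dataset:

from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion

stop = LabellingBudgetStoppingCriterion(active_dataset, labelling_budget=500)

metrics = {"test_loss": 0.31}   # placeholder metrics dict from the latest evaluation
uncertainty = [0.2, 0.5, 0.9]   # placeholder per-sample uncertainty scores
if stop.should_stop(metrics, uncertainty):
    print("Labelling budget exhausted, stop the active-learning loop.")
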
48 changes: 19 additions & 29 deletions baal/calibration/calibration.py
@@ -1,4 +1,5 @@
 from copy import deepcopy
+from typing import Optional

 import structlog
 import torch
@@ -7,6 +8,7 @@
 from torch.optim import Adam

 from baal import ModelWrapper
+from baal.modelwrapper import TrainingArgs
 from baal.utils.metrics import ECE, ECE_PerCLs

 log = structlog.get_logger("Calibrating...")
@@ -37,6 +39,7 @@ class DirichletCalibrator(object):
         reg_factor (float): Regularization factor for the linear layer weights.
         mu (float): Regularization factor for the linear layer biases.
             If not given, will be initialized by "l".
+        training_duration (int): How long to train calibration layer.
     """

@@ -46,7 +49,8 @@ def __init__(
         num_classes: int,
         lr: float,
         reg_factor: float,
-        mu: float = None,
+        mu: Optional[float] = None,
+        training_duration: int = 5,
     ):
         self.num_classes = num_classes
         self.criterion = nn.CrossEntropyLoss()
@@ -55,7 +59,17 @@
         self.mu = mu or reg_factor
         self.dirichlet_linear = nn.Linear(self.num_classes, self.num_classes)
         self.model = nn.Sequential(wrapper.model, self.dirichlet_linear)
-        self.wrapper = ModelWrapper(self.model, self.criterion)
+        self.optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)
+        self.wrapper = ModelWrapper(
+            self.model,
+            TrainingArgs(
+                criterion=self.criterion,
+                optimizer=self.optimizer,
+                regularizer=self.l2_reg,
+                epoch=training_duration,
+                use_cuda=wrapper.args.use_cuda,
+            ),
+        )

         self.wrapper.add_metric("ece", lambda: ECE())
         self.wrapper.add_metric("ece", lambda: ECE_PerCLs(num_classes))
@@ -75,8 +89,6 @@ def calibrate(
         self,
         train_set: Dataset,
         test_set: Dataset,
-        batch_size: int,
-        epoch: int,
         use_cuda: bool,
         double_fit: bool = False,
         **kwargs
@@ -88,8 +100,6 @@
         Args:
             train_set (Dataset): The training set.
             test_set (Dataset): The validation set.
-            batch_size (int): Batch size used.
-            epoch (int): Number of epochs to train the linear layer for.
             use_cuda (bool): If "True", will use GPU.
             double_fit (bool): If "True" would fit twice on the train set.
             kwargs (dict): Rest of parameters for baal.ModelWrapper.train_and_test_on_dataset().
@@ -106,36 +116,16 @@
         if use_cuda:
             self.dirichlet_linear.cuda()

-        optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)
-
         loss_history, weights = self.wrapper.train_and_test_on_datasets(
-            train_set,
-            test_set,
-            optimizer,
-            batch_size,
-            epoch,
-            use_cuda,
-            regularizer=self.l2_reg,
-            return_best_weights=True,
-            patience=None,
-            **kwargs
+            train_set, test_set, return_best_weights=True, patience=None, **kwargs
         )
         self.model.load_state_dict(weights)

         if double_fit:
             lr = self.lr / 10
-            optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
+            self.wrapper.args.optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
             loss_history, weights = self.wrapper.train_and_test_on_datasets(
-                train_set,
-                test_set,
-                optimizer,
-                batch_size,
-                epoch,
-                use_cuda,
-                regularizer=self.l2_reg,
-                return_best_weights=True,
-                patience=None,
-                **kwargs
+                train_set, test_set, return_best_weights=True, patience=None, **kwargs
             )
             self.model.load_state_dict(weights)

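
A hedged sketch of the updated calibration flow: the number of epochs now comes from training_duration at construction time, and the optimizer plus regularizer live inside the calibrator's own TrainingArgs, so calibrate() only needs the datasets and the remaining flags. wrapped_model, train_set and test_set are placeholders, and the hyper-parameter values are illustrative only:

from baal.calibration import DirichletCalibrator

calibrator = DirichletCalibrator(
    wrapper=wrapped_model,   # an existing ModelWrapper built with TrainingArgs
    num_classes=10,
    lr=0.001,
    reg_factor=0.01,
    training_duration=5,     # replaces the old epoch argument of calibrate()
)
calibrator.calibrate(train_set, test_set, use_cuda=False, double_fit=False)
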
10 changes: 5 additions & 5 deletions baal/ensemble.py
@@ -5,7 +5,7 @@
 from torch import nn, Tensor

 from baal import ModelWrapper
-from baal.modelwrapper import _stack_preds
+from baal.modelwrapper import _stack_preds, TrainingArgs
 from baal.utils.cuda_utils import to_cuda


@@ -15,16 +15,16 @@ class EnsembleModelWrapper(ModelWrapper):
     Args:
         model (nn.Module): A Model.
-        criterion (Callable): Loss function
+        args (TrainingArgs): Argument for model
     Notes:
         If you're looking to use ensembles for non-deep models, see our sklearn tutorial:
         baal.readthedocs.io/en/latest/notebooks/sklearn_tutorial.html
     """

-    def __init__(self, model, criterion):
-        super().__init__(model, criterion)
-        self._weights = []
+    def __init__(self, model, args: TrainingArgs):
+        super().__init__(model, args)
+        self._weights: List[Dict] = []

     def add_checkpoint(self):
         """
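
A hedged sketch of the ensemble wrapper under the new args-based constructor, mirroring the README example above; model and train_set are placeholders, the TrainingArgs fields are limited to those visible in this commit, and the checkpoint/reset strategy between members is up to the user rather than prescribed here:

from torch import nn, optim

from baal.ensemble import EnsembleModelWrapper
from baal.modelwrapper import TrainingArgs

ensemble = EnsembleModelWrapper(
    model,
    args=TrainingArgs(
        criterion=nn.CrossEntropyLoss(),
        optimizer=optim.SGD(model.parameters(), lr=0.001),
        epoch=10,
        use_cuda=False,
    ),
)

# Train several ensemble members and snapshot the weights after each run.
for _ in range(5):
    ensemble.train_on_dataset(train_set)
    ensemble.add_checkpoint()
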