
Commit

Add last_active_step iteration to iterate over the last N active steps. (#174)

Co-authored-by: fr.branchaud-charron <fr.branchaud-charron@servicenow.com>
Frédéric Branchaud-Charron and fr.branchaud-charron authored Dec 10, 2021
1 parent ed33f59 commit 51ab4a8
Showing 21 changed files with 345 additions and 259 deletions.
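
The headline change is the new last_active_steps option on the active datasets: when set to N, iterating the dataset only yields items labelled during the last N labelling steps, which is convenient for partial fine-tuning on freshly acquired data. A minimal usage sketch, assuming the ActiveLearningDataset wrapper exported below accepts last_active_steps and that each label_randomly() call is recorded as a new active step (the toy data is made up):

import torch
from torch.utils.data import TensorDataset

from baal.active.dataset import ActiveLearningDataset

# Hypothetical toy dataset: 100 samples with 10 features and binary labels.
base = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

# last_active_steps=1: iteration only covers items labelled at the latest step.
al_dataset = ActiveLearningDataset(base, last_active_steps=1)

al_dataset.label_randomly(10)  # active step 1
al_dataset.label_randomly(10)  # active step 2

# Only the 10 items acquired at step 2 should be visible for training.
print(len(al_dataset))  # expected: 10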
2 changes: 1 addition & 1 deletion Makefile
@@ -21,7 +21,7 @@ mypy:

.PHONY: check-mypy-error-count
check-mypy-error-count: MYPY_INFO = $(shell expr `poetry run mypy baal | grep ": error" | wc -l`)
-check-mypy-error-count: MYPY_ERROR_COUNT = 16
+check-mypy-error-count: MYPY_ERROR_COUNT = 9

check-mypy-error-count:
@if [ ${MYPY_INFO} -gt ${MYPY_ERROR_COUNT} ]; then \
4 changes: 4 additions & 0 deletions baal/active/dataset/__init__.py
@@ -0,0 +1,4 @@
from .numpy import ActiveNumpyArray
from .pytorch_dataset import ActiveLearningDataset

# Do not include HF Dataset here as it requires more deps.
126 changes: 126 additions & 0 deletions baal/active/dataset/base.py
@@ -0,0 +1,126 @@
import warnings
from typing import Union, List, Optional, Any, Iterable, Sized

import numpy as np
from sklearn.utils import check_random_state
from torch.utils import data as torchdata


class SplittedDataset(torchdata.Dataset):
"""Abstract class for Dataset that can be splitted.
Args:
labelled: An array that acts as a mask which is greater than 1 for every
data point that is labelled, and 0 for every data point that is not
labelled.
random_state: Set the random seed for label_randomly().
last_active_steps: If specified, will iterate over the last active steps
instead of the full dataset. Useful when doing partial finetuning.
"""

def __init__(
self,
labelled,
random_state=None,
last_active_steps: int = -1,
) -> None:
self.labelled_map = labelled
self.random_state = check_random_state(random_state)
if last_active_steps == 0 or last_active_steps < -1:
raise ValueError("last_active_steps must be a positive integer, or -1 to disable.")
self.last_active_steps = last_active_steps

def get_indices_for_active_step(self):
"""Returns the indices required for the active step.
Returns the indices of the labelled items. Also takes into account self.last_active_step.
Returns:
Array of the selected indices for training.
"""
if self.last_active_steps == -1:
min_labelled_step = 0
else:
min_labelled_step = max(0, self.current_al_step - self.last_active_steps)
indices = np.arange(len(self.labelled_map))
bool_mask = self.labelled_map > min_labelled_step
return indices[bool_mask]

def is_labelled(self, idx: int) -> bool:
"""Check if a datapoint is labelled."""
return bool(self.labelled[idx].item() == 1)

def __len__(self) -> int:
"""Return how many actual data / label pairs we have."""
return len(self.get_indices_for_active_step())

def __getitem__(self, index):
raise NotImplementedError

@property
def n_unlabelled(self):
"""The number of unlabelled data points."""
return (~self.labelled).sum()

@property
def n_labelled(self):
"""The number of labelled data points."""
return self.labelled.sum()

def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
"""
Label data points.
The index should be relative to the pool, not the overall data.
Args:
index: one or many indices to label.
value: The label value. If not provided, no modification
to the underlying dataset is done.
"""
raise NotImplementedError

def label_randomly(self, n: int = 1) -> None:
"""
Label `n` data-points randomly.
Args:
n (int): Number of samples to label.
"""
self.label(self.random_state.choice(self.n_unlabelled, n, replace=False).tolist())

@property
def _labelled(self):
warnings.warn(
"_labelled as been renamed labelled. Please update your script.", DeprecationWarning
)
return self.labelled

@property
def current_al_step(self) -> int:
"""Get the current active learning step."""
return int(self.labelled_map.max())

@property
def labelled(self):
"""An array that acts as a boolean mask which is True for every
data point that is labelled, and False for every data point that is not
labelled."""
return self.labelled_map.astype(bool)

def _labelled_to_oracle_index(self, index: int) -> int:
return int(self.labelled.nonzero()[0][index].squeeze().item())

def _pool_to_oracle_index(self, index: Union[int, List[int]]) -> List[int]:
if isinstance(index, np.int64) or isinstance(index, int):
index = [index]

lbl_nz = (~self.labelled).nonzero()[0]
return [int(lbl_nz[idx].squeeze().item()) for idx in index]

def _oracle_to_pool_index(self, index: Union[int, List[int]]) -> List[int]:
if isinstance(index, int):
index = [index]

# Pool indices cover only the unlabelled items and start at 0.
lbl_cs = np.cumsum(~self.labelled) - 1
return [int(lbl_cs[idx].squeeze().item()) for idx in index]
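
To make the selection logic above concrete: labelled_map is not a plain boolean mask, it stores the active-learning step at which each item was labelled (0 meaning unlabelled). A small self-contained illustration of get_indices_for_active_step(), using the same numpy operations as the code above (the example values are made up):

import numpy as np

# Step at which each item was labelled; 0 means still unlabelled.
labelled_map = np.array([1, 0, 2, 2, 0, 3])

last_active_steps = 2
current_al_step = int(labelled_map.max())                         # 3
min_labelled_step = max(0, current_al_step - last_active_steps)   # 1

indices = np.arange(len(labelled_map))
selected = indices[labelled_map > min_labelled_step]
print(selected)  # [2 3 5]: only items labelled at steps 2 and 3 are kept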
(diff for an additional changed file; file name not shown)
@@ -1,3 +1,5 @@
from typing import List

import numpy as np
import torch
from datasets import Dataset as HFDataset
@@ -31,7 +33,7 @@ def __init__(
):
self.dataset = dataset
self.targets, self.texts = self.dataset[target_key], self.dataset[input_key]
-self.targets_list = np.unique(self.targets).tolist()
+self.targets_list: List = np.unique(self.targets).tolist()
self.input_ids, self.attention_masks = (
self._tokenize(tokenizer, max_seq_len) if tokenizer else ([], [])
)
72 changes: 72 additions & 0 deletions baal/active/dataset/numpy.py
@@ -0,0 +1,72 @@
from itertools import zip_longest
from typing import Tuple, Optional, Any, Union

import numpy as np

from baal.active.dataset.base import SplittedDataset


class ActiveNumpyArray(SplittedDataset):
"""
Active dataset for numpy arrays. Useful when using sklearn.
Args:
dataset (Tuple[ndarray, ndarray]): The dataset as a (train_x, train_y) tuple of arrays.
labelled (Optional[np.ndarray]):
An array that acts as a boolean mask which is True for every
data point that is labelled, and False for every data point that is not
labelled.
random_state: Random seed used by label_randomly().
"""

def __init__(
self,
dataset: Tuple[np.ndarray, np.ndarray],
labelled: Optional[np.ndarray] = None,
random_state: Any = None,
) -> None:
self._dataset = dataset
# The labelled_map keeps track of the step at which an item has been labelled.
if labelled is not None:
labelled_map: np.ndarray = labelled.astype(int)
else:
labelled_map = np.zeros(len(self._dataset[0]), dtype=int)
super().__init__(labelled=labelled_map, random_state=random_state, last_active_steps=-1)

@property
def pool(self):
"""Return the unlabelled portion of the dataset."""
return self._dataset[0][~self.labelled], self._dataset[1][~self.labelled]

@property
def dataset(self):
"""Return the labelled portion of the dataset."""
return self._dataset[0][self.labelled], self._dataset[1][self.labelled]

def get_raw(self, idx: int) -> Any:
return self._dataset[0][idx], self._dataset[1][idx]

def __iter__(self):
return zip(*self._dataset)

def __getitem__(self, index):
index = self.get_indices_for_active_step()[index]
return self._dataset[0][index], self._dataset[1][index]

def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
"""
Label data points.
The index should be relative to the pool, not the overall data.
Args:
index (Union[list,int]): one or many indices to label.
value (Optional[Any]): The label value. If not provided, no modification
to the underlying dataset is done.
"""
if isinstance(index, int):
index = [index]
if not isinstance(value, (list, tuple)):
value = [value]
indexes = self._pool_to_oracle_index(index)
for index, val in zip_longest(indexes, value, fillvalue=None):
self.labelled_map[index] = 1
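
A short usage sketch of ActiveNumpyArray with scikit-learn, assuming the import path exposed by the new baal.active.dataset package above; the data and classifier choice are arbitrary:

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from baal.active.dataset import ActiveNumpyArray

# Hypothetical data: 200 samples, 5 features, binary labels.
x, y = np.random.rand(200, 5), np.random.randint(0, 2, 200)
al_dataset = ActiveNumpyArray((x, y))

al_dataset.label_randomly(20)                              # seed the labelled set
clf = RandomForestClassifier().fit(*al_dataset.dataset)    # train on the labelled split

pool_x, pool_y = al_dataset.pool                           # unlabelled split, e.g. to score for acquisition
al_dataset.label([0, 1, 2])                                # indices are relative to the pool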
(remaining changed files not shown)
