Skip to content

Commit

Permalink
Solve bug where a dataset wouldn't be able to label (#178)
Browse files Browse the repository at this point in the history
Co-authored-by: fr.branchaud-charron <fr.branchaud-charron@servicenow.com>
  • Loading branch information
Frédéric Branchaud-Charron and fr.branchaud-charron authored Dec 17, 2021
1 parent 3b49a92 commit 3db1207
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
48 changes: 28 additions & 20 deletions baal/active/dataset/pytorch_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from copy import deepcopy
from itertools import zip_longest
from typing import Union, Optional, Callable, Any, Dict
from typing import Union, Optional, Callable, Any, Dict, List

import numpy as np
import torch.utils.data as torchdata
Expand Down Expand Up @@ -32,13 +32,13 @@ class ActiveLearningDataset(SplittedDataset):
"""

def __init__(
self,
dataset: torchdata.Dataset,
labelled: Optional[np.ndarray] = None,
make_unlabelled: Callable = _identity,
random_state=None,
pool_specifics: Optional[dict] = None,
last_active_steps: int = -1,
self,
dataset: torchdata.Dataset,
labelled: Optional[np.ndarray] = None,
make_unlabelled: Callable = _identity,
random_state=None,
pool_specifics: Optional[dict] = None,
last_active_steps: int = -1,
) -> None:
self._dataset = dataset

Expand Down Expand Up @@ -143,20 +143,27 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
ValueError if the indices do not match the values.
"""
if isinstance(index, int):
index = [index]
if not isinstance(value, (list, tuple)):
value = [value]
if value[0] is not None and len(index) != len(value):
# We were provided only the index, we make a list.
index_lst = [index]
value_lst: List[Any] = [value]
else:
index_lst = index
if value is None:
value_lst = [value]
else:
value_lst = value

if value_lst[0] is not None and len(index_lst) != len(value_lst):
raise ValueError(
"Expected `index` and `value` to be of same length when `value` is provided."
f"Got index={len(index)} and value={len(value)}"
f"Got index={len(index_lst)} and value={len(value_lst)}"
)
indexes = self._pool_to_oracle_index(index)
indexes = self._pool_to_oracle_index(index_lst)
active_step = self.current_al_step + 1
for index, val in zip_longest(indexes, value, fillvalue=None):
for idx, val in zip_longest(indexes, value_lst, fillvalue=None):
if self.can_label and val is not None:
self._dataset.label(index, val)
self.labelled[index] = active_step
self._dataset.label(idx, val)
self.labelled_map[idx] = active_step
elif self.can_label and val is None:
warnings.warn(
"""The dataset is able to label data, but no label was provided.
Expand All @@ -166,12 +173,13 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
""",
UserWarning,
)
elif not self.can_label:
self.labelled_map[index] = active_step
else:
# Regular research usecase.
self.labelled_map[idx] = active_step
if val is not None:
warnings.warn(
"We will consider the original label of this datasample : {}, {}.".format(
self._dataset[index][0], self._dataset[index][1]
self._dataset[idx][0], self._dataset[idx][1]
),
UserWarning,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/active/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_pool(self):
assert self.dataset.n_labelled == labels_next_1

with pytest.raises(ValueError, match="same length"):
self.dataset.label(np.arange(0, 9), value=np.arange(1, 10))
self.dataset.label(np.arange(0, 9), value=np.arange(0, 10))

# cleanup
del self.dataset._dataset.label
Expand Down
13 changes: 13 additions & 0 deletions tests/active/file_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ def test_labelling(self):
self.dataset.label(actually_not_labelled[0], 1)
assert sum([1 for i, j in enumerate(self.dataset.lbls) if j >= 0]) == 11

def test_active_labelling(self):
    """Check labelling through the active wrapper and the re-label warning.

    Labels one sample whose entry in ``self.lbls`` is negative through
    ``self.active`` and expects ``len(self.active)`` to grow by one. Then
    labels a sample whose entry is non-negative with ``None`` directly on
    ``self.dataset``; a ``UserWarning`` must be emitted and the length of
    ``self.active`` must not change.
    """
    assert self.active.can_label

    # Split indices by the sign of the stored label
    # (presumably negative means "no label yet" -- confirm against fixture).
    unlabelled_ids = [pos for pos, lbl in enumerate(self.lbls) if lbl < 0]
    labelled_ids = [pos for pos, lbl in enumerate(self.lbls) if lbl >= 0]

    expected_length = len(self.active) + 1

    # Labelling a new sample through the active dataset adds one item.
    self.active.label(unlabelled_ids[0], 1)
    assert len(self.active) == expected_length

    # Labelling with None on the wrapped dataset warns and adds nothing.
    with pytest.warns(UserWarning):
        self.dataset.label(labelled_ids[0], None)
    assert len(self.active) == expected_length

def test_filedataset_segmentation(self):
target_trans = Compose([default_image_load_fn,
Resize(60), RandomRotation(90), ToTensor()])
Expand Down

0 comments on commit 3db1207

Please sign in to comment.