Skip to content

Commit

Permalink
fix: no pool preds for random heuristic (#277)
Browse files Browse the repository at this point in the history
* fix: no pool preds for random heuristic

* use isinstance() instead of __name__ attribute

* test: add test for get_probabilities with Random

---------

Co-authored-by: Frédéric Branchaud-Charron <frederic.branchaud.charron@gmail.com>
  • Loading branch information
arthur-thuy and Dref360 authored Sep 13, 2023
1 parent f98fe7b commit ce55a85
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
5 changes: 4 additions & 1 deletion baal/active/active_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def step(self, pool=None) -> bool:
indices = None

if len(pool) > 0:
probs = self.get_probabilities(pool, **self.kwargs)
if isinstance(self.heuristic, heuristics.Random):
probs = np.random.uniform(low=0, high=1, size=(len(pool), 1))
else:
probs = self.get_probabilities(pool, **self.kwargs)
if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
to_label, uncertainty = self.heuristic.get_ranks(probs)
if indices is not None:
Expand Down
20 changes: 20 additions & 0 deletions tests/active/active_loop_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pickle
import warnings
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -140,5 +141,24 @@ def test_deprecation():
assert issubclass(w[-1].category, DeprecationWarning)
assert "ndata_to_label" in str(w[-1].message)


@pytest.mark.parametrize('heur,num_get_probs', [(heuristics.Random(), 0),
(heuristics.BALD(), 1),
(heuristics.Entropy(), 1),
(heuristics.Variance(reduction='sum'), 1)
])
def test_get_probs(heur, num_get_probs):
dataset = ActiveLearningDataset(MyDataset(), make_unlabelled=lambda x: -1)
active_loop = ActiveLearningLoop(dataset,
get_probs_iter,
heur,
query_size=5,
dummy_param=1)
dataset.label_randomly(10)
with patch.object(active_loop, "get_probabilities") as mock_probs:
active_loop.step()
assert mock_probs.call_count == num_get_probs


if __name__ == '__main__':
pytest.main()

0 comments on commit ce55a85

Please sign in to comment.