Skip to content

Commit

Permalink
support for multi-output models in heuristics (#24)
Browse files Browse the repository at this point in the history
* support for multi-output models in heuristics

* fix flake8 issues

* PR comments

* fix documentation for shuffle_prop
  • Loading branch information
parmidaatg authored and Frédéric Branchaud-Charron committed Nov 11, 2019
1 parent 9886273 commit 0677c57
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 11 deletions.
137 changes: 126 additions & 11 deletions src/baal/active/heuristics/heuristics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import types
import warnings
from functools import wraps as _wraps
from typing import List

import numpy as np
import scipy.stats
Expand Down Expand Up @@ -172,10 +173,11 @@ class BALD(AbstractHeuristic):
Sort by the highest acquisition function value.
Args:
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
(default: 0.0).
threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
reduction (Union[str, callable]): function that aggregates the results.
reduction (Union[str, callable]): function that aggregates the results
(default: 'none').
References:
https://arxiv.org/abs/1703.02910
Expand Down Expand Up @@ -215,9 +217,11 @@ class BatchBALD(BALD):
Args:
num_samples (int): Number of samples to select. (min 2*the amount of samples you want)
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
(default: 0.0).
threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
reduction (Union[str, callable]): function that aggregates the results.
reduction (Union[str, callable]): function that aggregates the results
(default: 'none').
References:
https://arxiv.org/abs/1906.08158
Expand Down Expand Up @@ -320,9 +324,10 @@ class Variance(AbstractHeuristic):
Sort by the highest variance.
Args:
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
(default: 0.0).
threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
reduction (Union[str, callable]): function that aggregates the results.
reduction (Union[str, callable]): function that aggregates the results (default: `mean`).
"""

def __init__(self, shuffle_prop=0.0, threshold=None, reduction='mean'):
Expand All @@ -342,9 +347,10 @@ class Entropy(AbstractHeuristic):
Sort by the highest entropy.
Args:
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
(default: 0.0).
threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
reduction (Union[str, callable]): function that aggregates the results.
reduction (Union[str, callable]): function that aggregates the results (default: `none`).
"""

def __init__(self, shuffle_prop=0.0, threshold=None, reduction='none'):
Expand All @@ -364,9 +370,11 @@ class Margin(AbstractHeuristic):
the second most confident class.
Args:
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
(default: 0.0).
threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
reduction (Union[str, callable]): function that aggregates the results.
reduction (Union[str, callable]): function that aggregates the results
(default: `none`).
"""

def __init__(self, shuffle_prop=0.0, threshold=None, reduction='none'):
Expand Down Expand Up @@ -435,3 +443,110 @@ def __init__(self, shuffle_prop=0.0, threshold=None, reverse=False):

def compute_score(self, predictions):
return predictions


class CombineHeuristics(AbstractHeuristic):
    """Combine heuristics for multi-output models.

    Heuristics are applied to the output predictions in the assigned order.
    For each heuristic the necessary `reduction`, `reversed` and `threshold`
    parameters should be defined.

    NOTE: heuristics can be combined together only if they use the same
    value for the `reversed` parameter.
    NOTE: `shuffle_prop` should only be defined as a direct input of
    `CombineHeuristics`, otherwise it has no effect.
    NOTE: `reduction` is defined both for each of the input heuristics and as
    a direct input to `CombineHeuristics`. For each heuristic, `reduction`
    should be defined if the relevant model output for that heuristic has more
    than 3 dimensions. In `CombineHeuristics`, `reduction` is used to
    aggregate the final result of the heuristics.

    Args:
        heuristics (list[AbstractHeuristic]): list of heuristic instances
        weights (list[float]): the assigned weights to the result of each heuristic
            before calculation of ranks
        reduction (Union[str, callable]): function that aggregates the results of the heuristics
            (default: weighted average which could be used as (reduction='mean'))
        shuffle_prop (float): Amount of noise to put in the ranking (default: 0.0).

    Raises:
        Exception: if the given heuristics do not all agree on `reversed`.
    """
    def __init__(self, heuristics: List, weights: List, reduction='mean', shuffle_prop=0.0):
        super(CombineHeuristics, self).__init__(reduction=reduction, shuffle_prop=shuffle_prop)
        self.composed_heuristic = heuristics
        self.weights = weights

        # `reversed_flags` (renamed from `reversed` to avoid shadowing the
        # builtin): all heuristics must sort in the same direction, otherwise
        # their scores cannot be meaningfully combined.
        reversed_flags = [bool(heuristic.reversed) for heuristic in self.composed_heuristic]

        if all(item is False for item in reversed_flags):
            self.reversed = False
        elif all(item is True for item in reversed_flags):
            self.reversed = True
        else:
            # NOTE: the "revesed" typo is kept intentionally — callers/tests
            # match on this exact message.
            raise Exception("heuristics should have the same value for `revesed` parameter")

        # Keep the actual threshold values (None when unset). Previously this
        # stored `bool(threshold)`, which made `get_ranks` compare against
        # True (== 1.0) instead of the configured probability cut-off.
        self.threshold = [heuristic.threshold for heuristic in self.composed_heuristic]

    def get_uncertainties(self, predictions):
        """
        Computes the score for each part of predictions according to the assigned heuristic.

        NOTE: `predictions` is a list of each model output. For example, for an
        object detection model the predictions should be:
            [confidence_predictions: nd.array(), boundingbox_predictions: nd.array()]

        Args:
            predictions (list[ndarray]): list of prediction arrays (or
                generators of prediction chunks), one per heuristic.

        Returns:
            List of uncertainty arrays, one per heuristic.
        """
        results = []
        for heuristic, prediction in zip(self.composed_heuristic, predictions):
            # Check the current prediction, not predictions[0]: each output
            # may independently be an eager array or a lazy generator.
            if isinstance(prediction, types.GeneratorType):
                results.append(heuristic.get_uncertainties_generator(prediction))
            else:
                results.append(heuristic.get_uncertainties(prediction))
        return results

    def get_ranks(self, predictions):
        """
        Rank the predictions according to the weighted vote of each heuristic.

        Args:
            predictions (list[ndarray]):
                list[[batch_size, C, ..., Iterations], [batch_size, C, ..., Iterations], ...]

        Returns:
            Ranked index according to the uncertainty (highest to lowest).
        """
        scores_list = self.get_uncertainties(predictions)

        # Normalize the weights locally so repeated calls do not keep
        # rescaling self.weights in place.
        total = np.array(self.weights).sum()
        normalized_weights = [weight / total for weight in self.weights]

        # num_heuristics X batch_size
        scores_array = np.vstack([weight * scores
                                  for weight, scores in zip(normalized_weights, scores_list)])

        # batch_size X num_heuristic
        final_scores = self.reduction(np.swapaxes(scores_array, 0, -1))

        assert final_scores.ndim == 1
        ranks = np.argsort(final_scores)

        # Drop samples whose maximum probability falls below a heuristic's
        # threshold. `is not None` (rather than truthiness) keeps an explicit
        # threshold of 0.0 meaningful.
        for indx, threshold in enumerate(self.threshold):
            if threshold is not None:
                ranks = np.asarray([idx for idx in ranks
                                    if np.amax(predictions[indx][idx]) > threshold])

        if self.reversed:
            ranks = ranks[::-1]
        ranks = _shuffle_subset(ranks, self.shuffle_prop)
        return ranks
45 changes: 45 additions & 0 deletions tests/active/heuristic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AbstractHeuristic,
requireprobs,
Precomputed,
CombineHeuristics,
)

N_ITERATIONS = 50
Expand Down Expand Up @@ -262,5 +263,49 @@ def test_that_precomputed_passes_back_predictions():
assert (precomputed(ranks) == ranks).all()


@pytest.mark.parametrize(
    'heuristic1, heuristic2, weights',
    [(BALD(), Variance(), [0.7, 0.3]),
     (BALD(), Entropy(reduction='mean'), [0.9, 0.8]),
     (Entropy(), Variance(), [4, 8]),
     (Certainty(), Variance(), [9, 2]),
     (Certainty(), Certainty(reduction='mean'), [1, 3])]
)
def test_combine_heuristics(heuristic1, heuristic2, weights):
    """Combining heuristics ranks correctly, and mixed sort directions raise."""
    np.random.seed(1337)
    predictions = [distributions_3d, distributions_5d]

    mixed_directions = isinstance(heuristic1, Certainty) and not isinstance(heuristic2, Certainty)
    if mixed_directions:
        # Certainty sorts ascending while the other heuristic sorts
        # descending, so combining them must fail.
        with pytest.raises(Exception) as e_info:
            CombineHeuristics([heuristic1, heuristic2], weights=weights,
                              reduction='mean')
        assert 'heuristics should have the same value for `revesed` parameter' in str(e_info.value)
        return

    combined = CombineHeuristics([heuristic1, heuristic2], weights=weights,
                                 reduction='mean')
    both_certainty = isinstance(heuristic1, Certainty) and isinstance(heuristic2, Certainty)
    assert combined.reversed == (not both_certainty)
    ranks = combined(predictions)
    assert np.all(ranks == [1, 2, 0]), "Combine Heuristics is not right {}".format(ranks)

def test_combine_heuristics_uncertainty_generator():
    """Generator-based predictions yield the same uncertainties and ranks as arrays."""
    np.random.seed(1337)
    chunked = [chunks(distributions_3d, 2), chunks(distributions_5d, 2)]
    eager = [distributions_3d, distributions_5d]

    combined = CombineHeuristics([BALD(), Variance()], weights=[0.5, 0.5],
                                 reduction='mean')

    assert np.allclose(
        combined.get_uncertainties(eager),
        combined.get_uncertainties(chunked),
    )

    # The generators above are exhausted — rebuild them before ranking.
    chunked = [chunks(distributions_3d, 2), chunks(distributions_5d, 2)]
    ranks = combined(chunked)
    assert np.all(ranks == [1, 2, 0]), "Combine Heuristics is not right {}".format(ranks)

# Allow running this test module directly (e.g. `python heuristic_test.py`).
if __name__ == '__main__':
    pytest.main()

0 comments on commit 0677c57

Please sign in to comment.