From 0677c57038a5297af9a10d676c5c4b6e39f88789 Mon Sep 17 00:00:00 2001
From: Parmida Atighehchian
Date: Mon, 11 Nov 2019 11:01:43 -0500
Subject: [PATCH] support for multi-output models in heuristics (#24)

* support for multi-output models in heuristics

* fix flake8 issues

* PR comments

* fix documentation for shuffle_prop
---
 src/baal/active/heuristics/heuristics.py | 137 +++++++++++++++++++++--
 tests/active/heuristic_test.py           |  45 ++++++++
 2 files changed, 171 insertions(+), 11 deletions(-)

diff --git a/src/baal/active/heuristics/heuristics.py b/src/baal/active/heuristics/heuristics.py
index 87c1eed1..7b2ad084 100644
--- a/src/baal/active/heuristics/heuristics.py
+++ b/src/baal/active/heuristics/heuristics.py
@@ -1,6 +1,7 @@
 import types
 import warnings
 from functools import wraps as _wraps
+from typing import List
 
 import numpy as np
 import scipy.stats
@@ -172,10 +173,11 @@ class BALD(AbstractHeuristic):
     Sort by the highest acquisition function value.
 
     Args:
-        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
+        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
+            (default: 0.0).
         threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
-        reduction (Union[str, callable]): function that aggregates the results.
-
+        reduction (Union[str, callable]): function that aggregates the results
+            (default: `none`).
 
     References:
         https://arxiv.org/abs/1703.02910
@@ -215,9 +217,11 @@ class BatchBALD(BALD):
 
     Args:
         num_samples (int): Number of samples to select. (min 2*the amount of samples you want)
-        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
+        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
+            (default: 0.0).
         threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
-        reduction (Union[str, callable]): function that aggregates the results.
+        reduction (Union[str, callable]): function that aggregates the results
+            (default: `none`).
 
     References:
         https://arxiv.org/abs/1906.08158
@@ -320,9 +324,10 @@ class Variance(AbstractHeuristic):
     Sort by the highest variance.
 
     Args:
-        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
+        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
+            (default: 0.0).
         threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
-        reduction (Union[str, callable]): function that aggregates the results.
+        reduction (Union[str, callable]): function that aggregates the results (default: `mean`).
     """
 
     def __init__(self, shuffle_prop=0.0, threshold=None, reduction='mean'):
@@ -342,9 +347,10 @@ class Entropy(AbstractHeuristic):
     Sort by the highest entropy.
 
     Args:
-        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
+        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
+            (default: 0.0).
         threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
-        reduction (Union[str, callable]): function that aggregates the results.
+        reduction (Union[str, callable]): function that aggregates the results (default: `none`).
     """
 
     def __init__(self, shuffle_prop=0.0, threshold=None, reduction='none'):
@@ -364,9 +370,11 @@ class Margin(AbstractHeuristic):
     the second most confident class.
 
     Args:
-        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias.
+        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
+            (default: 0.0).
         threshold (Optional[Float]): Will ignore sample if the maximum prob is below this.
-        reduction (Union[str, callable]): function that aggregates the results.
+        reduction (Union[str, callable]): function that aggregates the results
+            (default: `none`).
     """
 
     def __init__(self, shuffle_prop=0.0, threshold=None, reduction='none'):
@@ -435,3 +443,110 @@ def __init__(self, shuffle_prop=0.0, threshold=None, reverse=False):
 
     def compute_score(self, predictions):
         return predictions
+
+
+class CombineHeuristics(AbstractHeuristic):
+    """Combine heuristics for multi-output models.
+    Heuristics are applied to the output predictions in the assigned order.
+    For each heuristic, the necessary `reduction`, `reversed` and `threshold`
+    parameters should be defined.
+
+    NOTE: heuristics can be combined only if they use the same value
+    for the `reversed` parameter.
+
+    NOTE: `shuffle_prop` should only be defined as a direct input of
+    `CombineHeuristics`; defining it on the individual heuristics has no effect.
+
+    NOTE: `reduction` is defined both for each input heuristic and as a direct
+    input to `CombineHeuristics`. For each heuristic, `reduction` should be defined
+    if the model output associated with that heuristic has more than 3 dimensions.
+    In `CombineHeuristics`, `reduction` is used to aggregate the final results of
+    the heuristics.
+
+    Args:
+        heuristics (list[AbstractHeuristic]): list of heuristic instances
+        weights (list[float]): weights assigned to the result of each heuristic
+            before the ranks are computed
+        reduction (Union[str, callable]): function that aggregates the results of the
+            heuristics (default: `mean`, i.e. a weighted average)
+        shuffle_prop (float): Amount of noise to put in the ranking. Helps with selection bias
+            (default: 0.0).
+
+    """
+    def __init__(self, heuristics: List, weights: List, reduction='mean', shuffle_prop=0.0):
+        super(CombineHeuristics, self).__init__(reduction=reduction, shuffle_prop=shuffle_prop)
+        self.composed_heuristic = heuristics
+        self.weights = weights
+
+        reversed_flags = [bool(heuristic.reversed) for heuristic in self.composed_heuristic]
+
+        if all(item is False for item in reversed_flags):
+            self.reversed = False
+        elif all(item is True for item in reversed_flags):
+            self.reversed = True
+        else:
+            raise Exception("heuristics should have the same value for `reversed` parameter")
+
+        self.threshold = [heuristic.threshold for heuristic in self.composed_heuristic]
+
+    def get_uncertainties(self, predictions):
+        """
+        Computes the score for each part of predictions according to the assigned heuristic.
+
+        NOTE: `predictions` is a list of the model's outputs. For example, for an
+        object detection model, the predictions should look like:
+        [confidence_predictions: np.ndarray, boundingbox_predictions: np.ndarray]
+
+        Args:
+            predictions (list[ndarray]): list of prediction arrays
+
+        Returns:
+            Array of uncertainties
+
+        """
+
+        results = []
+        for ind, prediction in enumerate(predictions):
+            if isinstance(prediction, types.GeneratorType):
+                results.append(self.composed_heuristic[ind].get_uncertainties_generator(prediction))
+            else:
+                results.append(self.composed_heuristic[ind].get_uncertainties(prediction))
+        return results
+
+    def get_ranks(self, predictions):
+        """
+        Rank the predictions according to the weighted vote of each heuristic.
+
+        Args:
+            predictions (list[ndarray]):
+                list[[batch_size, C, ..., Iterations], [batch_size, C, ..., Iterations], ...]
+
+        Returns:
+            Ranked index according to the uncertainty (highest to lowest).
+
+        """
+
+        scores_list = self.get_uncertainties(predictions)
+
+        # normalize the weights
+        w = np.array(self.weights).sum()
+        weights = [weight / w for weight in self.weights]
+
+        # num_heuristics X batch_size
+        scores_array = np.vstack([weight * scores
+                                  for weight, scores in zip(weights, scores_list)])
+
+        # batch_size X num_heuristics
+        final_scores = self.reduction(np.swapaxes(scores_array, 0, -1))
+
+        assert final_scores.ndim == 1
+        ranks = np.argsort(final_scores)
+
+        for indx, threshold in enumerate(self.threshold):
+            if threshold:
+                ranks = np.asarray([idx for idx in ranks
+                                    if np.amax(predictions[indx][idx]) > threshold])
+
+        if self.reversed:
+            ranks = ranks[::-1]
+        ranks = _shuffle_subset(ranks, self.shuffle_prop)
+        return ranks
diff --git a/tests/active/heuristic_test.py b/tests/active/heuristic_test.py
index c5bf6ed7..1b5b6b96 100644
--- a/tests/active/heuristic_test.py
+++ b/tests/active/heuristic_test.py
@@ -16,6 +16,7 @@
     AbstractHeuristic,
     requireprobs,
     Precomputed,
+    CombineHeuristics,
 )
 
 N_ITERATIONS = 50
@@ -262,5 +263,49 @@ def test_that_precomputed_passes_back_predictions():
     assert (precomputed(ranks) == ranks).all()
 
 
+@pytest.mark.parametrize(
+    'heuristic1, heuristic2, weights',
+    [(BALD(), Variance(), [0.7, 0.3]),
+     (BALD(), Entropy(reduction='mean'), [0.9, 0.8]),
+     (Entropy(), Variance(), [4, 8]),
+     (Certainty(), Variance(), [9, 2]),
+     (Certainty(), Certainty(reduction='mean'), [1, 3])]
+)
+def test_combine_heuristics(heuristic1, heuristic2, weights):
+    np.random.seed(1337)
+    predictions = [distributions_3d, distributions_5d]
+
+    if isinstance(heuristic1, Certainty) and not isinstance(heuristic2, Certainty):
+        with pytest.raises(Exception) as e_info:
+            heuristics = CombineHeuristics([heuristic1, heuristic2], weights=weights,
+                                           reduction='mean')
+        assert 'heuristics should have the same value for `reversed` parameter' in str(e_info.value)
+    else:
+        heuristics = CombineHeuristics([heuristic1, heuristic2], weights=weights,
+                                       reduction='mean')
+        if isinstance(heuristic1, Certainty) and isinstance(heuristic2, Certainty):
+            assert not heuristics.reversed
+        else:
+            assert heuristics.reversed
+        ranks = heuristics(predictions)
+        assert np.all(ranks == [1, 2, 0]), "Combine Heuristics is not right {}".format(ranks)
+
+
+def test_combine_heuristics_uncertainty_generator():
+    np.random.seed(1337)
+    prediction_chunks = [chunks(distributions_3d, 2), chunks(distributions_5d, 2)]
+    predictions = [distributions_3d, distributions_5d]
+
+    heuristics = CombineHeuristics([BALD(), Variance()], weights=[0.5, 0.5],
+                                   reduction='mean')
+
+    assert np.allclose(
+        heuristics.get_uncertainties(predictions),
+        heuristics.get_uncertainties(prediction_chunks),
+    )
+
+    prediction_chunks = [chunks(distributions_3d, 2), chunks(distributions_5d, 2)]
+    ranks = heuristics(prediction_chunks)
+    assert np.all(ranks == [1, 2, 0]), "Combine Heuristics is not right {}".format(ranks)
+
 if __name__ == '__main__':
     pytest.main()
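
A minimal usage sketch of the new CombineHeuristics API, mirroring the
BALD()/Variance() combination exercised in the tests above. The two-headed
model and the array shapes are illustrative assumptions; shapes follow the
[batch_size, C, ..., n_iterations] convention from the docstrings:

    import numpy as np

    from baal.active.heuristics import BALD, CombineHeuristics, Variance

    np.random.seed(1337)

    # Hypothetical two-output model: a classification head and a regression
    # head, each stacked over 20 MC-dropout iterations.
    cls_preds = np.random.rand(10, 3, 20)              # [batch, classes, iterations]
    cls_preds /= cls_preds.sum(axis=1, keepdims=True)  # normalize into probabilities
    reg_preds = np.random.randn(10, 4, 20)             # [batch, outputs, iterations]

    # BALD scores the classification head, Variance the regression head; the
    # weighted per-heuristic scores are averaged (reduction='mean') before ranking.
    heuristic = CombineHeuristics([BALD(), Variance()], weights=[0.5, 0.5],
                                  reduction='mean')
    ranks = heuristic([cls_preds, reg_preds])  # indices, most uncertain first
    print(ranks)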
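
The generator branch in get_uncertainties also lets predictions be streamed
chunk by chunk, as test_combine_heuristics_uncertainty_generator does. A sketch
continuing the example above; the chunks helper here is an assumption standing
in for the one used in the test file:

    def chunks(array, size):
        # Yield successive `size`-sized chunks along the batch axis.
        for i in range(0, len(array), size):
            yield array[i:i + size]

    # Streaming the predictions yields the same uncertainties as passing
    # the full arrays at once.
    streamed = heuristic.get_uncertainties([chunks(cls_preds, 2), chunks(reg_preds, 2)])
    in_memory = heuristic.get_uncertainties([cls_preds, reg_preds])
    assert np.allclose(streamed, in_memory)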