Add support for std dev
AmenRa committed Nov 28, 2023
1 parent 42e8032 commit f543163
Showing 5 changed files with 63 additions and 2 deletions.
5 changes: 5 additions & 0 deletions changelog.md
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.3.19] - 2023-11-28
+### Added
+- `Run` now has an additional property to store metrics standard deviation.
+- `evaluate` now has a `return_std` flag to compute metrics standard deviation.
+
 ## [0.3.18] - 2022-09-29
 ### Changed
 - `Qrels.from_df` now checks that scores are `numpy.int64` to avoid errors on Windows.
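The two changelog entries above correspond to the new `return_std` flag of `evaluate` and the new `std_scores` property of `Run`; the relevant diffs follow below. A minimal usage sketch of the new behavior (the qrels/run values here are illustrative, not taken from the repository):

```python
# Illustrative sketch of the feature added in this commit (ranx >= 0.3.19).
# The data below is made up; only Qrels, Run, and evaluate come from ranx.
from ranx import Qrels, Run, evaluate

qrels = Qrels({"q_1": {"d_1": 1}, "q_2": {"d_2": 1}})
run = Run({"q_1": {"d_1": 0.9, "d_3": 0.8}, "q_2": {"d_2": 0.7, "d_4": 0.6}})

# With return_std=True, each metric maps to a {"mean": ..., "std": ...} dict.
results = evaluate(qrels, run, ["map@5", "mrr"], return_std=True)
print(results["map@5"]["mean"], results["map@5"]["std"])

# The per-metric standard deviations are also stored on the Run object.
print(run.std_scores)  # {"map@5": <std across queries>, "mrr": <std across queries>}
```

Note that, per the `evaluate.py` diff below, `return_std` is silently disabled when `return_mean=False`.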
1 change: 1 addition & 0 deletions ranx/data_structures/run.py
@@ -72,6 +72,7 @@ def __init__(self, run: Dict[str, Dict[str, float]] = None, name: str = None):
         self.metadata = {}
         self.scores = defaultdict(dict)
         self.mean_scores = {}
+        self.std_scores = {}
 
     def keys(self):
         """Returns query ids. Used internally."""
15 changes: 14 additions & 1 deletion ranx/meta/evaluate.py
@@ -77,6 +77,7 @@ def evaluate(
     ],
     metrics: Union[List[str], str],
     return_mean: bool = True,
+    return_std: bool = False,
     threads: int = 0,
     save_results_in_run: bool = True,
     make_comparable: bool = False,
@@ -124,6 +125,9 @@
     elif threads != 0:
         set_num_threads(threads)
 
+    if not return_mean:
+        return_std = False
+
     if make_comparable and type(qrels) == Qrels and type(run) == Run:
         run = run.make_comparable(qrels)
 
@@ -145,12 +149,21 @@
     if type(run) == Run and save_results_in_run:
         for m, scores in metric_scores_dict.items():
             run.mean_scores[m] = np.mean(scores)
+            if return_std:
+                run.std_scores[m] = np.std(scores)
             for i, q_id in enumerate(run.get_query_ids()):
                 run.scores[m][q_id] = scores[i]
 
     # Prepare output -----------------------------------------------------------
     if return_mean:
         for m, scores in metric_scores_dict.items():
-            metric_scores_dict[m] = np.mean(scores)
+            if return_std:
+                metric_scores_dict[m] = {
+                    "mean": np.mean(scores),
+                    "std": np.std(scores),
+                }
+
+            else:
+                metric_scores_dict[m] = np.mean(scores)
 
     return metric_scores_dict[m] if len(metrics) == 1 else metric_scores_dict
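For clarity on the semantics: `np.std` uses `ddof=0` by default, so the stored value is the population standard deviation of the per-query metric scores. A small standalone check (the numbers are made up):

```python
import numpy as np

scores = np.array([0.5, 1.0])  # per-query values of one metric, illustrative

# np.std defaults to ddof=0, i.e. sqrt(mean((x - mean(x)) ** 2))
assert np.isclose(np.std(scores), np.sqrt(np.mean((scores - scores.mean()) ** 2)))
print(np.std(scores))  # 0.25
```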
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="ranx",
-    version="0.3.18",
+    version="0.3.19",
     author="Elias Bassani",
     author_email="elias.bssn@gmail.com",
     description="ranx: A Blazing-Fast Python Library for Ranking Evaluation, Comparison, and Fusion",
42 changes: 42 additions & 0 deletions tests/unit/ranx/meta/evaluate_test.py
@@ -538,3 +538,45 @@ def test_keys_control():
 
     with pytest.raises(Exception):
         evaluate(qrels, run, "ndcg@5")
+
+
+def test_std():
+    qrels_dict = {"q_1": {"d_12": 5, "d_25": 3}, "q_2": {"d_11": 6, "d_22": 1}}
+
+    run_dict = {
+        "q_1": {
+            "d_12": 0.9,
+            "d_23": 0.8,
+            "d_25": 0.7,
+            "d_36": 0.6,
+            "d_32": 0.5,
+            "d_35": 0.4,
+        },
+        "q_2": {
+            "d_12": 0.9,
+            "d_11": 0.8,
+            "d_25": 0.7,
+            "d_36": 0.6,
+            "d_22": 0.5,
+            "d_35": 0.4,
+        },
+    }
+
+    qrels = Qrels(qrels_dict)
+    run = Run(run_dict)
+
+    res = evaluate(qrels, run, ["map@5", "mrr"], return_std=True)
+
+    assert len(res.keys()) == 2
+    assert "map@5" in res
+    assert "mrr" in res
+    assert "mean" in res["map@5"]
+    assert "std" in res["map@5"]
+    assert "mean" in res["mrr"]
+    assert "std" in res["mrr"]
+    assert res["map@5"]["std"] == np.std(list(run.scores["map@5"].values()))
+    assert res["mrr"]["std"] == np.std(list(run.scores["mrr"].values()))
+    assert run.std_scores == {
+        "map@5": np.std(list(run.scores["map@5"].values())),
+        "mrr": np.std(list(run.scores["mrr"].values())),
+    }
