More correctness test (#10)
* add
* rename
* add
* add
* stable
* add
* add
* add
* add
kampersanda authored Sep 20, 2024
1 parent abae349 commit c878dc3
Showing 9 changed files with 142 additions and 110 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -67,4 +67,6 @@ jobs:
      - name: Build elinor-evaluate
        run: cargo build --release -p elinor-evaluate
      - name: Run correctness test
-       run: python scripts/compare_with_trec_eval.py target/release/elinor-evaluate
+       run: |
+         ./correctness-test/prepare_trec_eval.sh
+         ./correctness-test/compare_with_trec_eval.py trec_eval-9.0.8 target/release
8 changes: 8 additions & 0 deletions correctness-test/README.md
@@ -0,0 +1,8 @@
# Correctness test compared with the reference implementation

## With trec_eval

```shell
./correctness-test/prepare_trec_eval.sh
./correctness-test/compare_with_trec_eval.py trec_eval-9.0.8 target/release
```
107 changes: 107 additions & 0 deletions correctness-test/compare_with_trec_eval.py
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
"""
Script to check the correctness of elinor by comparing its output with trec_eval.
"""

import argparse
import subprocess
import sys


def run_trec_eval(
    trec_eval_dir: str, qrels_file: str, results_file: str
) -> dict[str, str]:
    command = f"./{trec_eval_dir}/trec_eval -c -m all_trec {qrels_file} {results_file}"
    print(f"Running: {command}")
    result = subprocess.run(command, capture_output=True, shell=True)
    parsed: dict[str, str] = {}
    for line in result.stdout.decode("utf-8").split("\n"):
        if not line:
            continue
        metric, _, value = line.split()
        parsed[metric] = value
    return parsed


def run_elinor_evaluate(
    elinor_dir: str, qrels_file: str, results_file: str
) -> dict[str, str]:
    ks = [0, 1, 5, 10, 15, 20, 30, 100, 200, 500, 1000]
    ks_args = " ".join([f"-k {k}" for k in ks])
    command = (
        f"./{elinor_dir}/elinor-evaluate -q {qrels_file} -r {results_file} {ks_args}"
    )
    print(f"Running: {command}")
    result = subprocess.run(command, capture_output=True, shell=True)
    parsed: dict[str, str] = {}
    for line in result.stdout.decode("utf-8").split("\n"):
        if not line:
            continue
        metric, value = line.split()
        parsed[metric] = value
    return parsed


def compare_decimal_places(a: str, b: str, decimal_places: int) -> bool:
    return round(float(a), decimal_places) == round(float(b), decimal_places)


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("trec_eval_dir")
    p.add_argument("elinor_dir")
    p.add_argument("--decimal-places", type=int, default=3)
    args = p.parse_args()

    trec_eval_dir: str = args.trec_eval_dir
    elinor_dir: str = args.elinor_dir
    decimal_places: int = args.decimal_places

    failed_ids = []
    test_data = [
        (f"{trec_eval_dir}/test/qrels.test", f"{trec_eval_dir}/test/results.test"),
        (f"{trec_eval_dir}/test/qrels.rel_level", f"{trec_eval_dir}/test/results.test"),
    ]

    for data_id, (qrels_file, results_file) in enumerate(test_data, 1):
        trec_results = run_trec_eval(trec_eval_dir, qrels_file, results_file)
        elinor_results = run_elinor_evaluate(elinor_dir, qrels_file, results_file)

        metric_pairs = []
        metric_pairs.extend([(f"success_{k}", f"success@{k}") for k in [1, 5, 10]])
        metric_pairs.extend(
            [
                ("set_P", "precision"),
                ("set_recall", "recall"),
                ("set_F", "f1"),
                ("Rprec", "r_precision"),
                ("map", "ap"),
                ("recip_rank", "rr"),
                ("ndcg", "ndcg"),
                ("bpref", "bpref"),
            ]
        )

        ks = [5, 10, 15, 20, 30, 100, 200, 500, 1000]
        metric_pairs.extend([(f"P_{k}", f"precision@{k}") for k in ks])
        metric_pairs.extend([(f"recall_{k}", f"recall@{k}") for k in ks])
        metric_pairs.extend([(f"map_cut_{k}", f"ap@{k}") for k in ks])
        metric_pairs.extend([(f"ndcg_cut_{k}", f"ndcg@{k}") for k in ks])

        print("case_id\ttrec_metric\telinor_metric\ttrec_score\telinor_score\tmatch")
        for metric_id, (trec_metric, elinor_metric) in enumerate(metric_pairs, 1):
            case_id = f"{data_id}.{metric_id}"
            trec_score = trec_results[trec_metric]
            elinor_score = elinor_results[elinor_metric]
            match = compare_decimal_places(trec_score, elinor_score, decimal_places)
            row = f"{case_id}\t{trec_metric}\t{elinor_metric}\t{trec_score}\t{elinor_score}\t{match}"
            print(row)

            if not match:
                failed_ids.append(case_id)

    if failed_ids:
        print("Mismatched cases:", failed_ids, file=sys.stderr)
        sys.exit(1)
    else:
        print(f"All metrics match 🎉 with {decimal_places=}")
15 changes: 15 additions & 0 deletions correctness-test/prepare_trec_eval.sh
@@ -0,0 +1,15 @@
#! /bin/bash

set -eux

# TREC EVAL
TREC_VERSION="9.0.8"
if [ -d "trec_eval-$TREC_VERSION" ]; then
    echo "Directory trec_eval-$TREC_VERSION exists."
else
    echo "Directory trec_eval-$TREC_VERSION does not exist."
    rm -f v$TREC_VERSION.tar.gz
    wget https://github.com/usnistgov/trec_eval/archive/refs/tags/v$TREC_VERSION.tar.gz
    tar -xf v$TREC_VERSION.tar.gz
    make -C trec_eval-$TREC_VERSION
fi
101 changes: 0 additions & 101 deletions scripts/compare_with_trec_eval.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/lib.rs
@@ -120,7 +120,7 @@ pub fn evaluate<K, M>(
    metrics: M,
) -> Result<Evaluated<K>, errors::ElinorError>
where
-    K: Clone + Eq + std::hash::Hash + std::fmt::Display,
+    K: Clone + Eq + Ord + std::hash::Hash + std::fmt::Display,
    M: IntoIterator<Item = Metric>,
{
    let metrics: HashSet<Metric> = metrics.into_iter().collect();
2 changes: 1 addition & 1 deletion src/metrics.rs
@@ -322,7 +322,7 @@ pub fn compute_metric<K>(
    metric: Metric,
) -> Result<HashMap<K, f64>, ElinorError>
where
-    K: Clone + Eq + std::hash::Hash + std::fmt::Display,
+    K: Clone + Eq + Ord + std::hash::Hash + std::fmt::Display,
{
    for query_id in run.query_ids() {
        if qrels.get_map(query_id).is_none() {
6 changes: 3 additions & 3 deletions src/relevance.rs
@@ -39,7 +39,7 @@ pub struct RelevanceStore<K, T> {

impl<K, T> RelevanceStore<K, T>
where
-    K: Eq + Hash + Clone,
+    K: Eq + Ord + Hash + Clone,
    T: Ord + Clone,
{
    /// Creates a relevance store from a map of query ids to relevance maps.
@@ -162,7 +162,7 @@ impl<K, T> RelevanceStoreBuilder<K, T> {
    /// Builds the relevance store.
    pub fn build(self) -> RelevanceStore<K, T>
    where
-        K: Eq + Hash + Clone + Display,
+        K: Eq + Ord + Hash + Clone + Display,
        T: Ord + Clone,
    {
        let mut map = HashMap::new();
@@ -174,7 +174,7 @@ impl<K, T> RelevanceStoreBuilder<K, T> {
                    score: score.clone(),
                })
                .collect::<Vec<_>>();
-            sorted.sort_by(|a, b| b.score.cmp(&a.score));
+            sorted.sort_by(|a, b| b.score.cmp(&a.score).then(a.doc_id.cmp(&b.doc_id)));
            map.insert(query_id, RelevanceData { sorted, map: rels });
        }
        RelevanceStore { name: None, map }
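
The tie-break on `doc_id` above (together with the `Ord` bounds added in `src/lib.rs` and `src/metrics.rs`) makes the ranking order deterministic when two documents share a score, which is what lets elinor's cutoff metrics line up with trec_eval. A minimal standalone sketch of the idea, not part of the commit and using plain tuples instead of the crate's `Relevance` type:

```rust
// Sketch only: descending sort by score, ties broken by ascending doc id.
fn sort_ranked(mut docs: Vec<(String, u32)>) -> Vec<(String, u32)> {
    docs.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
    docs
}

fn main() {
    let ranked = sort_ranked(vec![
        ("d2".to_string(), 3),
        ("d1".to_string(), 3),
        ("d3".to_string(), 5),
    ]);
    // Highest score first; equal scores keep a stable order across runs.
    assert_eq!(
        ranked,
        vec![
            ("d3".to_string(), 5),
            ("d1".to_string(), 3),
            ("d2".to_string(), 3),
        ]
    );
}
```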
7 changes: 4 additions & 3 deletions src/trec.rs
@@ -45,14 +45,15 @@
    for line in lines {
        let line = line.as_ref();
        let rows = line.split_whitespace().collect::<Vec<_>>();
-        if rows.len() != 4 {
+        if rows.len() < 4 {
            return Err(ElinorError::InvalidFormat(line.to_string()));
        }
        let query_id = rows[0].to_string();
        let doc_id = rows[2].to_string();
        let score = rows[3]
-            .parse::<GoldScore>()
+            .parse::<i32>()
            .map_err(|_| ElinorError::InvalidFormat(format!("Invalid score: {}", rows[3])))?;
+        let score = GoldScore::try_from(score.max(0)).unwrap();
        b.add_score(query_id, doc_id, score)?;
    }
    Ok(b.build())
@@ -98,7 +99,7 @@
    for line in lines {
        let line = line.as_ref();
        let rows = line.split_whitespace().collect::<Vec<_>>();
-        if rows.len() != 6 {
+        if rows.len() < 6 {
            return Err(ElinorError::InvalidFormat(line.to_string()));
        }
        let query_id = rows[0].to_string();
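
The qrels parser above now tolerates extra trailing columns (`rows.len() < 4` instead of an exact match) and accepts negative judgments by parsing into `i32` and clamping to zero before converting to the unsigned `GoldScore`, matching how trec_eval treats its own test qrels. A rough standalone sketch of that behavior, not part of the commit and assuming `GoldScore` is `u32`:

```rust
// Sketch only: mirrors the relaxed qrels-line handling, using Option instead of
// the crate's ElinorError for brevity.
fn parse_qrels_line(line: &str) -> Option<(String, String, u32)> {
    let rows: Vec<&str> = line.split_whitespace().collect();
    if rows.len() < 4 {
        return None; // too few columns; extra trailing columns are tolerated
    }
    let query_id = rows[0].to_string();
    let doc_id = rows[2].to_string();
    // Negative judgments (present in some TREC qrels) are clamped to 0.
    let score: i32 = rows[3].parse().ok()?;
    Some((query_id, doc_id, u32::try_from(score.max(0)).ok()?))
}

fn main() {
    assert_eq!(
        parse_qrels_line("q1 0 doc7 -2"),
        Some(("q1".to_string(), "doc7".to_string(), 0))
    );
    assert_eq!(parse_qrels_line("q1 0 doc7"), None);
}
```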
