From c878dc36a6c3270fbefd703d195902bee57b165a Mon Sep 17 00:00:00 2001
From: Shunsuke Kanda
Date: Fri, 20 Sep 2024 22:20:55 +0900
Subject: [PATCH] More correctness test (#10)

* add
* rename
* add
* add
* stable
* add
* add
* add
* add
---
 .github/workflows/ci.yml                   |   4 +-
 correctness-test/README.md                 |   8 ++
 correctness-test/compare_with_trec_eval.py | 107 +++++++++++++++++++++
 correctness-test/prepare_trec_eval.sh      |  15 +++
 scripts/compare_with_trec_eval.py          | 101 -------------------
 src/lib.rs                                 |   2 +-
 src/metrics.rs                             |   2 +-
 src/relevance.rs                           |   6 +-
 src/trec.rs                                |   7 +-
 9 files changed, 142 insertions(+), 110 deletions(-)
 create mode 100644 correctness-test/README.md
 create mode 100755 correctness-test/compare_with_trec_eval.py
 create mode 100755 correctness-test/prepare_trec_eval.sh
 delete mode 100755 scripts/compare_with_trec_eval.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 80e5adb..34d647a 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -67,4 +67,6 @@ jobs:
       - name: Build elinor-evaluate
         run: cargo build --release -p elinor-evaluate
       - name: Run correctness test
-        run: python scripts/compare_with_trec_eval.py target/release/elinor-evaluate
+        run: |
+          ./correctness-test/prepare_trec_eval.sh
+          ./correctness-test/compare_with_trec_eval.py trec_eval-9.0.8 target/release
diff --git a/correctness-test/README.md b/correctness-test/README.md
new file mode 100644
index 0000000..f90fca6
--- /dev/null
+++ b/correctness-test/README.md
@@ -0,0 +1,8 @@
+# Correctness test compared with the reference implementation
+
+## With trec_eval
+
+```shell
+./correctness-test/prepare_trec_eval.sh
+./correctness-test/compare_with_trec_eval.py trec_eval-9.0.8 target/release
+```
diff --git a/correctness-test/compare_with_trec_eval.py b/correctness-test/compare_with_trec_eval.py
new file mode 100755
index 0000000..5920492
--- /dev/null
+++ b/correctness-test/compare_with_trec_eval.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+"""
+Script to check the correctness of elinor by comparing its output with trec_eval.
+"""
+
+import argparse
+import subprocess
+import sys
+
+
+def run_trec_eval(
+    trec_eval_dir: str, qrels_file: str, results_file: str
+) -> dict[str, str]:
+    command = f"./{trec_eval_dir}/trec_eval -c -m all_trec {qrels_file} {results_file}"
+    print(f"Running: {command}")
+    result = subprocess.run(command, capture_output=True, shell=True)
+    parsed: dict[str, str] = {}
+    for line in result.stdout.decode("utf-8").split("\n"):
+        if not line:
+            continue
+        metric, _, value = line.split()
+        parsed[metric] = value
+    return parsed
+
+
+def run_elinor_evaluate(
+    elinor_dir: str, qrels_file: str, results_file: str
+) -> dict[str, str]:
+    ks = [0, 1, 5, 10, 15, 20, 30, 100, 200, 500, 1000]
+    ks_args = " ".join([f"-k {k}" for k in ks])
+    command = (
+        f"./{elinor_dir}/elinor-evaluate -q {qrels_file} -r {results_file} {ks_args}"
+    )
+    print(f"Running: {command}")
+    result = subprocess.run(command, capture_output=True, shell=True)
+    parsed: dict[str, str] = {}
+    for line in result.stdout.decode("utf-8").split("\n"):
+        if not line:
+            continue
+        metric, value = line.split()
+        parsed[metric] = value
+    return parsed
+
+
+def compare_decimal_places(a: str, b: str, decimal_places: int) -> bool:
+    return round(float(a), decimal_places) == round(float(b), decimal_places)
+
+
+if __name__ == "__main__":
+    p = argparse.ArgumentParser()
+    p.add_argument("trec_eval_dir")
+    p.add_argument("elinor_dir")
+    p.add_argument("--decimal-places", type=int, default=3)
+    args = p.parse_args()
+
+    trec_eval_dir: str = args.trec_eval_dir
+    elinor_dir: str = args.elinor_dir
+    decimal_places: int = args.decimal_places
+
+    failed_ids = []
+    test_data = [
+        (f"{trec_eval_dir}/test/qrels.test", f"{trec_eval_dir}/test/results.test"),
+        (f"{trec_eval_dir}/test/qrels.rel_level", f"{trec_eval_dir}/test/results.test"),
+    ]
+
+    for data_id, (qrels_file, results_file) in enumerate(test_data, 1):
+        trec_results = run_trec_eval(trec_eval_dir, qrels_file, results_file)
+        elinor_results = run_elinor_evaluate(elinor_dir, qrels_file, results_file)
+
+        metric_pairs = []
+        metric_pairs.extend([(f"success_{k}", f"success@{k}") for k in [1, 5, 10]])
+        metric_pairs.extend(
+            [
+                ("set_P", "precision"),
+                ("set_recall", "recall"),
+                ("set_F", "f1"),
+                ("Rprec", "r_precision"),
+                ("map", "ap"),
+                ("recip_rank", "rr"),
+                ("ndcg", "ndcg"),
+                ("bpref", "bpref"),
+            ]
+        )
+
+        ks = [5, 10, 15, 20, 30, 100, 200, 500, 1000]
+        metric_pairs.extend([(f"P_{k}", f"precision@{k}") for k in ks])
+        metric_pairs.extend([(f"recall_{k}", f"recall@{k}") for k in ks])
+        metric_pairs.extend([(f"map_cut_{k}", f"ap@{k}") for k in ks])
+        metric_pairs.extend([(f"ndcg_cut_{k}", f"ndcg@{k}") for k in ks])
+
+        print("case_id\ttrec_metric\telinor_metric\ttrec_score\telinor_score\tmatch")
+        for metric_id, (trec_metric, elinor_metric) in enumerate(metric_pairs, 1):
+            case_id = f"{data_id}.{metric_id}"
+            trec_score = trec_results[trec_metric]
+            elinor_score = elinor_results[elinor_metric]
+            match = compare_decimal_places(trec_score, elinor_score, decimal_places)
+            row = f"{case_id}\t{trec_metric}\t{elinor_metric}\t{trec_score}\t{elinor_score}\t{match}"
+            print(row)
+
+            if not match:
+                failed_ids.append(case_id)
+
+    if failed_ids:
+        print("Mismatched cases:", failed_ids, file=sys.stderr)
+        sys.exit(1)
+    else:
+        print(f"All metrics match 🎉 with {decimal_places=}")
diff --git a/correctness-test/prepare_trec_eval.sh b/correctness-test/prepare_trec_eval.sh
new file mode 100755
index 0000000..9095cac
--- /dev/null
+++ b/correctness-test/prepare_trec_eval.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -eux
+
+# TREC EVAL
+TREC_VERSION="9.0.8"
+if [ -d "trec_eval-$TREC_VERSION" ]; then
+    echo "Directory trec_eval-$TREC_VERSION exists."
+else
+    echo "Directory trec_eval-$TREC_VERSION does not exist."
+    rm -f v$TREC_VERSION.tar.gz
+    wget https://github.com/usnistgov/trec_eval/archive/refs/tags/v$TREC_VERSION.tar.gz
+    tar -xf v$TREC_VERSION.tar.gz
+    make -C trec_eval-$TREC_VERSION
+fi
diff --git a/scripts/compare_with_trec_eval.py b/scripts/compare_with_trec_eval.py
deleted file mode 100755
index 89f187c..0000000
--- a/scripts/compare_with_trec_eval.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""
-Script to check the correctness of elinor by comparing its output with trec_eval.
-
-Usage:
-    $ python3 ./scripts/compare_with_trec_eval.py ./target/release/elinor-evaluate
-"""
-
-import argparse
-import os
-import subprocess
-import sys
-
-
-def download_trec_eval():
-    if os.path.exists("trec_eval-9.0.8"):
-        print("trec_eval-9.0.8 already exists", file=sys.stderr)
-        return
-    subprocess.run("rm -f v9.0.8.tar.gz", shell=True)
-    subprocess.run(
-        "wget https://github.com/usnistgov/trec_eval/archive/refs/tags/v9.0.8.tar.gz",
-        shell=True,
-    )
-    subprocess.run("tar -xf v9.0.8.tar.gz", shell=True)
-    subprocess.run("make -C trec_eval-9.0.8", shell=True)
-
-
-def run_trec_eval() -> dict[str, str]:
-    command = "./trec_eval-9.0.8/trec_eval -c -m all_trec trec_eval-9.0.8/test/qrels.test trec_eval-9.0.8/test/results.test"
-    result = subprocess.run(command, capture_output=True, shell=True)
-    parsed: dict[str, str] = {}
-    for line in result.stdout.decode("utf-8").split("\n"):
-        if not line:
-            continue
-        metric, _, value = line.split()
-        parsed[metric] = value
-    return parsed
-
-
-def run_elinor(elinor_exe: str) -> dict[str, str]:
-    ks = [0, 1, 5, 10, 15, 20, 30, 100, 200, 500, 1000]
-    command = (
-        f"{elinor_exe} -q trec_eval-9.0.8/test/qrels.test -r trec_eval-9.0.8/test/results.test"
-        + "".join([f" -k {k}" for k in ks])
-    )
-    result = subprocess.run(command, capture_output=True, shell=True)
-    parsed: dict[str, str] = {}
-    for line in result.stdout.decode("utf-8").split("\n"):
-        if not line:
-            continue
-        metric, value = line.split()
-        parsed[metric] = value
-    return parsed
-
-
-if __name__ == "__main__":
-    p = argparse.ArgumentParser()
-    p.add_argument("elinor_exe")
-    args = p.parse_args()
-
-    download_trec_eval()
-    trec_results = run_trec_eval()
-    elinor_results = run_elinor(args.elinor_exe)
-
-    ks = [5, 10, 15, 20, 30, 100, 200, 500, 1000]
-
-    metric_pairs = []
-    metric_pairs.extend([(f"success_{k}", f"success@{k}") for k in [1, 5, 10]])
-    metric_pairs.extend(
-        [
-            ("set_P", "precision"),
-            ("set_recall", "recall"),
-            ("set_F", "f1"),
-            ("Rprec", "r_precision"),
-            ("map", "ap"),
-            ("recip_rank", "rr"),
-            ("ndcg", "ndcg"),
-            ("bpref", "bpref"),
-        ]
-    )
-    metric_pairs.extend([(f"P_{k}", f"precision@{k}") for k in ks])
-    metric_pairs.extend([(f"recall_{k}", f"recall@{k}") for k in ks])
-    metric_pairs.extend([(f"map_cut_{k}", f"ap@{k}") for k in ks])
-    metric_pairs.extend([(f"ndcg_cut_{k}", f"ndcg@{k}") for k in ks])
-
-    failed_rows = []
-
-    print("trec_metric\telinor_metric\ttrec_score\telinor_score\tmatch")
-    for trec_metric, elinor_metric in metric_pairs:
-        trec_score = trec_results[trec_metric]
-        elinor_score = elinor_results[elinor_metric]
-        match = trec_score == elinor_score
-        row = f"{trec_metric}\t{elinor_metric}\t{trec_score}\t{elinor_score}\t{match}"
-        if not match:
-            failed_rows.append(row)
-        print(row)
-
-    if failed_rows:
-        print("\nFailed rows:", file=sys.stderr)
-        for row in failed_rows:
-            print(row, file=sys.stderr)
-        sys.exit(1)
diff --git a/src/lib.rs b/src/lib.rs
index f3af5b1..bcf4cf6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -120,7 +120,7 @@ pub fn evaluate<K, M>(
     metrics: M,
 ) -> Result<Evaluated<K>, errors::ElinorError>
 where
-    K: Clone + Eq + std::hash::Hash + std::fmt::Display,
+    K: Clone + Eq + Ord + std::hash::Hash + std::fmt::Display,
     M: IntoIterator<Item = Metric>,
 {
     let metrics: HashSet<Metric> = metrics.into_iter().collect();
diff --git a/src/metrics.rs b/src/metrics.rs
index 4d93d9c..cf5dd2e 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -322,7 +322,7 @@ pub fn compute_metric<K>(
     metric: Metric,
 ) -> Result<HashMap<K, f64>, ElinorError>
 where
-    K: Clone + Eq + std::hash::Hash + std::fmt::Display,
+    K: Clone + Eq + Ord + std::hash::Hash + std::fmt::Display,
 {
     for query_id in run.query_ids() {
         if qrels.get_map(query_id).is_none() {
diff --git a/src/relevance.rs b/src/relevance.rs
index fb0113e..f6bf043 100644
--- a/src/relevance.rs
+++ b/src/relevance.rs
@@ -39,7 +39,7 @@ pub struct RelevanceStore<K, T> {
 
 impl<K, T> RelevanceStore<K, T>
 where
-    K: Eq + Hash + Clone,
+    K: Eq + Ord + Hash + Clone,
     T: Ord + Clone,
 {
     /// Creates a relevance store from a map of query ids to relevance maps.
@@ -162,7 +162,7 @@ impl<K, T> RelevanceStoreBuilder<K, T> {
     /// Builds the relevance store.
     pub fn build(self) -> RelevanceStore<K, T>
     where
-        K: Eq + Hash + Clone + Display,
+        K: Eq + Ord + Hash + Clone + Display,
         T: Ord + Clone,
     {
         let mut map = HashMap::new();
@@ -174,7 +174,7 @@ impl<K, T> RelevanceStoreBuilder<K, T> {
                     score: score.clone(),
                 })
                 .collect::<Vec<_>>();
-            sorted.sort_by(|a, b| b.score.cmp(&a.score));
+            sorted.sort_by(|a, b| b.score.cmp(&a.score).then(a.doc_id.cmp(&b.doc_id)));
             map.insert(query_id, RelevanceData { sorted, map: rels });
         }
         RelevanceStore { name: None, map }
diff --git a/src/trec.rs b/src/trec.rs
index 4472e65..078bada 100644
--- a/src/trec.rs
+++ b/src/trec.rs
@@ -45,14 +45,15 @@ where
     for line in lines {
         let line = line.as_ref();
         let rows = line.split_whitespace().collect::<Vec<_>>();
-        if rows.len() != 4 {
+        if rows.len() < 4 {
             return Err(ElinorError::InvalidFormat(line.to_string()));
         }
         let query_id = rows[0].to_string();
         let doc_id = rows[2].to_string();
         let score = rows[3]
-            .parse::<GoldScore>()
+            .parse::<i32>()
             .map_err(|_| ElinorError::InvalidFormat(format!("Invalid score: {}", rows[3])))?;
+        let score = GoldScore::try_from(score.max(0)).unwrap();
         b.add_score(query_id, doc_id, score)?;
     }
     Ok(b.build())
@@ -98,7 +99,7 @@ where
     for line in lines {
         let line = line.as_ref();
         let rows = line.split_whitespace().collect::<Vec<_>>();
-        if rows.len() != 6 {
+        if rows.len() < 6 {
             return Err(ElinorError::InvalidFormat(line.to_string()));
         }
         let query_id = rows[0].to_string();
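For reference, the qrels change in src/trec.rs above accepts lines with more than four whitespace-separated columns (query id, iteration, doc id, judgment, plus anything trailing) and clamps negative judgments to zero before converting to a gold score, which matches how trec_eval treats negative relevance levels. The following is a minimal standalone sketch of that behavior, not elinor's actual API: `parse_qrels_line` is a hypothetical helper written for illustration, and parsing the judgment as `i64` before clamping is an assumption.

```rust
// Sketch only: mirrors the tolerant qrels-line handling introduced in this patch.
fn parse_qrels_line(line: &str) -> Option<(String, String, u32)> {
    let rows: Vec<&str> = line.split_whitespace().collect();
    if rows.len() < 4 {
        // Too few columns: need query id, iteration, doc id, judgment.
        return None;
    }
    let query_id = rows[0].to_string();
    let doc_id = rows[2].to_string();
    // Parse as a signed integer so negative judgments are accepted,
    // then clamp to zero before converting to an unsigned gold score.
    let judgment: i64 = rows[3].parse().ok()?;
    let score = u32::try_from(judgment.max(0)).ok()?;
    Some((query_id, doc_id, score))
}

fn main() {
    // Extra trailing columns are tolerated; a negative judgment becomes 0.
    assert_eq!(
        parse_qrels_line("q1 0 doc7 -1 extra"),
        Some(("q1".to_string(), "doc7".to_string(), 0))
    );
    // Fewer than four columns is rejected.
    assert_eq!(parse_qrels_line("q1 0 doc7"), None);
    println!("qrels parsing sketch behaves as expected");
}
```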