Add basic stats for cli (#48)
* add

* add

* add

* other metric test
kampersanda authored Oct 26, 2024
1 parent d3f1725 commit a1b8270
Showing 5 changed files with 105 additions and 8 deletions.
25 changes: 24 additions & 1 deletion elinor-cli/README.md
@@ -60,9 +60,14 @@ cargo run --release -p elinor-cli --bin elinor-evaluate -- \

The available metrics are shown in [Metric](https://docs.rs/elinor/latest/elinor/metrics/enum.Metric.html).

-The output will show the macro-averaged scores for each metric:
+The output will show the basic statistics and the macro-averaged scores for each metric:

```
n_queries_in_true 8
n_queries_in_pred 8
n_docs_in_true 20
n_docs_in_pred 24
n_true_relevant_docs 14
precision@3 0.5833
ap 0.8229
rr 0.8125
@@ -135,6 +140,15 @@ cargo run --release -p elinor-cli --bin elinor-compare -- \
The output will be:

```
# Basic statistics
+-----------+-------+
| Key | Value |
+-----------+-------+
| n_systems | 2 |
| n_topics | 8 |
| n_metrics | 4 |
+-----------+-------+
# Alias
+----------+-----------------------------+
| Alias | Path |
@@ -198,6 +212,15 @@ cargo run --release -p elinor-cli --bin elinor-compare -- \
The output will be:

```
# Basic statistics
+-----------+-------+
| Key | Value |
+-----------+-------+
| n_systems | 3 |
| n_topics | 8 |
| n_metrics | 4 |
+-----------+-------+
# Alias
+----------+-----------------------------+
| Alias | Path |
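The five statistics rows printed by elinor-evaluate are plain tab-separated key/value pairs emitted ahead of the metric scores, so they are easy to peel off in a pipeline. A minimal sketch, assuming the invocation from the README (flags elided here as `<args>`) and the output format shown above:

```
# Keep only the basic-stats rows; every statistics key starts with "n_".
cargo run --release -p elinor-cli --bin elinor-evaluate -- <args> | grep '^n_'
```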
53 changes: 46 additions & 7 deletions elinor-cli/src/bin/compare.rs
@@ -43,6 +43,14 @@ struct Args {
    /// Print mode for the output (pretty or raw).
    #[arg(short, long, default_value = "pretty")]
    print_mode: PrintMode,

    /// Number of resamples for the bootstrap test.
    #[arg(long, default_value = "10000")]
    n_resamples: usize,

    /// Number of iterations for the randomized test.
    #[arg(long, default_value = "10000")]
    n_iters: usize,
}

fn main() -> Result<()> {
@@ -85,9 +93,33 @@ fn main() -> Result<()> {
    }
    let topic_header = topic_headers[0].as_str();

    println!("# Basic statistics");
    {
        let columns = vec![
            Series::new(
                "Key".into(),
                vec![
                    "n_systems".to_string(),
                    "n_topics".to_string(),
                    "n_metrics".to_string(),
                ],
            ),
            Series::new(
                "Value".into(),
                vec![
                    dfs.len() as u64,                      // number of input systems
                    dfs[0].get_columns()[0].len() as u64,  // rows in the first frame = topics
                    dfs[0].get_columns().len() as u64 - 1, // columns minus the topic-id column
                ],
            ),
        ];
        let df = DataFrame::new(columns)?;
        print_dataframe(&df, args.print_mode);
    }

    // If there is only one input CSV file, just print the means.
    if args.input_csvs.len() == 1 {
-        println!("# Means");
+        println!("\n# Means");
        {
            let metrics = extract_metrics(&dfs[0]);
            let values = get_means(&dfs[0], &metrics, topic_header);
@@ -101,7 +133,7 @@
        return Ok(());
    }

-    println!("# Alias");
+    println!("\n# Alias");
    {
        let columns = vec![
            Series::new(
@@ -123,10 +155,17 @@
    }

    if dfs.len() == 2 {
-        compare_two_systems(&dfs[0], &dfs[1], topic_header, args.print_mode)?;
+        compare_two_systems(
+            &dfs[0],
+            &dfs[1],
+            topic_header,
+            args.print_mode,
+            args.n_resamples,
+            args.n_iters,
+        )?;
    }
    if dfs.len() > 2 {
-        compare_multiple_systems(&dfs, topic_header, args.print_mode)?;
+        compare_multiple_systems(&dfs, topic_header, args.print_mode, args.n_iters)?;
    }

    Ok(())
@@ -180,6 +219,8 @@ fn compare_two_systems(
    df_2: &DataFrame,
    topic_header: &str,
    print_mode: PrintMode,
    n_resamples: usize,
    n_iters: usize,
) -> Result<()> {
    let metrics = extract_common_metrics([df_1, df_2]);
    if metrics.is_empty() {
@@ -278,7 +319,6 @@ fn compare_two_systems(
        print_dataframe(&df, print_mode);
    }

-    let n_resamples = 10000;
    println!("\n# Two-sided paired Bootstrap test (n_resamples = {n_resamples})");
    {
        let mut stats = vec![];
@@ -306,7 +346,6 @@
        print_dataframe(&df, print_mode);
    }

-    let n_iters = 10000;
    println!("\n# Fisher's randomized test (n_iters = {n_iters})");
    {
        let mut stats = vec![];
@@ -344,6 +383,7 @@ fn compare_multiple_systems(
    dfs: &[DataFrame],
    topic_header: &str,
    print_mode: PrintMode,
    n_iters: usize,
) -> Result<()> {
    let metrics = extract_common_metrics(dfs);
    if metrics.is_empty() {
@@ -382,7 +422,6 @@
        df_metrics.push(joined);
    }

-    let n_iters = 10000;
    let rthsd_tester = RandomizedTukeyHsdTester::new(dfs.len()).with_n_iters(n_iters);

    for (metric, df_metric) in metrics.iter().zip(df_metrics.iter()) {
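Both resampling knobs now surface as flags on elinor-compare. A usage sketch, assuming clap's default kebab-case flag names derived from the `n_resamples` and `n_iters` fields, and a hypothetical `--input-csvs` spelling for the `input_csvs` field (the actual flag name does not appear in this diff):

```
cargo run --release -p elinor-cli --bin elinor-compare -- \
  --input-csvs system_1.csv \
  --input-csvs system_2.csv \
  --n-resamples 20000 \
  --n-iters 5000
```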
11 changes: 11 additions & 0 deletions elinor-cli/src/bin/evaluate.rs
@@ -54,6 +54,12 @@ fn main() -> Result<()> {
        args.metrics
    };

    println!("n_queries_in_true\t{}", true_rels.n_queries());
    println!("n_queries_in_pred\t{}", pred_rels.n_queries());
    println!("n_docs_in_true\t{}", true_rels.n_docs());
    println!("n_docs_in_pred\t{}", pred_rels.n_docs());
    println!("n_true_relevant_docs\t{}", n_relevant_docs(&true_rels));

    let mut columns = vec![];
    for metric in metrics {
        let result = elinor::evaluate(&true_rels, &pred_rels, metric)?;
@@ -79,6 +85,11 @@
    Ok(())
}

/// Counts the judged documents whose relevance score is positive.
fn n_relevant_docs(true_rels: &TrueRelStore<String>) -> usize {
    let records = true_rels.records();
    records.into_iter().filter(|r| r.score > 0).count()
}

fn default_metrics() -> Vec<Metric> {
    let mut metrics = Vec::new();
    for k in [1, 5, 10] {
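The rule behind n_true_relevant_docs is that a judged document counts as relevant only when its graded score is positive; a score of 0 means judged but not relevant, mirroring trec_eval's num_rel (see the script below). A standalone sketch of just that rule, using a local stand-in for elinor's record type:

```
// Hypothetical stand-in for elinor's record type, for illustration only.
struct Record {
    score: u32,
}

// A judged document counts as relevant iff its score is positive.
fn n_relevant_docs(records: &[Record]) -> usize {
    records.iter().filter(|r| r.score > 0).count()
}

fn main() {
    let judged = [Record { score: 0 }, Record { score: 1 }, Record { score: 2 }];
    assert_eq!(n_relevant_docs(&judged), 2);
}
```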
10 changes: 10 additions & 0 deletions scripts/compare_with_trec_eval.py
@@ -74,6 +74,16 @@ def compare_decimal_places(a: str, b: str, decimal_places: int) -> bool:
    [metric for _, metric in metric_pairs],
)

# Also check the new basic statistics against trec_eval's aggregate counters.
metric_pairs.extend(
    [
        ("num_q", "n_queries_in_true"),
        ("num_q", "n_queries_in_pred"),
        ("num_ret", "n_docs_in_pred"),
        ("num_rel", "n_true_relevant_docs"),
    ]
)

failed_rows: list[str] = []
for trec_metric, elinor_metric in metric_pairs:
    trec_score = trec_results["trec_eval_output"][trec_metric]
14 changes: 14 additions & 0 deletions src/relevance.rs
@@ -84,6 +84,20 @@
            .collect()
    }

    /// Returns the relevance store as records.
    pub fn records(&self) -> Vec<Record<K, T>> {
        self.map
            .iter()
            .flat_map(|(query_id, data)| {
                data.sorted.iter().map(move |rel| Record {
                    query_id: query_id.clone(),
                    doc_id: rel.doc_id.clone(),
                    score: rel.score.clone(),
                })
            })
            .collect()
    }

    /// Returns the score for a given query-document pair.
    pub fn get_score<Q>(&self, query_id: &Q, doc_id: &Q) -> Option<&T>
    where
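The new accessor flattens the per-query judgment map into one record list, which is what lets the CLI count documents corpus-wide in a single pass. A standalone sketch of the same flat_map pattern over simplified types (elinor's generic key and score parameters are dropped for brevity):

```
use std::collections::BTreeMap;

#[derive(Debug)]
struct Record {
    query_id: String,
    doc_id: String,
    score: u32,
}

// Flatten a per-query map of (doc_id, score) judgments into one record list,
// mirroring the shape of the records() accessor above.
fn records(map: &BTreeMap<String, Vec<(String, u32)>>) -> Vec<Record> {
    map.iter()
        .flat_map(|(query_id, docs)| {
            docs.iter().map(move |(doc_id, score)| Record {
                query_id: query_id.clone(),
                doc_id: doc_id.clone(),
                score: *score,
            })
        })
        .collect()
}

fn main() {
    let mut map = BTreeMap::new();
    map.insert("q1".to_string(), vec![("d1".to_string(), 1), ("d2".to_string(), 0)]);
    println!("{:?}", records(&map)); // two flattened records for q1
}
```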
