Skip to content

Commit

Permalink
Merge pull request #62 from hotchpotch/report_to_df
Browse files Browse the repository at this point in the history
Add `report.to_dataframe()` method
  • Loading branch information
AmenRa authored Jul 1, 2024
2 parents 33ff11e + 35bf9f6 commit 94fe982
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
18 changes: 18 additions & 0 deletions ranx/data_structures/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from typing import Dict, List, Tuple

import pandas as pd
from tabulate import tabulate

from .frozenset_dict import FrozensetDict
Expand Down Expand Up @@ -319,6 +320,23 @@ def to_dict(self) -> Dict:

return d

def to_dataframe(self) -> pd.DataFrame:
"""Returns the Report data as a Pandas DataFrame.
Returns:
pd.DataFrame: Report data as a Pandas DataFrame
"""
report_dict = self.to_dict()
report_scores = {
name: report_dict[name]["scores"] for name in report_dict["model_names"]
}
df = pd.DataFrame.from_dict(report_scores, orient="index")
df = df.reset_index().rename(
columns={"index": "model_names"}
) # index to model_names column

return df

def save(self, path: str):
"""Save the Report data as JSON file.
See [**Report.to_dict**][ranx.report.to_dict] for more details.
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/ranx/data_structures/report_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,19 @@ def test_stat_test(qrels, runs, metrics):
assert report.stat_test == "tukey"
assert report.get_stat_test_label(report.stat_test) == "Tukey's HSD test"
assert report.get_stat_test_label(report.stat_test) in report.to_latex()


def test_to_dataframe(qrels, runs, metrics):
report = compare(qrels, runs, metrics)
report_df = report.to_dataframe()

assert report_df["model_names"].tolist() == report.model_names
assert report_df.columns.tolist() == ["model_names"] + metrics

assert all(
all(
report_df[report_df["model_names"] == model][metric].notnull().all()
for metric in metrics
)
for model in report_df["model_names"]
)

0 comments on commit 94fe982

Please sign in to comment.