Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/TablewareBox/evals
Browse files Browse the repository at this point in the history
  • Loading branch information
TablewareBox committed Mar 12, 2024
2 parents b145b96 + 733291c commit ae76148
Show file tree
Hide file tree
Showing 14 changed files with 191 additions and 8 deletions.
148 changes: 148 additions & 0 deletions evals/elsuite/choice_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os
from pathlib import Path
from typing import Any

import oss2
from oss2.credentials import EnvironmentVariableCredentialsProvider

import evals
import evals.metrics
from evals.api import CompletionFn
from evals.prompt.base import is_chat_prompt


def init_oss():
"""
Initialize OSS client.
"""
# Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables.
auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())

# 设置 Endpoint
endpoint = 'https://oss-cn-beijing.aliyuncs.com'

# 设置 Bucket
bucket_name = 'dp-filetrans-bj'
bucket = oss2.Bucket(auth, endpoint, bucket_name)

return bucket


def get_rag_dataset(samples_jsonl: str) -> list[dict]:
bucket = init_oss()
raw_samples = evals.get_jsonl(samples_jsonl)

for raw_sample in raw_samples:
for ftype in ["", "answer"]:
if f"{ftype}file_name" not in raw_sample and f"{ftype}file_link" not in raw_sample:
continue
if f"{ftype}file_name" in raw_sample:
oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"])
raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file

exists = bucket.object_exists(oss_file)
if exists:
print(f"文件 {oss_file} 已存在于 OSS 中。")
else:
# 上传文件
bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"])
print(f"文件 {oss_file} 已上传到 OSS。")
if f"{ftype}file_link" in raw_sample:
local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else \
os.path.basename(raw_sample[f"{ftype}file_link"])
oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"])
if not os.path.exists(local_file):
if bucket.object_exists(oss_file):
# 从 OSS 下载文件
Path(local_file).parent.mkdir(parents=True, exist_ok=True)
bucket.get_object_to_file(oss_file, local_file)
print(f"文件 {oss_file} 已下载到本地。")
return raw_samples


class RAGMatch(evals.Eval):
def __init__(
self,
completion_fns: list[CompletionFn],
samples_jsonl: str,
*args,
max_tokens: int = 500,
num_few_shot: int = 0,
few_shot_jsonl: str = None,
**kwargs,
):
super().__init__(completion_fns, *args, **kwargs)
assert len(completion_fns) == 1, "Match only supports one completion fn"
self.max_tokens = max_tokens
self.samples_jsonl = samples_jsonl
self.num_few_shot = num_few_shot
if self.num_few_shot > 0:
assert few_shot_jsonl is not None, "few shot requires few shot sample dataset"
self.few_shot_jsonl = few_shot_jsonl
self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl))

def eval_sample(self, sample: Any, *_):
assert isinstance(sample, dict), "sample must be a dict"
assert "input" in sample, "sample must have an 'input' key"
assert "ideal" in sample, "sample must have an 'ideal' key"
assert isinstance(sample["ideal"], str) or isinstance(
sample["ideal"], list
), "sample['ideal'] must be a string or list of strings"

prompt = sample["input"]
if self.num_few_shot > 0:
assert is_chat_prompt(sample["input"]), "few shot requires chat prompt"
prompt = sample["input"][:-1]
for s in self.few_shot[: self.num_few_shot]:
prompt += s["sample"]
prompt += sample["input"][-1:]

result = self.completion_fn(
prompt=prompt,
temperature=0.0,
**{k: v for k, v in sample.items() if k not in ["input", "ideal"]}
)
sampled = result.get_completions()[0]

extras = {}
if hasattr(result, "extras"):
if "extracted_answer" in result.extras:
sampled = result.extras["extracted_answer"].rstrip(".")
extras = result.extras
print(sampled)
sampled = sampled.split("\n")
for i in range(len(sampled)-1, -1, -1):
if i == 0:
sampled = sampled[0]
elif sampled[i] == "":
continue
else:
sampled = sampled[i]
break
for i in ["a)", "b)", "c)", "d)"]:
if i in sample["ideal"] and i in sampled:
continue
elif i not in sample["ideal"] and i not in sampled:
continue
else:
sampled = ""
break
if sampled != "":
sampled = sample["ideal"]
print("compare", sampled, sample["ideal"])
return evals.record_and_check_match(
prompt=prompt,
sampled=sampled,
expected=sample["ideal"],
file_name=sample["file_name"],
**extras
)

def run(self, recorder):
samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix())
self.eval_all_samples(recorder, samples)
events = recorder.get_events("match")
return {
"accuracy": evals.metrics.get_accuracy(events),
"boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events),
}
1 change: 0 additions & 1 deletion evals/elsuite/rag_table_extract_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ def eval_sample(self, sample, rng):
result = self.completion_fn(
prompt=prompt,
temperature=0.0,
max_tokens=5,
file_name=sample.file_name,
file_link=sample.file_link
)
Expand Down
2 changes: 1 addition & 1 deletion evals/elsuite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def tableMatching(df_ref, df_prompt, index='Compound', compare_fields=[], record
return {"recall_field": 0.0, "recall_index": 0.0, "recall_value": 0.0, "recall_value_strict": 0.0,
"accuracy_value": 0.0, "accuracy_value_strict": 0.0, "recall_SMILES": 0.0}
metrics = {}
index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate"]
index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate","AlloyName"]

if index not in [None, ""]:
df_ref[index] = df_ref[index].astype(str)
Expand Down
3 changes: 3 additions & 0 deletions evals/registry/data/01_alloychart/samples.jsonl
Git LFS file not shown
4 changes: 2 additions & 2 deletions evals/registry/data/01_alloycomposition/composition.jsonl
Git LFS file not shown
4 changes: 2 additions & 2 deletions evals/registry/data/01_alloynum/alloy_number.jsonl
Git LFS file not shown
4 changes: 2 additions & 2 deletions evals/registry/data/01_alloysort/sort.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions evals/registry/data/05_biochart/samples.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions evals/registry/data/05_biochart/samples_single.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions evals/registry/data/05_biochart/samples_single_gemini.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions evals/registry/data/drug_ChartQA/samples copy.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions evals/registry/data/drug_ChartQA/samples.jsonl
Git LFS file not shown
9 changes: 9 additions & 0 deletions evals/registry/evals/01_scipaper_alloychart.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

alloychart:
id: alloychart.dev.v0
metrics: [accuracy]

alloychart.dev.v0:
class: evals.elsuite.rag_match_fuzzy:RAGMatch
args:
samples_jsonl: 01_alloychart/samples.jsonl
9 changes: 9 additions & 0 deletions evals/registry/evals/01_scipaper_drugchart.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

drugchart:
id: drug_ChartQA.dev.v0
metrics: [accuracy]

drug_ChartQA.dev.v0:
class: evals.elsuite.choice_match:RAGMatch
args:
samples_jsonl: drug_ChartQA/samples.jsonl

0 comments on commit ae76148

Please sign in to comment.