diff --git a/evals/elsuite/choice_match.py b/evals/elsuite/choice_match.py new file mode 100644 index 0000000000..5700bf4cbc --- /dev/null +++ b/evals/elsuite/choice_match.py @@ -0,0 +1,148 @@ +import os +from pathlib import Path +from typing import Any + +import oss2 +from oss2.credentials import EnvironmentVariableCredentialsProvider + +import evals +import evals.metrics +from evals.api import CompletionFn +from evals.prompt.base import is_chat_prompt + + +def init_oss(): + """ + Initialize OSS client. + """ + # Please set OSS_ACCESS_KEY_ID & OSS_ACCESS_KEY_SECRET in your environment variables. + auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider()) + + # 设置 Endpoint + endpoint = 'https://oss-cn-beijing.aliyuncs.com' + + # 设置 Bucket + bucket_name = 'dp-filetrans-bj' + bucket = oss2.Bucket(auth, endpoint, bucket_name) + + return bucket + + +def get_rag_dataset(samples_jsonl: str) -> list[dict]: + bucket = init_oss() + raw_samples = evals.get_jsonl(samples_jsonl) + + for raw_sample in raw_samples: + for ftype in ["", "answer"]: + if f"{ftype}file_name" not in raw_sample and f"{ftype}file_link" not in raw_sample: + continue + if f"{ftype}file_name" in raw_sample: + oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_name"]) + raw_sample[f"{ftype}file_link"] = "https://dp-filetrans-bj.oss-cn-beijing.aliyuncs.com/" + oss_file + + exists = bucket.object_exists(oss_file) + if exists: + print(f"文件 {oss_file} 已存在于 OSS 中。") + else: + # 上传文件 + bucket.put_object_from_file(oss_file, raw_sample[f"{ftype}file_name"]) + print(f"文件 {oss_file} 已上传到 OSS。") + if f"{ftype}file_link" in raw_sample: + local_file = raw_sample[f"{ftype}file_name"] if f"{ftype}file_name" in raw_sample else \ + os.path.basename(raw_sample[f"{ftype}file_link"]) + oss_file = "changjunhan/" + os.path.basename(raw_sample[f"{ftype}file_link"]) + if not os.path.exists(local_file): + if bucket.object_exists(oss_file): + # 从 OSS 下载文件 + Path(local_file).parent.mkdir(parents=True, exist_ok=True) + bucket.get_object_to_file(oss_file, local_file) + print(f"文件 {oss_file} 已下载到本地。") + return raw_samples + + +class RAGMatch(evals.Eval): + def __init__( + self, + completion_fns: list[CompletionFn], + samples_jsonl: str, + *args, + max_tokens: int = 500, + num_few_shot: int = 0, + few_shot_jsonl: str = None, + **kwargs, + ): + super().__init__(completion_fns, *args, **kwargs) + assert len(completion_fns) == 1, "Match only supports one completion fn" + self.max_tokens = max_tokens + self.samples_jsonl = samples_jsonl + self.num_few_shot = num_few_shot + if self.num_few_shot > 0: + assert few_shot_jsonl is not None, "few shot requires few shot sample dataset" + self.few_shot_jsonl = few_shot_jsonl + self.few_shot = evals.get_jsonl(self._prefix_registry_path(self.few_shot_jsonl)) + + def eval_sample(self, sample: Any, *_): + assert isinstance(sample, dict), "sample must be a dict" + assert "input" in sample, "sample must have an 'input' key" + assert "ideal" in sample, "sample must have an 'ideal' key" + assert isinstance(sample["ideal"], str) or isinstance( + sample["ideal"], list + ), "sample['ideal'] must be a string or list of strings" + + prompt = sample["input"] + if self.num_few_shot > 0: + assert is_chat_prompt(sample["input"]), "few shot requires chat prompt" + prompt = sample["input"][:-1] + for s in self.few_shot[: self.num_few_shot]: + prompt += s["sample"] + prompt += sample["input"][-1:] + + result = self.completion_fn( + prompt=prompt, + temperature=0.0, + **{k: v for k, v in sample.items() if k not in ["input", "ideal"]} + ) + sampled = result.get_completions()[0] + + extras = {} + if hasattr(result, "extras"): + if "extracted_answer" in result.extras: + sampled = result.extras["extracted_answer"].rstrip(".") + extras = result.extras + print(sampled) + sampled = sampled.split("\n") + for i in range(len(sampled)-1, -1, -1): + if i == 0: + sampled = sampled[0] + elif sampled[i] == "": + continue + else: + sampled = sampled[i] + break + for i in ["a)", "b)", "c)", "d)"]: + if i in sample["ideal"] and i in sampled: + continue + elif i not in sample["ideal"] and i not in sampled: + continue + else: + sampled = "" + break + if sampled != "": + sampled = sample["ideal"] + print("compare", sampled, sample["ideal"]) + return evals.record_and_check_match( + prompt=prompt, + sampled=sampled, + expected=sample["ideal"], + file_name=sample["file_name"], + **extras + ) + + def run(self, recorder): + samples = get_rag_dataset(self._prefix_registry_path(self.samples_jsonl).as_posix()) + self.eval_all_samples(recorder, samples) + events = recorder.get_events("match") + return { + "accuracy": evals.metrics.get_accuracy(events), + "boostrap_std": evals.metrics.get_bootstrap_accuracy_std(events), + } diff --git a/evals/elsuite/rag_table_extract_comp.py b/evals/elsuite/rag_table_extract_comp.py index 4413ce849d..ad1c116f75 100644 --- a/evals/elsuite/rag_table_extract_comp.py +++ b/evals/elsuite/rag_table_extract_comp.py @@ -310,7 +310,6 @@ def eval_sample(self, sample, rng): result = self.completion_fn( prompt=prompt, temperature=0.0, - max_tokens=5, file_name=sample.file_name, file_link=sample.file_link ) diff --git a/evals/elsuite/utils.py b/evals/elsuite/utils.py index 705885b3eb..ec037f7870 100644 --- a/evals/elsuite/utils.py +++ b/evals/elsuite/utils.py @@ -367,7 +367,7 @@ def tableMatching(df_ref, df_prompt, index='Compound', compare_fields=[], record return {"recall_field": 0.0, "recall_index": 0.0, "recall_value": 0.0, "recall_value_strict": 0.0, "accuracy_value": 0.0, "accuracy_value_strict": 0.0, "recall_SMILES": 0.0} metrics = {} - index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate"] + index_names = ["Compound", "Name", "SMILES", "Nickname", "Substrate","AlloyName"] if index not in [None, ""]: df_ref[index] = df_ref[index].astype(str) diff --git a/evals/registry/data/01_alloychart/samples.jsonl b/evals/registry/data/01_alloychart/samples.jsonl new file mode 100644 index 0000000000..c0a2fe60b0 --- /dev/null +++ b/evals/registry/data/01_alloychart/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:805f5c92e67bbb8a143fd30368bd0a480486d8733a1e1c1bed8985b2ad007bde +size 7783 diff --git a/evals/registry/data/01_alloycomposition/composition.jsonl b/evals/registry/data/01_alloycomposition/composition.jsonl index f05b75dac7..c0d22563a0 100644 --- a/evals/registry/data/01_alloycomposition/composition.jsonl +++ b/evals/registry/data/01_alloycomposition/composition.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9b2190d4d26fc45c5ba0e6615f3c10964fd7a547be434c58e548f91eb2e412c5 -size 33632 +oid sha256:5a4facecc1fe28bd1d84fae1273eecdaa03c2eac65e4a66a3f348f8ddff671b6 +size 33206 diff --git a/evals/registry/data/01_alloynum/alloy_number.jsonl b/evals/registry/data/01_alloynum/alloy_number.jsonl index 55d0c5e1d9..5414adf7c2 100644 --- a/evals/registry/data/01_alloynum/alloy_number.jsonl +++ b/evals/registry/data/01_alloynum/alloy_number.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:70ae2613eb223f2a9b7e8cbf10f81b5e626a9ae6827d94c7e3e25df1b8e4e4bf -size 36088 +oid sha256:2ca8af0565b74d457d188e00a5f25ff574b5d40175bc2c6321523d47753907f9 +size 35391 diff --git a/evals/registry/data/01_alloysort/sort.jsonl b/evals/registry/data/01_alloysort/sort.jsonl index 08873c35c6..eb53a2e052 100644 --- a/evals/registry/data/01_alloysort/sort.jsonl +++ b/evals/registry/data/01_alloysort/sort.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:45b0bf49f634caf51353105a1d62a5cf8285654b0ba8a3e744442b879400bfd0 -size 15245 +oid sha256:d70ec2f034dc30585433fd1eff176bf5d72035d6a421062d0dc834a8d434a458 +size 15876 diff --git a/evals/registry/data/05_biochart/samples.jsonl b/evals/registry/data/05_biochart/samples.jsonl new file mode 100644 index 0000000000..96e54bf643 --- /dev/null +++ b/evals/registry/data/05_biochart/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f17bc79d6a2f3de2d03c1682462fb1453a654cbc017f42d1edbe96689aca9895 +size 7500 diff --git a/evals/registry/data/05_biochart/samples_single.jsonl b/evals/registry/data/05_biochart/samples_single.jsonl new file mode 100644 index 0000000000..8d05e68ac6 --- /dev/null +++ b/evals/registry/data/05_biochart/samples_single.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe96f94f5279ca67003295191c7f2c2002976a49bbe4856e58d0799bb7f58f2c +size 7664 diff --git a/evals/registry/data/05_biochart/samples_single_gemini.jsonl b/evals/registry/data/05_biochart/samples_single_gemini.jsonl new file mode 100644 index 0000000000..9f72c78011 --- /dev/null +++ b/evals/registry/data/05_biochart/samples_single_gemini.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1093a3a82f10c6ac2ce9f9896cc9159bbea4380c3d4bf9e4977224ab02f4df52 +size 7636 diff --git a/evals/registry/data/drug_ChartQA/samples copy.jsonl b/evals/registry/data/drug_ChartQA/samples copy.jsonl new file mode 100644 index 0000000000..06e4ff7a51 --- /dev/null +++ b/evals/registry/data/drug_ChartQA/samples copy.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f5cf6d65edf256002bdd6d4bf8fdae23698a6150b23bf778e246fa9c51b8074 +size 8787 diff --git a/evals/registry/data/drug_ChartQA/samples.jsonl b/evals/registry/data/drug_ChartQA/samples.jsonl new file mode 100644 index 0000000000..06e4ff7a51 --- /dev/null +++ b/evals/registry/data/drug_ChartQA/samples.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f5cf6d65edf256002bdd6d4bf8fdae23698a6150b23bf778e246fa9c51b8074 +size 8787 diff --git a/evals/registry/evals/01_scipaper_alloychart.yaml b/evals/registry/evals/01_scipaper_alloychart.yaml new file mode 100644 index 0000000000..0f5753b33f --- /dev/null +++ b/evals/registry/evals/01_scipaper_alloychart.yaml @@ -0,0 +1,9 @@ + +alloychart: + id: alloychart.dev.v0 + metrics: [accuracy] + +alloychart.dev.v0: + class: evals.elsuite.rag_match_fuzzy:RAGMatch + args: + samples_jsonl: 01_alloychart/samples.jsonl diff --git a/evals/registry/evals/01_scipaper_drugchart.yaml b/evals/registry/evals/01_scipaper_drugchart.yaml new file mode 100644 index 0000000000..bc4079eac8 --- /dev/null +++ b/evals/registry/evals/01_scipaper_drugchart.yaml @@ -0,0 +1,9 @@ + +drugchart: + id: drug_ChartQA.dev.v0 + metrics: [accuracy] + +drug_ChartQA.dev.v0: + class: evals.elsuite.choice_match:RAGMatch + args: + samples_jsonl: drug_ChartQA/samples.jsonl