From 5de7231c99d2c718fdfde88b73a310d6ae1ad74e Mon Sep 17 00:00:00 2001
From: Mikhail Vyrodov
Date: Sun, 2 Jun 2024 21:31:57 +0300
Subject: [PATCH] Add fastchat validation support

---
 goat/backend/eval.py             | 34 +++++++++++++++++++++++++++++++++-----
 goat/database/bd_init_script.sql |  3 ++-
 goat/frontend/app.py             |  8 ++++++--
 goat/utils/database_helper.py    | 12 ++++++++----
 pyproject.toml                   |  1 +
 5 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/goat/backend/eval.py b/goat/backend/eval.py
index 49d6eea..a0ccf0f 100644
--- a/goat/backend/eval.py
+++ b/goat/backend/eval.py
@@ -1,7 +1,10 @@
 # type: ignore
 import json
+import os
 from pathlib import Path
 
+from fastchat.llm_judge.gen_model_answer import run_eval
+from fastchat.utils import str_to_torch_dtype
 from lm_eval import evaluator
 from lm_eval.models.huggingface import HFLM
 
@@ -9,17 +12,38 @@
 from goat.utils.database_helper import DatabaseHelper
 
 
-def eval(model_name: str, precision: str):
+def eval(model_name: str, precision: str, generate_fastchat: bool):
     lm = HFLM(pretrained=model_name, dtype=precision)
     taskname = "goat"
     results = evaluator.simple_evaluate(model=lm, tasks=[taskname])
 
-    filename = model_name.replace("/", "__")
-    Path("results").mkdir(exist_ok=True)
-    with open(f"results/{filename}.json", "w", encoding="utf-8") as f:
+    model_id = model_name.replace("/", "__")
+    Path("goat/backend/results").mkdir(parents=True, exist_ok=True)
+    lm_eval_output_file = f"goat/backend/results/{model_id}_lm_eval.json"
+    with open(lm_eval_output_file, "w", encoding="utf-8") as f:
         json.dump(results, f, ensure_ascii=False)
 
-    add_results(input_path=f"results/{filename}.json")
+    if generate_fastchat:
+        fastchat_filename = os.path.join("goat/backend/results", model_id + "_fastchat.jsonl")
+        question_file = "goat/backend/data/question.jsonl"
+
+        run_eval(
+            model_path=model_name,
+            model_id=model_id,
+            answer_file=fastchat_filename,
+            question_file=question_file,
+            question_begin=None,
+            question_end=None,
+            max_new_token=1024,
+            num_choices=1,
+            num_gpus_per_model=1,
+            num_gpus_total=1,
+            max_gpu_memory=None,
+            dtype=str_to_torch_dtype(precision),
+            revision="main",
+        )
+
+    add_results(input_path=lm_eval_output_file)
 
 
 if __name__ == "__main__":
diff --git a/goat/database/bd_init_script.sql b/goat/database/bd_init_script.sql
index d9285a3..577241d 100644
--- a/goat/database/bd_init_script.sql
+++ b/goat/database/bd_init_script.sql
@@ -16,7 +16,8 @@ create table if not exists public.eval_requests
         constraint eval_requests_pk
             primary key,
     model_name varchar not null,
-    precision varchar not null
+    precision varchar not null,
+    validate_big_tasks boolean not null
 );
 
 alter table public.eval_requests
diff --git a/goat/frontend/app.py b/goat/frontend/app.py
index 0f0a0e3..9a98dd7 100644
--- a/goat/frontend/app.py
+++ b/goat/frontend/app.py
@@ -3,7 +3,7 @@
 from goat.frontend.precision import Precision
 
-from ..utils.database_helper import DatabaseHelper, EvalRequest
+from ..utils.database_helper import DatabaseHelper
 
 TITLE = "Goat leaderboard"
 INTRODUCTION_TEXT = "This is really nice introduction text!!!"
 
@@ -45,11 +45,15 @@
             value="float16",
             interactive=True,
         )
+        validate_big_tasks = gr.Checkbox(
+            label="Validate on big text tasks",
+            info="Do you need to validate your model on tasks that require a large text answer?",
+        )
         submit_button = gr.Button("Submit Eval")
         submission_result = gr.Markdown()
         submit_button.click(
             db_helper.add_eval_request,
-            [model_name, model_precision],
+            [model_name, model_precision, validate_big_tasks],
             submission_result,
         )
 
diff --git a/goat/utils/database_helper.py b/goat/utils/database_helper.py
index efc1240..3f84403 100644
--- a/goat/utils/database_helper.py
+++ b/goat/utils/database_helper.py
@@ -28,6 +28,7 @@ class EvalResult:
 class EvalRequest:
     model_name: str
     precision: str
+    validate_big_tasks: bool
 
 
 class DatabaseHelper:
@@ -47,8 +48,10 @@ def __init__(self):
         self.leaderboard = Table("leaderboard", metadata, autoload_with=self.engine)
         self.eval_requests = Table("eval_requests", metadata, autoload_with=self.engine)
 
-    def add_eval_request(self, model_name, precision):
-        request = insert(self.eval_requests).values(model_name=model_name, precision=precision)
+    def add_eval_request(self, model_name, precision, validate_big_tasks):
+        request = insert(self.eval_requests).values(
+            model_name=model_name, precision=precision, validate_big_tasks=validate_big_tasks
+        )
         self.session.execute(request)
         self.session.commit()
 
@@ -72,11 +75,12 @@ def listen_to_new_requests(self, action):
             notify = self.connection.notifies.pop()
             query = "SELECT * FROM eval_requests"
             df = pd.DataFrame(self.engine.connect().execute(text(query)))
-            model, precision = (
+            model, precision, validate_big_tasks = (
                 df.loc[df["id"] == int(notify.payload)]["model_name"].to_string(index=False),
                 df.loc[df["id"] == int(notify.payload)]["precision"].to_string(index=False),
+                df.loc[df["id"] == int(notify.payload)]["validate_big_tasks"].to_string(index=False),
             )
-            action(model, precision)
+            action(model, precision, validate_big_tasks)
 
     def get_leaderboard_df(self):
         query = "SELECT * FROM leaderboard"
diff --git a/pyproject.toml b/pyproject.toml
index d8353ae..4528ea3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "torchvision==0.17.0",
     "transformer_engine==0.0.0",
     "lm_eval@git+https://github.com/deepvk/lm-evaluation-harness@goat",
+    "fschat[model_worker,llm_judge]@git+https://github.com/deepvk/FastChat/@goat",
 ]
 
 [tool.isort]
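
For reference, a minimal usage sketch of the extended entry point introduced by this patch; the model name below is a placeholder, and the call assumes a FastChat question file exists at goat/backend/data/question.jsonl (the path hard-coded in eval.py):

    # Illustrative sketch, not part of the patch: run the lm_eval harness and,
    # when generate_fastchat is set, also generate FastChat answers for judging.
    from goat.backend.eval import eval

    eval(
        model_name="org/model-name",  # placeholder Hugging Face model id
        precision="float16",
        generate_fastchat=True,  # also writes goat/backend/results/<model_id>_fastchat.jsonl
    )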