Add fastchat validation support

VyrodovMikhail committed Jun 2, 2024
1 parent 08e7182 commit 5de7231
Showing 5 changed files with 46 additions and 12 deletions.
34 changes: 29 additions & 5 deletions goat/backend/eval.py
@@ -1,25 +1,49 @@
# type: ignore
import json
import os
from pathlib import Path

from fastchat.llm_judge.gen_model_answer import run_eval
from fastchat.utils import str_to_torch_dtype
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

from goat.backend.add_results import add_results
from goat.utils.database_helper import DatabaseHelper


def eval(model_name: str, precision: str):
def eval(model_name: str, precision: str, generate_fastchat: bool):
    lm = HFLM(pretrained=model_name, dtype=precision)
    taskname = "goat"
    results = evaluator.simple_evaluate(model=lm, tasks=[taskname])

    filename = model_name.replace("/", "__")
    Path("results").mkdir(exist_ok=True)
    with open(f"results/{filename}.json", "w", encoding="utf-8") as f:
    model_id = model_name.replace("/", "__")
    # Ensure the shared results directory exists (output files land directly in it)
    Path("goat/backend/results").mkdir(parents=True, exist_ok=True)
    lm_eval_output_file = f"goat/backend/results/{model_id}_lm_eval.json"
    with open(lm_eval_output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)

    add_results(input_path=f"results/{filename}.json")
    if generate_fastchat:
        fastchat_filename = os.path.join("goat/backend/results", model_id + "_fastchat.jsonl")
        question_file = "goat/backend/data/question.jsonl"

        run_eval(
            model_path=model_name,
            model_id=model_id,
            answer_file=fastchat_filename,
            question_file=question_file,
            question_begin=None,
            question_end=None,
            max_new_token=1024,
            num_choices=1,
            num_gpus_per_model=1,
            num_gpus_total=1,
            max_gpu_memory=None,
            dtype=str_to_torch_dtype(precision),
            revision="main",
        )

    add_results(input_path=lm_eval_output_file)


if __name__ == "__main__":
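The __main__ block is truncated in this view. A minimal sketch of how the new flag might be wired up from the command line (the argparse flag names here are assumptions, not the commit's actual code):

    import argparse

    from goat.backend.eval import eval

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--model-name", required=True)
        parser.add_argument("--precision", default="float16")
        # Assumed flag: opt in to the FastChat answer-generation pass
        parser.add_argument("--generate-fastchat", action="store_true")
        args = parser.parse_args()
        eval(args.model_name, args.precision, args.generate_fastchat)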
3 changes: 2 additions & 1 deletion goat/database/bd_init_script.sql
@@ -16,7 +16,8 @@ create table if not exists public.eval_requests
        constraint eval_requests_pk
            primary key,
    model_name varchar not null,
    precision varchar not null
    precision varchar not null,
    validate_big_tasks boolean not null
);

alter table public.eval_requests
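Note that bd_init_script.sql only runs against a fresh database (create table if not exists), so a database initialized before this commit will not pick up the new column. A hedged migration sketch using the SQLAlchemy engine already present in database_helper.py (the DSN is a placeholder, and NOT NULL needs a DEFAULT to backfill existing rows):

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://user:password@localhost/goat")  # placeholder DSN
    with engine.begin() as conn:
        # Backfill existing rows with false so the NOT NULL constraint can hold
        conn.execute(text(
            "ALTER TABLE public.eval_requests "
            "ADD COLUMN IF NOT EXISTS validate_big_tasks boolean NOT NULL DEFAULT false"
        ))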
8 changes: 6 additions & 2 deletions goat/frontend/app.py
@@ -3,7 +3,7 @@

from goat.frontend.precision import Precision

from ..utils.database_helper import DatabaseHelper, EvalRequest
from ..utils.database_helper import DatabaseHelper

TITLE = "Goat leaderboard"
INTRODUCTION_TEXT = "This is a really nice introduction text!"
@@ -45,11 +45,15 @@
value="float16",
interactive=True,
)
validate_big_tasks = gr.Checkbox(
label="Validate on big text tasks",
info="Do you need to validate your model on tasks that require large text answer?",
)
submit_button = gr.Button("Submit Eval")
submission_result = gr.Markdown()
submit_button.click(
db_helper.add_eval_request,
[model_name, model_precision],
[model_name, model_precision, validate_big_tasks],
submission_result,
)

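Because Button.click passes the listed input components' values positionally, add_eval_request receives the checkbox state as its third argument. A self-contained sketch of the same wiring (labels and the handler body are illustrative only, not the app's full layout):

    import gradio as gr

    def add_eval_request(model_name, precision, validate_big_tasks):
        # Stand-in for DatabaseHelper.add_eval_request
        return f"Queued {model_name} ({precision}), big-task validation: {validate_big_tasks}"

    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model name")
        model_precision = gr.Dropdown(choices=["float16", "bfloat16"], value="float16")
        validate_big_tasks = gr.Checkbox(label="Validate on big text tasks")
        submission_result = gr.Markdown()
        gr.Button("Submit Eval").click(
            add_eval_request,
            [model_name, model_precision, validate_big_tasks],  # inputs, in declaration order
            submission_result,
        )

    demo.launch()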
12 changes: 8 additions & 4 deletions goat/utils/database_helper.py
@@ -28,6 +28,7 @@ class EvalResult:
class EvalRequest:
    model_name: str
    precision: str
    validate_big_tasks: bool


class DatabaseHelper:
@@ -47,8 +47,10 @@ def __init__(self):
        self.leaderboard = Table("leaderboard", metadata, autoload_with=self.engine)
        self.eval_requests = Table("eval_requests", metadata, autoload_with=self.engine)

    def add_eval_request(self, model_name, precision):
        request = insert(self.eval_requests).values(model_name=model_name, precision=precision)
    def add_eval_request(self, model_name, precision, validate_big_tasks):
        request = insert(self.eval_requests).values(
            model_name=model_name, precision=precision, validate_big_tasks=validate_big_tasks
        )
        self.session.execute(request)
        self.session.commit()

@@ -72,11 +75,12 @@ def listen_to_new_requests(self, action):
        notify = self.connection.notifies.pop()
        query = "SELECT * FROM eval_requests"
        df = pd.DataFrame(self.engine.connect().execute(text(query)))
        model, precision = (
        model, precision, validate_big_tasks = (
            df.loc[df["id"] == int(notify.payload)]["model_name"].to_string(index=False),
            df.loc[df["id"] == int(notify.payload)]["precision"].to_string(index=False),
            # Read the boolean directly: to_string() would yield the string "True"/"False",
            # and any non-empty string is truthy downstream
            bool(df.loc[df["id"] == int(notify.payload)]["validate_big_tasks"].iloc[0]),
        )
        action(model, precision)
        action(model, precision, validate_big_tasks)

    def get_leaderboard_df(self):
        query = "SELECT * FROM leaderboard"
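On the worker side, the action callback now has to accept the third argument; presumably it maps straight onto the new eval() signature. A hedged sketch of that wiring (it lives outside this diff, so the exact code is an assumption):

    from goat.backend.eval import eval
    from goat.utils.database_helper import DatabaseHelper

    db_helper = DatabaseHelper()
    # Each NOTIFY payload is an eval_requests id; the row's settings drive one eval run
    db_helper.listen_to_new_requests(
        lambda model, precision, validate_big_tasks: eval(model, precision, validate_big_tasks)
    )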
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"torchvision==0.17.0",
"transformer_engine==0.0.0",
"lm_eval@git+https://github.com/deepvk/lm-evaluation-harness@goat",
"fschat[model_worker,llm_judge]@git+https://github.com/deepvk/FastChat/@goat",
]

[tool.isort]
