Skip to content

Commit

Permalink
CI fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
VyrodovMikhail committed Jun 3, 2024
1 parent 78689cb commit 477ff3b
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ jobs:
- name: "isort"
run: isort . --check --diff
- name: "mypy"
run: mypy
run: mypy --ignore-missing-imports
- name: "pytests"
run: pytest
15 changes: 8 additions & 7 deletions goat/backend/add_results.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# type: ignore
import json

from datasets import get_dataset_config_names, load_dataset

from goat.utils.database_helper import DatabaseHelper, EvalResult


def get_datasets_len(tasks):
def get_datasets_len(tasks: list[str]) -> dict[str, int]:
datasets_len = dict()
datasets_len["single_choice"] = 0
datasets_len["multiple_choice"] = 0
Expand All @@ -24,16 +23,18 @@ def get_datasets_len(tasks):
return datasets_len


def get_metrics_values(tasks, evaluation, datasets_len):
def get_metrics_values(
tasks: list[str], evaluation: dict[str, dict], datasets_len: dict[str, int]
) -> tuple[float, float, float]:
metrics = [
"multi_choice_em_unordered,get-answer",
"word_in_set,none",
"acc,none",
]

single_choice_score = 0
multiple_choice_score = 0
word_gen_score = 0
single_choice_score = 0.0
multiple_choice_score = 0.0
word_gen_score = 0.0

for task in tasks:
for metric in metrics:
Expand All @@ -53,7 +54,7 @@ def get_metrics_values(tasks, evaluation, datasets_len):
return single_choice_score, multiple_choice_score, word_gen_score


def add_results(input_path):
def add_results(input_path: str) -> None:
with open(input_path, "r") as j:
contents = json.loads(j.read())
evaluation = contents["results"]
Expand Down
3 changes: 1 addition & 2 deletions goat/backend/eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# type: ignore
import json
import os
from pathlib import Path
Expand All @@ -12,7 +11,7 @@
from goat.utils.database_helper import DatabaseHelper


def eval(model_name: str, precision: str, generate_fastchat: bool):
def eval(model_name: str, precision: str, generate_fastchat: bool) -> None:
lm = HFLM(pretrained=model_name, dtype=precision)
taskname = "goat"
results = evaluator.simple_evaluate(model=lm, tasks=[taskname])
Expand Down
1 change: 0 additions & 1 deletion goat/frontend/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# type: ignore
import gradio as gr

from goat.frontend.precision import Precision
Expand Down
18 changes: 9 additions & 9 deletions goat/frontend/precision.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# type: ignore
from enum import Enum


Expand All @@ -8,11 +7,12 @@ class Precision(Enum):
float32 = "float32"
Unknown = "?"

def from_str(precision):
if precision in ["torch.float16", "float16"]:
return Precision.float16
if precision in ["torch.float32", "float32"]:
return Precision.float32
if precision in ["torch.bfloat16", "bfloat16"]:
return Precision.bfloat16
return Precision.Unknown

def from_str(precision: str) -> Precision:
if precision in ["torch.float16", "float16"]:
return Precision.float16
if precision in ["torch.float32", "float32"]:
return Precision.float32
if precision in ["torch.bfloat16", "bfloat16"]:
return Precision.bfloat16
return Precision.Unknown
22 changes: 12 additions & 10 deletions goat/utils/database_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# type: ignore
import os
import select
from dataclasses import dataclass
from typing import Callable

import pandas as pd
import psycopg2
Expand Down Expand Up @@ -31,15 +31,15 @@ class EvalRequest:
validate_big_tasks: bool


def postgres_str_to_bool(val):
if val == 'True':
def postgres_str_to_bool(val: str) -> bool:
if val == "True":
return True
else:
return False


class DatabaseHelper:
def __init__(self):
def __init__(self) -> None:
self.engine = create_engine(
f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_IP}:{POSTGRES_PORT}/{POSTGRES_DB}",
echo=True,
Expand All @@ -55,14 +55,14 @@ def __init__(self):
self.leaderboard = Table("leaderboard", metadata, autoload_with=self.engine)
self.eval_requests = Table("eval_requests", metadata, autoload_with=self.engine)

def add_eval_request(self, model_name, precision, validate_big_tasks):
def add_eval_request(self, model_name: str, precision: str, validate_big_tasks: bool) -> None:
request = insert(self.eval_requests).values(
model_name=model_name, precision=precision, validate_big_tasks=validate_big_tasks
)
self.session.execute(request)
self.session.commit()

def add_eval_result(self, eval_result):
def add_eval_result(self, eval_result: EvalResult) -> None:
stmt = insert(self.leaderboard).values(
model=eval_result.model,
single_choice=eval_result.single_choice,
Expand All @@ -72,7 +72,7 @@ def add_eval_result(self, eval_result):
self.session.execute(stmt)
self.session.commit()

def listen_to_new_requests(self, action):
def listen_to_new_requests(self, action: Callable[[str, str, bool], None]) -> None:
cur = self.connection.cursor()
cur.execute("LISTEN id;")
while True:
Expand All @@ -85,14 +85,16 @@ def listen_to_new_requests(self, action):
model, precision, validate_big_tasks = (
df.loc[df["id"] == int(notify.payload)]["model_name"].to_string(index=False),
df.loc[df["id"] == int(notify.payload)]["precision"].to_string(index=False),
postgres_str_to_bool(df.loc[df["id"] == int(notify.payload)]["validate_big_tasks"].to_string(index=False)),
postgres_str_to_bool(
df.loc[df["id"] == int(notify.payload)]["validate_big_tasks"].to_string(index=False)
),
)
action(model, precision, validate_big_tasks)

def get_leaderboard_df(self):
def get_leaderboard_df(self) -> pd.DataFrame:
query = "SELECT * FROM leaderboard"
df = pd.DataFrame(self.engine.connect().execute(text(query)))
return df

def end_connection(self):
def end_connection(self) -> None:
self.connection.close()

0 comments on commit 477ff3b

Please sign in to comment.