diff --git a/.gitignore b/.gitignore
index cb2b65cbe6..d1cd9abd75 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,6 @@ venv/
 
 .idea/
 build
+
+openai-key.txt
+*.code-workspace
diff --git a/docs/eval-templates.md b/docs/eval-templates.md
index ab949375e4..c6c90696d0 100644
--- a/docs/eval-templates.md
+++ b/docs/eval-templates.md
@@ -11,6 +11,10 @@ For a model completion `a` and a reference list of correct answers `B`, the foll
 - [`basic/includes.py:Includes`](../evals/elsuite/basic/includes.py): `any([(b in a) for b in B])`
 - [`basic/fuzzy_match.py:FuzzyMatch`](../evals/elsuite/basic/fuzzy_match.py): `any([(a in b or b in a) for b in B])`
 
+To compare a model completion `a` in *JSON format* to a reference list of correct answers `B` also formatted in JSON, use the following eval:
+- [`basic/json_match.py:JsonMatch`](../evals/elsuite/basic/json_match.py) yields a match if `a` is identical to at least one answer from `B`. Two JSON objects are
+identical if they have the same set of keys and the values for each key are identical. Key order is not significant, and whitespace outside values is ignored. Invalid JSON never matches.
+
 Which eval template you use will depend on your use case. It is always recommended that you inspect the completions from your model, as this will help you determine how and whether to tweak your prompt (or your reference answers) and pick your eval template. Academic benchmarks oftentimes fit the mold of these basic evals, and we have implemented several end-to-end examples of academic evals as Jupyter notebooks in the `examples` folder.
 
 Sometimes, [custom eval logic](custom-eval.md) will better suit your needs. One example of this is the [machine translation](../evals/elsuite/translate.py) eval [example](../examples/lafand-mt.ipynb), in which there is a unique and clearly defined metric that we wish to use in our eval. You should use your best judgment when deciding between custom eval logic, using a basic eval template, or using model-graded evals as described next.
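For illustration, here is a minimal sketch of the comparison rule the documentation above describes: both strings decode to equal Python objects even though key order and whitespace differ. The example values are invented.

```python
import json

# Key order and surrounding whitespace differ, but the decoded objects are equal.
a = '{\n  "name": "Reykjavík",\n  "cases": ["nf", "þf"]\n}'
b = '{"cases": ["nf", "þf"], "name": "Reykjavík"}'
assert json.loads(a) == json.loads(b)

# A differing value does not match; invalid JSON would raise on decoding
# and is treated as "never matches" by the eval.
c = '{"cases": ["nf", "þf"], "name": "Akureyri"}'
assert json.loads(a) != json.loads(c)
```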
diff --git a/evals/elsuite/basic/json_match.py b/evals/elsuite/basic/json_match.py
new file mode 100644
index 0000000000..dfaa00a51a
--- /dev/null
+++ b/evals/elsuite/basic/json_match.py
@@ -0,0 +1,106 @@
+import json
+import random
+from typing import Any, Dict, List, Mapping, Union, cast
+
+import numpy as np
+
+import evals
+from evals.api import CompletionFn
+from evals.record import RecorderBase
+
+
+def json_match(sampled_json: Any, correct_json: Any) -> bool:
+    """Return True if the sampled completion in JSON format
+    matches a correct answer, component by component"""
+    if sampled_json is None or correct_json is None:
+        # Missing values are never correct
+        return False
+    if isinstance(sampled_json, dict):
+        if isinstance(correct_json, dict):
+            sample = cast(Mapping[str, Any], sampled_json)
+            correct = cast(Mapping[str, Any], correct_json)
+            all_keys = set(sample.keys()) | set(correct.keys())
+            return all(json_match(sample.get(key), correct.get(key)) for key in all_keys)
+        else:
+            return False
+    elif isinstance(sampled_json, list):
+        if isinstance(correct_json, list):
+            slist = cast(List[Any], sampled_json)
+            clist = cast(List[Any], correct_json)
+            if len(slist) != len(clist):
+                # Lists must have the same length
+                return False
+            return all(json_match(s, c) for s, c in zip(slist, clist))
+        else:
+            return False
+    # Not a structured item: do a direct comparison
+    return sampled_json == correct_json
+
+
+class JsonMatch(evals.Eval):
+    """Compares a JSON completion with one or more ideal answers,
+    also coded in JSON. The decoded JSON objects are compared
+    elementwise and must match exactly."""
+
+    def __init__(
+        self,
+        completion_fns: list[CompletionFn],
+        samples_jsonl: str,
+        *args: Any,
+        max_tokens: int = 512,  # Increase this for longer JSON completions
+        **kwargs: Any,
+    ):
+        super().__init__(completion_fns, *args, **kwargs)
+        assert len(completion_fns) == 1, "JsonMatch only supports one completion fn"
+        self.max_tokens = max_tokens
+        self.samples_jsonl = samples_jsonl
+
+    def eval_sample(self, sample: Any, rng: random.Random):
+        del rng
+
+        assert isinstance(sample, dict), "sample must be a dict"
+        assert "input" in sample, "sample must have an 'input' key"
+        assert "ideal" in sample, "sample must have an 'ideal' key"
+
+        prompt = cast(str, sample["input"])
+        correct_answers = cast(Union[str, List[str]], sample["ideal"])
+        if not isinstance(correct_answers, list):
+            correct_answers = [correct_answers]
+
+        result = self.completion_fn(
+            prompt=prompt,
+            temperature=0.0,  # Sample deterministically: exact matching leaves no room for variation
+            max_tokens=self.max_tokens,
+        )
+        sampled = result.get_completions()[0]
+
+        sampled_json: Any
+        try:
+            sampled_json = json.loads(sampled)
+        except ValueError:
+            # If the sampled string is not valid JSON, it will never match
+            sampled_json = None
+
+        # Allow the following to raise ValueError; the correct answers
+        # should always be valid JSON
+        correct_json = [json.loads(correct_answer) for correct_answer in correct_answers]
+
+        matches = [json_match(sampled_json, cj) for cj in correct_json]
+
+        evals.record.record_match(
+            True in matches,
+            expected=correct_answers,
+            picked=[sampled for i in range(len(correct_answers)) if matches[i]],
+        )
+        evals.record.record_metrics(
+            accuracy=float(True in matches),
+        )
+
+    def run(self, recorder: RecorderBase) -> Dict[str, float]:
+        samples = self.get_samples()
+        self.eval_all_samples(recorder, samples)
+
+        return {
+            "accuracy": np.mean(recorder.get_scores("accuracy")),
+        }
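A short sketch of how the recursive helper above behaves on nested input; these calls mirror the implementation as written:

```python
from evals.elsuite.basic.json_match import json_match

# Nested dicts and lists are compared element by element.
assert json_match({"a": [1, {"b": 2}]}, {"a": [1, {"b": 2}]})

# The union of keys is checked, so a key missing on either side
# comes back as None and fails the match.
assert not json_match({"a": 1}, {"a": 1, "b": 2})

# Scalars fall through to plain Python equality, so 1 matches 1.0.
assert json_match({"a": 1}, {"a": 1.0})
```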
diff --git a/evals/elsuite/basic/json_match_test.py b/evals/elsuite/basic/json_match_test.py
new file mode 100644
index 0000000000..84d3cdd732
--- /dev/null
+++ b/evals/elsuite/basic/json_match_test.py
@@ -0,0 +1,98 @@
+from pathlib import Path
+from typing import Any, Type, Union
+
+from mock import patch
+from pytest import mark, raises
+
+from evals.api import DummyCompletionFn
+from evals.elsuite.basic.json_match import JsonMatch
+from evals.record import DummyRecorder
+from evals.utils.test import TestCompletionFn
+
+
+@mark.parametrize(
+    "completion, ideal, expected_metrics",
+    [
+        # Basic match
+        ('{ "key": "value" }', '{ "key": "value" }', dict(accuracy=1.0)),
+        # Whitespace is not significant
+        ('{\n "key":"value"\n }\n', '{ "key": "value" }', dict(accuracy=1.0)),
+        # Key order is not significant
+        (
+            '{ "key2": "foo", "key1": "bar" }',
+            '{ "key1": "bar", "key2": "foo" }',
+            dict(accuracy=1.0),
+        ),
+        # No match if values are different
+        ('{ "key": "value" }', '{ "key": "notvalue" }', dict(accuracy=0)),
+        # Values can be numbers as well as strings
+        ('{ "key": 100 }', '{ "key": 100 }', dict(accuracy=1.0)),
+        # Numerical values do not match if they differ
+        ('{ "key": 100 }', '{ "key": 100.1 }', dict(accuracy=0)),
+        # Completion is accepted if it is found in an array of valid answers
+        ('{ "key": 100 }', ['{ "key": 100.1 }', '{ "key": 100 }'], dict(accuracy=1.0)),
+        # Completion is not accepted if it is not found in an array of valid answers
+        ('{ "key": 100 }', ['{ "key": 100.1 }', '{ "key": 99.9 }'], dict(accuracy=0)),
+        # Different keys do not match
+        ('{ "key": "value" }', '{ "anotherkey": "value" }', dict(accuracy=0)),
+        # Missing keys do not match
+        (
+            '{ "key": "value" }',
+            '{ "key": "value", "anotherkey": "value" }',
+            dict(accuracy=0),
+        ),
+        # Extra keys do not match
+        (
+            '{ "key": "value", "anotherkey": "value" }',
+            '{ "key": "value" }',
+            dict(accuracy=0),
+        ),
+        # Lists are supported, and matched by element equality
+        ('{ "key": [1.0,2.0,3.0] }', '{ "key": [1, 2, 3] }', dict(accuracy=1.0)),
+        # Lists of different lengths do not match
+        ('{ "key": [1, 2, 3] }', '{ "key": [1, 2, 3, 3] }', dict(accuracy=0)),
+        # Lists that are not equal index-by-index do not match
+        ('{ "key": [1, 2, 3] }', '{ "key": [1, 3, 2] }', dict(accuracy=0)),
+        # An empty list does not match a nonempty list
+        ('{ "key": [] }', '{ "key": [1] }', dict(accuracy=0)),
+        # Completion with invalid JSON is not accepted
+        ('{ "key": "value }', '{ "key": "value" }', dict(accuracy=0)),
+    ],
+)
+def test_eval_sample(
+    completion: str,
+    ideal: Union[str, list[str]],
+    expected_metrics: dict[str, float],
+) -> None:
+    eval = JsonMatch(
+        completion_fns=[TestCompletionFn(completion)],
+        samples_jsonl="",
+        eval_registry_path=Path("."),
+    )
+
+    recorder = DummyRecorder(None)
+    with recorder.as_default_recorder("x"), patch.object(
+        recorder, "record_metrics", wraps=recorder.record_metrics
+    ) as record_metrics:
+        eval.eval_sample(dict(input=completion, ideal=ideal), None)
+        record_metrics.assert_called_once_with(**expected_metrics)
+
+
+@mark.parametrize(
+    "sample, expected_error",
+    [
+        (None, AssertionError),
+        ("", AssertionError),
+        (dict(ideal="world"), AssertionError),  # Missing input
+        (dict(input="world"), AssertionError),  # Missing ideal answer
+    ],
+)
+def test_eval_sample_raises(sample: Any, expected_error: Type[Exception]) -> None:
+    eval = JsonMatch(
+        completion_fns=[DummyCompletionFn()],
+        samples_jsonl="",
+        eval_registry_path=Path("."),
+    )
+
+    with raises(expected_error):
+        eval.eval_sample(sample, None)
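The parametrized cases above double as executable documentation of the matching rules. Assuming `pytest` and the `mock` package are installed, they can be run in isolation with:

```sh
pytest evals/elsuite/basic/json_match_test.py
```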
diff --git a/evals/registry/data/icelandic-inflection-easy/samples.jsonl b/evals/registry/data/icelandic-inflection-easy/samples.jsonl
new file mode 100644
index 0000000000..53bd0fc8f6
--- /dev/null
+++ b/evals/registry/data/icelandic-inflection-easy/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d85df00cf22b3c4638efc9f61c42d7adca7cdf19ccae107ef515fb5b5616e706
+size 72354
diff --git a/evals/registry/data/icelandic-inflection-hard/samples.jsonl b/evals/registry/data/icelandic-inflection-hard/samples.jsonl
new file mode 100644
index 0000000000..785f92b0c6
--- /dev/null
+++ b/evals/registry/data/icelandic-inflection-hard/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02d06a3b274f136c038a5f6fd12e03cc63b29db11e6f481e7eaded8b941bd849
+size 74148
diff --git a/evals/registry/data/icelandic-inflection-medium/samples.jsonl b/evals/registry/data/icelandic-inflection-medium/samples.jsonl
new file mode 100644
index 0000000000..047b236209
--- /dev/null
+++ b/evals/registry/data/icelandic-inflection-medium/samples.jsonl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c71f284b2caee78a244cdffa3db3830435e4906a6f0855f70d8efcaf104df9a
+size 75509
diff --git a/evals/registry/evals/icelandic-inflection-easy.yaml b/evals/registry/evals/icelandic-inflection-easy.yaml
new file mode 100644
index 0000000000..eaaa0bb83b
--- /dev/null
+++ b/evals/registry/evals/icelandic-inflection-easy.yaml
@@ -0,0 +1,9 @@
+icelandic-inflection-easy:
+  id: icelandic-inflection-easy.dev.v0
+  description: Test the model's ability to correctly inflect Icelandic noun phrases (easiest category)
+  metrics: [accuracy]
+
+icelandic-inflection-easy.dev.v0:
+  class: evals.elsuite.basic.json_match:JsonMatch
+  args:
+    samples_jsonl: icelandic-inflection-easy/samples.jsonl
diff --git a/evals/registry/evals/icelandic-inflection-hard.yaml b/evals/registry/evals/icelandic-inflection-hard.yaml
new file mode 100644
index 0000000000..aa2c7253be
--- /dev/null
+++ b/evals/registry/evals/icelandic-inflection-hard.yaml
@@ -0,0 +1,9 @@
+icelandic-inflection-hard:
+  id: icelandic-inflection-hard.dev.v0
+  description: Test the model's ability to correctly inflect Icelandic noun phrases (hard category)
+  metrics: [accuracy]
+
+icelandic-inflection-hard.dev.v0:
+  class: evals.elsuite.basic.json_match:JsonMatch
+  args:
+    samples_jsonl: icelandic-inflection-hard/samples.jsonl
diff --git a/evals/registry/evals/icelandic-inflection-medium.yaml b/evals/registry/evals/icelandic-inflection-medium.yaml
new file mode 100644
index 0000000000..15cae7c67c
--- /dev/null
+++ b/evals/registry/evals/icelandic-inflection-medium.yaml
@@ -0,0 +1,9 @@
+icelandic-inflection-medium:
+  id: icelandic-inflection-medium.dev.v0
+  description: Test the model's ability to correctly inflect Icelandic noun phrases (medium category)
+  metrics: [accuracy]
+
+icelandic-inflection-medium.dev.v0:
+  class: evals.elsuite.basic.json_match:JsonMatch
+  args:
+    samples_jsonl: icelandic-inflection-medium/samples.jsonl
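The three samples.jsonl files are stored via Git LFS, so only pointer stubs appear in this diff. Going by the standard evals sample schema and how `JsonMatch` reads `input` and `ideal`, a line could look roughly like the following; the prompt wording, case keys, and inflections here are purely hypothetical, not taken from the actual dataset:

```json
{"input": [{"role": "system", "content": "Inflect the Icelandic noun phrase in all four cases. Answer in JSON."}, {"role": "user", "content": "fallegur hestur"}], "ideal": "{\"nf\": \"fallegur hestur\", \"þf\": \"fallegan hest\", \"þgf\": \"fallegum hesti\", \"ef\": \"fallegs hests\"}"}
```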
diff --git a/evals/utils/test.py b/evals/utils/test.py
index 0ad762f2c5..f42bdcb4e3 100644
--- a/evals/utils/test.py
+++ b/evals/utils/test.py
@@ -5,6 +5,9 @@
 
 
 class TestCompletionResult(CompletionResult):
+
+    __test__ = False  # Prevent pytest from trying to run this class as a test
+
     def __init__(self, completion: str):
         self.completion = completion
 
@@ -13,6 +16,9 @@ def get_completions(self) -> list[str]:
 
 
 class TestCompletionFn(CompletionFn):
+
+    __test__ = False  # Prevent pytest from trying to run this class as a test
+
     def __init__(self, completion: str):
         self.completion = completion
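With the registry entries in place, each new eval should be runnable through the repository's standard `oaieval` CLI, for example as below; the model name is illustrative, and an OpenAI API key is required:

```sh
oaieval gpt-3.5-turbo icelandic-inflection-easy
```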