Skip to content

Commit

Permalink
one more refactor of test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o committed Dec 30, 2024
1 parent dc98b6c commit fbb7343
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 66 deletions.
87 changes: 37 additions & 50 deletions tests/integration_tests/chat_completion_suites/text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from openai import BadRequestError, UnprocessableEntityError

from aidial_adapter_openai.constant import ChatCompletionDeploymentType
from tests.integration_tests.base import (
TestSuite,
exclude_deployments,
include_deployments,
)
from tests.integration_tests.base import TestSuite
from tests.utils.openai import ExpectedException, ai, sys, user


def _test_text_common(s: TestSuite) -> None:
def build_text_common(s: TestSuite) -> None:
# Basic dialog tests
s.test_case(
name="dialog recall",
Expand Down Expand Up @@ -84,56 +80,47 @@ def _test_text_common(s: TestSuite) -> None:
)


@include_deployments([ChatCompletionDeploymentType.MISTRAL])
def _test_mistral_stop_sequence(s: TestSuite) -> None:
    """Register the "stop sequence" case with the Mistral expectation.

    The request asks for a reply containing "John" while passing "John" and
    "john" as stop sequences; the check passes only if "john" is still
    present in the lower-cased reply content.
    """

    def still_mentions_john(response) -> bool:
        # Mistral just ignores stop sequence
        return "john" in response.content.lower()

    s.test_case(
        name="stop sequence",
        stop=["John", "john"],
        messages=[user('Reply with "Hello John Doe"')],
        expected=still_mentions_john,
    )
def build_stop_sequence(s: TestSuite) -> None:
if s.deployment_type == ChatCompletionDeploymentType.MISTRAL:
# Mistral just ignores stop sequence

def expected(s) -> bool:
return "john" in s.content.lower()

else:

def expected(s) -> bool:
return "john" not in s.content.lower()

@exclude_deployments([ChatCompletionDeploymentType.MISTRAL])
def _test_stop_sequence(s: TestSuite) -> None:
s.test_case(
name="stop sequence",
stop=["John", "john"],
messages=[user('Reply with "Hello John Doe"')],
expected=lambda s: "john" not in s.content.lower(),
expected=expected,
)


@include_deployments([ChatCompletionDeploymentType.DATABRICKS])
def _test_databricks_multi_system(s: TestSuite) -> None:
    """Register the "many system" case with the Databricks expectation.

    Two consecutive system messages followed by a user question are sent;
    the case expects a 400 ``BadRequestError`` whose message mentions that
    chat roles must alternate.
    """
    conversation = [
        sys("act as a helpful assistant"),
        sys("act as a calculator"),
        user("2+5=?"),
    ]

    # Databricks does not allow multiple system messages
    rejection = ExpectedException(
        type=BadRequestError,
        message=("Chat message input roles must alternate"),
        status_code=400,
    )

    s.test_case(
        name="many system",
        messages=conversation,
        expected=rejection,
    )


@exclude_deployments([ChatCompletionDeploymentType.DATABRICKS])
def _test_multi_system(s: TestSuite) -> None:
    """Register the "many system" case with the default expectation.

    Two system messages followed by the question "2+5=?" are sent; the case
    passes when the lower-cased reply content contains "7".
    """

    def answers_seven(response) -> bool:
        return "7" in response.content.lower()

    conversation = [
        sys("act as a helpful assistant"),
        sys("act as a calculator"),
        user("2+5=?"),
    ]

    s.test_case(
        name="many system",
        messages=conversation,
        expected=answers_seven,
    )
def build_multi_system(s: TestSuite) -> None:
    """Register the "many system" case, choosing the expectation per deployment.

    The same conversation (two system messages, then "2+5=?") is used for
    every deployment. For Databricks the case expects a 400
    ``BadRequestError``; for any other deployment type it expects "7" to
    appear in the lower-cased reply content.
    """
    conversation = [
        sys("act as a helpful assistant"),
        sys("act as a calculator"),
        user("2+5=?"),
    ]

    is_databricks = (
        s.deployment_type == ChatCompletionDeploymentType.DATABRICKS
    )

    if is_databricks:
        # Databricks does not allow multiple system messages
        expectation = ExpectedException(
            type=BadRequestError,
            message=("Chat message input roles must alternate"),
            status_code=400,
        )
    else:

        def expectation(response) -> bool:
            return "7" in response.content.lower()

    # Single registration point: only the expectation differs per deployment.
    s.test_case(
        name="many system",
        messages=conversation,
        expected=expectation,
    )
2 changes: 1 addition & 1 deletion tests/integration_tests/chat_completion_suites/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def supports_functions(deployment_type: ChatCompletionDeploymentType):
ChatCompletionDeploymentType.MISTRAL,
]
)
def _test_tools_common(s: TestSuite) -> None:
def build_tools_common(s: TestSuite) -> None:
if supports_parallel_tool_calls(s.deployment_type):
city_config = [[("Glasgow", 15)], [("Glasgow", 15), ("London", 20)]]
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/chat_completion_suites/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ChatCompletionDeploymentType.GPT4_VISION,
]
)
def _test_vision_common(s: TestSuite) -> None:
def build_vision_common(s: TestSuite) -> None:
s.test_case(
name="image_in_content_parts",
messages=[
Expand Down
24 changes: 10 additions & 14 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@

from tests.integration_tests.base import TestCase, TestSuite, TestSuiteBuilder
from tests.integration_tests.chat_completion_suites.text import (
_test_databricks_multi_system,
_test_mistral_stop_sequence,
_test_multi_system,
_test_stop_sequence,
_test_text_common,
build_multi_system,
build_stop_sequence,
build_text_common,
)
from tests.integration_tests.chat_completion_suites.tools import (
_test_tools_common,
build_tools_common,
)
from tests.integration_tests.chat_completion_suites.vision import (
_test_vision_common,
build_vision_common,
)
from tests.integration_tests.constants import TEST_DEPLOYMENTS_CONFIG
from tests.utils.openai import (
Expand All @@ -45,13 +43,11 @@ def create_test_cases(
"test_case",
create_test_cases(
[
_test_text_common,
_test_stop_sequence,
_test_mistral_stop_sequence,
_test_multi_system,
_test_databricks_multi_system,
_test_tools_common,
_test_vision_common,
build_text_common,
build_stop_sequence,
build_multi_system,
build_tools_common,
build_vision_common,
]
),
ids=lambda tc: tc.get_id(),
Expand Down

0 comments on commit fbb7343

Please sign in to comment.