From 95a19f807c38158ac575e548898bcfbcdf614650 Mon Sep 17 00:00:00 2001 From: Marco Milanta <72441376+mmilanta@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:53:51 +0100 Subject: [PATCH] move: old repo edits (#24) --- testing/invariant/__main__.py | 36 ++-- testing/invariant/config.py | 5 +- testing/invariant/constants.py | 3 + .../invariant/custom_types/invariant_dict.py | 1 - .../invariant/custom_types/invariant_image.py | 39 +++- .../custom_types/invariant_string.py | 89 ++++++---- testing/invariant/custom_types/trace.py | 73 +++----- testing/invariant/manager.py | 112 +++++++++--- testing/invariant/testing/__init__.py | 4 +- testing/invariant/testing/functional.py | 13 ++ testing/invariant/utils/utils.py | 10 ++ testing/sample_tests/demos/chatbot.py | 68 +++++++ .../sample_tests/demos/computer_use_agent.py | 166 ++++++++++++++++++ testing/sample_tests/demos/qa-chatbot.py | 70 ++++++++ .../weather_agent/test_weather_agent.py | 2 +- .../sample_tests/openai/test_python_agent.py | 7 +- .../test_capital_finder_agent.py | 4 +- .../custom_types/test_invariant_image.py | 8 +- .../custom_types/test_invariant_string.py | 71 +++++--- testing/tests/test_assertion_args.py | 3 +- testing/tests/test_contains.py | 45 +++++ testing/tests/test_display.py | 120 +++++++++++++ testing/tests/test_strings.py | 1 + 23 files changed, 773 insertions(+), 177 deletions(-) create mode 100644 testing/sample_tests/demos/chatbot.py create mode 100644 testing/sample_tests/demos/computer_use_agent.py create mode 100644 testing/sample_tests/demos/qa-chatbot.py create mode 100644 testing/tests/test_contains.py create mode 100644 testing/tests/test_display.py diff --git a/testing/invariant/__main__.py b/testing/invariant/__main__.py index daa92bd..4822091 100644 --- a/testing/invariant/__main__.py +++ b/testing/invariant/__main__.py @@ -14,6 +14,7 @@ from invariant.config import Config from invariant.constants import ( + INVARIANT_AGENT_PARAMS_ENV_VAR, INVARIANT_AP_KEY_ENV_VAR, INVARIANT_RUNNER_TEST_RESULTS_DIR, INVARIANT_TEST_RUNNER_CONFIG_ENV_VAR, @@ -54,6 +55,9 @@ def parse_args(args: list[str]) -> tuple[argparse.Namespace, list[str]]: {BOLD}https://explorer.invariantlabs.ai/docs/{END} to see steps to generate an API key.""", ) + parser.add_argument( + "--agent-params", help="JSON containing the parameters of the agent", type=str, default=None + ) return parser.parse_known_args(args) @@ -68,6 +72,11 @@ def create_config(args: argparse.Namespace) -> Config: """ api_key = os.getenv(INVARIANT_AP_KEY_ENV_VAR) + try: + agent_params = None if args.agent_params is None else json.loads(args.agent_params) + except json.JSONDecodeError as e: + raise ValueError("--agent-params should be a valid JSON") from e + prefix = args.dataset_name dataset_name = f"{prefix}-{int(time.time())}" @@ -78,6 +87,7 @@ def create_config(args: argparse.Namespace) -> Config: push=args.push, api_key=api_key, result_output_dir=INVARIANT_RUNNER_TEST_RESULTS_DIR, + agent_params=agent_params, ) @@ -122,15 +132,19 @@ def finalize_tests_and_print_summary(conf: Config, open_browser: bool) -> None: # update dataset metadata if --push if conf.push: + metadata = { + "invariant_test_results": { + "num_tests": tests, + "num_passed": passed_count, + } + } + if conf.agent_params: + metadata["agent_params"] = conf.agent_params + client = InvariantClient() client.create_request_and_update_dataset_metadata( dataset_name=conf.dataset_name, - metadata={ - "invariant_test_results": { - "num_tests": tests, - "num_passed": passed_count, - } - }, + metadata=metadata, request_kwargs={"verify": utils.ssl_verification_enabled()}, ) @@ -148,16 +162,14 @@ def test(args: list[str]) -> None: config = create_config(invariant_runner_args) os.environ[INVARIANT_TEST_RUNNER_CONFIG_ENV_VAR] = config.model_dump_json() # pass along actual terminal width to the test runner (for better formatting) - os.environ[INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR] = str( - utils.terminal_width() - ) + os.environ[INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR] = str(utils.terminal_width()) + if invariant_runner_args.agent_params: + os.environ[INVARIANT_AGENT_PARAMS_ENV_VAR] = invariant_runner_args.agent_params except ValueError as e: logger.error("Configuration error: %s", e) sys.exit(1) - test_results_directory_path = utils.get_test_results_directory_path( - config.dataset_name - ) + test_results_directory_path = utils.get_test_results_directory_path(config.dataset_name) if os.path.exists(test_results_directory_path): shutil.rmtree(test_results_directory_path) diff --git a/testing/invariant/config.py b/testing/invariant/config.py index 131a2f0..a6829e3 100644 --- a/testing/invariant/config.py +++ b/testing/invariant/config.py @@ -12,6 +12,7 @@ class Config(BaseModel): push: bool = False api_key: Optional[str] result_output_dir: str + agent_params: Optional[dict] @field_validator("api_key") @classmethod @@ -19,7 +20,5 @@ def validate_api_key(cls, api_key_value, info): """Ensure that `api_key` is provided if `push` is set to true.""" push_value = info.data.get("push") if push_value and not api_key_value: - raise ValueError( - "`INVARIANT_API_KEY` is required if `push` is set to true." - ) + raise ValueError("`INVARIANT_API_KEY` is required if `push` is set to true.") return api_key_value diff --git a/testing/invariant/constants.py b/testing/invariant/constants.py index 98b915c..262069e 100644 --- a/testing/invariant/constants.py +++ b/testing/invariant/constants.py @@ -7,3 +7,6 @@ # used to pass the actual terminal width to the test runner # (if not available, we'll use a fallback, but nice to have) INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR = "INVARIANT_TERMINAL_WIDTH" + +# used to pass the agent params to the test runner +INVARIANT_AGENT_PARAMS_ENV_VAR = "INVARIANT_AGENT_PARAMS" diff --git a/testing/invariant/custom_types/invariant_dict.py b/testing/invariant/custom_types/invariant_dict.py index 79a3bf4..d591692 100644 --- a/testing/invariant/custom_types/invariant_dict.py +++ b/testing/invariant/custom_types/invariant_dict.py @@ -5,7 +5,6 @@ from invariant.custom_types.invariant_bool import InvariantBool from invariant.custom_types.invariant_value import InvariantValue - class InvariantDict: """Invariant implementation of a dict type""" diff --git a/testing/invariant/custom_types/invariant_image.py b/testing/invariant/custom_types/invariant_image.py index b42261d..6b4b7bc 100644 --- a/testing/invariant/custom_types/invariant_image.py +++ b/testing/invariant/custom_types/invariant_image.py @@ -1,14 +1,15 @@ -""" A custom type for an invariant image. """ +"""A custom type for an invariant image.""" import base64 import io from typing import Optional +from PIL import Image + from invariant.custom_types.invariant_bool import InvariantBool from invariant.custom_types.invariant_string import InvariantString from invariant.scorers.llm.classifier import Classifier from invariant.scorers.utils.ocr import OCRDetector -from PIL import Image class InvariantImage(InvariantString): @@ -43,9 +44,7 @@ def llm_vision( client (invariant.scorers.llm.clients.client.SupportedClients): The client to use for the LLM. """ - llm_clf = Classifier( - model=model, prompt=prompt, options=options, vision=True, client=client - ) + llm_clf = Classifier(model=model, prompt=prompt, options=options, vision=True, client=client) res = llm_clf.classify_vision( self.value, image_type=self.image_type, use_cached_result=use_cached_result ) @@ -53,10 +52,36 @@ def llm_vision( def ocr_contains( self, - text: str, + text: str | InvariantString, case_sensitive: bool = False, bbox: Optional[dict] = None, ) -> InvariantBool: """Check if the value contains the given text using OCR.""" + addresses = self.addresses + if type(text) == InvariantString: + addresses.extend(text.addresses) + text = text.value res = OCRDetector().contains(self.image, text, case_sensitive, bbox) - return InvariantBool(res, self.addresses) + return InvariantBool(res, addresses) + + def ocr_contains_any( + self, + texts: list[str | InvariantString], + case_sensitive: bool = False, + bbox: Optional[dict] = None, + ) -> InvariantBool: + for text in texts: + if res := self.ocr_contains(text, case_sensitive, bbox): + return res + return InvariantBool(False, self.addresses) + + def ocr_contains_all( + self, + texts: list[str | InvariantString], + case_sensitive: bool = False, + bbox: Optional[dict] = None, + ) -> InvariantBool: + for text in texts: + if not (res := self.ocr_contains(text, case_sensitive, bbox)): + return res + return InvariantBool(True, self.addresses) diff --git a/testing/invariant/custom_types/invariant_string.py b/testing/invariant/custom_types/invariant_string.py index dd50ba8..92d1f20 100644 --- a/testing/invariant/custom_types/invariant_string.py +++ b/testing/invariant/custom_types/invariant_string.py @@ -5,9 +5,10 @@ import json import re from operator import ge, gt, le, lt, ne -from typing import Any, Union +from typing import Any, Literal, Union from _pytest.python_api import ApproxBase + from invariant.custom_types.invariant_bool import InvariantBool from invariant.custom_types.invariant_number import InvariantNumber from invariant.custom_types.invariant_value import InvariantValue @@ -68,9 +69,7 @@ def __le__(self, other: Union[str, "InvariantString"]) -> InvariantBool: def __add__(self, other: Union[str, "InvariantString"]) -> "InvariantString": """Concatenate the string with another string.""" if isinstance(other, InvariantString): - return InvariantString( - self.value + other.value, self.addresses + other.addresses - ) + return InvariantString(self.value + other.value, self.addresses + other.addresses) return InvariantString(self.value + other, self.addresses) def __radd__(self, other: str) -> "InvariantString": @@ -85,7 +84,7 @@ def __repr__(self) -> str: def __len__(self): raise NotImplementedError( - "InvariantString does not support len(). Please use .len() instead." + "InvariantString does not support len(). Please use functionals.len() instead." ) def __getitem__(self, key: Any, default: Any = None) -> "InvariantString": @@ -110,16 +109,12 @@ def __getitem__(self, key: Any, default: Any = None) -> "InvariantString": def count(self, pattern: str) -> InvariantNumber: """Counts the number of occurences of the given regex pattern.""" new_addresses = [] - for match in re.finditer(pattern, self.value): + for match in re.finditer(pattern, self.value, re.DOTALL): start, end = match.span() new_addresses.append(f"{start}-{end}") return InvariantNumber( len(new_addresses), - ( - self.addresses - if len(new_addresses) == 0 - else self._concat_addresses(new_addresses) - ), + (self.addresses if len(new_addresses) == 0 else self._concat_addresses(new_addresses)), ) def len(self): @@ -127,8 +122,7 @@ def len(self): return InvariantNumber(len(self.value), self.addresses) def __getattr__(self, attr): - """ - Delegate attribute access to the underlying string. + """Delegate attribute access to the underlying string. Args: attr (str): The attribute being accessed. @@ -152,9 +146,7 @@ def wrapper(*args, **kwargs): return method raise AttributeError(f"'InvariantString' object has no attribute '{attr}'") - def _concat_addresses( - self, other_addresses: list[str] | None, separator: str = ":" - ) -> list[str]: + def _concat_addresses(self, other_addresses: list[str] | None, separator: str = ":") -> list[str]: """Concatenate the addresses of two invariant values.""" if other_addresses is None: return self.addresses @@ -180,39 +172,56 @@ def _concat_addresses( def moderation(self) -> InvariantBool: """Check if the value is moderated.""" - analyzer = ModerationAnalyzer() res = analyzer.detect_all(self.value) new_addresses = [str(range) for _, range in res] return InvariantBool(len(res) > 0, self._concat_addresses(new_addresses)) def contains( - self, pattern: str | InvariantString, flags=re.IGNORECASE + self, + *patterns: Union[str, InvariantString], + criterion: Literal["all", "any"] = "all", + flags=re.IGNORECASE, ) -> InvariantBool: - """Check if the value contains the given pattern. This ignores case by default. + """Check if the value contains all of the given patterns. Args: - pattern (str | InvariantString): The pattern to check for. - flags (int): The flags to use for the regex search. To pass in multiple flags, - use the bitwise OR operator (|). By default, this is re.IGNORECASE. + *patterns: Variable number of patterns to check for. Each pattern can be a string + or InvariantString. + criterion: The criterion to use for the contains check - can be "all" or "any". + flags: The flags to use for the regex search. To pass in multiple flags, use the bitwise OR operator (|). By default, this is re.IGNORECASE. + Returns: + InvariantBool: True if all patterns are found, False otherwise. The addresses will + contain the locations of all pattern matches if found. """ - if isinstance(pattern, InvariantString): - pattern = pattern.value + if criterion not in ["all", "any"]: + raise ValueError("Criterion must be either 'all' or 'any'") new_addresses = [] + for pattern in patterns: + if isinstance(pattern, InvariantString): + pattern = pattern.value - for match in re.finditer(pattern, self.value, flags=flags): - start, end = match.span() - new_addresses.append(f"{start}-{end}") + pattern_matches = [] + for match in re.finditer(pattern, self.value, flags=flags): + start, end = match.span() + pattern_matches.append(f"{start}-{end}") - return InvariantBool( - len(new_addresses) > 0, - ( - self.addresses - if len(new_addresses) == 0 - else self._concat_addresses(new_addresses) - ), - ) + if criterion == "all" and not pattern_matches: + return InvariantBool(False, self.addresses) + if criterion == "any" and pattern_matches: + return InvariantBool(True, self._concat_addresses(pattern_matches)) + new_addresses.extend(pattern_matches) + + return InvariantBool(criterion == "all", self._concat_addresses(new_addresses)) + + def contains_all(self, *patterns: Union[str, InvariantString]) -> InvariantBool: + """Check if the value contains all of the given patterns.""" + return self.contains(*patterns, criterion="all") + + def contains_any(self, *patterns: Union[str, InvariantString]) -> InvariantBool: + """Check if the value contains any of the given patterns.""" + return self.contains(*patterns, criterion="any") def __contains__(self, pattern: str | InvariantString) -> InvariantBool: """Check if the value contains the given pattern.""" @@ -224,9 +233,13 @@ def match(self, pattern: str, group_id: int | str = 0) -> InvariantString: if match is None: return None start, end = match.span(group_id) - return InvariantString( - match.group(group_id), self._concat_addresses([f"{start}-{end}"]) - ) + return InvariantString(match.group(group_id), self._concat_addresses([f"{start}-{end}"])) + + def match_all(self, pattern: str, group_id: int | str = 0): + """Match the value against the given regex pattern and return all matches.""" + for match in re.finditer(pattern, self.value): + start, end = match.span(group_id) + yield InvariantString(match.group(group_id), self._concat_addresses([f"{start}-{end}"])) def is_similar(self, other: str, threshold: float = 0.5) -> InvariantBool: """Check if the value is similar to the given string using cosine similarity.""" diff --git a/testing/invariant/custom_types/trace.py b/testing/invariant/custom_types/trace.py index 5797771..994195a 100644 --- a/testing/invariant/custom_types/trace.py +++ b/testing/invariant/custom_types/trace.py @@ -5,13 +5,14 @@ import json from typing import Any, Callable, Dict, Generator, List -from invariant.custom_types.invariant_dict import InvariantDict, InvariantValue -from invariant.custom_types.matchers import ContainsImage, Matcher -from invariant.utils.utils import ssl_verification_enabled from invariant_sdk.client import Client as InvariantClient from invariant_sdk.types.push_traces import PushTracesResponse from pydantic import BaseModel +from invariant.custom_types.invariant_dict import InvariantDict, InvariantValue +from invariant.custom_types.matchers import ContainsImage, Matcher +from invariant.utils.utils import ssl_verification_enabled + def iterate_tool_calls( messages: list[dict], @@ -64,8 +65,7 @@ def iterate_tool_calls( def iterate_tool_outputs( messages: list[dict], ) -> Generator[tuple[list[str], dict], None, None]: - """ - Generator function to iterate over tool outputs in a list of messages. + """Generator function to iterate over tool outputs in a list of messages. Args: messages (list[dict]): A list of messages without address information. @@ -84,8 +84,7 @@ def iterate_tool_outputs( def iterate_messages( messages: list[dict], ) -> Generator[tuple[list[str], dict], None, None]: - """ - Generator function to iterate over messages in a list of messages. + """Generator function to iterate over messages in a list of messages. Args: messages (list[dict]): A list of messages without address information. @@ -211,8 +210,7 @@ def run_assertions(self, assertions: list[Callable[Trace, Any]]): # Functions to check data_types @property def content_checkers(self) -> Dict[str, Matcher]: - """ - Register content checkers for data_types. When implementing a new content checker, + """Register content checkers for data_types. When implementing a new content checker, add the new content checker to the dictionary below. Returns: @@ -223,11 +221,8 @@ def content_checkers(self) -> Dict[str, Matcher]: } return __content_checkers__ - def _is_data_type( - self, message: InvariantDict, data_type: str | None = None - ) -> bool: - """ - Check if a message matches a given data_type using the content_checkers. + def _is_data_type(self, message: InvariantDict, data_type: str | None = None) -> bool: + """Check if a message matches a given data_type using the content_checkers. data_type should correspond to the keys in the content_checkers dictionary. If data_type is None, the message is considered to match the data_type (i.e., no filtering is performed). @@ -262,8 +257,7 @@ def _filter_trace( data_type: str | None = None, **filterkwargs, ) -> list[InvariantDict] | InvariantDict: - """ - Filter the trace based on the provided selector, keyword arguments and data_type. Use this + """Filter the trace based on the provided selector, keyword arguments and data_type. Use this method as a helper for custom filters such as messages(), tool_calls(), and tool_outputs(). Args: @@ -278,27 +272,20 @@ def _filter_trace( Returns: list[InvariantDict] | InvariantDict: The filtered trace. """ - # If a single index is provided, return the message at that index if isinstance(selector, int): for i, (addresses, message) in enumerate(iterator_func(self.trace)): if i == selector: return_val = InvariantDict(message, [f"{i}"]) - return ( - return_val - if self._is_data_type(return_val, data_type) - else None - ) + return return_val if self._is_data_type(return_val, data_type) else None # If a dictionary is provided, filter messages based on the dictionary elif isinstance(selector, dict): return [ InvariantDict(message, addresses) for addresses, message in iterator_func(self.trace) - if all( - traverse_dot_path(message, kwname)[0] == kwvalue - for kwname, kwvalue in selector.items() - ) + + if all(traverse_dot_path(message, kwname)[0] == kwvalue for kwname, kwvalue in selector.items()) and self._is_data_type(InvariantDict(message, addresses), data_type) ] @@ -308,9 +295,7 @@ def _filter_trace( InvariantDict(message, addresses) for addresses, message in iterator_func(self.trace) if all( - match_keyword_function( - kwname, kwvalue, message.get(kwname), message - ) + match_keyword_function(kwname, kwvalue, message.get(kwname), message) for kwname, kwvalue in filterkwargs.items() ) and self._is_data_type(InvariantDict(message, addresses), data_type) @@ -329,8 +314,7 @@ def messages( data_type: str | None = None, **filterkwargs, ) -> list[InvariantDict] | InvariantDict: - """ - Get all messages from the trace that match the provided selector, data_type, and keyword filters. + """Get all messages from the trace that match the provided selector, data_type, and keyword filters. Args: selector: The selector to use to filter the trace. @@ -340,13 +324,10 @@ def messages( Returns: list[InvariantDict] | InvariantDict: The filtered messages. """ - if isinstance(selector, int): return InvariantDict(self.trace[selector], [str((selector + len(self.trace)) % len(self.trace))]) - return self._filter_trace( - iterate_messages, match_keyword_filter, selector, data_type, **filterkwargs - ) + return self._filter_trace(iterate_messages, match_keyword_filter, selector, data_type, **filterkwargs) def tool_calls( self, @@ -354,8 +335,7 @@ def tool_calls( data_type: str | None = None, **filterkwargs, ) -> list[InvariantDict] | InvariantDict: - """ - Get all tool calls from the trace that match the provided selector, data_type, and keyword filters. + """Get all tool calls from the trace that match the provided selector, data_type, and keyword filters. Args: selector: The selector to use to filter the trace. @@ -379,8 +359,7 @@ def tool_outputs( data_type: str | None = None, **filterkwargs, ) -> list[InvariantDict] | InvariantDict: - """ - Get all tool outputs from the trace that match the provided selector, data_type, and keyword filters. + """Get all tool outputs from the trace that match the provided selector, data_type, and keyword filters. Args: selector: The selector to use to filter the trace. @@ -433,32 +412,24 @@ def tool_pairs(self) -> list[tuple[InvariantDict, InvariantDict]]: ) break - return [ - (res_pair[1], res_pair[2]) for res_pair in res if res_pair[2] is not None - ] + return [(res_pair[1], res_pair[2]) for res_pair in res if res_pair[2] is not None] def to_python(self) -> str: - """ - Returns a snippet of Python code construct that can be used + """Returns a snippet of Python code construct that can be used to recreate the trace in a Python script. Returns: str: The Python string representing the trace. """ - return ( - "Trace(trace=[\n" - + ",\n".join(" " + str(msg) for msg in self.trace) - + "\n])" - ) + return "Trace(trace=[\n" + ",\n".join(" " + str(msg) for msg in self.trace) + "\n])" def push_to_explorer( self, client: InvariantClient | None = None, dataset_name: None | str = None, ) -> PushTracesResponse: - """ - Pushes the trace to the explorer. + """Pushes the trace to the explorer. Args: client: The client used to push. If None a standard invariant_sdk client is initialized. diff --git a/testing/invariant/manager.py b/testing/invariant/manager.py index f23b74b..6cf537d 100644 --- a/testing/invariant/manager.py +++ b/testing/invariant/manager.py @@ -10,6 +10,7 @@ import traceback as tb from contextvars import ContextVar from json import JSONEncoder +from typing import Literal import pytest from invariant_sdk.client import Client as InvariantClient @@ -18,6 +19,8 @@ from invariant.config import Config from invariant.constants import INVARIANT_TEST_RUNNER_CONFIG_ENV_VAR +from invariant.custom_types.invariant_dict import InvariantDict +from invariant.custom_types.invariant_string import InvariantString from invariant.custom_types.test_result import AssertionResult, TestResult from invariant.formatter import format_trace from invariant.utils import utils @@ -265,10 +268,18 @@ def handle_outcome(self): pytest.fail(error_message, pytrace=False) - def _create_annotation( + def _create_annotations( self, assertion: AssertionResult, address: str, source: str, assertion_id: int - ): - """Create an annotation for a single assertion.""" + ) -> list[dict]: + """Create annotations for a single assertion. + + This converts assertion to a standard which is easy to parse by the explorer. + In particular: + * addresses pointing at a part of a message content, are rendered by highlighting that part. + * addresses pointing at full messages are rendered by highlighting, + the whole message content if available, otherwise all the tool calls. + * addresses pointing at tool calls are rendered by highlighting the tool call name. + """ content = assertion.message # if there is no message, we extract the assertion call @@ -279,21 +290,69 @@ def _create_annotation( content = "\n".join(remainder) content = utils.ast_truncate(content.lstrip(">")) - return { - # non-localized assertions are top-level - "address": "messages." + address if address != "" else address, - # the assertion message - "content": content, - # metadata as expected by Explorer - "extra_metadata": { - "source": source, - "test": assertion.test, - "passed": assertion.passed, - "line": assertion.test_line, - # ID of the assertion (if an assertion results in multiple annotations) - "assertion_id": assertion_id, - }, - } + if address == "": + address_to_push = [""] + + elif address.isdigit(): + # Case where the address points to a message, but not a portion of the content + msg = self.trace.trace[int(address)] + if msg.get("content", False): + if isinstance(msg["content"], str | InvariantString): + address_to_push_inner = [f".content:0-{len(msg['content'])}"] + elif isinstance(msg["content"], dict | InvariantDict): + address_to_push_inner = [ + f".content.{k}:0-{len(msg['content'][k])}" + for k in msg["content"] + ] + else: + address_to_push_inner = [""] + elif msg.get("tool_calls", False): + address_to_push_inner = [ + f".tool_calls.{i}.function.name:0-{len(tool_call['function']['name'])}" + for i, tool_call in enumerate(msg["tool_calls"]) + ] + + address_to_push = [ + "messages." + address + atpi for atpi in address_to_push_inner + ] + + elif len(address.split(".")) > 1 and address.split(".")[1] == "tool_calls": + msg = self.trace.trace[int(address.split(".")[0])] + if len(address.split(".")) > 2: + if not address.split(".")[2].isdigit(): + raise ValueError( + f"Tool call index must be an integer, got {address.split('.')[2]}" + ) + tool_calls = [msg["tool_calls"][int(address.split(".")[2])]] + else: + tool_calls = msg["tool_calls"] + address_to_push = [ + "messages." + + address + + f".function.name:0-{len(tool_call['function']['name'])}" + for tool_call in tool_calls + ] + else: + address_to_push = ["messages." + address] + + return [ + { + # non-localized assertions are top-level + "address": atp, + # the assertion message + "content": content, + # metadata as expected by Explorer + "extra_metadata": { + "source": source, + "test": assertion.test, + "passed": assertion.passed, + "line": assertion.test_line, + # ID of the assertion (if an assertion results in multiple annotations) + "assertion_id": assertion_id, + }, + } + for atp in address_to_push + ] def push(self) -> PushTracesResponse: """Push the test results to Explorer.""" @@ -304,7 +363,6 @@ def push(self) -> PushTracesResponse: annotations = [] for assertion in self.assertions: assertion_id = id(assertion) - for address in assertion.addresses: source = ( "test-assertion" if assertion.type == "HARD" else "test-expectation" @@ -312,18 +370,16 @@ def push(self) -> PushTracesResponse: if assertion.passed: source += "-passed" - annotations.append( - self._create_annotation(assertion, address, source, assertion_id) + annotations += self._create_annotations( + assertion, address, source, assertion_id ) if len(assertion.addresses) == 0: - annotations.append( - self._create_annotation( - assertion, - "", - "test-assertion" + ("-passed" if assertion.passed else ""), - assertion_id, - ) + annotations += self._create_annotations( + assertion, + "", + "test-assertion" + ("-passed" if assertion.passed else ""), + assertion_id, ) test_result = self._get_test_result() metadata = { diff --git a/testing/invariant/testing/__init__.py b/testing/invariant/testing/__init__.py index 682e742..d7a7106 100644 --- a/testing/invariant/testing/__init__.py +++ b/testing/invariant/testing/__init__.py @@ -1,4 +1,4 @@ -""" Imports for invariant testing. """ +"""Imports for invariant testing.""" from invariant.custom_types.assertions import ( assert_equals, @@ -19,6 +19,7 @@ ) from invariant.custom_types.trace import Trace from invariant.custom_types.trace_factory import TraceFactory +from invariant.utils.utils import get_agent_param # re-export trace and various assertion types __all__ = [ @@ -37,4 +38,5 @@ "HasSubstring", "IsSimilar", "IsFactuallyEqual", + "get_agent_param", ] diff --git a/testing/invariant/testing/functional.py b/testing/invariant/testing/functional.py index 5257517..6654117 100644 --- a/testing/invariant/testing/functional.py +++ b/testing/invariant/testing/functional.py @@ -11,6 +11,7 @@ from invariant.custom_types.invariant_bool import InvariantBool from invariant.custom_types.invariant_number import InvariantNumber from invariant.custom_types.invariant_value import InvariantValue +from invariant.custom_types.invariant_string import InvariantString def map( # pylint: disable=redefined-builtin @@ -82,6 +83,18 @@ def map_func(a): return sum(map(map_func, iterable)) +def frequency(iterable: Iterable[InvariantNumber | InvariantString]) -> dict[int | float | str, InvariantNumber]: + """Return a dictionary with the frequency of each string in the iterable.""" + freq = {} + for item in iterable: + if item.value in freq: + new_freq = (freq[item.value][0] + 1, freq[item.value][1] + item.addresses) + freq[item.value] = new_freq + else: + freq[item.value] = (1, item.addresses) + return {k: InvariantNumber(v[0], v[1]) for k, v in freq.items()} + + def match( pattern: str, iterable: Iterable[InvariantValue], group_id: int | str = 0 ) -> list[InvariantValue]: diff --git a/testing/invariant/utils/utils.py b/testing/invariant/utils/utils.py index df26d7d..fd5eb16 100644 --- a/testing/invariant/utils/utils.py +++ b/testing/invariant/utils/utils.py @@ -1,15 +1,25 @@ """Utility functions for the invariant runner.""" import ast +import json import os import shutil from invariant.constants import ( + INVARIANT_AGENT_PARAMS_ENV_VAR, INVARIANT_RUNNER_TEST_RESULTS_DIR, INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR, ) +def get_agent_param(param: str) -> str | None: + """Get a parameter from the environment variable.""" + params = os.getenv(INVARIANT_AGENT_PARAMS_ENV_VAR) + if params is None: + return None + return json.loads(params)[param] + + def get_test_results_directory_path(dataset_name: str) -> str: """Get the directory path for the test results.""" return f"{INVARIANT_RUNNER_TEST_RESULTS_DIR}/results_for_{dataset_name}" diff --git a/testing/sample_tests/demos/chatbot.py b/testing/sample_tests/demos/chatbot.py new file mode 100644 index 0000000..265d05b --- /dev/null +++ b/testing/sample_tests/demos/chatbot.py @@ -0,0 +1,68 @@ +import os + +import openai +import pytest + +from invariant.custom_types.trace_factory import TraceFactory +from invariant.testing import Trace, assert_true, get_agent_param +from invariant.testing import functional as F + + +def run_agent(prompt: str) -> Trace: + agent_prompt = get_agent_param("prompt") + agent_model = get_agent_param("model") + + client = openai.OpenAI() + messages = [ + {"role": "system", "content": agent_prompt}, + {"role": "user", "content": prompt}, + ] + response = client.chat.completions.create( + model=agent_model, + messages=messages, + ) + return TraceFactory.from_openai(messages + [response.choices[0].message.model_dump()]) + + +@pytest.mark.parametrize( + "country,capital", [("France", "Paris"), ("Germany", "Berlin"), ("Italy", "Rome"), ("Spain", "Madrid")] +) +def test_capitals(country, capital): + trace = run_agent(f"What's the capital of {country}?") + with trace.as_context(): + assert_true(trace.messages(role="assistant")[0]["content"].contains(capital)) + + +@pytest.mark.parametrize("n", [5, 10]) +def test_emails(n): + trace = run_agent(f"Write {n} randomly generated e-mail addresses") + with trace.as_context(): + emails = trace.messages(role="assistant")[0]["content"].match_all(r"[a-zA-Z0-9_\.]+@[a-zA-Z0-9\._]+") + assert_true(F.len(emails) == n) + + +def test_small_big(): + trace = run_agent("What is the opposite of small? Answer with one word only.") + with trace.as_context(): + assert_true(trace.messages(role="assistant")[0]["content"].is_similar("big")) + + +def test_haiku(): + trace = run_agent("Write a haiku that mentions 7 cities in Switzerland") + with trace.as_context(): + msg = trace.messages(role="assistant")[0]["content"] + assert_true(F.len(msg.extract("city in Switzerland")) == 7) + + +def test_python_code(): + trace = run_agent( + "Write a function that takes a list of numbers and returns the sum of the squares of the numbers" + ) + with trace.as_context(): + msg = trace.messages(role="assistant")[0]["content"] + if "```python" in msg: + code = msg.match("```python(.*)```", 1) + assert_true(code.is_valid_code("python")) + else: + res = msg.is_valid_code("python") + assert_true(res) diff --git a/testing/sample_tests/demos/computer_use_agent.py b/testing/sample_tests/demos/computer_use_agent.py new file mode 100644 index 0000000..e70a956 --- /dev/null +++ b/testing/sample_tests/demos/computer_use_agent.py @@ -0,0 +1,166 @@ +import re + +import urllib3 + +import invariant.testing.functional as F +from invariant.custom_types.invariant_image import InvariantImage +from invariant.custom_types.trace_factory import TraceFactory +from invariant.testing import Trace, assert_false, assert_true, expect_true +from invariant.utils.explorer import from_explorer + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +def does_not_click_on_firefox_menu(trace: Trace): + """Agent should not click on the firefox hamburger menu on the right.""" + for tool_out in trace.tool_outputs(data_type="image"): + assert_false(tool_out["content"].ocr_contains_all("New tab", "New window")) + + +def does_not_make_python_error(trace: Trace): + """Agent should not produce code that results in ModuleNotFoundError.""" + for tool_out in trace.messages(role="tool"): + assert_false(tool_out["content"].contains("ModuleNotFoundError")) + + +def does_not_make_file_edit_errors(trace: Trace): + """Agent should not make file edit errors.""" + for tool_out in trace.tool_outputs(): + assert_false(tool_out["content"].contains("Cannot overwrite files using command `create`.")) + + +global_asserts = [ + does_not_click_on_firefox_menu, + does_not_make_python_error, + does_not_make_file_edit_errors, +] + + +def run_agent(prompt: str) -> Trace: + if "annotate" in prompt: + return TraceFactory.from_explorer("mbalunovic/tests-1732714692", 1) + elif "100 traces" in prompt: + return TraceFactory.from_explorer("mbalunovic/tests-1732714692", 2) + elif "chats-about-food" in prompt: + return TraceFactory.from_explorer("mbalunovic/tests-1732714692", 3) + elif "anthropic" in prompt: + return TraceFactory.from_explorer("mbalunovic/tests-1732714692", 4) + elif "fastapi" in prompt: + return TraceFactory.from_explorer("mbalunovic/tests-1732714692", 5) + elif "fibonacci" in prompt: + return TraceFactory.from_explorer("mbalunovic/tests-1732714692", 6) + + +def test_annotation(): + trace = run_agent( + """Go to this snippet https://explorer.invariantlabs.ai/trace/9d55fa77-18f5-4a3b-9f7f-deae06833c58 + and annotate the first comment with: "nice nice" """ + ) + + with trace.as_context(): + trace.run_assertions(global_asserts) + + expect_true( + max( + F.frequency( + F.filter( + lambda x: "http" in x.value, + F.map( + lambda tc: tc.argument("text"), + trace.tool_calls({"arguments.action": "type", "name": "computer"}), + ), + ) + ).values() + ) + <= 1 + ) + # assert that the last screenshot contains the text "annotated" and text "nice nice" + last_screenshot = trace.messages(role="tool")[-1]["content"] + assert_true(last_screenshot.ocr_contains_all(["annotated", "nice nice"])) + + +def test_upload_traces(): + trace = run_agent("""upload a dataset of 100 traces using a browser""") + with trace.as_context(): + trace.run_assertions(global_asserts) + assert_false( + F.any( + F.map( + lambda x: x.argument("command").contains_all("100", "EOF", ".py"), + trace.tool_calls(name="bash"), + ) + ) + ) + + +def test_food_dataset(): + trace = run_agent("""create an empty dataset "chats-about-food", then use sdk to push 4 different traces + to it and then finally use sdk to update the metadata of the dataset to have "weather="snowy day" and "mood"="great" + after that go to the UI and verify that there are 4 traces and metadata is good""") + with trace.as_context(): + assert_true( + F.any( + F.map( + lambda x: x.argument("file_text").contains( + "create_request_and_push_trace" + ), + trace.tool_calls(name="str_replace_editor"), + ) + ) + ) + + +def test_anthropic(): + trace = run_agent( + """use https://github.com/anthropics/anthropic-sdk-python to generate some traces and upload them + to the explorer using invariant sdk. your ANTHROPIC_API_KEY is already set up with a valid key""" + ) + with trace.as_context(): + trace.run_assertions(global_asserts) + + edit_tool_calls = trace.tool_calls({"name": "str_replace_editor", "arguments.command": "create"}) + file_text = edit_tool_calls[0].argument("file_text") + assert_true(file_text.contains_any("import anthropic", "from anthropic import")) + + # Extract the dataset name from a tool output and check if it's in the last screenshot + tool_outs = trace.messages(role="tool") + dataset_name = F.match(r"Dataset: (\w+)", F.map(lambda x: x["content"], tool_outs), 1)[0] + tool_out = trace.messages(role="tool")[-1] + assert_true(tool_out["content"].ocr_contains(dataset_name)) + + +def test_code_agent_fastapi(): + trace = run_agent( + """use fastapi to create a count_words api that receives a string and counts + the number of words in it, then write a small client that tests it with a couple of different inputs""" + ) + + with trace.as_context(): + trace.run_assertions(global_asserts) + + for tool_call, tool_out in trace.tool_pairs(): + assert_false( + tool_call["function"]["name"] == "bash" + and tool_out.get("content", "").contains("Permission denied") + ) + + tool_calls = trace.tool_calls({"name": "str_replace_editor"}) + + max_freq = max( + F.frequency(F.map(lambda x: x.argument("file_text"), tool_calls)).values() + ) + assert_true(max_freq <= 2, "At least 3 edits to the same file with the same text") + +def test_fibonacci(): + trace = run_agent( + """write me a python function compute_fibonacci(n) that computes n-th fibonacci number and test it on a few inputs""" + ) + with trace.as_context(): + trace.run_assertions(global_asserts) + + tool_calls = trace.tool_calls({"name": "str_replace_editor", "arguments.command": "create"}) + for tc in tool_calls: + res = tc.argument("file_text").execute_contains( + "144", "print(compute_fibonacci(12))" + ) + assert_true(res, "Execution output does not contain 144") diff --git a/testing/sample_tests/demos/qa-chatbot.py b/testing/sample_tests/demos/qa-chatbot.py new file mode 100644 index 0000000..ed1635e --- /dev/null +++ b/testing/sample_tests/demos/qa-chatbot.py @@ -0,0 +1,70 @@ +import os + +import openai +import pytest + +from invariant.custom_types.trace_factory import TraceFactory +from invariant.testing import Trace, assert_true, get_agent_param +from invariant.testing import functional as F + + +def run_agent(prompt: str) -> Trace: + agent_prompt = get_agent_param("prompt") + agent_model = get_agent_param("model") + + client = openai.OpenAI() + messages = [ + {"role": "system", "content": agent_prompt}, + {"role": "user", "content": prompt}, + ] + response = client.chat.completions.create( + model=agent_model, + messages=messages, + ) + return TraceFactory.from_openai(messages + [response.choices[0].message.model_dump()]) + + +@pytest.mark.parametrize( + "country,capital", [("France", "Paris"), ("Germany", "Berlin"), ("Italy", "Rome"), ("Spain", "Madrid")] +) +def test_capitals(country, capital): + trace = run_agent(f"What's the capital of {country}?") + with trace.as_context(): + assert_true(trace.messages(role="assistant")[0]["content"].contains(capital)) + + +@pytest.mark.parametrize("n", [5, 10]) +def test_emails(n): + trace = run_agent(f"Write {n} randomly generated e-mail addresses") + with trace.as_context(): + emails = list( + trace.messages(role="assistant")[0]["content"].match_all(r"[a-zA-Z0-9_\.]+@[a-zA-Z0-9\._]+") + ) + assert_true(F.len(emails) == n) + + +def test_small_big(): + trace = run_agent("What is the opposite of small? Answer with one word only.") + with trace.as_context(): + assert_true(trace.messages(role="assistant")[0]["content"].is_similar("big")) + + +def test_haiku(): + trace = run_agent("Write a haiku that mentions 7 cities in Switzerland") + with trace.as_context(): + msg = trace.messages(role="assistant")[0]["content"] + assert_true(F.len(msg.extract("city in Switzerland")) == 7) + + +def test_python_code(): + trace = run_agent( + "Write a function that takes a list of numbers and returns the sum of the squares of the numbers" + ) + with trace.as_context(): + msg = trace.messages(role="assistant")[0]["content"] + if "```python" in msg: + code = msg.match("```python(.*)```", 1) + assert_true(code.is_valid_code("python")) + else: + res = msg.is_valid_code("python") + assert_true(res) diff --git a/testing/sample_tests/langgraph/weather_agent/test_weather_agent.py b/testing/sample_tests/langgraph/weather_agent/test_weather_agent.py index ea0fe82..f4742df 100644 --- a/testing/sample_tests/langgraph/weather_agent/test_weather_agent.py +++ b/testing/sample_tests/langgraph/weather_agent/test_weather_agent.py @@ -58,7 +58,7 @@ def test_weather_agent_with_sf_and_nyc(weather_agent): find_weather_tool_calls = trace.tool_calls(name="_find_weather") assert_true(len(find_weather_tool_calls) == 2) find_weather_tool_call_args = str( - F.map(lambda x: x["function"]["arguments"], find_weather_tool_calls) + F.map(lambda x: x.argument(), find_weather_tool_calls) ) assert_true( "San Francisco" in find_weather_tool_call_args diff --git a/testing/sample_tests/openai/test_python_agent.py b/testing/sample_tests/openai/test_python_agent.py index 0e2df15..e287b8a 100644 --- a/testing/sample_tests/openai/test_python_agent.py +++ b/testing/sample_tests/openai/test_python_agent.py @@ -1,9 +1,10 @@ import json from unittest.mock import MagicMock -import invariant.testing.functional as F import openai -from invariant.testing import TraceFactory, assert_true, expect_equals + +import invariant.testing.functional as F +from invariant.testing import TraceFactory, assert_true, expect_equals, Trace def run_python(code): @@ -161,7 +162,7 @@ def test_java_question(): run_python_tool_call = trace.tool_calls(name="run_python") assert_true(F.len(run_python_tool_call) == 0) expect_equals( - "I can only help with Python code.", trace.messages(-1)["content"] + expected_response, trace.messages(-1)["content"] ) assert_true(trace.messages(-1)["content"].levenshtein(expected_response) < 5) diff --git a/testing/sample_tests/swarm/capital_finder_agent/test_capital_finder_agent.py b/testing/sample_tests/swarm/capital_finder_agent/test_capital_finder_agent.py index 52c67a0..3526b40 100644 --- a/testing/sample_tests/swarm/capital_finder_agent/test_capital_finder_agent.py +++ b/testing/sample_tests/swarm/capital_finder_agent/test_capital_finder_agent.py @@ -1,11 +1,11 @@ """Tests for the capital_finder_agent""" +import invariant.testing.functional as F import pytest from invariant.wrappers.swarm_wrapper import SwarmWrapper -from swarm import Swarm -import invariant.testing.functional as F from invariant.testing import assert_equals, assert_false, assert_true +from swarm import Swarm from .capital_finder_agent import create_agent diff --git a/testing/tests/custom_types/test_invariant_image.py b/testing/tests/custom_types/test_invariant_image.py index cae9fc0..0350441 100644 --- a/testing/tests/custom_types/test_invariant_image.py +++ b/testing/tests/custom_types/test_invariant_image.py @@ -1,8 +1,9 @@ -""" Tests for the InvariantImage class. """ +"""Tests for the InvariantImage class.""" import base64 import pytest + from invariant.custom_types.invariant_image import InvariantImage from invariant.custom_types.invariant_string import InvariantString from invariant.utils.packages import is_program_installed @@ -56,3 +57,8 @@ def test_ocr_detector(): "making", bbox={"x1": 50, "y1": 10, "x2": 120, "y2": 40} ) assert not inv_img.ocr_contains("LLM") + + assert inv_img.ocr_contains_all(["agents", "making"]) + assert not inv_img.ocr_contains_all(["agents", "making", "LLM"]) + assert inv_img.ocr_contains_any(["something", "agents", "abc"]) + assert not inv_img.ocr_contains_any(["something", "def", "abc"]) diff --git a/testing/tests/custom_types/test_invariant_string.py b/testing/tests/custom_types/test_invariant_string.py index 48bd163..81f01de 100644 --- a/testing/tests/custom_types/test_invariant_string.py +++ b/testing/tests/custom_types/test_invariant_string.py @@ -3,12 +3,13 @@ from unittest.mock import patch import pytest +from pytest import approx + from invariant.custom_types.invariant_bool import InvariantBool from invariant.custom_types.invariant_number import InvariantNumber from invariant.custom_types.invariant_string import InvariantString from invariant.scorers.code import Dependencies from invariant.utils.packages import is_program_installed -from pytest import approx def test_invariant_string_initialization(): @@ -82,6 +83,38 @@ def test_invariant_string_contains(value, substring, expected): assert result.value == expected +@pytest.mark.parametrize( + "value, substrings, expected", + [ + (InvariantString("Hello World"), ["Hello", "World"], True), + (InvariantString("Hello World"), ["Hello", "world"], True), + (InvariantString("Hello World"), ["Hello", "Goodbye"], False), + (InvariantString("Hello World"), ["Hell", "o", "World"], True), + ], +) +def test_invariant_string_contains_all(value, substrings, expected): + """Test the contains_all method of InvariantString.""" + result = value.contains_all(*substrings) + assert isinstance(result, InvariantBool) + assert result.value == expected + + +@pytest.mark.parametrize( + "value, substrings, expected", + [ + (InvariantString("Hello World"), ["Hello", "Goodbye"], True), + (InvariantString("Hello World"), ["goodbye", "farewell"], False), + (InvariantString("Hello World"), ["Hell", "Bye"], True), + (InvariantString("Hello World"), ["Goodbye", "Farewell"], False), + ], +) +def test_invariant_string_contains_any(value, substrings, expected): + """Test the contains_any method of InvariantString.""" + result = value.contains_any(*substrings) + assert isinstance(result, InvariantBool) + assert result.value == expected + + @pytest.mark.parametrize( "value1, value2, expected_value, expected_addresses", [ @@ -95,9 +128,7 @@ def test_invariant_string_contains(value, substring, expected): ("World", InvariantString("Hello", ["addr1"]), "WorldHello", ["addr1:0-5"]), ], ) -def test_invariant_string_concatenation( - value1, value2, expected_value, expected_addresses -): +def test_invariant_string_concatenation(value1, value2, expected_value, expected_addresses): """Test the concatenation of InvariantString objects.""" result = value1 + value2 assert isinstance(result, InvariantString) @@ -162,17 +193,11 @@ def test_contains_ignores_case_by_default(): def test_match(): """Test the match transformer of InvariantString.""" - res = InvariantString("Dataset: demo\nAuthor: demo-agent", [""]).match( - "Dataset: (.*)", 1 - ) + res = InvariantString("Dataset: demo\nAuthor: demo-agent", [""]).match("Dataset: (.*)", 1) assert res.value == "demo" and res.addresses == [":9-13"] - res = InvariantString("Dataset: demo\nAuthor: demo-agent", [""]).match( - "Author: (?P.*)", "author" - ) + res = InvariantString("Dataset: demo\nAuthor: demo-agent", [""]).match("Author: (?P.*)", "author") assert res.value == "demo-agent" and res.addresses == [":22-32"] - res = InvariantString("My e-mail is abc@def.com, and yours?", [""]).match( - "[a-z\\.]*@[a-z\\.]*", 0 - ) + res = InvariantString("My e-mail is abc@def.com, and yours?", [""]).match("[a-z\\.]*@[a-z\\.]*", 0) assert res.value == "abc@def.com" and res.addresses == [":13-24"] @@ -190,9 +215,7 @@ def test_is_valid_code(): """Test the is_valid_code transformer of InvariantString.""" assert InvariantString("def hello():\n\treturn 1").is_valid_code("python") - res = InvariantString( - """a = 2\n2x = a\nc=a""", ["messages.0.content"] - ).is_valid_code("python") + res = InvariantString("""a = 2\n2x = a\nc=a""", ["messages.0.content"]).is_valid_code("python") assert isinstance(res, InvariantBool) assert len(res.addresses) == 1 and res.addresses[0] == "messages.0.content:6-12" assert not res @@ -205,9 +228,7 @@ def test_is_valid_code(): } """ - res = InvariantString(invalid_json_example, ["messages.0.content"]).is_valid_code( - "json" - ) + res = InvariantString(invalid_json_example, ["messages.0.content"]).is_valid_code("json") assert isinstance(res, InvariantBool) assert len(res.addresses) == 1 and res.addresses[0] == "messages.0.content:33-54" assert not res @@ -244,9 +265,7 @@ def test_moderation(): pytest.param( "claude-3-5-sonnet-20241022", "Anthropic", - marks=pytest.mark.skip( - "Skipping because we have not setup the API key in the CI" - ), + marks=pytest.mark.skip("Skipping because we have not setup the API key in the CI"), ), ], ) @@ -282,9 +301,7 @@ def test_extract(): assert res[3] == "pears" and res[3].addresses[0] == "message.0.content:104-109" -@pytest.mark.skipif( - not is_program_installed("docker"), reason="Skip for now, needs docker" -) +@pytest.mark.skipif(not is_program_installed("docker"), reason="Skip for now, needs docker") def test_execute_without_detect_packages(): """Test the code execution transformer of InvariantString without detect_packages.""" code = InvariantString("""def f(n):\treturn n**2""", ["messages.0.content"]) @@ -293,9 +310,7 @@ def test_execute_without_detect_packages(): assert len(res.addresses) == 1 and res.addresses[0] == "messages.0.content:0-21" -@pytest.mark.skipif( - not is_program_installed("docker"), reason="Skip for now, needs docker" -) +@pytest.mark.skipif(not is_program_installed("docker"), reason="Skip for now, needs docker") def test_execute_with_detect_packages(): """Test the code execution transformer of InvariantString with detect_packages.""" with patch("invariant.scorers.code._get_dependencies") as mock_get_dependencies: diff --git a/testing/tests/test_assertion_args.py b/testing/tests/test_assertion_args.py index aac57af..2c616a3 100644 --- a/testing/tests/test_assertion_args.py +++ b/testing/tests/test_assertion_args.py @@ -1,5 +1,6 @@ """Checks that flipping the order of (expected, actual) on _equals assertions and -expectations does not crash the test.""" +expectations does not crash the test. +""" from invariant.testing import Trace, assert_equals, expect_equals diff --git a/testing/tests/test_contains.py b/testing/tests/test_contains.py new file mode 100644 index 0000000..d53cec4 --- /dev/null +++ b/testing/tests/test_contains.py @@ -0,0 +1,45 @@ +import invariant.testing.functional as F +from invariant.testing import Trace, assert_true + +from .testutils import should_fail_with + + +@should_fail_with(num_assertion=1) +def test_in(): + """Test that expect_equals works fine with the right order.""" + trace = Trace( + trace=[ + {"role": "user", "content": "Hello there"}, + {"role": "assistant", "content": "there where!?"}, + {"role": "assistant", "content": "Hello to you as well"}, + ] + ) + + with trace.as_context(): + assert_true(F.len(trace.messages(content=lambda c: "Hello" in c)) == 3) + assert_true(F.len(trace.messages(content=lambda c: "there" in c)) == 2) + + +@should_fail_with(num_assertion=1) +def test_in_word_level(): + """Test that expect_equals works fine with the right order.""" + trace = Trace( + trace=[ + {"role": "user", "content": "Hello there"}, + {"role": "assistant", "content": "there where!?"}, + {"role": "assistant", "content": "Hello to you as well"}, + ] + ) + + with trace.as_context(): + trace.messages(content=lambda c: "Hello" in c) + hellos = [msg["content"].contains("Hello") for msg in trace.messages()] + theres = [msg["content"].contains("there") for msg in trace.messages()] + assert_true( + F.len([x for x in hellos if x]) == 3, + "Expected 3 messages to contain 'Hello'", + ) + assert_true( + F.len([x for x in theres if x]) == 2, + "Expected 2 messages to contain 'there'", + ) diff --git a/testing/tests/test_display.py b/testing/tests/test_display.py new file mode 100644 index 0000000..2e0e1e7 --- /dev/null +++ b/testing/tests/test_display.py @@ -0,0 +1,120 @@ +import invariant.testing.functional as F +from invariant.testing import Trace, assert_true + + +def test_assertion_points_to_substring(): + """Test to display how addresses pointing to substr of content.""" + trace = Trace( + trace=[ + {"role": "user", "content": "Hello there"}, + ] + ) + with trace.as_context(): + assert_true( + trace.messages()[0]["content"].contains("Hello"), + "Expected Hello to be in the first message", + ) + + +def test_assertion_points_to_message_content_string(): + """Test to display how addresses pointing to message with content wich is a str.""" + trace = Trace( + trace=[ + {"role": "user", "content": "Hello there"}, + ] + ) + with trace.as_context(): + assert_true( + F.len(trace.messages()) == 1, + "Expected to have exactly one message", + ) + + +def test_assertion_points_to_message_content_dict(): + """Test to display how addresses pointing to message with content wich is a dict.""" + trace = Trace( + trace=[ + {"role": "user", "content": {"this": "is", "a": "dictionary"}}, + ] + ) + with trace.as_context(): + assert_true( + F.len(trace.messages()) == 1, + "Expected to have exactly one message", + ) + + +def test_assertion_points_to_message_tool_call(): + """Test to display how addresses pointing to message composed only of tool calls are displayed.""" + trace = Trace( + trace=[ + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_uB9tU43cqiiE1CyYzrg7b07b", + "function": { + "arguments": '{"query":"lunch with Sarah","date":"2024-05-15"}', + "name": "search_calendar_events", + }, + "type": "function", + }, + { + "id": "call_uB9tU43cqiiE1CyYzrg7b07b", + "function": { + "arguments": '{"query":"lunch with Sarah","date":"2024-05-15"}', + "name": "search_calendar_events", + }, + "type": "function", + }, + ], + }, + ] + ) + with trace.as_context(): + assert_true( + F.len(trace.messages()) == 1, + "Expected to have exactly one message", + ) + + +def test_assertion_points_to_tool_call(): + """Test to display how addresses pointing to tool call are displayed.""" + trace = Trace( + trace=[ + {"role": "user", "content": "Hello there"}, + { + "role": "assistant", + "content": { + "option1": "there where!?", + "option2": "Hello to you as well", + "Hello": "Hello to you as well", + "there": {"there": "Hello to you as well"}, + }, + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_uB9tU43cqiiE1CyYzrg7b07b", + "function": { + "arguments": '{"query":"lunch with Sarah","date":"2024-05-15"}', + "name": "tool_1", + }, + "type": "function", + }, + { + "id": "call_uB9tU43cqiiE1CyYzrg7b07b", + "function": { + "arguments": '{"query":"lunch with Sarah","date":"2024-05-15"}', + "name": "tool_2", + }, + "type": "function", + }, + ], + }, + ] + ) + + with trace.as_context(): + assert_true(F.len(trace.tool_calls(name="tool_2")) == 1) diff --git a/testing/tests/test_strings.py b/testing/tests/test_strings.py index e03be92..6f188ad 100644 --- a/testing/tests/test_strings.py +++ b/testing/tests/test_strings.py @@ -1,6 +1,7 @@ import base64 import pytest + from invariant.scorers.base import approx from invariant.scorers.llm.classifier import Classifier from invariant.scorers.llm.detector import Detector