Skip to content

Commit

Permalink
move: old repo edits (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmilanta authored Dec 10, 2024
1 parent bc9343b commit 95a19f8
Show file tree
Hide file tree
Showing 23 changed files with 773 additions and 177 deletions.
36 changes: 24 additions & 12 deletions testing/invariant/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from invariant.config import Config
from invariant.constants import (
INVARIANT_AGENT_PARAMS_ENV_VAR,
INVARIANT_AP_KEY_ENV_VAR,
INVARIANT_RUNNER_TEST_RESULTS_DIR,
INVARIANT_TEST_RUNNER_CONFIG_ENV_VAR,
Expand Down Expand Up @@ -54,6 +55,9 @@ def parse_args(args: list[str]) -> tuple[argparse.Namespace, list[str]]:
{BOLD}https://explorer.invariantlabs.ai/docs/{END} to see steps to generate
an API key.""",
)
parser.add_argument(
"--agent-params", help="JSON containing the parameters of the agent", type=str, default=None
)
return parser.parse_known_args(args)


Expand All @@ -68,6 +72,11 @@ def create_config(args: argparse.Namespace) -> Config:
"""
api_key = os.getenv(INVARIANT_AP_KEY_ENV_VAR)

try:
agent_params = None if args.agent_params is None else json.loads(args.agent_params)
except json.JSONDecodeError as e:
raise ValueError("--agent-params should be a valid JSON") from e

prefix = args.dataset_name
dataset_name = f"{prefix}-{int(time.time())}"

Expand All @@ -78,6 +87,7 @@ def create_config(args: argparse.Namespace) -> Config:
push=args.push,
api_key=api_key,
result_output_dir=INVARIANT_RUNNER_TEST_RESULTS_DIR,
agent_params=agent_params,
)


Expand Down Expand Up @@ -122,15 +132,19 @@ def finalize_tests_and_print_summary(conf: Config, open_browser: bool) -> None:

# update dataset metadata if --push
if conf.push:
metadata = {
"invariant_test_results": {
"num_tests": tests,
"num_passed": passed_count,
}
}
if conf.agent_params:
metadata["agent_params"] = conf.agent_params

client = InvariantClient()
client.create_request_and_update_dataset_metadata(
dataset_name=conf.dataset_name,
metadata={
"invariant_test_results": {
"num_tests": tests,
"num_passed": passed_count,
}
},
metadata=metadata,
request_kwargs={"verify": utils.ssl_verification_enabled()},
)

Expand All @@ -148,16 +162,14 @@ def test(args: list[str]) -> None:
config = create_config(invariant_runner_args)
os.environ[INVARIANT_TEST_RUNNER_CONFIG_ENV_VAR] = config.model_dump_json()
# pass along actual terminal width to the test runner (for better formatting)
os.environ[INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR] = str(
utils.terminal_width()
)
os.environ[INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR] = str(utils.terminal_width())
if invariant_runner_args.agent_params:
os.environ[INVARIANT_AGENT_PARAMS_ENV_VAR] = invariant_runner_args.agent_params
except ValueError as e:
logger.error("Configuration error: %s", e)
sys.exit(1)

test_results_directory_path = utils.get_test_results_directory_path(
config.dataset_name
)
test_results_directory_path = utils.get_test_results_directory_path(config.dataset_name)
if os.path.exists(test_results_directory_path):
shutil.rmtree(test_results_directory_path)

Expand Down
5 changes: 2 additions & 3 deletions testing/invariant/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ class Config(BaseModel):
push: bool = False
api_key: Optional[str]
result_output_dir: str
agent_params: Optional[dict]

@field_validator("api_key")
@classmethod
def validate_api_key(cls, api_key_value, info):
"""Ensure that `api_key` is provided if `push` is set to true."""
push_value = info.data.get("push")
if push_value and not api_key_value:
raise ValueError(
"`INVARIANT_API_KEY` is required if `push` is set to true."
)
raise ValueError("`INVARIANT_API_KEY` is required if `push` is set to true.")
return api_key_value
3 changes: 3 additions & 0 deletions testing/invariant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
# used to pass the actual terminal width to the test runner
# (if not available, we'll use a fallback, but nice to have)
INVARIANT_TEST_RUNNER_TERMINAL_WIDTH_ENV_VAR = "INVARIANT_TERMINAL_WIDTH"

# used to pass the agent params to the test runner
INVARIANT_AGENT_PARAMS_ENV_VAR = "INVARIANT_AGENT_PARAMS"
1 change: 0 additions & 1 deletion testing/invariant/custom_types/invariant_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from invariant.custom_types.invariant_bool import InvariantBool
from invariant.custom_types.invariant_value import InvariantValue


class InvariantDict:
"""Invariant implementation of a dict type"""

Expand Down
39 changes: 32 additions & 7 deletions testing/invariant/custom_types/invariant_image.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
""" A custom type for an invariant image. """
"""A custom type for an invariant image."""

import base64
import io
from typing import Optional

from PIL import Image

from invariant.custom_types.invariant_bool import InvariantBool
from invariant.custom_types.invariant_string import InvariantString
from invariant.scorers.llm.classifier import Classifier
from invariant.scorers.utils.ocr import OCRDetector
from PIL import Image


class InvariantImage(InvariantString):
Expand Down Expand Up @@ -43,20 +44,44 @@ def llm_vision(
client (invariant.scorers.llm.clients.client.SupportedClients): The
client to use for the LLM.
"""
llm_clf = Classifier(
model=model, prompt=prompt, options=options, vision=True, client=client
)
llm_clf = Classifier(model=model, prompt=prompt, options=options, vision=True, client=client)
res = llm_clf.classify_vision(
self.value, image_type=self.image_type, use_cached_result=use_cached_result
)
return InvariantString(res, self.addresses)

def ocr_contains(
self,
text: str,
text: str | InvariantString,
case_sensitive: bool = False,
bbox: Optional[dict] = None,
) -> InvariantBool:
"""Check if the value contains the given text using OCR."""
addresses = self.addresses
if type(text) == InvariantString:
addresses.extend(text.addresses)
text = text.value
res = OCRDetector().contains(self.image, text, case_sensitive, bbox)
return InvariantBool(res, self.addresses)
return InvariantBool(res, addresses)

def ocr_contains_any(
self,
texts: list[str | InvariantString],
case_sensitive: bool = False,
bbox: Optional[dict] = None,
) -> InvariantBool:
for text in texts:
if res := self.ocr_contains(text, case_sensitive, bbox):
return res
return InvariantBool(False, self.addresses)

def ocr_contains_all(
self,
texts: list[str | InvariantString],
case_sensitive: bool = False,
bbox: Optional[dict] = None,
) -> InvariantBool:
for text in texts:
if not (res := self.ocr_contains(text, case_sensitive, bbox)):
return res
return InvariantBool(True, self.addresses)
89 changes: 51 additions & 38 deletions testing/invariant/custom_types/invariant_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import json
import re
from operator import ge, gt, le, lt, ne
from typing import Any, Union
from typing import Any, Literal, Union

from _pytest.python_api import ApproxBase

from invariant.custom_types.invariant_bool import InvariantBool
from invariant.custom_types.invariant_number import InvariantNumber
from invariant.custom_types.invariant_value import InvariantValue
Expand Down Expand Up @@ -68,9 +69,7 @@ def __le__(self, other: Union[str, "InvariantString"]) -> InvariantBool:
def __add__(self, other: Union[str, "InvariantString"]) -> "InvariantString":
"""Concatenate the string with another string."""
if isinstance(other, InvariantString):
return InvariantString(
self.value + other.value, self.addresses + other.addresses
)
return InvariantString(self.value + other.value, self.addresses + other.addresses)
return InvariantString(self.value + other, self.addresses)

def __radd__(self, other: str) -> "InvariantString":
Expand All @@ -85,7 +84,7 @@ def __repr__(self) -> str:

def __len__(self):
raise NotImplementedError(
"InvariantString does not support len(). Please use .len() instead."
"InvariantString does not support len(). Please use functionals.len() instead."
)

def __getitem__(self, key: Any, default: Any = None) -> "InvariantString":
Expand All @@ -110,25 +109,20 @@ def __getitem__(self, key: Any, default: Any = None) -> "InvariantString":
def count(self, pattern: str) -> InvariantNumber:
"""Counts the number of occurences of the given regex pattern."""
new_addresses = []
for match in re.finditer(pattern, self.value):
for match in re.finditer(pattern, self.value, re.DOTALL):
start, end = match.span()
new_addresses.append(f"{start}-{end}")
return InvariantNumber(
len(new_addresses),
(
self.addresses
if len(new_addresses) == 0
else self._concat_addresses(new_addresses)
),
(self.addresses if len(new_addresses) == 0 else self._concat_addresses(new_addresses)),
)

def len(self):
"""Return the length of the string."""
return InvariantNumber(len(self.value), self.addresses)

def __getattr__(self, attr):
"""
Delegate attribute access to the underlying string.
"""Delegate attribute access to the underlying string.
Args:
attr (str): The attribute being accessed.
Expand All @@ -152,9 +146,7 @@ def wrapper(*args, **kwargs):
return method
raise AttributeError(f"'InvariantString' object has no attribute '{attr}'")

def _concat_addresses(
self, other_addresses: list[str] | None, separator: str = ":"
) -> list[str]:
def _concat_addresses(self, other_addresses: list[str] | None, separator: str = ":") -> list[str]:
"""Concatenate the addresses of two invariant values."""
if other_addresses is None:
return self.addresses
Expand All @@ -180,39 +172,56 @@ def _concat_addresses(

def moderation(self) -> InvariantBool:
"""Check if the value is moderated."""

analyzer = ModerationAnalyzer()
res = analyzer.detect_all(self.value)
new_addresses = [str(range) for _, range in res]
return InvariantBool(len(res) > 0, self._concat_addresses(new_addresses))

def contains(
self, pattern: str | InvariantString, flags=re.IGNORECASE
self,
*patterns: Union[str, InvariantString],
criterion: Literal["all", "any"] = "all",
flags=re.IGNORECASE,
) -> InvariantBool:
"""Check if the value contains the given pattern. This ignores case by default.
"""Check if the value contains all of the given patterns.
Args:
pattern (str | InvariantString): The pattern to check for.
flags (int): The flags to use for the regex search. To pass in multiple flags,
use the bitwise OR operator (|). By default, this is re.IGNORECASE.
*patterns: Variable number of patterns to check for. Each pattern can be a string
or InvariantString.
criterion: The criterion to use for the contains check - can be "all" or "any".
flags: The flags to use for the regex search. To pass in multiple flags, use the bitwise OR operator (|). By default, this is re.IGNORECASE.
Returns:
InvariantBool: True if all patterns are found, False otherwise. The addresses will
contain the locations of all pattern matches if found.
"""
if isinstance(pattern, InvariantString):
pattern = pattern.value
if criterion not in ["all", "any"]:
raise ValueError("Criterion must be either 'all' or 'any'")
new_addresses = []
for pattern in patterns:
if isinstance(pattern, InvariantString):
pattern = pattern.value

for match in re.finditer(pattern, self.value, flags=flags):
start, end = match.span()
new_addresses.append(f"{start}-{end}")
pattern_matches = []
for match in re.finditer(pattern, self.value, flags=flags):
start, end = match.span()
pattern_matches.append(f"{start}-{end}")

return InvariantBool(
len(new_addresses) > 0,
(
self.addresses
if len(new_addresses) == 0
else self._concat_addresses(new_addresses)
),
)
if criterion == "all" and not pattern_matches:
return InvariantBool(False, self.addresses)
if criterion == "any" and pattern_matches:
return InvariantBool(True, self._concat_addresses(pattern_matches))
new_addresses.extend(pattern_matches)

return InvariantBool(criterion == "all", self._concat_addresses(new_addresses))

def contains_all(self, *patterns: Union[str, InvariantString]) -> InvariantBool:
"""Check if the value contains all of the given patterns."""
return self.contains(*patterns, criterion="all")

def contains_any(self, *patterns: Union[str, InvariantString]) -> InvariantBool:
"""Check if the value contains any of the given patterns."""
return self.contains(*patterns, criterion="any")

def __contains__(self, pattern: str | InvariantString) -> InvariantBool:
"""Check if the value contains the given pattern."""
Expand All @@ -224,9 +233,13 @@ def match(self, pattern: str, group_id: int | str = 0) -> InvariantString:
if match is None:
return None
start, end = match.span(group_id)
return InvariantString(
match.group(group_id), self._concat_addresses([f"{start}-{end}"])
)
return InvariantString(match.group(group_id), self._concat_addresses([f"{start}-{end}"]))

def match_all(self, pattern: str, group_id: int | str = 0):
"""Match the value against the given regex pattern and return all matches."""
for match in re.finditer(pattern, self.value):
start, end = match.span(group_id)
yield InvariantString(match.group(group_id), self._concat_addresses([f"{start}-{end}"]))

def is_similar(self, other: str, threshold: float = 0.5) -> InvariantBool:
"""Check if the value is similar to the given string using cosine similarity."""
Expand Down
Loading

0 comments on commit 95a19f8

Please sign in to comment.