-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
180 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .text import RegexClassifier, SubstringClassifier | ||
|
||
__all__ = [ | ||
"RegexClassifier", | ||
"SubstringClassifier", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import re | ||
from ..core import BaseClassifier, Score | ||
|
||
|
||
class RegexClassifier(BaseClassifier[bool]): | ||
def __init__(self, *, pattern: re.Pattern, flag_matches=True) -> None: | ||
self._pattern = pattern | ||
self._flag_matches = flag_matches | ||
|
||
def score_text(self, text: str) -> Score[bool]: | ||
if re.search(self._pattern, text): | ||
return Score[bool]( | ||
flagged=True if self._flag_matches else False, | ||
value=True, | ||
description=f"Return True if the pattern {self._pattern.pattern} is found in the text", | ||
explanation=f"Found pattern {self._pattern.pattern} in text", | ||
) | ||
|
||
return Score[bool]( | ||
flagged=False if self._flag_matches else True, | ||
value=False, | ||
description=f"Return True if the pattern {self._pattern.pattern} is found in the text", | ||
explanation=f"Did not find pattern {self._pattern.pattern} in text", | ||
) | ||
|
||
|
||
class SubstringClassifier(RegexClassifier): | ||
def __init__(self, *, substring: str, ignore_case=True, flag_matches=True) -> None: | ||
compiled_pattern = ( | ||
re.compile(substring, re.IGNORECASE) | ||
if ignore_case | ||
else re.compile(substring) | ||
) | ||
super().__init__(pattern=compiled_pattern, flag_matches=flag_matches) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,21 @@ | ||
from typing import TypeVar, Generic | ||
from abc import ABC, abstractmethod | ||
from typing import Literal | ||
from dataclasses import dataclass | ||
|
||
|
||
T = TypeVar("T", int, float, str, bool) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Score: | ||
class Score(Generic[T]): | ||
flagged: bool | ||
score_type: Literal["int", "float", "str", "bool"] | ||
score_value: int | float | str | bool | ||
score_description: str = "" | ||
score_explanation: str = "" | ||
value: T | ||
description: str = "" | ||
explanation: str = "" | ||
|
||
|
||
class BaseClassifier(ABC): | ||
class BaseClassifier(ABC, Generic[T]): | ||
@abstractmethod | ||
def score_text(self, text: str) -> Score: | ||
def score_text(self, text: str) -> Score[T]: | ||
"""Score the text and return a Score object.""" | ||
raise NotImplementedError("score_text method not implemented") | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,7 +96,7 @@ def execute( | |
) | ||
) | ||
|
||
if score.score_value: | ||
if score.flagged: | ||
break | ||
|
||
current_prompt_text = response | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.