Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 11, 2024
1 parent 3a2950b commit 43a7c88
Show file tree
Hide file tree
Showing 30 changed files with 257 additions and 197 deletions.
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pip install aisploit
```python
from typing import Any
import textwrap
from aisploit.core import BaseCallbackHandler, BasePromptValue, Score
from aisploit.core import BaseCallbackHandler, BasePromptValue, Score, Response
from aisploit.model import ChatOpenAI
from aisploit.redteam import RedTeamJob, RedTeamClassifierTask
from aisploit.target import target
Expand All @@ -36,14 +36,18 @@ def play_game(level: GandalfLevel, max_attempt=5) -> None:
gandalf_scorer = GandalfScorer(level=level, chat_model=chat_model)

class GandalfHandler(BaseCallbackHandler):
def on_redteam_attempt_start(self, attempt: int, prompt: BasePromptValue, **kwargs: Any):
def on_redteam_attempt_start(
self, attempt: int, prompt: BasePromptValue, **kwargs: Any
):
print(f"Attempt #{attempt}")
print("Sending the following to Gandalf:")
print(f"{prompt.to_string()}\n")

def on_redteam_attempt_end(self, attempt: int, response: str, score: Score, **kwargs: Any):
def on_redteam_attempt_end(
self, attempt: int, response: Response, score: Score, **kwargs: Any
):
print("Response from Gandalf:")
print(f"{response}\n")
print(f"{response.content}\n")

task = RedTeamClassifierTask(
objective=textwrap.dedent(
Expand All @@ -58,7 +62,7 @@ def play_game(level: GandalfLevel, max_attempt=5) -> None:
),
classifier=gandalf_scorer,
)

@target
def send_prompt(prompt: str):
return gandalf_bot.invoke(prompt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
pipeline,
)

from ...core import BaseClassifier, Score
from ...core import BaseTextClassifier, Score


class PipelinePromptInjectionIdentifier(BaseClassifier[float]):
class PipelinePromptInjectionIdentifier(BaseTextClassifier[float]):
def __init__(
self,
*,
Expand All @@ -29,8 +29,8 @@ def __init__(
self._injection_label = injection_label
self._threshold = threshold

def score_text(self, text: str) -> Score[float]:
result = self._model(text)
def score(self, input: str) -> Score[float]:
result = self._model(input)

score = (
result[0]["score"]
Expand Down
12 changes: 6 additions & 6 deletions aisploit/classifier/openai/moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from openai import OpenAI
from openai.types.moderation import Moderation

from ...core import BaseClassifier, Score
from ...core import BaseTextClassifier, Score


class ModerationClassifier(BaseClassifier[Moderation]):
class ModerationClassifier(BaseTextClassifier[Moderation]):
"""A classifier that uses the OpenAI Moderations API for scoring."""

def __init__(
Expand All @@ -19,14 +19,14 @@ def __init__(

self._client = OpenAI(api_key=api_key)

def score_text(self, text: str) -> Score[Moderation]:
"""Score the text using the OpenAI Moderations API."""
response = self._client.moderations.create(input=text)
def score(self, input: str) -> Score[Moderation]:
"""Score the input using the OpenAI Moderations API."""
response = self._client.moderations.create(input=input)
output = response.results[0]

return Score[Moderation](
flagged=output.flagged,
value=output,
description="Moderation score for the given text",
description="Moderation score for the given input",
explanation="Details about the moderation score",
)
41 changes: 33 additions & 8 deletions aisploit/classifier/text.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,56 @@
import re
from ..core import BaseClassifier, Score
from ..core import BaseTextClassifier, Score


class RegexClassifier(BaseClassifier[bool]):
class RegexClassifier(BaseTextClassifier[bool]):
"""A text classifier based on regular expressions."""

def __init__(self, *, pattern: re.Pattern, flag_matches=True) -> None:
"""Initialize the RegexClassifier.
Args:
pattern (re.Pattern): The regular expression pattern to match.
flag_matches (bool, optional): Flag indicating whether matches should be flagged. Defaults to True.
"""
self._pattern = pattern
self._flag_matches = flag_matches

def score_text(self, text: str) -> Score[bool]:
if re.search(self._pattern, text):
def score(self, input: str) -> Score[bool]:
"""Score the input based on the regular expression pattern.
Args:
input (str): The input text to be scored.
Returns:
Score[bool]: A Score object representing the result of scoring.
"""
if re.search(self._pattern, input):
return Score[bool](
flagged=True if self._flag_matches else False,
value=True,
description=f"Return True if the pattern {self._pattern.pattern} is found in the text",
explanation=f"Found pattern {self._pattern.pattern} in text",
description=f"Return True if the pattern {self._pattern.pattern} is found in the input",
explanation=f"Found pattern {self._pattern.pattern} in input",
)

return Score[bool](
flagged=False if self._flag_matches else True,
value=False,
description=f"Return True if the pattern {self._pattern.pattern} is found in the text",
explanation=f"Did not find pattern {self._pattern.pattern} in text",
description=f"Return True if the pattern {self._pattern.pattern} is found in the input",
explanation=f"Did not find pattern {self._pattern.pattern} in input",
)


class SubstringClassifier(RegexClassifier):
"""A text classifier based on substring matching."""

def __init__(self, *, substring: str, ignore_case=True, flag_matches=True) -> None:
"""Initialize the SubstringClassifier.
Args:
substring (str): The substring to match.
ignore_case (bool, optional): Flag indicating whether to ignore case when matching substrings. Defaults to True.
flag_matches (bool, optional): Flag indicating whether matches should be flagged. Defaults to True.
"""
compiled_pattern = (
re.compile(substring, re.IGNORECASE)
if ignore_case
Expand Down
6 changes: 4 additions & 2 deletions aisploit/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from .callbacks import BaseCallbackHandler, Callbacks, CallbackManager
from .classifier import BaseClassifier, Score
from .classifier import BaseClassifier, BaseTextClassifier, Score
from .converter import BaseConverter, BaseChatModelConverter
from .dataset import BaseDataset, YamlDeserializable
from .generator import BaseGenerator
from .job import BaseJob
from .model import BaseLLM, BaseChatModel, BaseModel, BaseEmbeddings
from .prompt import BasePromptValue
from .report import BaseReport
from .target import BaseTarget
from .target import BaseTarget, Response
from .vectorstore import BaseVectorStore

__all__ = [
"BaseCallbackHandler",
"Callbacks",
"CallbackManager",
"BaseClassifier",
"BaseTextClassifier",
"Score",
"BaseConverter",
"BaseChatModelConverter",
Expand All @@ -29,5 +30,6 @@
"BasePromptValue",
"BaseReport",
"BaseTarget",
"Response",
"BaseVectorStore",
]
23 changes: 19 additions & 4 deletions aisploit/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .prompt import BasePromptValue
from .classifier import Score
from .target import Response


class BaseCallbackHandler:
Expand All @@ -20,13 +21,13 @@ def on_redteam_attempt_start(
pass

def on_redteam_attempt_end(
self, attempt: int, response: str, score: Score, *, run_id: str
self, attempt: int, response: Response, score: Score, *, run_id: str
):
"""Called when a red team attempt ends.
Args:
attempt (int): The attempt number.
response (str): The response from the attempt.
response (Response): The response from the attempt.
score (Score): The score of the attempt.
run_id (str): The ID of the current run.
"""
Expand All @@ -50,6 +51,12 @@ def on_scanner_plugin_end(self, name: str, *, run_id: str):
"""
pass

def on_sender_send_prompt_start(self):
    """Hook for the sender "send prompt start" event; does nothing by default.

    CallbackManager.on_sender_send_prompt_start calls this on every
    handler with no arguments. Subclasses override it to react.
    NOTE(review): presumably fired just before a sender sends a prompt --
    the emitting code is outside this view; confirm.
    """
    return None

def on_sender_send_prompt_end(self):
    """Hook for the sender "send prompt end" event; does nothing by default.

    CallbackManager.on_sender_send_prompt_end calls this on every handler
    with no arguments. Subclasses override it to react.
    NOTE(review): presumably fired right after a sender has sent a prompt --
    the emitting code is outside this view; confirm.
    """
    return None


Callbacks = Sequence[BaseCallbackHandler]

Expand Down Expand Up @@ -84,12 +91,12 @@ def on_redteam_attempt_start(self, attempt: int, prompt: BasePromptValue):
attempt=attempt, prompt=prompt, run_id=self.run_id
)

def on_redteam_attempt_end(self, attempt: int, response: str, score: Score):
def on_redteam_attempt_end(self, attempt: int, response: Response, score: Score):
"""Notify callback handlers when a red team attempt ends.
Args:
attempt (int): The attempt number.
response (str): The response from the attempt.
response (Response): The response from the attempt.
score (Score): The score of the attempt.
"""
for cb in self._callbacks:
Expand All @@ -114,3 +121,11 @@ def on_scanner_plugin_end(self, name: str):
"""
for cb in self._callbacks:
cb.on_scanner_plugin_end(name=name, run_id=self.run_id)

def on_sender_send_prompt_start(self):
    """Notify callback handlers that sending a prompt has started.

    Calls ``on_sender_send_prompt_start()`` on each handler in
    ``self._callbacks``, in iteration order. No arguments are forwarded;
    unlike the scanner-plugin notifications, ``run_id`` is not passed.
    """
    for handler in self._callbacks:
        handler.on_sender_send_prompt_start()

def on_sender_send_prompt_end(self):
    """Notify callback handlers that sending a prompt has ended.

    Calls ``on_sender_send_prompt_end()`` on each handler in
    ``self._callbacks``, in iteration order. No arguments are forwarded;
    unlike the scanner-plugin notifications, ``run_id`` is not passed.
    """
    for handler in self._callbacks:
        handler.on_sender_send_prompt_end()
27 changes: 20 additions & 7 deletions aisploit/core/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,44 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass


T = TypeVar("T")
Input = TypeVar("Input")


@dataclass(frozen=True)
class Score(Generic[T]):
"""A class representing a score."""
"""A class representing a score.
Attributes:
flagged (bool): Whether the score is flagged.
value (T): The value of the score.
description (str): Optional description of the score.
explanation (str): Optional explanation of the score.
"""

flagged: bool
value: T
description: str = ""
explanation: str = ""


class BaseClassifier(ABC, Generic[T]):
class BaseClassifier(ABC, Generic[T, Input]):
"""An abstract base class for classifiers."""

@abstractmethod
def score_text(self, text: str) -> Score[T]:
"""Score the text and return a Score object.
def score(self, input: Input) -> Score[T]:
"""Score the input and return a Score object.
Args:
text (str): The text to be scored.
input (Input): The input to be scored.
Returns:
Score[T]: A Score object representing the score of the text.
Score[T]: A Score object representing the score of the input.
"""
pass


class BaseTextClassifier(BaseClassifier[T, str], Generic[T]):
    """Abstract base class for classifiers that score text.

    Binds the ``Input`` type parameter of ``BaseClassifier`` to ``str`` and
    leaves the score value type ``T`` generic. It adds no members of its
    own; subclasses still have to implement the abstract ``score`` method.
    """
28 changes: 27 additions & 1 deletion aisploit/core/target.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,35 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from .prompt import BasePromptValue


@dataclass(frozen=True)
class Response:
    """Immutable value object holding what a target sent back.

    Attributes:
        content (str): The content of the response.
    """

    content: str

    def __repr__(self) -> str:
        """Render the response as ``content=<repr of content>``.

        This replaces the dataclass-generated repr, so the class name is
        deliberately not part of the output.
        """
        return "content=" + repr(self.content)


class BaseTarget(ABC):
"""An abstract base class for targets."""

@abstractmethod
def send_prompt(self, prompt: BasePromptValue) -> str:
def send_prompt(self, prompt: BasePromptValue) -> Response:
"""Send a prompt to the target and return the response.
Args:
prompt (BasePromptValue): The prompt to send.
Returns:
Response: The response from the target.
"""
pass
Loading

0 comments on commit 43a7c88

Please sign in to comment.