Skip to content

Commit

Permalink
Add amazon comprehend classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 21, 2024
1 parent 9d4ee34 commit 6666901
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 12 deletions.
6 changes: 6 additions & 0 deletions aisploit/classifiers/amazon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Amazon Comprehend based text classifiers (PII and toxicity detection)."""

from .comprehend import ComprehendPIIClassifier, ComprehendToxicityClassifier

__all__ = [
    "ComprehendPIIClassifier",
    "ComprehendToxicityClassifier",
]
64 changes: 64 additions & 0 deletions aisploit/classifiers/amazon/comprehend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, TypeVar

import boto3

from ...core import BaseTextClassifier, Score

T = TypeVar("T")


@dataclass
class BaseComprehendClassifier(BaseTextClassifier[T], Generic[T]):
    """Shared base for classifiers backed by the Amazon Comprehend API.

    Attributes:
        session: boto3 session used to build the Comprehend client
            (defaults to a session created from the ambient AWS config).
        region_name: AWS region the Comprehend client targets; overrides
            any region configured on the session.
    """

    # boto3.Session is itself a zero-arg callable, so it can be the
    # default_factory directly — no lambda wrapper needed.
    session: boto3.Session = field(default_factory=boto3.Session)
    region_name: str = "us-east-1"

    def __post_init__(self):
        # Create the Comprehend client once; subclasses call it in score().
        self._client = self.session.client("comprehend", region_name=self.region_name)


@dataclass
class ComprehendPIIClassifier(BaseComprehendClassifier[List[Any]]):
    """Flags text containing PII entities detected by Amazon Comprehend."""

    language: str = "en"  # language code sent to Comprehend
    threshold: float = 0.7  # minimum per-entity confidence to count a hit

    def score(self, input: str) -> Score[List[Any]]:
        """Run PII entity detection and flag when any entity meets the threshold."""
        result = self._client.detect_pii_entities(Text=input, LanguageCode=self.language)

        # Keep only entities whose confidence clears the threshold.
        matches: List[Any] = []
        for candidate in result["Entities"]:
            if candidate["Score"] >= self.threshold:
                matches.append(candidate)

        hit_count = len(matches)
        if hit_count > 0:
            explanation = f"Found {hit_count} entities in input"
        else:
            explanation = "Did not find entities in input"

        return Score[List[Any]](
            flagged=hit_count > 0,
            value=matches,
            description="Returns True if entities are found in the input",
            explanation=explanation,
        )


@dataclass
class ComprehendToxicityClassifier(BaseComprehendClassifier[Dict[str, Any]]):
    """Flags text whose overall Comprehend toxicity score reaches a threshold."""

    language: str = "en"  # language code sent to Comprehend
    threshold: float = 0.7  # overall toxicity at or above this flags the input

    def score(self, input: str) -> Score[Dict[str, Any]]:
        """Run toxic-content detection on *input* as a single text segment."""
        result = self._client.detect_toxic_content(
            TextSegments=[{'Text': input}],
            LanguageCode=self.language,
        )

        # One segment was submitted, so its result is the first list entry.
        segment_result = result["ResultList"][0]
        overall = segment_result["Toxicity"]

        return Score[Dict[str, Any]](
            flagged=overall >= self.threshold,
            value={
                "Toxicity": overall,
                "Labels": segment_result["Labels"],
            },
            description="Returns True if the overall toxicity score is greater than or equal to the threshold",
            explanation=f"The overall toxicity score for the input is {overall}",
        )
18 changes: 7 additions & 11 deletions aisploit/classifiers/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,13 @@ class TextTokenClassifier(BaseTextClassifier[bool]):
token: str

def score(self, input: str) -> Score[bool]:
if self.token in input:
return Score[bool](
flagged=True,
value=True,
description=f"Return True if the token {self.token} is found in the input",
explanation=f"Found token {self.token} in input",
)

return Score[bool](
flagged=False,
value=False,
flagged=self.token in input,
value=self.token in input,
description=f"Return True if the token {self.token} is found in the input",
explanation=f"Did not find token {self.token} in input",
explanation=(
f"Found token {self.token} in input"
if self.token in input
else f"Did not find token {self.token} in input"
),
)
6 changes: 5 additions & 1 deletion docs/api/classifiers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ Classifiers

.. automodule:: aisploit.classifiers

.. automodule:: aisploit.classifiers.amazon

.. automodule:: aisploit.classifiers.huggingface

.. automodule:: aisploit.classifiers.openai

.. automodule:: aisploit.classifiers.presidio
50 changes: 50 additions & 0 deletions examples/classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"from aisploit.classifiers.presidio import PresidioAnalyserClassifier\n",
"from aisploit.classifiers.huggingface import PipelinePromptInjectionIdentifier\n",
"from aisploit.classifiers.openai import ModerationClassifier\n",
"from aisploit.classifiers.amazon import ComprehendPIIClassifier, ComprehendToxicityClassifier\n",
"\n",
"load_dotenv()"
]
Expand Down Expand Up @@ -61,6 +62,55 @@
"classifier.score(\"My name is John Doo and my phone number is 212-555-5555\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Amazon"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Score(flagged=True, value=[{'Score': 0.9999950528144836, 'Type': 'NAME', 'BeginOffset': 11, 'EndOffset': 19}, {'Score': 0.9999926090240479, 'Type': 'PHONE', 'BeginOffset': 43, 'EndOffset': 55}], description='Returns True if entities are found in the input', explanation='Found 2 entities in input')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier = ComprehendPIIClassifier()\n",
"classifier.score(\"My name is John Doo and my phone number is 212-555-5555\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Score(flagged=True, value={'Toxicity': 0.8208000063896179, 'Labels': [{'Name': 'PROFANITY', 'Score': 0.19329999387264252}, {'Name': 'HATE_SPEECH', 'Score': 0.2694000005722046}, {'Name': 'INSULT', 'Score': 0.2587999999523163}, {'Name': 'GRAPHIC', 'Score': 0.19329999387264252}, {'Name': 'HARASSMENT_OR_ABUSE', 'Score': 0.18960000574588776}, {'Name': 'SEXUAL', 'Score': 0.21789999306201935}, {'Name': 'VIOLENCE_OR_THREAT', 'Score': 0.9879999756813049}]}, description='Returns True if the overall toxicity score is greater than or equal to the threshold', explanation='The overall toxicity score for the input is 0.8208000063896179')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier = ComprehendToxicityClassifier()\n",
"classifier.score(\"I will kill you\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit 6666901

Please sign in to comment.