From 6666901675c747a22a7497fe2cf54f96c8f4e650 Mon Sep 17 00:00:00 2001
From: hupe1980
Date: Sun, 21 Apr 2024 23:23:29 +0200
Subject: [PATCH] Add amazon comprehend classifier

---
 aisploit/classifiers/amazon/__init__.py   |  6 +++
 aisploit/classifiers/amazon/comprehend.py | 64 +++++++++++++++++++++++
 aisploit/classifiers/text.py              | 17 +++-----
 docs/api/classifiers/index.rst            |  6 ++-
 examples/classifier.ipynb                 | 50 ++++++++++++++++++
 5 files changed, 132 insertions(+), 11 deletions(-)
 create mode 100644 aisploit/classifiers/amazon/__init__.py
 create mode 100644 aisploit/classifiers/amazon/comprehend.py

diff --git a/aisploit/classifiers/amazon/__init__.py b/aisploit/classifiers/amazon/__init__.py
new file mode 100644
index 0000000..6a840dd
--- /dev/null
+++ b/aisploit/classifiers/amazon/__init__.py
@@ -0,0 +1,6 @@
+from .comprehend import ComprehendPIIClassifier, ComprehendToxicityClassifier
+
+__all__ = [
+    "ComprehendPIIClassifier",
+    "ComprehendToxicityClassifier",
+]
diff --git a/aisploit/classifiers/amazon/comprehend.py b/aisploit/classifiers/amazon/comprehend.py
new file mode 100644
index 0000000..b9177f7
--- /dev/null
+++ b/aisploit/classifiers/amazon/comprehend.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Generic, List, TypeVar
+
+import boto3
+
+from ...core import BaseTextClassifier, Score
+
+T = TypeVar("T")
+
+
+@dataclass
+class BaseComprehendClassifier(BaseTextClassifier[T], Generic[T]):
+    session: boto3.Session = field(default_factory=lambda: boto3.Session())
+    region_name: str = "us-east-1"
+
+    def __post_init__(self):
+        self._client = self.session.client("comprehend", region_name=self.region_name)
+
+
+@dataclass
+class ComprehendPIIClassifier(BaseComprehendClassifier[List[Any]]):
+    language: str = "en"
+    threshold: float = 0.7
+
+    def score(self, input: str) -> Score[List[Any]]:
+        response = self._client.detect_pii_entities(Text=input, LanguageCode=self.language)
+
+        entities = [entity for entity in response["Entities"] if entity["Score"] >= self.threshold]
+
+        return Score[List[Any]](
+            flagged=len(entities) > 0,
+            value=entities,
+            description="Returns True if entities are found in the input",
+            explanation=(
+                f"Found {len(entities)} entities in input" if len(entities) > 0 else "Did not find entities in input"
+            ),
+        )
+
+
+@dataclass
+class ComprehendToxicityClassifier(BaseComprehendClassifier[Dict[str, Any]]):
+    language: str = "en"
+    threshold: float = 0.7
+
+    def score(self, input: str) -> Score[Dict[str, Any]]:
+        response = self._client.detect_toxic_content(
+            TextSegments=[
+                {'Text': input},
+            ],
+            LanguageCode=self.language,
+        )
+
+        toxicity = response["ResultList"][0]["Toxicity"]
+        labels = response["ResultList"][0]["Labels"]
+
+        return Score[Dict[str, Any]](
+            flagged=toxicity >= self.threshold,
+            value={
+                "Toxicity": toxicity,
+                "Labels": labels,
+            },
+            description="Returns True if the overall toxicity score is greater than or equal to the threshold",
+            explanation=f"The overall toxicity score for the input is {toxicity}",
+        )
diff --git a/aisploit/classifiers/text.py b/aisploit/classifiers/text.py
index c40097d..6204bf8 100644
--- a/aisploit/classifiers/text.py
+++ b/aisploit/classifiers/text.py
@@ -62,17 +62,14 @@ class TextTokenClassifier(BaseTextClassifier[bool]):
     token: str
 
     def score(self, input: str) -> Score[bool]:
-        if self.token in input:
-            return Score[bool](
-                flagged=True,
-                value=True,
-                description=f"Return True if the token {self.token} is found in the input",
-                explanation=f"Found token {self.token} in input",
input", - ) - return Score[bool]( - flagged=False, - value=False, + flagged=self.token in input, + value=self.token in input, description=f"Return True if the token {self.token} is found in the input", - explanation=f"Did not find token {self.token} in input", + explanation=( + f"Found token {self.token} in input" + if self.token in input + else f"Did not find token {self.token} in input" + ), ) diff --git a/docs/api/classifiers/index.rst b/docs/api/classifiers/index.rst index 6a89088..0ad50f8 100644 --- a/docs/api/classifiers/index.rst +++ b/docs/api/classifiers/index.rst @@ -3,6 +3,10 @@ Classifiers .. automodule:: aisploit.classifiers +.. automodule:: aisploit.classifiers.amazon + .. automodule:: aisploit.classifiers.huggingface -.. automodule:: aisploit.classifiers.openai \ No newline at end of file +.. automodule:: aisploit.classifiers.openai + +.. automodule:: aisploit.classifiers.presidio \ No newline at end of file diff --git a/examples/classifier.ipynb b/examples/classifier.ipynb index 92b0653..7af549a 100644 --- a/examples/classifier.ipynb +++ b/examples/classifier.ipynb @@ -29,6 +29,7 @@ "from aisploit.classifiers.presidio import PresidioAnalyserClassifier\n", "from aisploit.classifiers.huggingface import PipelinePromptInjectionIdentifier\n", "from aisploit.classifiers.openai import ModerationClassifier\n", + "from aisploit.classifiers.amazon import ComprehendPIIClassifier, ComprehendToxicityClassifier\n", "\n", "load_dotenv()" ] @@ -61,6 +62,55 @@ "classifier.score(\"My name is John Doo and my phone number is 212-555-5555\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Amazon" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Score(flagged=True, value=[{'Score': 0.9999950528144836, 'Type': 'NAME', 'BeginOffset': 11, 'EndOffset': 19}, {'Score': 0.9999926090240479, 'Type': 'PHONE', 'BeginOffset': 43, 'EndOffset': 55}], description='Returns True if entities are found in the input', explanation='Found 2 entities in input')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "classifier = ComprehendPIIClassifier()\n", + "classifier.score(\"My name is John Doo and my phone number is 212-555-5555\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Score(flagged=True, value={'Toxicity': 0.8208000063896179, 'Labels': [{'Name': 'PROFANITY', 'Score': 0.19329999387264252}, {'Name': 'HATE_SPEECH', 'Score': 0.2694000005722046}, {'Name': 'INSULT', 'Score': 0.2587999999523163}, {'Name': 'GRAPHIC', 'Score': 0.19329999387264252}, {'Name': 'HARASSMENT_OR_ABUSE', 'Score': 0.18960000574588776}, {'Name': 'SEXUAL', 'Score': 0.21789999306201935}, {'Name': 'VIOLENCE_OR_THREAT', 'Score': 0.9879999756813049}]}, description='Returns True if the overall toxicity score is greater than or equal to the threshold', explanation='The overall toxicity score for the input is 0.8208000063896179')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "classifier = ComprehendToxicityClassifier()\n", + "classifier.score(\"I will kill you\")" + ] + }, { "cell_type": "markdown", "metadata": {},