Skip to content

Commit

Permalink
Add translation converter
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 20, 2024
1 parent 8e94f24 commit ec4c8e4
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 13 deletions.
2 changes: 2 additions & 0 deletions aisploit/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .remove_punctuation import RemovePunctuationConverter
from .sequence import SequenceConverter
from .stemming import StemmingConverter
from .translation import TranslationConverter
from .unicode_confusable import UnicodeConfusableConverter

__all__ = [
Expand All @@ -27,5 +28,6 @@
"RemovePunctuationConverter",
"SequenceConverter",
"StemmingConverter",
"TranslationConverter",
"UnicodeConfusableConverter",
]
14 changes: 10 additions & 4 deletions aisploit/converters/gender.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import textwrap
from dataclasses import dataclass, field
from typing import Dict

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSerializable

from ..core import BaseChatModel, BaseChatModelConverter
from ..core import BaseChatModelConverter

_template = ChatPromptTemplate.from_template(
textwrap.dedent(
Expand All @@ -24,10 +27,13 @@
)


@dataclass
class GenderConverter(BaseChatModelConverter):
def __init__(self, *, chat_model: BaseChatModel, prompt=_template) -> None:
super().__init__(chat_model)
self._chain = prompt | chat_model | StrOutputParser()
prompt: ChatPromptTemplate = field(default_factory=lambda: _template)
_chain: RunnableSerializable[Dict, str] = field(init=False)

def __post_init__(self) -> None:
self._chain = self.prompt | self.chat_model | StrOutputParser()

def _convert(self, prompt: str) -> str:
return self._chain.invoke({"input": prompt})
36 changes: 36 additions & 0 deletions aisploit/converters/translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import textwrap
from dataclasses import dataclass, field
from typing import Dict

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSerializable

from ..core import BaseChatModelConverter

_template = ChatPromptTemplate.from_template(
textwrap.dedent(
"""
Translate the following text to {language}.
Please maintain the original meaning and context as closely as possible.
Input text:
{input}
Translated text:
"""
)
)


@dataclass
class TranslationConverter(BaseChatModelConverter):
language: str
prompt: ChatPromptTemplate = field(default_factory=lambda: _template)
_chain: RunnableSerializable[Dict, str] = field(init=False)

def __post_init__(self) -> None:
self._chain = self.prompt | self.chat_model | StrOutputParser()

def _convert(self, prompt: str) -> str:
return self._chain.invoke({"input": prompt, "language": self.language})
7 changes: 4 additions & 3 deletions aisploit/core/converter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

from langchain_core.prompt_values import StringPromptValue
Expand Down Expand Up @@ -52,9 +53,9 @@ def __repr__(self) -> str:
return f"<{prefix}.{self.__class__.__name__}>"


@dataclass
class BaseChatModelConverter(BaseConverter, ABC):
def __init__(self, chat_model: BaseChatModel) -> None:
self._chat_model = chat_model
chat_model: BaseChatModel

def __repr__(self) -> str:
"""Return a string representation of the converter.
Expand All @@ -66,4 +67,4 @@ def __repr__(self) -> str:
if not self.__module__.startswith(prefix):
prefix = "custom"

return f"<{prefix}.{self.__class__.__name__}(chat_model={self._chat_model.get_name()})>"
return f"<{prefix}.{self.__class__.__name__}(chat_model={self.chat_model.get_name()})>"
45 changes: 39 additions & 6 deletions examples/converter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
" SequenceConverter,\n",
" StemmingConverter,\n",
" UnicodeConfusableConverter,\n",
" TranslationConverter,\n",
")\n",
"from aisploit.models import ChatOpenAI\n",
"\n",
Expand Down Expand Up @@ -194,13 +195,45 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## RemovePunctuationConverter"
"## TranslationConverter"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"H3ll0, w0rld! H0w 4r3 y0u?"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"converter = TranslationConverter(chat_model=chat_model, language=\"leetspeak\")\n",
"converted_prompt = converter.convert(\"Hello, world! How are you?\")\n",
"\n",
"display(Markdown(converted_prompt.to_string()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RemovePunctuationConverter"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
Expand Down Expand Up @@ -231,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -263,7 +296,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -295,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -327,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -366,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down

0 comments on commit ec4c8e4

Please sign in to comment.