diff --git a/aisploit/redteam/job.py b/aisploit/redteam/job.py index 7377cd6..ea95810 100644 --- a/aisploit/redteam/job.py +++ b/aisploit/redteam/job.py @@ -11,6 +11,7 @@ from ..core import ( BaseChatModel, BaseJob, + BaseConverter, BaseTarget, Callbacks, CallbackManager, @@ -36,6 +37,7 @@ def __init__( task: RedTeamTask, target: BaseTarget, get_session_history: GetSessionHistoryCallable = get_session_history, + converter: Optional[BaseConverter] = None, callbacks: Callbacks = [], verbose=False, ) -> None: @@ -45,6 +47,7 @@ def __init__( self._task = task self._target = target self._get_session_history = get_session_history + self._converter = converter self._callbacks = callbacks def execute( @@ -80,7 +83,11 @@ def execute( config={"configurable": {"session_id": run_id}}, ) - current_prompt = StringPromptValue(text=current_prompt_text) + current_prompt = ( + self._converter.convert(current_prompt_text) + if self._converter + else StringPromptValue(text=current_prompt_text) + ) callback_manager.on_redteam_attempt_start(attempt, current_prompt) @@ -101,8 +108,7 @@ def execute( ) ) - # task.is_completed - if score.flagged: + if score.flagged: # task is completed break current_prompt_text = response.content