Skip to content

Commit

Permalink
Fix memory handling
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 5, 2024
1 parent 43823f0 commit 928d353
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
28 changes: 9 additions & 19 deletions aisploit/redteam/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(

self._chain = RunnableWithMessageHistory(
runnable, # type: ignore[arg-type]
get_session_history=lambda session_id: self._history,
get_session_history=self.get_session_history,
input_messages_key="input",
output_messages_key="output",
history_messages_key="chat_history",
Expand All @@ -87,8 +87,8 @@ def invoke(self, message: str) -> str:
- str: The response from the bot.
"""
return self._chain.invoke(
{"conversation_objective": self._conversation_objective, "input": message},
config={"configurable": {"session_id": self._conversation_id}},
{"conversation_objective": self.conversation_objective, "input": message},
config={"configurable": {"session_id": self.conversation_id}},
)

def is_conversation_complete(self) -> bool:
Expand All @@ -98,7 +98,7 @@ def is_conversation_complete(self) -> bool:
Returns:
- bool: True if the conversation is complete, False otherwise.
"""
current_messages = self._history.messages
current_messages = self.get_session_history(self.conversation_id).messages

# If there are no messages, then the conversation is not complete
if not current_messages or len(current_messages) == 0:
Expand All @@ -110,21 +110,11 @@ def is_conversation_complete(self) -> bool:

return False

def clear_history(self) -> None:
"""
Clear the conversation history.
"""
self._history.clear()
def get_session_history(self, conversation_id: str) -> BaseChatMessageHistory:
return self._history

@property
def history(self) -> List[BaseMessage]:
"""
Get the conversation history.
Returns:
- List[BaseMessage]: The list of messages in the conversation history.
"""
return self._history.messages
def clear_history(self, conversation_id: str) -> None:
self.get_session_history(conversation_id).clear()

@property
def conversation_id(self) -> str:
Expand Down Expand Up @@ -153,4 +143,4 @@ def __str__(self):
Returns:
- str: The string representation.
"""
return f"Red Teaming Bot ID {self._conversation_id}"
return f"Red Teaming Bot ID {self.conversation_id}"
4 changes: 2 additions & 2 deletions aisploit/redteam/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self._classifier = classifier
self._initial_prompt = initial_prompt
self._callback_manager = CallbackManager(
id=self.conversation_id,
id=bot.conversation_id,
callbacks=callbacks,
)

Expand All @@ -38,7 +38,7 @@ def conversation_id(self):

def execute(self, max_attempt=5, clear_history=True):
if clear_history:
self._bot.clear_history()
self._bot.clear_history(self.conversation_id)

current_prompt = self._initial_prompt

Expand Down

0 comments on commit 928d353

Please sign in to comment.