Skip to content

Commit

Permalink
improve agent.info docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-ashkinaze committed Jul 7, 2024
1 parent 91078d5 commit c308a77
Showing 1 changed file with 23 additions and 38 deletions.
61 changes: 23 additions & 38 deletions plurals/agent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import datetime
import warnings
from typing import Optional, Dict

import pandas as pd
from typing import Optional, Dict, Any
from plurals.helpers import *
from litellm import completion
import warnings

from plurals.helpers import *

DEFAULTS = load_yaml("instructions.yaml")


class Agent:
"""
A class to represent an agent that processes tasks based on specific characteristics.
Expand Down Expand Up @@ -35,20 +37,13 @@ class Agent:
original_task_description (str): The original task description without modifications.
current_task_description (str): The current task description that appends `previous_responses'.
history (list): A list of dicts like {'prompt':prompt, 'response':response, 'model':model}
info (dict): A dictionary of different attributes of the agent.
info (dict): A dictionary of different attributes of the agent. Contains keys for: 'task', 'system_instructions', 'history', 'persona', 'ideology', 'query_str', 'model', 'persona_template', and 'kwargs'
"""

def __init__(self,
task: Optional[str] = None,
data: Optional[pd.DataFrame] = None,
persona_mapping: Optional[Dict[str, Any]] = None,
ideology: Optional[str] = None,
query_str: Optional[str] = None,
model: str = "gpt-4o",
system_instructions: Optional[str] = None,
persona_template: Optional[str] = "default",
persona: Optional[str] = None,
**kwargs):
def __init__(self, task: Optional[str] = None, data: Optional[pd.DataFrame] = None,
persona_mapping: Optional[Dict[str, Any]] = None, ideology: Optional[str] = None,
query_str: Optional[str] = None, model: str = "gpt-4o", system_instructions: Optional[str] = None,
persona_template: Optional[str] = "default", persona: Optional[str] = None, **kwargs):
self.model = model
self.system_instructions = system_instructions
self._history = []
Expand Down Expand Up @@ -103,7 +98,8 @@ def _set_system_instructions(self):
self.persona = self._generate_persona()

# Use the persona_template to create system_instructions
self.persona_template = self.defaults['persona_template'].get(self.persona_template, self.persona_template).strip()
self.persona_template = self.defaults['persona_template'].get(self.persona_template,
self.persona_template).strip()
self.system_instructions = SmartString(self.persona_template).format(persona=self.persona,
task=self.task_description).strip()

Expand Down Expand Up @@ -165,7 +161,6 @@ def process(self, previous_responses: str = "") -> Optional[str]:
self.current_task_description = self.original_task_description
return self._get_response(self.current_task_description)


def _get_random_persona(self, data: pd.DataFrame) -> str:
"""
Generates a random persona description based on the dataset.
Expand All @@ -190,14 +185,9 @@ def _get_response(self, task: str) -> Optional[str]:
Optional[str]: The response from the LLM.
"""
if self.system_instructions:
messages = [
{"role": "system", "content": self.system_instructions},
{"role": "user", "content": task}
]
messages = [{"role": "system", "content": self.system_instructions}, {"role": "user", "content": task}]
else:
messages = [
{"role": "user", "content": task}
]
messages = [{"role": "user", "content": task}]
try:
response = completion(model=self.model, messages=messages, **self.kwargs)
content = response.choices[0].message.content
Expand Down Expand Up @@ -275,7 +265,8 @@ def _validate_templates(self):
if self.persona_template:
default_templates = list(self.defaults['persona_template'].keys())

assert '${persona}' in self.persona_template or self.persona_template in default_templates, "If you pass in a persona_template, it must contain a ${persona} placeholder or be one of the default templates:" + str(default_templates)
assert '${persona}' in self.persona_template or self.persona_template in default_templates, "If you pass in a persona_template, it must contain a ${persona} placeholder or be one of the default templates:" + str(
default_templates)

def _convert_ideology_to_query_str(self, ideology: str) -> str:
"""
Expand All @@ -284,7 +275,7 @@ def _convert_ideology_to_query_str(self, ideology: str) -> str:
if ideology.lower() == 'liberal':
return "ideo5=='Liberal'|ideo5=='Very liberal'"
elif ideology.lower() == 'conservative':
return"ideo5=='Conservative'|ideo5=='Very conservative'"
return "ideo5=='Conservative'|ideo5=='Very conservative'"
elif ideology.lower() == 'moderate':
return "ideo5 == 'Moderate'"
elif ideology.lower() == "very liberal":
Expand All @@ -301,17 +292,11 @@ def history(self):
else:
return self._history

@ property
@property
def info(self):
return {"task": self.task_description,
"system_instructions": self.system_instructions,
"history": self.history,
"persona": self.persona,
"ideology": self.ideology,
"query_str": self.query_str,
"model": self.model,
"persona_template": self.persona_template,
"kwargs": self.kwargs}
return {"task": self.task_description, "system_instructions": self.system_instructions, "history": self.history,
"persona": self.persona, "ideology": self.ideology, "query_str": self.query_str, "model": self.model,
"persona_template": self.persona_template, "kwargs": self.kwargs}

def __repr__(self):
return str(self.info)

0 comments on commit c308a77

Please sign in to comment.