Skip to content

Commit

Permalink
fix overwrite combo inst and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-ashkinaze committed Oct 15, 2024
1 parent 664d393 commit e46b777
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
10 changes: 5 additions & 5 deletions plurals/deliberation.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def process(self):
previous_responses_slice = previous_responses[-self.last_n:]
previous_responses_str = format_previous_responses(
previous_responses_slice)
agent.combination_instructions = self.combination_instructions
agent.combination_instructions = agent.combination_instructions if agent.combination_instructions else self.combination_instructions
response = agent.process(
previous_responses=previous_responses_str)
previous_responses.append(response)
Expand Down Expand Up @@ -590,7 +590,8 @@ def __init__(
if len(agents) != 2:
raise ValueError("Debate requires exactly two agents.")
super().__init__(
agents, task, shuffle, cycles, last_n, combination_instructions, moderator)
agents=agents, task=task, shuffle=shuffle, cycles=cycles, last_n=last_n, moderator=moderator)
self.combination_instructions = combination_instructions

@staticmethod
def _format_previous_responses(responses: List[str]) -> str:
Expand Down Expand Up @@ -648,7 +649,7 @@ def process(self):
previous_responses_str = self._format_previous_responses(
previous_responses_agent2[-self.last_n:])

agent.combination_instructions = self.combination_instructions
agent.combination_instructions = agent.combination_instructions if agent.combination_instructions else self.combination_instructions
response = agent.process(
previous_responses=previous_responses_str)
response = self._strip_placeholders(response)
Expand Down Expand Up @@ -765,7 +766,6 @@ def __init__(self,
self.edges = edges

super().__init__(agents=self.agents, task=task, last_n=last_n,
combination_instructions=combination_instructions,
moderator=moderator)
self._build_graph()

Expand Down Expand Up @@ -838,7 +838,7 @@ def process(self):
# Process agents according to topological order
response_dict = {}
for agent in topological_order:
agent.combination_instructions = self.combination_instructions
agent.combination_instructions = agent.combination_instructions if agent.combination_instructions else self.combination_instructions
# Gather responses from all predecessors to form the input for the current agent
previous_responses = [response_dict[pred]
for pred in self.agents if agent in self.graph[pred]]
Expand Down
69 changes: 69 additions & 0 deletions plurals/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,26 @@ def test_chain_combination_instructions(self):
self.assertIsNotNone(mixed.final_response)
self.assertEqual(DEFAULTS['combination_instructions']['chain'], mixed.combination_instructions)

def test_chain_with_different_combination_instructions(self):
task = "Say hello"
agent1 = Agent(task=task, combination_instructions="Hello to previous ${previous_responses}",
model="gpt-3.5-turbo")
agent2 = Agent(task=task, combination_instructions="Prev2 to previous ${previous_responses}", model="gpt-4")
agent3 = Agent(task=task, combination_instructions="Greetings to previous ${previous_responses}",
model="gpt-3.5-turbo")

chain = Chain([agent1, agent2, agent3], task=task)
chain.process()

self.assertEqual(agent1.combination_instructions, "Hello to previous ${previous_responses}")
self.assertEqual(agent2.combination_instructions, "Prev2 to previous ${previous_responses}")
self.assertEqual(agent3.combination_instructions, "Greetings to previous ${previous_responses}")

self.assertEqual(len(chain.responses), 3)

self.assertIn("Prev2 to previous", agent2.prompts[0]['user'])
self.assertIn("Greetings to previous", agent3.prompts[0]['user'])

def test_chain_debate_instructions(self):
a2 = Agent(ideology='moderate', model=self.model)
a3 = Agent(ideology='liberal', model=self.model)
Expand Down Expand Up @@ -670,6 +690,26 @@ def test_debate_with_different_agent_tasks(self):
self.assertIn("effective in other countries", debate.responses[2])
self.assertIn("self-defense", debate.responses[3])

def test_debate_with_different_combination_instructions(self):
task = "Debate the pros and cons of artificial intelligence in 10 words."
agent1 = Agent(task=task, combination_instructions="Respond to previous argument: ${previous_responses}",
model="gpt-3.5-turbo")
agent2 = Agent(task=task, combination_instructions="Counter the previous point: ${previous_responses}",
model="gpt-4")

debate = Debate([agent1, agent2], task=task)
debate.process()

print(agent1.combination_instructions)
self.assertEqual(agent1.combination_instructions, "Respond to previous argument: ${previous_responses}")
self.assertEqual(agent2.combination_instructions, "Counter the previous point: ${previous_responses}")

self.assertEqual(len(debate.responses), 2)

self.assertIn("Counter the previous point:", agent2.prompts[0]['user'])




class TestAgentStructures(unittest.TestCase):

Expand Down Expand Up @@ -764,6 +804,35 @@ def test_graph_with_different_agent_tasks(self):
self.assertIn("inequality", graph.responses[1])
self.assertIn("Both economic growth and social equality", final_response)

def test_graph_with_different_combination_instructions(self):
task = "Say hello"
agent1 = Agent(task=task, combination_instructions="Hello to previous ${previous_responses}",
model="gpt-3.5-turbo")
agent2 = Agent(task=task, combination_instructions="Prev2 to previous ${previous_responses}", model="gpt-4")
agent3 = Agent(task=task, combination_instructions="Greetings to previous ${previous_responses}",
model="gpt-3.5-turbo")

agents = {
'agent1': agent1,
'agent2': agent2,
'agent3': agent3
}
edges = [('agent1', 'agent2'), ('agent1', 'agent3'), ('agent2', 'agent3')]

graph = Graph(agents=agents, edges=edges, task=task)
graph.process()

# Check if the combination instructions are preserved for each agent
self.assertEqual(agent1.combination_instructions, "Hello to previous ${previous_responses}")
self.assertEqual(agent2.combination_instructions, "Prev2 to previous ${previous_responses}")
self.assertEqual(agent3.combination_instructions, "Greetings to previous ${previous_responses}")

self.assertEqual(len(graph.responses), 3)

# Check if the combination instructions are used in the prompts
self.assertIn("Prev2 to previous", agent2.prompts[0]['user'])
self.assertIn("Greetings to previous", agent3.prompts[0]['user'])

def test_dict_to_list_conversion(self):
"""Test that the dictionary method (Method 2) correctly converts to the list method (Method 1) internally."""
agents_dict = {'liberal': Agent(system_instructions="you are a liberal", model=self.model, kwargs=self.kwargs),
Expand Down

0 comments on commit e46b777

Please sign in to comment.