diff --git a/plurals/deliberation.py b/plurals/deliberation.py index 3e0b73a..eedc43b 100644 --- a/plurals/deliberation.py +++ b/plurals/deliberation.py @@ -498,7 +498,7 @@ def process(self): previous_responses_slice = previous_responses[-self.last_n:] previous_responses_str = format_previous_responses( previous_responses_slice) - agent.combination_instructions = self.combination_instructions + agent.combination_instructions = agent.combination_instructions if agent.combination_instructions else self.combination_instructions response = agent.process( previous_responses=previous_responses_str) previous_responses.append(response) @@ -590,7 +590,8 @@ def __init__( if len(agents) != 2: raise ValueError("Debate requires exactly two agents.") super().__init__( - agents, task, shuffle, cycles, last_n, combination_instructions, moderator) + agents=agents, task=task, shuffle=shuffle, cycles=cycles, last_n=last_n, moderator=moderator) + self.combination_instructions = combination_instructions @staticmethod def _format_previous_responses(responses: List[str]) -> str: @@ -648,7 +649,7 @@ def process(self): previous_responses_str = self._format_previous_responses( previous_responses_agent2[-self.last_n:]) - agent.combination_instructions = self.combination_instructions + agent.combination_instructions = agent.combination_instructions if agent.combination_instructions else self.combination_instructions response = agent.process( previous_responses=previous_responses_str) response = self._strip_placeholders(response) @@ -765,7 +766,6 @@ def __init__(self, self.edges = edges super().__init__(agents=self.agents, task=task, last_n=last_n, - combination_instructions=combination_instructions, moderator=moderator) self._build_graph() @@ -838,7 +838,7 @@ def process(self): # Process agents according to topological order response_dict = {} for agent in topological_order: - agent.combination_instructions = self.combination_instructions + agent.combination_instructions = agent.combination_instructions if agent.combination_instructions else self.combination_instructions # Gather responses from all predecessors to form the input for the current agent previous_responses = [response_dict[pred] for pred in self.agents if agent in self.graph[pred]] diff --git a/plurals/unit_test.py b/plurals/unit_test.py index 00efb4a..da8d08f 100644 --- a/plurals/unit_test.py +++ b/plurals/unit_test.py @@ -454,6 +454,26 @@ def test_chain_combination_instructions(self): self.assertIsNotNone(mixed.final_response) self.assertEqual(DEFAULTS['combination_instructions']['chain'], mixed.combination_instructions) + def test_chain_with_different_combination_instructions(self): + task = "Say hello" + agent1 = Agent(task=task, combination_instructions="Hello to previous ${previous_responses}", + model="gpt-3.5-turbo") + agent2 = Agent(task=task, combination_instructions="Prev2 to previous ${previous_responses}", model="gpt-4") + agent3 = Agent(task=task, combination_instructions="Greetings to previous ${previous_responses}", + model="gpt-3.5-turbo") + + chain = Chain([agent1, agent2, agent3], task=task) + chain.process() + + self.assertEqual(agent1.combination_instructions, "Hello to previous ${previous_responses}") + self.assertEqual(agent2.combination_instructions, "Prev2 to previous ${previous_responses}") + self.assertEqual(agent3.combination_instructions, "Greetings to previous ${previous_responses}") + + self.assertEqual(len(chain.responses), 3) + + self.assertIn("Prev2 to previous", agent2.prompts[0]['user']) + self.assertIn("Greetings to previous", agent3.prompts[0]['user']) + def test_chain_debate_instructions(self): a2 = Agent(ideology='moderate', model=self.model) a3 = Agent(ideology='liberal', model=self.model) @@ -670,6 +690,26 @@ def test_debate_with_different_agent_tasks(self): self.assertIn("effective in other countries", debate.responses[2]) self.assertIn("self-defense", debate.responses[3]) + def test_debate_with_different_combination_instructions(self): + task = "Debate the pros and cons of artificial intelligence in 10 words." + agent1 = Agent(task=task, combination_instructions="Respond to previous argument: ${previous_responses}", + model="gpt-3.5-turbo") + agent2 = Agent(task=task, combination_instructions="Counter the previous point: ${previous_responses}", + model="gpt-4") + + debate = Debate([agent1, agent2], task=task) + debate.process() + + print(agent1.combination_instructions) + self.assertEqual(agent1.combination_instructions, "Respond to previous argument: ${previous_responses}") + self.assertEqual(agent2.combination_instructions, "Counter the previous point: ${previous_responses}") + + self.assertEqual(len(debate.responses), 2) + + self.assertIn("Counter the previous point:", agent2.prompts[0]['user']) + + + class TestAgentStructures(unittest.TestCase): @@ -764,6 +804,35 @@ def test_graph_with_different_agent_tasks(self): self.assertIn("inequality", graph.responses[1]) self.assertIn("Both economic growth and social equality", final_response) + def test_graph_with_different_combination_instructions(self): + task = "Say hello" + agent1 = Agent(task=task, combination_instructions="Hello to previous ${previous_responses}", + model="gpt-3.5-turbo") + agent2 = Agent(task=task, combination_instructions="Prev2 to previous ${previous_responses}", model="gpt-4") + agent3 = Agent(task=task, combination_instructions="Greetings to previous ${previous_responses}", + model="gpt-3.5-turbo") + + agents = { + 'agent1': agent1, + 'agent2': agent2, + 'agent3': agent3 + } + edges = [('agent1', 'agent2'), ('agent1', 'agent3'), ('agent2', 'agent3')] + + graph = Graph(agents=agents, edges=edges, task=task) + graph.process() + + # Check if the combination instructions are preserved for each agent + self.assertEqual(agent1.combination_instructions, "Hello to previous ${previous_responses}") + self.assertEqual(agent2.combination_instructions, "Prev2 to previous ${previous_responses}") + self.assertEqual(agent3.combination_instructions, "Greetings to previous ${previous_responses}") + + self.assertEqual(len(graph.responses), 3) + + # Check if the combination instructions are used in the prompts + self.assertIn("Prev2 to previous", agent2.prompts[0]['user']) + self.assertIn("Greetings to previous", agent3.prompts[0]['user']) + def test_dict_to_list_conversion(self): """Test that the dictionary method (Method 2) correctly converts to the list method (Method 1) internally.""" agents_dict = {'liberal': Agent(system_instructions="you are a liberal", model=self.model, kwargs=self.kwargs),