diff --git a/llm_generate/generate.py b/llm_generate/generate.py deleted file mode 100644 index c75edba7..00000000 --- a/llm_generate/generate.py +++ /dev/null @@ -1,735 +0,0 @@ -import logging -import re -from typing import TypeVar, cast - -import gin -from beartype import beartype -from beartype.typing import Type -from langchain.callbacks import StdOutCallbackHandler -from langchain.chains import LLMChain -from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI -from langchain.output_parsers import PydanticOutputParser -from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - PromptTemplate, -) -from langchain.schema import ( - BaseOutputParser, - HumanMessage, - OutputParserException, -) -from pydantic import BaseModel, Field, validator -from rich import print -from rich.logging import RichHandler -from typing_extensions import Literal - -from sotopia.database import EnvironmentProfile, RelationshipProfile -from sotopia.messages import ( - ActionType, - AgentAction, - ScriptBackground, - ScriptEnvironmentResponse, -) -from sotopia.utils import format_docstring - -from langchain_callback_handler import LoggingCallbackHandler -from llama2 import Llama2 - -log = logging.getLogger("generate") -logging_handler = LoggingCallbackHandler("langchain") - -LLM_Name = Literal[ - "togethercomputer/llama-2-7b-chat", - "togethercomputer/llama-2-70b-chat", - "togethercomputer/mpt-30b-chat", - "gpt-3.5-turbo", - "text-davinci-003", - "gpt-4", - "human", - "redis", -] - -OutputType = TypeVar("OutputType", bound=object) - - -class EnvResponse(BaseModel): - reasoning: str = Field( - description="first reiterate agents' social goals and then reason about what agents say/do and whether that aligns with their goals." - ) - p1_rate: int = Field( - description="rating of participant 1, on the scale of 0 to 9" - ) - p2_rate: int = Field( - description="rating of participant 2, on the scale of 0 to 9" - ) - - -class EnvResponsePydanticOutputParser(PydanticOutputParser[EnvResponse]): - def __init__(self, pydantic_object: Type[BaseModel] = EnvResponse) -> None: - super(EnvResponsePydanticOutputParser, self).__init__( - pydantic_object=pydantic_object - ) - - def parse(self, text: str) -> EnvResponse: - # remove trailing commas before ) or ] from text - text = re.sub(r",\s*(\)|\])", r"\1", text) - return super().parse(text) - - def get_format_instructions(self) -> str: - format_instruction = super().get_format_instructions() - return format_instruction - - -class ListOfIntOutputParser(BaseOutputParser[list[int]]): - number_of_int: int | None - range_of_int: tuple[int, int] | None - - def __init__( - self, - number_of_int: int | None = None, - range_of_int: tuple[int, int] | None = None, - ): - """ - Parse the output to a list of integers - - Args: - number_of_int (int | None): The number of integers in the output. If None, the number of integers is not fixed. - """ - super().__init__() - self.number_of_int = number_of_int - self.range_of_int = range_of_int - - def _get_description_text(self) -> str: - return f"a list of{' ' + str(self.number_of_int) if self.number_of_int else ''} intergers{' within the range of' + str(self.range_of_int) if self.range_of_int else ''} separated by space" - - def get_format_instructions(self) -> str: - return "Please output " + self._get_description_text() - - def parse(self, output: str) -> list[int]: - try: - if ":" in output: - output = output.split(":")[1] - result = [int(x) for x in output.split(" ") if x] - if self.number_of_int and len(result) != self.number_of_int: - msg = ( - f"Expect {self.number_of_int} integers, got {len(result)}" - ) - raise OutputParserException(msg) - if self.range_of_int: - for x in result: - if x < self.range_of_int[0] or x > self.range_of_int[1]: - msg = f"Expect integers within the range of {self.range_of_int}, got {result}" - raise OutputParserException(msg) - return result - except KeyboardInterrupt: - raise KeyboardInterrupt - except Exception as e: - msg = f"Exception {e}: the output format is not correct. Expect {self._get_description_text()}, got {output}" - raise OutputParserException(msg) - - @property - def _type(self) -> str: - """Return the type key.""" - return "list[int]" - - -class ListOfStrOutputParser(BaseOutputParser[list[str]]): - number_of_str: int | None - - def __init__( - self, - number_of_str: int | None = None, - ): - """ - Parse the output to a list of strings - - Args: - number_of_str (int | None): The number of strings in the output. If None, the number of strings is not fixed. - """ - super().__init__() - self.number_of_str = number_of_str - - def _get_description_text(self) -> str: - return f"a list of{' ' + str(self.number_of_str) if self.number_of_str else ''} strings separated by space" - - def get_format_instructions(self) -> str: - return "Please output " + self._get_description_text() - - def parse(self, output: str) -> list[str]: - try: - result = output.split(" ") - if self.number_of_str and len(result) != self.number_of_str: - msg = f"Expect {self.number_of_str} strings, got {len(result)}" - raise OutputParserException(msg) - return result - except KeyboardInterrupt: - raise KeyboardInterrupt - except Exception as e: - msg = f"Exception {e}: the output format is not correct. Expect {self._get_description_text()}, got {output}" - raise OutputParserException(msg) - - @property - def _type(self) -> str: - """Return the type key.""" - return "list[str]" - - -class StrOutputParser(BaseOutputParser[str]): - def __init__(self) -> None: - super().__init__() - - def get_format_instructions(self) -> str: - return "Please output a string" - - def parse(self, output: str) -> str: - return output - - @property - def _type(self) -> str: - """Return the type key.""" - return "str" - - -def _return_fixed_model_version( - model_name: Literal["gpt-3.5-turbo", "gpt-4"] -) -> str: - return { - "gpt-3.5-turbo": "gpt-3.5-turbo-0613", - "gpt-4": "gpt-4-0613", - }[model_name] - - -@gin.configurable -@beartype -def obtain_chain( - model_name: LLM_Name, - template: str, - input_variables: list[str], - temperature: float = 0.7, - max_retries: int = 6, -) -> LLMChain: - """ - Using langchain to sample profiles for participants - """ - match model_name: - case "gpt-3.5-turbo" | "gpt-4": - human_message_prompt = HumanMessagePromptTemplate( - prompt=PromptTemplate( - template=template, - input_variables=input_variables, - ) - ) - chat_prompt_template = ChatPromptTemplate.from_messages( - [human_message_prompt] - ) - chat = ChatOpenAI( - model_name=_return_fixed_model_version(model_name), - temperature=temperature, - max_retries=max_retries, - ) - chain = LLMChain(llm=chat, prompt=chat_prompt_template) - return chain - case "text-davinci-003": - # Warning: no interactive mode for 003 - llm = OpenAI( - model_name=model_name, - temperature=temperature, - max_retries=max_retries, - ) - prompt = PromptTemplate( - input_variables=input_variables, - template=template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - return chain - case "togethercomputer/llama-2-7b-chat" | "togethercomputer/llama-2-70b-chat": - human_message_prompt = HumanMessagePromptTemplate( - prompt=PromptTemplate( - template=template, - input_variables=input_variables, - ) - ) - chat_prompt_template = ChatPromptTemplate.from_messages( - [human_message_prompt] - ) - together_llm = Llama2( - model_name=model_name, temperature=temperature - ) - chain = LLMChain(llm=together_llm, prompt=chat_prompt_template) - return chain - case _: - raise ValueError(f"Invalid model name: {model_name}") - - -@beartype -def format_bad_output( - ill_formed_output: str, - format_instructions: str, - model_name: LLM_Name = "gpt-3.5-turbo", -) -> str: - template = """ - Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser. - Original string: {ill_formed_output} - - Format instructions: {format_instructions} - - Please only generate the JSON: - """ - chain = obtain_chain( - model_name=model_name, - template=template, - input_variables=re.findall(r"{(.*?)}", template), - ) - input_values = { - "ill_formed_output": ill_formed_output, - "format_instructions": format_instructions, - } - reformat = chain.predict([logging_handler], **input_values) - log.info(f"Reformated output: {reformat}") - return reformat - - -@beartype -def generate( - model_name: LLM_Name, - template: str, - input_values: dict[str, str], - output_parser: BaseOutputParser[OutputType], - temperature: float = 0.7, -) -> OutputType: - input_variables = re.findall(r"{(.*?)}", template) - assert set(input_variables) == set( - list(input_values.keys()) + ["format_instructions"] - ) or set(input_variables) == set( - list(input_values.keys()) - ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}" - # process template - template = format_docstring(template) - chain = obtain_chain( - model_name=model_name, - template=template, - input_variables=input_variables, - temperature=temperature, - ) - if "format_instructions" not in input_values: - input_values[ - "format_instructions" - ] = output_parser.get_format_instructions() - result = chain.predict([logging_handler], **input_values) - try: - parsed_result = output_parser.parse(result) - except KeyboardInterrupt: - raise KeyboardInterrupt - except Exception as e: - log.debug( - f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse", - extra={"markup": True}, - ) - reformat_parsed_result = format_bad_output( - result, format_instructions=output_parser.get_format_instructions() - ) - parsed_result = output_parser.parse(reformat_parsed_result) - log.info(f"Generated result: {parsed_result}") - return parsed_result - - -@gin.configurable -@beartype -async def agenerate( - model_name: LLM_Name, - template: str, - input_values: dict[str, str], - output_parser: BaseOutputParser[OutputType], - temperature: float = 0.7, -) -> tuple[OutputType, str]: - input_variables = re.findall(r"{(.*?)}", template) - assert set(input_variables) == set( - list(input_values.keys()) + ["format_instructions"] - ) or set(input_variables) == set( - list(input_values.keys()) - ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}" - # process template - template = format_docstring(template) - chain = obtain_chain( - model_name=model_name, - template=template, - input_variables=input_variables, - temperature=temperature, - ) - if "format_instructions" not in input_values: - input_values[ - "format_instructions" - ] = output_parser.get_format_instructions() - result = await chain.apredict([logging_handler], **input_values) - prompt = logging_handler.retrive_prompt() - try: - parsed_result = output_parser.parse(result) - except Exception as e: - log.debug( - f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse", - extra={"markup": True}, - ) - reformat_parsed_result = format_bad_output( - result, format_instructions=output_parser.get_format_instructions() - ) - parsed_result = output_parser.parse(reformat_parsed_result) - log.info(f"Generated result: {parsed_result}") - return parsed_result, prompt - - -# deprecated function -@beartype -def generate_episode( - model_name: LLM_Name, - participants: str = "Jack (a greedy person), Rose", - topic: str = "lawsuit", - extra_info: str = "", -) -> EnvResponse: - """ - Using langchain to generate an example episode - """ - return generate( - model_name=model_name, - template=""" - Please generate a episode for the interaction between {participants} regarding {topic}. - You should generate the personal backgrounds and goals in this interaction. - Use the following extra info if given: {extra_info} - Please use the following format: - {format_instructions} - """, - input_values=dict( - participants=participants, - topic=topic, - extra_info=extra_info, - ), - output_parser=EnvResponsePydanticOutputParser(), - ) - - -@gin.configurable -@beartype -async def agenerate_env_profile( - model_name: LLM_Name, - inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex", - examples: str = "", -) -> tuple[EnvironmentProfile, str]: - """ - Using langchain to generate the background - """ - return await agenerate( - model_name=model_name, - template="""Please generate scenarios and goals based on the examples below as well as the inspirational prompt, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it. - Examples: - {examples} - Inspirational prompt: {inspiration_prompt} - Please use the following format: - {format_instructions} - """, - input_values=dict( - inspiration_prompt=inspiration_prompt, - examples=examples, - ), - output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile), - ) - - -@beartype -async def agenerate_relationship_profile( - model_name: LLM_Name, - agents_profiles: list[str], -) -> tuple[RelationshipProfile, str]: - """ - Using langchain to generate the background - """ - agent_profile = "\n".join(agents_profiles) - return await agenerate( - model_name=model_name, - template="""Please generate relationship between two agents based on the agents' profiles below. Note that you generate - {agent_profile} - Please use the following format: - {format_instructions} - """, - input_values=dict( - agent_profile=agent_profile, - ), - output_parser=PydanticOutputParser( - pydantic_object=RelationshipProfile - ), - ) - - -@beartype -async def agenerate_enviroment_profile( - model_name: LLM_Name, - inspiration_prompt: str = "asking my boyfriend to stop being friends with his ex", - examples: str = "", -) -> tuple[EnvironmentProfile, str]: - """ - Using langchain to generate the background - """ - return await agenerate( - model_name=model_name, - template="""Please generate scenarios and goals based on the examples below as well as the inspirational prompt, when creating the goals, try to find one point that both sides may not agree upon initially and need to collaboratively resolve it. - Examples: - {examples} - Inspirational prompt: {inspiration_prompt} - Please use the following format: - {format_instructions} - """, - input_values=dict( - inspiration_prompt=inspiration_prompt, - examples=examples, - ), - output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile), - ) - - -@beartype -def fill_in_background( - model_name: LLM_Name, - partial_background: ScriptBackground, -) -> ScriptBackground: - """ - Fill in the missing information of the background - """ - return generate( - model_name=model_name, - template="""Please fill in all missing information of the given background, don't leave any tag: - {partial_background} - Please use the following format: - {format_instructions} - """, - input_values=dict( - partial_background=partial_background.to_natural_language(), - ), - output_parser=PydanticOutputParser(pydantic_object=ScriptBackground), - ) - - -@beartype -def generate_action( - model_name: LLM_Name, - history: str, - turn_number: int, - action_types: list[ActionType], - agent: str, - goal: str, -) -> AgentAction: - """ - Using langchain to generate an example episode - """ - try: - return generate( - model_name=model_name, - template=""" - Imagine you are {agent}, your task is to act/speak like {agent} with {agent}'s social goal in mind. - You can find {agent}'s background and goal in the following history: - {history} - You are at Turn #{turn_number}. Your available action types are - {action_list}. - Note: You can "leave" this conversation if 1. this conversation makes you uncomfortable, 2. you find it uninteresting/you lose your patience, 3. you have achieved your social goals, 4. or for other reasons you want to leave. - - Please only generate a JSON string including the action type and the argument. - Your action should follow the given format: - {format_instructions} - """, - input_values=dict( - agent=agent, - turn_number=str(turn_number), - history=history, - action_list=" ".join(action_types), - ), - output_parser=PydanticOutputParser(pydantic_object=AgentAction), - ) - except KeyboardInterrupt: - raise KeyboardInterrupt - except: - return AgentAction(action_type="none", argument="") - - -@beartype -def generate_action_speak( - model_name: LLM_Name, - history: str, - turn_number: int, - action_types: list[ActionType], - agent: str, - goal: str, -) -> AgentAction: - """ - Using langchain to generate the action but only speak action is allowed - """ - try: - utterance = generate( - model_name=model_name, - template=""" - You are {agent}. - {history} - - You are at Turn #{turn_number}. Your available action type is speak. - Your goal is: {goal} - Follow the given format: - {agent} said: - should not include any quotation marks, "Turn #", or etc. - """, - input_values=dict( - agent=agent, - turn_number=str(turn_number), - history=history, - goal=goal, - ), - output_parser=StrOutputParser(), - ) - # delete the first line - utterance = utterance.replace(f"{agent} said:", "") - utterance = utterance.replace(f"Turn #{turn_number}:", "") - utterance = utterance.strip() - utterance = utterance.replace('"', "") - return AgentAction(action_type="speak", argument=utterance) - except KeyboardInterrupt: - raise KeyboardInterrupt - except: - return AgentAction(action_type="none", argument="") - - -@gin.configurable -@beartype -async def agenerate_action( - model_name: LLM_Name, - history: str, - turn_number: int, - action_types: list[ActionType], - agent: str, - goal: str, - temperature: float = 0.7, -) -> tuple[AgentAction, str]: - """ - Using langchain to generate an example episode - """ - try: - return await agenerate( - model_name=model_name, - template=""" - Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal. - You can find {agent}'s background and goal in the 'Here is the context of the interaction' field. - Note that {agent}'s secret and goal is only visible to you. - You should try your best to achieve {agent}'s goal in a way that align with their character traits. - Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before). - {history}. - You are at Turn #{turn_number}. Your available action types are - {action_list}. - Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave. - - Please only generate a JSON string including the action type and the argument. - Your action should follow the given format: - {format_instructions} - """, - input_values=dict( - agent=agent, - turn_number=str(turn_number), - history=history, - action_list=" ".join(action_types), - ), - output_parser=PydanticOutputParser(pydantic_object=AgentAction), - temperature=temperature, - ) - except: - return AgentAction(action_type="none", argument=""), "" - - -@beartype -def process_history( - script: ScriptBackground | EnvResponse | dict[str, AgentAction] -) -> str: - """ - Format the script background - """ - result = "" - if isinstance(script, ScriptBackground | EnvResponse): - script = script.dict() - result = "The initial observation\n\n" - for key, value in script.items(): - if value: - result += f"{key}: {value} \n" - return result - - -@beartype -def generate_init_profile( - model_name: LLM_Name, basic_info: dict[str, str] -) -> str: - """ - Using langchain to generate the background - """ - return generate( - model_name=model_name, - template="""Please expand a fictional background for {name}. Here is the basic information: - {name}'s age: {age} - {name}'s gender identity: {gender_identity} - {name}'s pronouns: {pronoun} - {name}'s occupation: {occupation} - {name}'s big 5 personality traits: {bigfive} - {name}'s moral Foundation: think {mft} is more important than others - {name}'s Schwartz portrait value: {schwartz} - {name}'s decision-making style: {decision_style} - {name}'s secret: {secret} - Include the previous information in the background. - Then expand the personal backgrounds with concrete details (e.g, look, family, hobbies, friends and etc.) - For the personality and values (e.g., MBTI, moral foundation, and etc.), - remember to use examples and behaviors in the person's life to demonstrate it. - """, - input_values=dict( - name=basic_info["name"], - age=basic_info["age"], - gender_identity=basic_info["gender_identity"], - pronoun=basic_info["pronoun"], - occupation=basic_info["occupation"], - bigfive=basic_info["Big_Five_Personality"], - mft=basic_info["Moral_Foundation"], - schwartz=basic_info["Schwartz_Portrait_Value"], - decision_style=basic_info["Decision_making_Style"], - secret=basic_info["secret"], - ), - output_parser=StrOutputParser(), - ) - - -@beartype -def convert_narratives(model_name: LLM_Name, narrative: str, text: str) -> str: - if narrative == "first": - return generate( - model_name=model_name, - template="""Please convert the following text into a first-person narrative. - e.g, replace name, he, she, him, her, his, and hers with I, me, my, and mine. - {text}""", - input_values=dict(text=text), - output_parser=StrOutputParser(), - ) - elif narrative == "second": - return generate( - model_name=model_name, - template="""Please convert the following text into a second-person narrative. - e.g, replace name, he, she, him, her, his, and hers with you, your, and yours. - {text}""", - input_values=dict(text=text), - output_parser=StrOutputParser(), - ) - else: - raise ValueError(f"Narrative {narrative} is not supported.") - - -@beartype -def generate_goal(model_name: LLM_Name, background: str) -> str: - """ - Using langchain to generate the background - """ - return generate( - model_name=model_name, - template="""Please generate your goal based on the background: - {background} - """, - input_values=dict(background=background), - output_parser=StrOutputParser(), - ) \ No newline at end of file diff --git a/llm_generate/generate_episode_constraint_based_sampling.py b/llm_generate/generate_episode_constraint_based_sampling.py deleted file mode 100644 index 703430ac..00000000 --- a/llm_generate/generate_episode_constraint_based_sampling.py +++ /dev/null @@ -1,118 +0,0 @@ -import asyncio -import logging -import sys -from logging import FileHandler -from typing import Literal - -from rich import print -from rich.logging import RichHandler - -from sotopia.agents import LLMAgent -from sotopia.database import ( - AgentProfile, - EnvAgentComboStorage, - EnvironmentProfile, -) -from sotopia.envs import ParallelSotopiaEnv -from sotopia.envs.evaluators import ( - ReachGoalLLMEvaluator, - RuleBasedTerminatedEvaluator, -) -from sotopia.generation_utils.generate import LLM_Name -from sotopia.messages import AgentAction, Observation -from sotopia.samplers import ( - BaseSampler, - ConstraintBasedSampler, - EnvAgentCombo, -) -from sotopia.server import run_async_server - -# date and message only -FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" - -logging.basicConfig( - level=15, - format=FORMAT, - datefmt="[%X]", - handlers=[ - RichHandler(), - FileHandler("./round_robin_parallel_sotopia_env_2.log"), - ], -) - -model_names: dict[str, LLM_Name] = { - "env": "gpt-4", - "agent1": "gpt-3.5-turbo", - "agent2": "gpt-3.5-turbo", -} - -push_to_db = sys.argv[1] -assert push_to_db in ["True", "False"], "push_to_db should be True or False" -push_to_db_bool = push_to_db == "True" -env_ids: list[str] = [] - -for code_name in ["secret_feeling"]: - envs_with_code_name = EnvironmentProfile.find( - EnvironmentProfile.codename == code_name - ).all() - assert len(envs_with_code_name) - assert (env_id := envs_with_code_name[0].pk) - env_ids.append(env_id) - - -for env_id in env_ids: - assert env_id is not None, "env_id should not be None" - env_agent_combo_storage_list = EnvAgentComboStorage.find( - EnvAgentComboStorage.env_id == env_id - ).all() - - sampler = ( - ConstraintBasedSampler[Observation, AgentAction]( - env_candidates=[env_id], - ) - if len(env_agent_combo_storage_list) == 0 - else BaseSampler[Observation, AgentAction]() - ) - - env_agent_combo_list: list[EnvAgentCombo[Observation, AgentAction]] = [] - - import pdb; pdb.set_trace() - - for env_agent_combo_storage in env_agent_combo_storage_list: - assert isinstance(env_agent_combo_storage, EnvAgentComboStorage) - env_profile = EnvironmentProfile.get(env_agent_combo_storage.env_id) - env = ParallelSotopiaEnv( - env_profile=env_profile, - model_name=model_names["env"], - action_order="round-robin", - evaluators=[ - RuleBasedTerminatedEvaluator( - max_turn_number=20, max_stale_turn=2 - ), - ], - terminal_evaluators=[ - ReachGoalLLMEvaluator(model_names["env"]), - ], - ) - agent_profiles = [ - AgentProfile.get(id) for id in env_agent_combo_storage.agent_ids - ] - - agents = [ - LLMAgent(agent_profile=agent_profile, model_name=agent_model) - for agent_profile, agent_model in zip( - agent_profiles, [model_names["agent1"], model_names["agent2"]] - ) - ] - - env_agent_combo_list.append((env, agents)) - asyncio.run( - run_async_server( - model_dict=model_names, - action_order="round-robin", - sampler=sampler, - env_agent_combo_list=env_agent_combo_list, - push_to_db=push_to_db_bool, - using_async=False, - ) - ) \ No newline at end of file diff --git a/llm_generate/langchain_callback_handler.py b/llm_generate/langchain_callback_handler.py deleted file mode 100644 index 81f3afc0..00000000 --- a/llm_generate/langchain_callback_handler.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -from typing import Any - -from langchain.callbacks import StdOutCallbackHandler - -logging.addLevelName(15, "LangChain") - - -class LoggingCallbackHandler(StdOutCallbackHandler): - """Callback Handler that prints to std out.""" - - always_verbose = True - - def __init__(self, name: str) -> None: - """Initialize callback handler.""" - super().__init__() - self.logger = logging.getLogger(name) - self.prompt = "" - - def on_chain_start(self, *args: Any, **kwargs: Any) -> None: - pass - - def on_chain_end(self, *args: Any, **kwargs: Any) -> None: - pass - - def on_agent_action(self, *args: Any, **kwargs: Any) -> Any: - pass - - def on_tool_end( - self, - *args: Any, - **kwargs: Any, - ) -> None: - pass - - def on_tool_error( - self, error: BaseException | KeyboardInterrupt, **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_text( - self, - text: str, - color: str | None = None, - end: str = "", - **kwargs: Any, - ) -> None: - """Run when agent ends.""" - # leave only prompt for environment - text = text.replace("\x1b[32;1m\x1b[1;3mHuman: ", "") - logging.log(15, f"LLM Call: {text}") - self.prompt = text - - def retrive_prompt(self) -> str: - return self.prompt - - def on_agent_finish(self, *args: Any, **kwargs: Any) -> None: - """Run on agent end.""" - pass \ No newline at end of file diff --git a/llm_generate/llama2.py b/llm_generate/llama2.py deleted file mode 100644 index cd0c54f6..00000000 --- a/llm_generate/llama2.py +++ /dev/null @@ -1,213 +0,0 @@ -from __future__ import annotations - -import logging -import os -import sys -from typing import ( - Any, - Callable, - Coroutine, - Dict, - List, - Mapping, - Optional, - Tuple, - Union, - cast, -) - -import together -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import BaseChatModel -from langchain.schema import ( - AIMessage, - BaseMessage, - ChatGeneration, - ChatMessage, - ChatResult, - HumanMessage, - SystemMessage, -) -from langchain.utils import get_from_dict_or_env -from pydantic import Extra, Field, root_validator -from tenacity import ( - before_sleep_log, - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - - -def _convert_message_to_dict(message: BaseMessage) -> dict[str, Any]: - if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} - elif isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} - elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} - elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": message.content} - else: - raise ValueError(f"Got unknown type {message}") - if "name" in message.additional_kwargs: - message_dict["name"] = message.additional_kwargs["name"] - return message_dict - - -def _convert_dict_to_message(_dict: dict[str, str]) -> BaseMessage: - text = _dict["text"] - return AIMessage(content=text) - - -def _make_prompt_from_dict(dialog: List[dict[str, str]]) -> str: - """ - Follow chat format in https://docs.together.ai/docs/python-chat - example: together complete ": List the best restaurants in SF\n: " - python example: together.Complete.create(prompt=(above), model=model_name) - The convertion is similar to llama2 official, EXCEPT not using special tags: - https://github.com/facebookresearch/llama/blob/main/example_chat_completion.py - """ - user_tag = "" - assistant_tag = "" - tag_sep = ": " - conv_sep = "\n" - - if dialog[0]["role"] == "system": - dialog = [ - { - "role": dialog[1]["role"], - "content": dialog[0]["content"] - + conv_sep - + dialog[1]["content"], - } - ] + dialog[2:] - assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( - [msg["role"] == "assistant" for msg in dialog[1::2]] - ), ( - "model only supports 'system', 'user' and 'assistant' roles, " - "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" - ) - ret = "" - for d in dialog: - tt = user_tag - if d["role"] == "user": - tt = user_tag - elif d["role"] == "assistant": - tt = assistant_tag - ret += tt + tag_sep + d["content"].strip() + conv_sep - ret += assistant_tag + tag_sep - return ret - - -logger = logging.getLogger(__name__) - - -class Llama2(BaseChatModel): - client: type[together.Complete] = together.Complete #: :meta private: - model_name: str = "togethercomputer/llama-2-7b-chat" - """Model name to use.""" - # default Together params - temperature: float = 0.7 - max_tokens: int = 128 - top_p: float = 0.7 - top_k: int = 50 - repetition_penalty: float = 1.0 - start: bool = False - _llm_type: str = "llama2" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.ignore - - @root_validator(pre=True) - def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = { - field.alias for field in cls.__fields__.values() - } - - extra = values.get("model_kwargs", {}) - for field_name in list(values): - if field_name not in all_required_field_names: - if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") - extra[field_name] = values.pop(field_name) - values["model_kwargs"] = extra - return values - - @root_validator() - def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: - together_api_key = get_from_dict_or_env( - values, "together_api_key", "TOGETHER_API_KEY" - ) - together.api_key = together_api_key - values["client"] = together.Complete - return values - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if not self.start: - together.Models.start(self.model_name) - self.start = True - prompt, params = self._create_message_dicts(messages, stop) - response = self.client.create(prompt=prompt, **params) - chat_result = self._create_chat_result(response) - return chat_result - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling Together API.""" - return { - "model": self.model_name, - "max_tokens": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "repetition_penalty": self.repetition_penalty, - } - - def _create_message_dicts( - self, messages: List[BaseMessage], stop: Optional[List[str]] - ) -> Tuple[str, Dict[str, Any]]: - params: Dict[str, Any] = { - **{"model": self.model_name}, - **self._default_params, - } - if stop is not None: - if "stop" in params: - raise ValueError( - "`stop` found in both the input and default params." - ) - params["stop"] = stop - message_dicts = [_convert_message_to_dict(m) for m in messages] - message_prompt = _make_prompt_from_dict(message_dicts) - return message_prompt, params - - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: - generations = [] - for res in response["output"]["choices"]: - message = _convert_dict_to_message(res) - gen = ChatGeneration(message=message) - generations.append(gen) - llm_output = {"model_name": self.model_name} - return ChatResult(generations=generations, llm_output=llm_output) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: List[str] | None = None, - run_manager: AsyncCallbackManagerForLLMRun | None = None, - **kwargs: Any, - ) -> ChatResult: - sync_run_manager = cast(CallbackManagerForLLMRun, run_manager) - return self._generate(messages, stop, sync_run_manager) \ No newline at end of file diff --git a/llm_generate/step2_push_agent_relationship_env_to_db.py b/llm_generate/step2_push_agent_relationship_env_to_db.py index ed955eb9..7a8fe8e8 100644 --- a/llm_generate/step2_push_agent_relationship_env_to_db.py +++ b/llm_generate/step2_push_agent_relationship_env_to_db.py @@ -9,13 +9,11 @@ AgentProfile, EnvironmentProfile, RelationshipProfile, - RelationshipType, ) from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage from sotopia.samplers import ConstraintBasedSampler from sotopia.messages import AgentAction, Observation from sotopia.agents import LLMAgent -import redis @@ -41,16 +39,13 @@ def retrieve_agent_by_first_name(first_name: str) -> AgentProfile: def add_env_profile(**kwargs: dict[str, Any]) -> None: - import pdb; pdb.set_trace() env_profile = EnvironmentProfile(**kwargs) - env_profile.save(r) - import pdb; pdb.set_trace() + env_profile.save() def add_env_profiles(env_profiles: list[dict[str, Any]]) -> None: for env_profile in env_profiles: add_env_profile(**env_profile) - import pdb; pdb.set_trace() def add_relationship_profile(**kwargs: dict[str, Any]) -> None: @@ -90,9 +85,13 @@ def sample_env_agent_combo_and_push_to_db(env_id: str) -> None: sampler = ConstraintBasedSampler[Observation, AgentAction]( env_candidates=[env_id] ) - env_agent_combo_list = list( - sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False) - ) + try: + env_agent_combo_list = list( + sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False) + ) + except: + return + print(len(env_agent_combo_list)) for env, agent in env_agent_combo_list: EnvAgentComboStorage( env_id=env.profile.pk, @@ -111,7 +110,6 @@ def relationship_map(relationship: str) -> int: df = pd.read_csv(sys.argv[1]) type = sys.argv[2] if type == "agent": - import pdb; pdb.set_trace() agents = cast(list[dict[str, Any]], df.to_dict(orient="records")) for agent in agents: agent["age"] = int(agent["age"]) @@ -120,18 +118,8 @@ def relationship_map(relationship: str) -> int: "schwartz_personal_values" ].split(",") add_agents_to_database(agents) + Migrator().run() elif type == "environment": - pks = EnvironmentProfile.all_pks() - ''' - df = df[ - ( - df["Xuhui"].astype(float).fillna(0) - + df["Leena"].astype(float).fillna(0) - + df["Hao"].astype(float).fillna(0) - ) - > 1 - ] - ''' df = df[ [ "codename", @@ -146,19 +134,17 @@ def relationship_map(relationship: str) -> int: envs = cast(list[dict[str, Any]], df.to_dict(orient="records")) for env in envs: env["agent_goals"] = ast.literal_eval(env["agent_goals"]) - # check env['relationship'] is int assert isinstance(env["relationship"], int) - - add_env_profiles(envs) Migrator().run() elif type == "relationship": - import pdb; pdb.set_trace() relationships = cast( list[dict[str, Any]], df.to_dict(orient="records") ) for relationship in relationships: - relationship["relationship"] = relationship_map( - relationship["relationship"] - ) + assert isinstance(relationship["relationship"], int) add_relationship_profiles(relationships) Migrator().run() + elif type == 'agentenvcombo': + env_ids = list(EnvironmentProfile.all_pks()) + for env_id in env_ids: + sample_env_agent_combo_and_push_to_db(env_id) \ No newline at end of file diff --git a/llm_generate/test_redis.py b/llm_generate/test_redis1.py similarity index 100% rename from llm_generate/test_redis.py rename to llm_generate/test_redis1.py diff --git a/llm_generate/test_redis2.py b/llm_generate/test_redis2.py index 9e3996cf..06c691c9 100644 --- a/llm_generate/test_redis2.py +++ b/llm_generate/test_redis2.py @@ -7,12 +7,5 @@ class Person(JsonModel): # Create an instance of your model person = Person(name="John", age=30) -# Save to Redis -person.save() - -# Retrieve from Redis -retrieved_person = Person.load(person.id) - -# Print the retrieved data -print(retrieved_person.name) # Output: John -print(retrieved_person.age) # Output: 30 \ No newline at end of file +# Save to Redis with a specific key +person.save() \ No newline at end of file