-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sync private repo with this public repo (#12)
- Loading branch information
Showing
34 changed files
with
6,010 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Come here if you encounter any issues | ||
|
||
## Missing episodes | ||
|
||
Large batch size may cause some episodes to be skipped. This happens when the server cannot handle the load. Try reducing the batch size. You can also use the script in `examples/fix_missing_episodes.py` to fix the missing episodes. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Example Scripts For Using The Library | ||
|
||
## Example 1: Evaluating existing episodes | ||
|
||
```python | ||
python examples/evaluate_existing_episodes.py --tag=<tag to upload to the database> --model=<the model used to re-evaluate the existing episodes> --batch_size=<batch size used for evaluation> --push-to-db | ||
``` | ||
|
||
Run ```python examples/evaluate_existing_episodes.py --help``` for more information. | ||
|
||
## Example 2: Generate script-like episodes | ||
See `docs/simulation_modes.md` for more information. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Hyperparameters that are used in the simulation | ||
|
||
## Tags | ||
|
||
- `TAG`: The tag of the simulation. This tag is used to identify the simulation in the database. | ||
- `TAG_TO_CHECK_EXISTING_EPISODES`: Scripts like `examples/experiment_eval.py` check if there are existing episodes with the same tag in the database. If there are, the simulation **will not** be run. This is to avoid running the same simulation twice. If you want to run the simulation again, you can change the tag or set `TAG_TO_CHECK_EXISTING_EPISODES` to `None`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Different Modes of Simulation | ||
|
||
## Simulation Modes | ||
|
||
The simulation can be run in different modes. The mode is specified in the configuration file. The following modes are available: | ||
|
||
### Sotopia-lite | ||
|
||
- `lite`: The simulation runs without the characters' detailed background information, using only their names. To use this mode, set `lite` to `True` in the gin configuration command. | ||
e.g., | ||
```bash | ||
python examples/experiment_eval.py \ | ||
--gin_file sotopia_conf/generation_utils_conf/generate.gin \ | ||
--gin_file sotopia_conf/server_conf/server.gin \ | ||
--gin_file sotopia_conf/run_async_server_in_batch.gin \ | ||
'--gin.ENV_IDS=[]' \ | ||
'--gin.AGENT1_MODEL="gpt-3.5-turbo"' \ | ||
'--gin.AGENT2_MODEL="gpt-3.5-turbo"' \ | ||
'--gin.BATCH_SIZE=5' \ | ||
'--gin.TAG="lite_gpt3.5_gpt3.5"' \ | ||
'--gin.TAG_TO_CHECK_EXISTING_EPISODES="lite_gpt3.5_gpt3.5"' \ | ||
'--gin.PUSH_TO_DB=False' \ | ||
'--gin.OMNISCIENT=False' \ | ||
'--gin.VERBOSE=False' \ | ||
'--gin.LITE=True' \ | ||
``` | ||
|
||
### Sotopia-script | ||
|
||
- `script`: The simulation has the LLMs generate the entire interaction in one shot, as in a script-writing setting. To use this mode, set `script` to `True` in the gin configuration command. | ||
|
||
e.g., | ||
|
||
```bash | ||
python examples/generate_script.py \ | ||
--gin_file sotopia_conf/generation_utils_conf/generate.gin \ | ||
--gin_file sotopia_conf/run_async_server_in_batch_script.gin \ | ||
'--gin.ENV_IDS=[]' \ | ||
'--gin.SCRIPT_MODEL="gpt-3.5-turbo"' \ | ||
'--gin.BATCH_SIZE=5' \ | ||
'--gin.TAG="lite_script_gpt3.5_gpt3.5"' \ | ||
'--gin.TAG_TO_CHECK_EXISTING_EPISODES="lite_script_gpt3.5_gpt3.5"' \ | ||
'--gin.PUSH_TO_DB=True' \ | ||
'--gin.VERBOSE=False' \ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import asyncio | ||
import logging | ||
import subprocess | ||
import typing | ||
from datetime import datetime | ||
from logging import FileHandler | ||
|
||
import gin | ||
import typer | ||
from experiment_eval import _iterate_env_agent_combo_not_in_db | ||
from rich import print | ||
from rich.logging import RichHandler | ||
from tqdm import tqdm | ||
from tqdm.asyncio import tqdm_asyncio | ||
|
||
from sotopia.agents.llm_agent import Agents | ||
from sotopia.database.logs import AnnotationForEpisode, EpisodeLog | ||
from sotopia.database.persistent_profile import EnvironmentProfile | ||
from sotopia.generation_utils.generate import LLM_Name, agenerate_script | ||
from sotopia.messages.message_classes import ( | ||
AgentAction, | ||
Observation, | ||
ScriptBackground, | ||
) | ||
from sotopia.samplers import ( | ||
BaseSampler, | ||
ConstraintBasedSampler, | ||
EnvAgentCombo, | ||
) | ||
from sotopia.server import aevaluate_one_episode, arun_one_script | ||
|
||
# date and message only
FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"

# Embed the current git commit hash in the log file name so every run's log
# is traceable to the exact code version. check_output raises immediately on
# a non-zero exit instead of silently yielding an empty hash.
git_head_hash = subprocess.check_output(
    ["git", "rev-parse", "HEAD"], shell=False
).strip()

logging.basicConfig(
    level=15,  # between DEBUG (10) and INFO (20)
    format=FORMAT,
    datefmt="[%X]",
    handlers=[
        RichHandler(),
        # One log file per run, stamped with time and commit hash.
        # NOTE(review): assumes ./logs already exists — confirm, or create it first.
        FileHandler(
            datetime.now().strftime(
                f"./logs/%H_%M_%d_%m_%Y_{git_head_hash.decode('utf-8')}.log"
            )
        ),
    ],
)
app = typer.Typer()
|
||
|
||
def _aevaluate_episode_batch(
    episode_batch: list[EpisodeLog],
    model: LLM_Name,
    tag: str | None,
    push_to_db: bool,
) -> None:
    """Re-evaluate one batch of episodes concurrently and wait for completion."""
    logging.info(
        f"Running batch of {len(episode_batch)} episodes: {episode_batch}"
    )
    episode_futures = [
        aevaluate_one_episode(
            episode=episode,
            model=model,
            tag=tag,
            push_to_db=push_to_db,
        )
        for episode in episode_batch
    ]
    asyncio.run(
        tqdm_asyncio.gather(*episode_futures, desc="Running one batch")
    )


def run_async_server_in_batch_aevaluate(
    batch_size: int = 10,
    model: LLM_Name = "gpt-4",
    reeval_list: list[str] | None = None,
    tag: str | None = None,
    push_to_db: bool = False,
    verbose: bool = False,
) -> None:
    """Re-evaluate the episodes in ``reeval_list`` with ``model``, ``batch_size`` at a time.

    Args:
        batch_size: Number of episodes evaluated concurrently per batch.
        model: Model name used for the re-evaluation.
        reeval_list: Primary keys of the episodes to re-evaluate (default: none).
        tag: Optional tag attached to the new evaluations.
        push_to_db: Whether to persist the evaluation results.
        verbose: When False, silence the root logger's console output.
    """
    # Avoid the mutable-default-argument pitfall: default to a fresh list.
    if reeval_list is None:
        reeval_list = []

    if not verbose:
        logger = logging.getLogger()
        logger.setLevel(logging.CRITICAL)
        # Drop the first handler (the RichHandler installed at module setup)
        # so only the file handler keeps receiving records.
        logger.removeHandler(logger.handlers[0])

    episode_batch: list[EpisodeLog] = []
    for env_pk in tqdm(reeval_list):
        episode_batch.append(EpisodeLog.get(env_pk))
        if len(episode_batch) == batch_size:
            _aevaluate_episode_batch(episode_batch, model, tag, push_to_db)
            episode_batch = []
    # Flush the final, possibly smaller, batch.
    if episode_batch:
        _aevaluate_episode_batch(episode_batch, model, tag, push_to_db)
|
||
|
||
@app.command()
def run_server(
    tag: str = "reeval_llama2",
    model: str = "togethercomputer/llama-2-70b-chat",  # Why typer does not accept LLM_Name?
    batch_size: int = 10,
    push_to_db: bool = True,
    verbose: bool = False,
) -> None:
    """CLI entry point: re-evaluate every human-annotated episode with ``model``."""
    # Gather the episode pks referenced by annotations; a set dedupes episodes
    # that carry more than one annotation.
    unique_episode_pks = {
        AnnotationForEpisode.get(annotation_pk).episode
        for annotation_pk in AnnotationForEpisode.all_pks()
    }
    run_async_server_in_batch_aevaluate(
        tag=tag,
        model=typing.cast(LLM_Name, model),
        batch_size=batch_size,
        push_to_db=push_to_db,
        verbose=verbose,
        reeval_list=list(unique_episode_pks),
    )
|
||
|
||
# Script entry point: hand control to the typer CLI app.
if __name__ == "__main__":
    app()
Oops, something went wrong.