
Commit

Merge branch 'main' into bug/fix-scenario-agent-name
lwaekfjlk authored Nov 13, 2023
2 parents 11a1a20 + 430ace6 commit c07e7ab
Showing 110 changed files with 9,449 additions and 24 deletions.
167 changes: 166 additions & 1 deletion .gitignore
@@ -5,6 +5,13 @@ __pycache__
dist
.venv

# Byte-compiled / optimized / DLL files
*.py[cod]
*$py.class

# C extensions
*.so

# Log
*.log
*.log.*
@@ -13,6 +20,7 @@ llm_ft/checkpoints/*
llm_ft/*_checkpoints/*
!**/dummy_conversation.json
!llm_ft/deepspeed_config_s2.json
!llm_rl/data/*.json

# Editor
.idea
@@ -33,4 +41,161 @@ tests/state_of_the_union.txt

# Build
build
!dummy_file

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

llm_rl/preprocess/GPT4-4_Redis_Easy_No_Slide

llm_rl/*cache/
2 changes: 1 addition & 1 deletion README.md
@@ -7,4 +7,4 @@ We split our overall framework into multiple parts
2. Together AI Finetuning --> Input the train and test data / Output model checkpoint
3. LLM Finetuning --> Input the train and test data / Output model checkpoint
4. LLM Deployment --> Input LLM finetuned model checkpoint / Output deployable OpenAI-compatible API
5. Eval --> Input model checkpoint / Output evaluation scores
20 changes: 20 additions & 0 deletions data_process/fastchat_data_preprocess.py
@@ -0,0 +1,20 @@
import json
import os

sotopia_data_dir = "/Users/pamela/Documents/capstone/sotopia-ft-data/ft-data-gpt4-gpt4-easy-2-side-partial"

ft_data_list = []
count = 0
for file in os.listdir(sotopia_data_dir):
    # Each source file holds one episode with a "prompt" and a "result".
    with open(os.path.join(sotopia_data_dir, file), 'r') as f:
        file_dict = json.load(f)
    # Convert the pair into FastChat's two-turn conversation format.
    fastchat_dict = {"id": f"identity_{count}", "conversations": []}
    fastchat_dict["conversations"].append(
        {"from": "human", "value": file_dict["prompt"]})
    fastchat_dict["conversations"].append(
        {"from": "gpt", "value": file_dict["result"]})
    ft_data_list.append(fastchat_dict)
    count += 1

with open("fastchat-ft-gp4-gpt4-easy-2-side-partial.json", "w") as f:
    f.write(json.dumps(ft_data_list, indent=4))
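
For reference, a minimal sketch of the record shapes this script maps between, assuming (as the loop above implies) that each source file holds a single prompt/result pair; the field values here are hypothetical:

# Hypothetical input file contents (one JSON object per episode file):
example_input = {
    "prompt": "Scenario and dialogue history shown to the model ...",
    "result": "The model's reply used as the fine-tuning target ...",
}

# The FastChat record the script would emit for the first such file:
example_record = {
    "id": "identity_0",
    "conversations": [
        {"from": "human", "value": example_input["prompt"]},
        {"from": "gpt", "value": example_input["result"]},
    ],
}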
75 changes: 75 additions & 0 deletions eval/llm_eval.py
@@ -0,0 +1,75 @@
import json
from tqdm import tqdm

from langchain.llms import OpenAI
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import BaseModel, Field


class QuantitativeEval(BaseModel):
    agent1_name: str = Field(description="Agent 1's name")
    agent1_gain: int = Field(description="Agent 1's gain/loss")
    agent2_name: str = Field(description="Agent 2's name")
    agent2_gain: int = Field(description="Agent 2's gain/loss")


def get_model_parser(model_name='text-davinci-003'):
    """Return a prompt | model chain and the parser for its structured output."""
    model = OpenAI(model_name=model_name, temperature=0.0)
    parser = PydanticOutputParser(pydantic_object=QuantitativeEval)

    prompt_text = (
        "Try to understand the following situation and answer the question in the end. "
        "\n Situation: {situation}"
        "\n Question: {question}"
        "\n Please represent loss as negative values. {format_instructions}\n "
    )

    prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["situation", "question"],
        partial_variables={"format_instructions": parser.get_format_instructions()}
    )

    prompt_and_model = prompt | model

    return prompt_and_model, parser


def evaluate(environment_episode_map, environment_question_map, model_name='text-davinci-003'):
    results = {}
    model, response_parser = get_model_parser(model_name=model_name)

    for environment_id, episodes in tqdm(environment_episode_map.items()):
        results_for_env = []

        for episode in episodes:
            situation = episode["messages_and_rewards"]
            question = environment_question_map.get(environment_id)

            if question:
                # Ask the judge model and parse its reply into QuantitativeEval.
                model_response = model.invoke({"situation": situation, "question": question})
                parsed_output = response_parser.parse(model_response)
                episode["output"] = parsed_output.dict()

            results_for_env.append(episode)

        results[environment_id] = results_for_env

    return results


def main():
    with open("human_readable_eps_by_env.json", "r") as f:
        env_eps_map = json.load(f)

    with open("env_specific_eval.json", "r") as f:
        env_question_map = json.load(f)

    res = evaluate(env_eps_map, env_question_map)

    with open("env_specific_eval_with_output.json", "w") as f:
        json.dump(res, f)


if __name__ == "__main__":
    main()
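
For context, a plausible shape for the two input files main() reads, inferred from how evaluate() indexes them; the episode fields match what pull_data.py below writes, while the placeholder values and the question text are illustrative:

# human_readable_eps_by_env.json: environment id -> list of episode records.
# Only "messages_and_rewards" is read by evaluate(); other keys pass through.
env_eps_map_example = {
    "01H7VFHNV13MHN97GAH73E3KM8": [
        {
            "env_pk": "01H7VFHNV13MHN97GAH73E3KM8",
            "ep_pk": "<episode primary key>",
            "agents": ["<agent 1 pk>", "<agent 2 pk>"],
            "messages_and_rewards": "<rendered transcript and rewards>",
        },
    ],
}

# env_specific_eval.json: environment id -> the question posed to the judge model.
env_question_map_example = {
    "01H7VFHNV13MHN97GAH73E3KM8": "How much money did each agent gain or lose?",
}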
25 changes: 25 additions & 0 deletions eval/pull_data.py
@@ -0,0 +1,25 @@
from sotopia.database.logs import EpisodeLog
from sotopia.database.persistent_profile import EnvironmentProfile
from sotopia.database.persistent_profile import AgentProfile
import json

TAG = "ft-llama-2-13b-chat_baseline_ruiyi_1010_7"  # Baseline tag

HARD_ENVS = ["01H7VFHNV13MHN97GAH73E3KM8", "01H7VFHN5WVC5HKKVBHZBA553R", "01H7VFHN9W0WAFZCBT09PKJJNK", "01H7VFHPDZVVCDZR3AARA547CY", "01H7VFHPQQQY6H4DNC6NBQ8XTG", "01H7VFHN7WJK7VWVRZZTQ6DX9T", "01H7VFHPS5WJW2694R1MNC8JFY",
             "01H7VFHNN7XTR99319DS8KZCQM", "01H7VFHQ11NAMZS4A2RDGDB01V", "01H7VFHPSWGDGEYRP63H2DJKV0", "01H7VFHNF4G18PC9JHGRC8A1R6", "01H7VFHNNYH3W0VRWVY178K2TK", "01H7VFHP8AN5643B0NR0NP00VE", "01H7VFHN7A1ZX5KSMT2YN9RXC4"]

envs = []
eps_by_env = dict()
human_readable_eps_by_env = dict()

for env_profile_id in HARD_ENVS:
    # Pull every logged episode for this environment under the baseline tag.
    eps = list(EpisodeLog.find(EpisodeLog.tag == TAG,
                               EpisodeLog.environment == env_profile_id))
    eps_by_env[env_profile_id] = eps
    human_readable_eps_by_env[env_profile_id] = []
    for ep in eps:
        agent_profiles, messages_and_rewards = ep.render_for_humans()
        human_readable_eps_by_env[env_profile_id].append(
            {"env_pk": env_profile_id, "ep_pk": ep.pk, "agents": ep.agents,
             "messages_and_rewards": "\n".join(messages_and_rewards)})

with open("human_readable_eps_by_env.json", "w") as f:
    json.dump(human_readable_eps_by_env, f)
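
A quick sanity check one might append to the script above (a sketch, assuming at least one episode matched the tag for the first hard environment):

# Sketch: print the first pulled transcript for the first hard environment.
env_id = HARD_ENVS[0]
if human_readable_eps_by_env[env_id]:
    first = human_readable_eps_by_env[env_id][0]
    print(f"episode {first['ep_pk']} between agents {first['agents']}:")
    print(first["messages_and_rewards"][:500])  # first 500 characters only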