-
Notifications
You must be signed in to change notification settings - Fork 2
/
learn.py
129 lines (124 loc) · 4.24 KB
/
learn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import ast
import json
import logging
import os
from pathlib import Path

import hydra
import openai
import wandb
from omegaconf import OmegaConf

from evolution.utils.misc import *
from evolution.utils.extract_task_code import *
from zero_hero.behavior import BehaviorCaptioner
from zero_hero.core import TaskNode, TaskDatabase, SkillDatabase
from zero_hero.utils import FakeWandb
@hydra.main(config_path="cfg", config_name="config", version_base="1.1")
def main(cfg):
    """Pop one task from the task database and try to learn it.

    Iteratively proposes success/reward nodes for the task, runs them,
    logs statistics to wandb (or a local stand-in when wandb is disabled),
    and finally records the task's outcome ("completed"/"failed") back
    into the task database.

    Args:
        cfg: Hydra/OmegaConf config (see cfg/config.yaml).
    """
    workspace_dir = Path.cwd()
    logging.info(f"Workspace: {workspace_dir}")
    openai.api_key = os.getenv("OPENAI_API_KEY")
    logging.info(cfg)

    env_name = cfg.env.env_name.lower()
    task = cfg.task
    # An explicitly supplied task is pinned to seed 99 so its artifacts land
    # in a dedicated environment index, separate from the auto-generated runs.
    specified_task = task is not None and len(task) > 0
    seed = 99 if specified_task else cfg.seed
    env_idx = f"E{seed:02d}"
    tdb = TaskDatabase(
        env_name=env_name,
        env_idx=env_idx,
    )
    sdb = SkillDatabase(env_name, env_idx)
    if specified_task:
        tdb.add_task(task)
        tdb.render()
    task = tdb.pop()
    if task is None or task == "":
        # Empty database: nothing left to learn.
        logging.info(f"Nothing to do with task database {tdb.store_path}!")
        return
    cfg.task = task
    cfg.seed = seed

    # Precedents may arrive as a stringified list (e.g. via CLI override);
    # normalize to a real list before use. Requires `import ast` (top of file).
    if isinstance(cfg.precedents, str):
        cfg.precedents = ast.literal_eval(cfg.precedents)
    precedents = cfg.precedents
    if precedents is None or len(precedents) == 0:
        # Without precedent skills there is nothing to finetune from.
        cfg.finetune = False

    my_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    if cfg.use_wandb:
        wandbrun = wandb.init(
            project=cfg.wandb_project,
            config=my_cfg,
        )
    else:
        # Drop-in replacement exposing .log() so the loop below is unconditional.
        wandbrun = FakeWandb(my_cfg)

    task_node: TaskNode = TaskNode(
        code=cfg.task,
        n_samples=cfg.n_success_samples,
        temperature=cfg.temperature,
        model=cfg.model,
        precedents=precedents,
        skill_database=sdb,
    ).init()
    bc = BehaviorCaptioner(
        init_sys_prompt=f"{task_node.prompt_dir}/task/behavior_context.txt",
    )
    logging.info(f"Learning skill: {task}.")

    for task_ite in range(cfg.task_iterations):
        # Raise sampling temperature each outer iteration to diversify proposals.
        task_node.temperature += 0.2
        if task_node.num_variants >= cfg.num_variants:
            break  # enough successful variants collected already
        task_node.propose(
            n_samples=cfg.n_reward_samples,
            iterations=2,
            temperature=cfg.temperature + task_ite * 0.2,
            model=cfg.model,
        )  # params for child init
        for reward_ite in range(cfg.reward_iterations):
            success_nodes = task_node.children
            for success_node in success_nodes:
                reward_nodes = success_node.propose(
                    num_envs=cfg.num_envs,
                    headless=cfg.headless,
                    video=cfg.video,
                    memory_requirement=cfg.memory_requirement,
                    min_gpu=cfg.min_gpu,
                    max_iterations=cfg.max_iterations,
                    task_ite=task_ite,
                    reward_ite=reward_ite,
                    behavior_captioner=bc,
                    finetune=cfg.finetune,
                )
                for node in reward_nodes:
                    node.run()
            for success_node in success_nodes:
                _, succ_stat = success_node.collect()
                wandbrun.log(
                    {
                        **succ_stat,
                        "reward_ite": reward_ite,
                        "task_ite": task_ite,
                    }
                )
        # NOTE: `reward_ite` here is the last value from the inner loop.
        task_stat = task_node.collect()
        wandbrun.log(
            {
                **task_stat,
                "task_ite": task_ite,
                "reward_ite": reward_ite,
            }
        )

    if task_node.num_variants > 0:
        task_status = "completed"
        variants = [v.best_reward.idx for v in task_node.variants]
        logging.info(
            f"Collected new skill {task} with {task_node.num_variants} variants: {variants}."
        )
    else:
        task_status = "failed"
        logging.info(f"Mission impossible on {task}.")
        variants = [""]
    # Reload before updating (the database may have changed on disk since pop).
    # NOTE(review): only the first variant is persisted — confirm intentional.
    tdb.load()
    tdb.update_task({"command": task, "status": task_status, "variants": variants[0]})
    tdb.render()
    logging.info(f"Done! for task: {task}.")


if __name__ == "__main__":
    main()