-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconf.py
36 lines (30 loc) · 996 Bytes
/
conf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from spirl.models.skill_prior_mdl import SkillPriorMdl
from spirl.components.logger import Logger
from spirl.utils.general_utils import AttrDict
from spirl.configs.default_data_configs.kitchen import data_spec
from spirl.components.evaluator import TopOfNSequenceEvaluator
# Absolute directory of this config file, with symlinks resolved.
current_dir = os.path.dirname(os.path.realpath(__file__))

# Top-level training configuration, exposed as an AttrDict.
configuration = AttrDict({
    # Classes used for the model and for logging.
    'model': SkillPriorMdl,
    'logger': Logger,
    # Data location; '.' is a relative path (the base it resolves against
    # is decided by the consumer of this config -- not visible here).
    'data_dir': '.',
    'epoch_cycles_train': 10,
    # Evaluation setup. NOTE(review): presumably the evaluator keeps the best
    # of `top_of_n_eval` samples ranked by `top_comp_metric` -- confirm
    # against TopOfNSequenceEvaluator.
    'evaluator': TopOfNSequenceEvaluator,
    'top_of_n_eval': 100,
    'top_comp_metric': 'mse',
})
# Model hyperparameters. The state/action dimensions are read from the
# imported kitchen `data_spec`; the remaining values are fixed here.
model_config = AttrDict({
    'state_dim': data_spec.state_dim,
    'action_dim': data_spec.n_actions,
    # Also determines the dataset sub-sequence length set further below.
    'n_rollout_steps': 10,
    'kl_div_weight': 5e-4,
    # NOTE(review): nz_enc / nz_mid look like layer or embedding widths --
    # confirm against SkillPriorMdl.
    'nz_enc': 128,
    'nz_mid': 128,
    'n_processing_layers': 5,
})
# Dataset
# NOTE(review): `dataset_spec` appears to reference the imported `data_spec`
# object itself (no copy is made here), so the assignment below would also
# be visible through `data_spec` -- confirm.
data_config = AttrDict(dataset_spec=data_spec)
# One step more than the rollout horizon: per the original note, the last
# action of the sequence gets cropped.
data_config.dataset_spec.subseq_len = model_config.n_rollout_steps + 1