Skip to content

Commit

Permalink
Expand HP search space for PC
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Jan 6, 2024
1 parent f630e9e commit f688325
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions src/imitation/scripts/config/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,42 @@ def pc():
parallel_run_config = dict(
sacred_ex_name="train_preference_comparisons",
run_name="pc_tuning",
base_named_configs=["logging.wandb_logging"],
base_named_configs=[],
base_config_updates={
"environment": {"num_vec": 1},
"total_timesteps": 2e7,
"total_comparisons": 5000,
"query_schedule": "hyperbolic",
"gatherer_kwargs": {"sample": True},
"total_comparisons": 1000,
"active_selection": True,
},
search_space={
"named_configs": [
"reward.normalize_output_disable",
],
"named_configs": ["reward.reward_ensemble"],
"config_updates": {
"num_iterations": tune.choice([25, 50]),
"initial_comparison_frac": tune.choice([0.1, 0.25]),
"active_selection_oversampling": tune.randint(1, 11),
"comparison_queue_size": tune.randint(1, 1001), # upper bound determined by total_comparisons=1000
"exploration_frac": tune.uniform(0.0, 0.5),
"fragment_length": 100, # TODO: tune this too!
"gatherer_kwargs": {
"temperature": tune.uniform(0.0, 2.0),
"discount_factor": tune.uniform(0.95, 1.0),
"sample": tune.choice([True, False]),
},
"initial_comparison_frac": tune.uniform(0.01, 1.0),
"num_iterations": tune.randint(1, 51),
"preference_model_kwargs": {
"noise_prob": tune.uniform(0.0, 0.1),
"discount_factor": tune.uniform(0.95, 1.0),
},
"query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
"trajectory_generator_kwargs": {
"switch_prob": tune.uniform(0.1, 1),
"random_prob": tune.uniform(0.1, 0.9),
},
"transition_oversampling": tune.uniform(0.9, 2.0),
"reward_trainer_kwargs": {
"epochs": tune.choice([1, 3, 6]),
"epochs": tune.randint(1, 11),
},
"rl": {
"batch_size": tune.choice([512, 2048, 8192]),
"rl_kwargs": {
"learning_rate": tune.loguniform(1e-5, 1e-2),
"ent_coef": tune.loguniform(1e-7, 1e-3),
},
},
Expand Down

0 comments on commit f688325

Please sign in to comment.