Skip to content

Commit

Permalink
[BugFix] Task loading schema validation (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Nov 16, 2023
1 parent 3bcdf6a commit 0830a2b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 0 additions & 1 deletion benchmarl/environments/pettingzoo/simple_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
class TaskConfig:
task: str = MISSING
max_cycles: int = MISSING
local_ratio: float = MISSING
2 changes: 1 addition & 1 deletion benchmarl/environments/pettingzoo/waterworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TaskConfig:
poison_speed: float = MISSING
poison_reward: float = MISSING
food_reward: float = MISSING
encounter_rewar: float = MISSING
encounter_reward: float = MISSING
thrust_penalty: float = MISSING
local_ratio: float = MISSING
speed_features: bool = MISSING
6 changes: 4 additions & 2 deletions benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:

def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> Task:
return task_config_registry[task_name].update_config(
OmegaConf.to_container(cfg, resolve=True)
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)


Expand All @@ -56,5 +56,7 @@ def load_model_config_from_hydra(cfg: DictConfig) -> ModelConfig:
else:
model_class = model_config_registry[cfg.name]
return model_class(
**parse_model_config(OmegaConf.to_container(cfg, resolve=True))
**parse_model_config(
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
)
)

0 comments on commit 0830a2b

Please sign in to comment.