diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py index d0dfa949..b481afc8 100644 --- a/allenact/algorithms/onpolicy_sync/engine.py +++ b/allenact/algorithms/onpolicy_sync/engine.py @@ -1551,6 +1551,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False): self.training_pipeline.current_stage.training_settings ) + # Change engine attributes that depend on the current stage + self.training_pipeline.current_stage.change_engine_attributes(self) + rollout_storage = self.training_pipeline.rollout_storage uuid_to_storage = self.training_pipeline.current_stage_storage self.initialize_storage_and_viz( @@ -1644,6 +1647,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False): ) uuid_to_storage = new_uuid_to_storage + # Change engine attributes that depend on the current stage + self.training_pipeline.current_stage.change_engine_attributes(self) + already_saved_checkpoint = False if self.is_distributed: diff --git a/allenact/utils/experiment_utils.py b/allenact/utils/experiment_utils.py index 609b8010..87f6bbba 100644 --- a/allenact/utils/experiment_utils.py +++ b/allenact/utils/experiment_utils.py @@ -644,8 +644,11 @@ def __init__( stage_components: Optional[Sequence[StageComponent]] = None, early_stopping_criterion: Optional[EarlyStoppingCriterion] = None, training_settings: Optional[TrainingSettings] = None, + callback_to_change_engine_attributes: Optional[Dict[str, Any]] = None, **training_settings_kwargs, ): + self.callback_to_change_engine_attributes = callback_to_change_engine_attributes + # Populate TrainingSettings members # THIS MUST COME FIRST IN `__init__` as otherwise `__getattr__` will loop infinitely. assert training_settings is None or len(training_settings_kwargs) == 0 @@ -707,6 +710,17 @@ def reset(self): for memory in self.stage_component_uuid_to_stream_memory.values(): memory.clear() + # TODO: Replace Any with the correct type + def change_engine_attributes(self, engine: Any): + if self.callback_to_change_engine_attributes is not None: + for key, value in self.callback_to_change_engine_attributes.items(): + # check if the engine has the attribute + assert hasattr(engine, key) + + func = value["func"] + args = value["args"] + setattr(engine, key, func(engine, **args)) + @property def stage_components(self) -> Tuple[StageComponent]: return tuple(self._stage_components) @@ -747,7 +761,7 @@ def add_stage_component(self, stage_component: StageComponent): self.stage_component_uuid_to_stream_memory[stage_component.uuid] = Memory() def __setattr__(self, key: str, value: Any): - if key != "training_settings" and self.training_settings.has_key(key): + if key not in ["training_settings", "callback_to_change_engine_attributes"] and self.training_settings.has_key(key): raise NotImplementedError( f"Cannot set {key} in {self.__name__}, update the" f" `training_settings` attribute of {self.__name__} instead."