Skip to content

Commit

Permalink
add callback to PipelineStage
Browse files Browse the repository at this point in the history
  • Loading branch information
KuoHaoZeng committed Jul 10, 2024
1 parent 0dfed43 commit 7e68e4a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
6 changes: 6 additions & 0 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
self.training_pipeline.current_stage.training_settings
)

# Change engine attributes that depend on the current stage
self.training_pipeline.current_stage.change_engine_attributes(self)

rollout_storage = self.training_pipeline.rollout_storage
uuid_to_storage = self.training_pipeline.current_stage_storage
self.initialize_storage_and_viz(
Expand Down Expand Up @@ -1644,6 +1647,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
)
uuid_to_storage = new_uuid_to_storage

# Change engine attributes that depend on the current stage
self.training_pipeline.current_stage.change_engine_attributes(self)

already_saved_checkpoint = False

if self.is_distributed:
Expand Down
16 changes: 15 additions & 1 deletion allenact/utils/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,11 @@ def __init__(
stage_components: Optional[Sequence[StageComponent]] = None,
early_stopping_criterion: Optional[EarlyStoppingCriterion] = None,
training_settings: Optional[TrainingSettings] = None,
callback_to_change_engine_attributes: Optional[Dict[str, Any]] = None,
**training_settings_kwargs,
):
self.callback_to_change_engine_attributes = callback_to_change_engine_attributes

# Populate TrainingSettings members
# THIS MUST COME FIRST IN `__init__` as otherwise `__getattr__` will loop infinitely.
assert training_settings is None or len(training_settings_kwargs) == 0
Expand Down Expand Up @@ -707,6 +710,17 @@ def reset(self):
for memory in self.stage_component_uuid_to_stream_memory.values():
memory.clear()

# TODO: Replace Any with the correct type
def change_engine_attributes(self, engine: Any):
if self.callback_to_change_engine_attributes is not None:
for key, value in self.callback_to_change_engine_attributes.items():
# check if the engine has the attribute
assert hasattr(engine, key)

func = value["func"]
args = value["args"]
setattr(engine, key, func(engine, **args))

@property
def stage_components(self) -> Tuple[StageComponent]:
return tuple(self._stage_components)
Expand Down Expand Up @@ -747,7 +761,7 @@ def add_stage_component(self, stage_component: StageComponent):
self.stage_component_uuid_to_stream_memory[stage_component.uuid] = Memory()

def __setattr__(self, key: str, value: Any):
if key != "training_settings" and self.training_settings.has_key(key):
if key not in ["training_settings", "callback_to_change_engine_attributes"] and self.training_settings.has_key(key):
raise NotImplementedError(
f"Cannot set {key} in {self.__name__}, update the"
f" `training_settings` attribute of {self.__name__} instead."
Expand Down

0 comments on commit 7e68e4a

Please sign in to comment.