Commit
Add docstrings
sadra-barikbin committed Jul 16, 2023
1 parent abc6a8c commit 99a45c6
Showing 3 changed files with 21 additions and 4 deletions.
8 changes: 5 additions & 3 deletions docs/source/engine.rst
@@ -69,7 +69,7 @@ Resuming the training
It is possible to resume the training from a checkpoint and approximately reproduce original run's behaviour.
Using Ignite, this can be easily done using :class:`~ignite.handlers.checkpoint.Checkpoint` handler. Engine provides two methods
to serialize and deserialize its internal state :meth:`~ignite.engine.engine.Engine.state_dict` and
- :meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler etc user can
+ :meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler, metrics, etc., user can
store the trainer and then resume the training. For example:

.. code-block:: python
@@ -82,8 +82,9 @@ store the trainer and then resume the training. For example:
optimizer = ...
lr_scheduler = ...
data_loader = ...
+ metric = ...
- to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
+ to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)
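
As a quick, hedged illustration of the two Engine methods the docs refer to, here is a minimal sketch; the no-op process function and toy data are placeholders:

```python
from ignite.engine import Engine

trainer = Engine(lambda engine, batch: None)  # placeholder process function
trainer.run([0, 1, 2], max_epochs=2)

state = trainer.state_dict()
print(state)  # OrderedDict with keys like "iteration", "epoch_length", "max_epochs"

# A fresh engine picks up the counters and would resume where the first stopped
new_trainer = Engine(lambda engine, batch: None)
new_trainer.load_state_dict(state)
```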
@@ -104,8 +105,9 @@ We can then restore the training from the last checkpoint.
optimizer = ...
lr_scheduler = ...
data_loader = ...
+ metric = ...
- to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
+ to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
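To make the two doc snippets concrete, here is a hedged, self-contained version of the save-then-resume pattern; the linear model, toy data, and Accuracy metric are illustrative stand-ins, and checkpointing the metric assumes an ignite version that includes this commit's Metric state_dict()/load_state_dict() support:

```python
import torch
from torch import nn

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver
from ignite.metrics import Accuracy

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

def train_step(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return y_pred.detach(), y  # (y_pred, y) so the attached metric can update

trainer = Engine(train_step)
metric = Accuracy()
metric.attach(trainer, "train_acc")

# Everything in to_save must expose state_dict()/load_state_dict();
# with this commit, that includes serializable metrics.
to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "metric": metric}
handler = Checkpoint(to_save, DiskSaver("/tmp/training", create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

data = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
trainer.run(data, max_epochs=2)

# Later: restore everything from the last checkpoint and continue training.
to_load = {"trainer": trainer, "model": model, "optimizer": optimizer, "metric": metric}
checkpoint = torch.load("/tmp/training/" + handler.last_checkpoint)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
```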
2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
@@ -710,7 +710,7 @@ def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

def load_state_dict(self, state_dict: Mapping) -> None:
"""Method replace internal state of the class with provided state dict data.
"""Method replaces internal state of the class with provided state dict data.
Args:
state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
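The docstring fixed above belongs to the Checkpoint handler's own (de)serialization. A small sketch of that round-trip, with the model as a placeholder (the "saved" list is empty here because nothing has been checkpointed yet):

```python
from torch import nn
from ignite.handlers import Checkpoint, DiskSaver

model = nn.Linear(2, 2)
handler = Checkpoint({"model": model}, DiskSaver("/tmp/ckpts", create_dir=True))

sd = handler.state_dict()  # OrderedDict([("saved", [...])]) of (priority, filename) pairs
fresh = Checkpoint({"model": model}, DiskSaver("/tmp/ckpts", create_dir=True))
fresh.load_state_dict(sd)  # fresh handler now tracks the same saved checkpoints
```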
15 changes: 15 additions & 0 deletions ignite/metrics/metric.py
@@ -551,6 +551,12 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise
return engine.has_event_handler(self.completed, usage.COMPLETED)

def state_dict(self) -> OrderedDict:
"""Method returns state dict with attributes of the metric specified in its
`_state_dict_all_req_keys` attribute. Can be used to save internal state of the class.
If there's an active distributed configuration, some collective operations is done and
the list of values across ranks is saved under each attribute's name in the dict.
"""
state = OrderedDict()
for attr_name in self._state_dict_all_req_keys:
attr = getattr(self, attr_name)
@@ -565,6 +571,15 @@ def state_dict(self) -> OrderedDict:
return state

def load_state_dict(self, state_dict: Mapping) -> None:
"""Method replaces internal state of the class with provided state dict data.
If there's an active distributed configuration, the process uses its rank to pick the proper value from
the list of values saved under each attribute's name in the dict.
Args:
state_dict: a dict containing attributes of the metric specified in its `_state_dict_all_req_keys`
attribute.
"""
super().load_state_dict(state_dict)
rank = idist.get_local_rank()
for attr in self._state_dict_all_req_keys:
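A hedged sketch of the Metric round-trip these docstrings describe, assuming Accuracy lists its counters in `_state_dict_all_req_keys` (an internal detail that may differ across versions). In a single-process run each attribute is saved as a one-element list, and load_state_dict picks the entry for this process's rank:

```python
import torch
from ignite.metrics import Accuracy

metric = Accuracy()
metric.update((torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 0])))

sd = metric.state_dict()   # one list entry per rank under each attribute name
fresh = Accuracy()
fresh.load_state_dict(sd)  # this rank picks its own value back
assert fresh.compute() == metric.compute()
```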
