diff --git a/docs/source/engine.rst b/docs/source/engine.rst
index 48fa9cc576c..b7bd846df4d 100644
--- a/docs/source/engine.rst
+++ b/docs/source/engine.rst
@@ -69,7 +69,7 @@ Resuming the training
 It is possible to resume the training from a checkpoint and approximately reproduce original run's behaviour.
 Using Ignite, this can be easily done using :class:`~ignite.handlers.checkpoint.Checkpoint` handler. Engine provides two methods
 to serialize and deserialize its internal state :meth:`~ignite.engine.engine.Engine.state_dict` and
-:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler etc user can
+:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing the model, optimizer, lr scheduler, metrics, etc., the user can
 store the trainer and then resume the training. For example:

 .. code-block:: python
@@ -82,8 +82,9 @@ store the trainer and then resume the training. For example:
     optimizer = ...
     lr_scheduler = ...
     data_loader = ...
+    metric = ...

-    to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
+    to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
     handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
     trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
     trainer.run(data_loader, max_epochs=100)
@@ -104,8 +105,9 @@ We can then restore the training from the last checkpoint.
     optimizer = ...
     lr_scheduler = ...
     data_loader = ...
+    metric = ...

-    to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
+    to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
     checkpoint = torch.load(checkpoint_file)
     Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index 39bd359c289..f4bd9435f70 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -710,7 +710,7 @@ def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

     def load_state_dict(self, state_dict: Mapping) -> None:
-        """Method replace internal state of the class with provided state dict data.
+        """Method replaces internal state of the class with provided state dict data.

         Args:
             state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py
index 9ebc362459a..d9c5f01fbbd 100644
--- a/ignite/metrics/metric.py
+++ b/ignite/metrics/metric.py
@@ -551,6 +551,12 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise
         return engine.has_event_handler(self.completed, usage.COMPLETED)

     def state_dict(self) -> OrderedDict:
+        """Method returns a state dict with attributes of the metric specified in its
+        ``_state_dict_all_req_keys`` attribute. Can be used to save the internal state of the class.
+
+        If there's an active distributed configuration, some collective operations are done and
+        the list of values across ranks is saved under each attribute's name in the dict.
+        """
         state = OrderedDict()
         for attr_name in self._state_dict_all_req_keys:
             attr = getattr(self, attr_name)
@@ -565,6 +571,15 @@ def state_dict(self) -> OrderedDict:
         return state

     def load_state_dict(self, state_dict: Mapping) -> None:
+        """Method replaces internal state of the class with provided state dict data.
+
+        If there's an active distributed configuration, the process uses its rank to pick the proper value from
+        the list of values saved under each attribute's name in the dict.
+
+        Args:
+            state_dict: a dict containing attributes of the metric specified in its ``_state_dict_all_req_keys``
+                attribute.
+        """
         super().load_state_dict(state_dict)
         rank = idist.get_local_rank()
         for attr in self._state_dict_all_req_keys:
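The docstrings added above describe a metric serialization round trip; a minimal sketch of that round trip in a single-process run is shown below. It assumes ``Accuracy`` is a metric that lists its accumulators (e.g. ``_num_correct`` and ``_num_examples``) in ``_state_dict_all_req_keys``; those attribute names are an assumption about the metric implementation and appear only in comments.

.. code-block:: python

    import torch

    from ignite.metrics import Accuracy

    # Accumulate some state on a metric instance.
    metric = Accuracy()
    metric.update((torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 1])))

    # Serialize the accumulators listed in `_state_dict_all_req_keys`
    # (for Accuracy, roughly the number of correct predictions and the
    # number of seen examples).
    saved = metric.state_dict()

    # Restore the accumulated state into a fresh instance, e.g. after a
    # restart, and keep accumulating from where the previous run stopped.
    restored = Accuracy()
    restored.load_state_dict(saved)
    restored.update((torch.tensor([[0.7, 0.3]]), torch.tensor([0])))
    print(restored.compute())  # accuracy over all three examples

In practice the same state travels through ``Checkpoint`` by listing the metric in ``to_save`` / ``to_load``, as in the ``engine.rst`` snippets above.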