From c8bb930f4648fa1fc82a6672ff6b10af08b4bd8c Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Thu, 20 Jul 2023 12:49:02 -0700 Subject: [PATCH] Refactor TypeHandler to operate over batches of values, rather than individual ones. This allows more flexibility for implementations that may operate more efficiently on batches. PiperOrigin-RevId: 549712245 --- flax/training/checkpoints.py | 8 ++++++-- flax/training/orbax_utils.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 958053e405..9623cb7776 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -962,8 +962,12 @@ def restore_checkpoint( restore_kwargs['restore_args'] = orbax_utils.restore_args_from_target( target ) - if orbax_transforms is not None: - restore_kwargs['transforms'] = orbax_transforms + if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler): # pylint: disable=protected-access + restore_kwargs['transforms'] = ( + orbax_utils.maybe_construct_transformations( + target, orbax_transforms + ) + ) restored = orbax_checkpointer.restore( ckpt_path, item=target, **restore_kwargs) restored = serialization.to_state_dict(restored) diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index 113c154fcb..44481a3827 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -40,6 +40,19 @@ def save_args_from_target(target: Any) -> Any: ) +def maybe_construct_transformations( + target: Any, transforms: Optional[Any] +) -> Any: + if transforms is not None: + return transforms + flat_transforms = {} + flat_target = ocp.utils.to_flat_dict(target, sep='/', keep_empty_nodes=True) + for k, v in flat_target.items(): + if v is None: + flat_transforms[k] = ocp.Transform(use_fallback=True) + return flat_transforms + + def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any: """Creates Orbax `restore_args` given a target Pytree.