def maybe_construct_transformations(
    target: Any, transforms: Optional[Any]
) -> Any:
  """Returns `transforms`, or derives a default set from `target`.

  Args:
    target: Pytree that the checkpoint is being restored into. It is only
      inspected when `transforms` is None.
    transforms: Caller-supplied transformations. Any value other than None
      (including an empty dict) is handed back untouched.

  Returns:
    `transforms` itself when it is not None. Otherwise a flat dict keyed by
    the '/'-joined paths of `target`, holding one
    `ocp.Transform(use_fallback=True)` for every entry of the flattened
    target whose value is None. Entries with non-None values get no
    transform.
  """
  if transforms is None:
    # Empty nodes are kept while flattening so that None-valued entries of
    # the target show up in the flat dict and can be detected below.
    flattened = ocp.utils.to_flat_dict(target, sep='/', keep_empty_nodes=True)
    # NOTE(review): `use_fallback=True` presumably tells Orbax to take these
    # entries from its fallback tree instead of the checkpoint -- confirm
    # against the Orbax transformations documentation.
    return {
        path: ocp.Transform(use_fallback=True)
        for path, leaf in flattened.items()
        if leaf is None
    }
  return transforms