Skip to content

Commit

Permalink
Refactor TypeHandler to operate over batches of values, rather than i…
Browse files Browse the repository at this point in the history
…ndividual ones. This allows more flexibility for implementations that may operate more efficiently on batches.

PiperOrigin-RevId: 547892330
  • Loading branch information
cpgaffney1 authored and Flax Authors committed Jul 19, 2023
1 parent 15d6857 commit 83f36c9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
9 changes: 7 additions & 2 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,8 +962,13 @@ def restore_checkpoint(
restore_kwargs['restore_args'] = orbax_utils.restore_args_from_target(
target
)
if orbax_transforms is not None:
restore_kwargs['transforms'] = orbax_transforms
if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler): # pylint: disable=protected-access
restore_kwargs['transforms'] = (
orbax_utils.maybe_construct_transformations(
target, orbax_transforms
)
)
logging.info(restore_kwargs)
restored = orbax_checkpointer.restore(
ckpt_path, item=target, **restore_kwargs)
restored = serialization.to_state_dict(restored)
Expand Down
13 changes: 13 additions & 0 deletions flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ def save_args_from_target(target: Any) -> Any:
)


def maybe_construct_transformations(
target: Any, transforms: Optional[Any]
) -> Any:
if transforms is not None:
return transforms
flat_transforms = {}
flat_target = ocp.utils.to_flat_dict(target, sep='/', keep_empty_nodes=True)
for k, v in flat_target.items():
if v is None:
flat_transforms[k] = ocp.Transform(use_fallback=True)
return flat_transforms


def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any:
"""Creates Orbax `restore_args` given a target Pytree.
Expand Down

0 comments on commit 83f36c9

Please sign in to comment.