Skip to content

Commit

Permalink
Refactor TypeHandler to operate over batches of values, rather than i…
Browse files Browse the repository at this point in the history
…ndividual ones. This allows more flexibility for implementations that may operate more efficiently on batches.

PiperOrigin-RevId: 547892330
  • Loading branch information
cpgaffney1 authored and Flax Authors committed Jul 20, 2023
1 parent b4591c1 commit 5ccf8db
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
8 changes: 6 additions & 2 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,8 +962,12 @@ def restore_checkpoint(
restore_kwargs['restore_args'] = orbax_utils.restore_args_from_target(
target
)
if orbax_transforms is not None:
restore_kwargs['transforms'] = orbax_transforms
if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler): # pylint: disable=protected-access
restore_kwargs['transforms'] = (
orbax_utils.maybe_construct_transformations(
target, orbax_transforms
)
)
restored = orbax_checkpointer.restore(
ckpt_path, item=target, **restore_kwargs)
restored = serialization.to_state_dict(restored)
Expand Down
13 changes: 13 additions & 0 deletions flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ def save_args_from_target(target: Any) -> Any:
)


def maybe_construct_transformations(
target: Any, transforms: Optional[Any]
) -> Any:
if transforms is not None:
return transforms
flat_transforms = {}
flat_target = ocp.utils.to_flat_dict(target, sep='/', keep_empty_nodes=True)
for k, v in flat_target.items():
if v is None:
flat_transforms[k] = ocp.Transform(use_fallback=True)
return flat_transforms


def restore_args_from_target(target: Any, mesh: Optional[Mesh] = None) -> Any:
"""Creates Orbax `restore_args` given a target Pytree.
Expand Down

0 comments on commit 5ccf8db

Please sign in to comment.