Saving an optimizer? #180

Closed · Answered by aaarrti
HHalva asked this question in Q&A
Aug 31, 2021 · 3 comments · 8 replies

I would typically save my complete flax.training.train_state.TrainState using orbax-checkpoint. Then, before training:

from typing import Any, Mapping, TypeVar

from jax import tree_util
import optax
from orbax.checkpoint import CheckpointManager, checkpoint_utils

TX = TypeVar("TX", bound=optax.OptState)


def restore_optimizer_state(opt_state: TX, restored: Mapping[str, Any]) -> TX:
    """Restore optimizer state from a loaded checkpoint (or .msgpack file)."""
    # The restored checkpoint is a plain pytree (e.g. nested dicts) with the same
    # leaves as the live optimizer state; re-attach the optax tree structure.
    return tree_util.tree_unflatten(
        tree_util.tree_structure(opt_state), tree_util.tree_leaves(restored)
    )

# Make "empty" training state
training_state = ...

mngr = CheckpointManager(...)
# A checkpoint was written during a previous training run that was interrupted.

all_steps = mngr.a…
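
To illustrate what restore_optimizer_state does, here is a minimal, self-contained sketch (not part of the original answer): the toy params, tx, and the fake restored dict below are made up for illustration. A restored checkpoint typically comes back as a plain pytree (nested dicts of arrays) with the same leaves, in the same order, as the live optimizer state; the helper simply re-attaches the original optax tree structure to those leaves.

import jax.numpy as jnp
import optax
from jax import tree_util

# Hypothetical toy parameters, only used to build an optimizer state with a
# known tree structure.
params = {"b": jnp.zeros((3,)), "w": jnp.ones((3,))}
tx = optax.adam(1e-3)

# "Empty" optimizer state: supplies the target tree structure.
opt_state = tx.init(params)

# Stand-in for what a checkpoint restore typically hands back: the same leaves,
# in the same order, but wrapped in plain dicts instead of the original optax
# NamedTuples.
restored = {str(i): leaf for i, leaf in enumerate(tree_util.tree_leaves(opt_state))}

# Re-attach the optax tree structure to the restored leaves.
rebuilt = restore_optimizer_state(opt_state, restored)
assert tree_util.tree_structure(rebuilt) == tree_util.tree_structure(opt_state)

The same pattern applies to an optimizer state restored from an orbax checkpoint: build an "empty" state with tx.init, then unflatten the restored leaves into that structure before resuming training.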
