Is there a recommended way for storing and loading an optimizer?
load params you need:
I would typically save my complete flax.training.TrainState using orbax-checkpoint. Then, before training, I restore it from the latest checkpoint:

from typing import Any, Mapping, TypeVar

import optax
from jax import tree_util
from orbax.checkpoint import CheckpointManager, checkpoint_utils

TX = TypeVar("TX", bound=optax.OptState)


def restore_optimizer_state(opt_state: TX, restored: Mapping[str, Any]) -> TX:
    """Restore optimizer state from loaded checkpoint (or .msgpack file)."""
    # Reuse the tree structure of the freshly created state and fill it with
    # the leaves of the restored state, which comes back as a plain dict.
    return tree_util.tree_unflatten(
        tree_util.tree_structure(opt_state), tree_util.tree_leaves(restored)
    )


# Make "empty" training state
training_state = ...
mngr = CheckpointManager(...)

# Checkpoints were written during a previous training run, which was interrupted
all_steps = mngr.all_steps(True)
latest_step = max(all_steps)

restore_args = checkpoint_utils.construct_restore_args(training_state)
restored_dict = mngr.restore(latest_step, restore_kwargs={"restore_args": restore_args})
restored_optimizer = restore_optimizer_state(training_state.opt_state, restored_dict["opt_state"])
# replace() returns a new state object, so keep the result
training_state = training_state.replace(params=restored_dict["params"], step=restored_dict["step"], opt_state=restored_optimizer, ...)

from flax.serialization import from_bytes, msgpack_restore, to_bytes

If you would rather not use orbax-checkpoint:

from flax.serialization import msgpack_restore, to_bytes
import zlib
from typing import Any, Dict, Optional


def save_as_msgpack(params, save_path: str, compression: Optional[str] = None) -> None:
    msgpack_bytes: bytes = to_bytes(params)
    if compression == "GZIP":
        # Note: zlib.compress writes a zlib stream, not a .gz file
        msgpack_bytes = zlib.compress(msgpack_bytes)
    with open(save_path, "wb+") as file:
        file.write(msgpack_bytes)


def load_from_msgpack(params, save_path: str, compression: Optional[str] = None) -> Dict[str, Any]:
    # `params` is not read here; the argument only mirrors save_as_msgpack
    with open(save_path, "rb") as file:
        bytes_data = file.read()
    if compression == "GZIP":
        bytes_data = zlib.decompress(bytes_data)
    params = msgpack_restore(bytes_data)
    return params


# You can dump/reload any pytree using these functions.
# You'd still need to create an "empty" optimizer to restore the actual pytree,
# since the reloaded object will be a plain Python dict.
I would typically save my complete flax.training.TrainState using orbax-checkpoint. Then, before training