I would like to open a discussion on the proposal for a new training and evaluation API. It allows training loops to be created concisely, with every component being replaceable.
A few built-in components aim to cover the most common use cases, while the lightweight core API is intended to supply a unified interface for building custom extensions.
The current draft implementation includes:

- logging to stdout
- TensorBoard integration
- writing regular and best-metric checkpoints (the best-metric idea is sketched right after this list)
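How the draft's checkpoint writer decides what the "best" checkpoint is belongs to the implementation itself; the fragment below is only a rough sketch of the underlying idea (track the best value of one metric and signal when it improves). The class name, the metric name, and the "lower is better" direction are all assumptions made for illustration, not part of the proposed API.

```python
class BestMetricTracker:
    """Illustrative only: tracks the best (here: lowest) value of one metric.

    The metric name and the comparison direction are assumptions;
    the draft's checkpoint writer may handle this differently.
    """

    def __init__(self, metric_name="eval lnpp"):
        self._metric_name = metric_name
        self._best = float("inf")

    def improved(self, summary):
        value = summary[self._metric_name]  # Assumes a flat {metric name: value} mapping.
        if value < self._best:
            self._best = value
            return True   # The caller would overwrite the "best" checkpoint here.
        return False
```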
Details can be found in the FLIP, and a notebook with a complete working example is available here.
With the proposed API, training can be described as follows:
```python
train_loop = TrainLoop(
    init=model().init,                      # A function to initialize the parameters of the model.
    task=TrainTask(
        apply=model().apply,                # A function to propagate through the model.
        optimizer=optax.sgd(learning_rate=0.1),  # An optimizer.
        loss=categorical_cross_entropy,     # A loss function.
        data=train_datastream,              # A generator of training examples.
    ),
    n_steps_per_checkpoint=25,              # The number of training steps per checkpoint.
    n_steps=100,                            # The total number of training steps.
)

eval_loop = EvalLoop(
    task=EvalTask(
        apply=model().apply,
        metrics=dict(
            lnpp=categorical_cross_entropy,  # Sets log perplexity as an evaluation metric.
        ),
        data=eval_datastream,               # A generator of evaluation examples.
    ),
    n_steps=10,                             # The number of evaluation steps per checkpoint.
)

summary_logger = SummaryLogger()
summary_writer = SummaryWriter(output_dir="/tmp/tensorboard")
checkpoint_writer = CheckpointFileWriter(output_dir="/tmp/checkpoints")

for checkpoint in train_loop:        # Yields a checkpoint every `n_steps_per_checkpoint` steps.
    summary = eval_loop(checkpoint)  # Estimates evaluation metrics for the checkpoint.
    _ = summary_logger(summary)      # Prints summaries to stdout.
    _ = summary_writer(summary)      # Writes summaries to the TensorBoard dir.
    _ = checkpoint_writer(summary)   # Saves the checkpoint to the local file system.
```
Running the example produces output like the following:

```
Total model initialization time is 8.74 seconds.
Total number of trainable weights: 328960 ~ 1.2 MB.
Step 1: Ran 1 train steps in 6.15 seconds
Step 1: train seconds_per_step | 6.14884496
Step 1: train gradients_l2norm | 0.00296191
Step 1: train weights_l2norm | 14.01843071
Step 1: train loss | 5.54912329
Step 1: eval lnpp | 5.54803228
```
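Because each consumer in the loop above is just a callable applied to the value it receives, custom extensions can be dropped in next to the built-in ones. The sketch below is illustrative only and not part of the draft API: the class name is invented, and it assumes the summary behaves like a flat mapping from metric names (such as `train loss` or `eval lnpp`) to scalar values, which is a guess based on the printed output.

```python
import csv
import os


class CsvSummaryWriter:
    """Hypothetical extension: appends each summary to a CSV file.

    Assumes `summary` is a flat mapping from metric name to scalar value;
    the actual summary structure is defined by the FLIP, not here.
    """

    def __init__(self, output_path="/tmp/metrics.csv"):
        self._output_path = output_path

    def __call__(self, summary):
        write_header = not os.path.exists(self._output_path)
        with open(self._output_path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=sorted(summary))
            if write_header:
                writer.writeheader()
            writer.writerow(dict(summary))
        return summary  # Return the summary unchanged; the loop above discards it anyway.
```

Such a writer would slot into the same loop, e.g. `_ = csv_writer(summary)` next to the existing `summary_writer` call.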