
Reload opt_state and modify learning rate #262

Closed · Answered by rosshemsley
borisdayma asked this question in Q&A

Hey @borisdayma

There are two standard ways to handle the learning rate:

  • Pass a constant value, which is used for every update.
  • Pass a schedule function, in which case optax keeps track of the number of steps elapsed and uses the learning rate computed by the schedule at the current step count.
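The schedule option above can be sketched in plain Python. An optax schedule is just a callable from step count to learning rate; the exponential-decay function and its parameter values here are illustrative, not optax's own implementation:

```python
def exponential_decay(init_lr, decay_rate, transition_steps):
    """Hypothetical schedule factory: returns a callable mapping
    the step count to a decayed learning rate."""
    def schedule(step):
        # Halve (for decay_rate=0.5) the learning rate every
        # `transition_steps` steps, smoothly.
        return init_lr * decay_rate ** (step / transition_steps)
    return schedule

schedule = exponential_decay(init_lr=1e-3, decay_rate=0.5, transition_steps=100)
print(schedule(0))    # → 0.001
print(schedule(100))  # → 0.0005
```

Passing such a callable where a learning rate is expected is what lets optax recompute the rate from its internal step counter on every update.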

However, if you'd like more control over the learning rate (or any other hyperparameter), you can put the hyperparameters of your optimizer into the optimizer's state and then mutate the state however you would like. This is required because optax optimizers are pure functions - so the only way to dynamically change their behavior is to change the data passed in.

import numpy as np
import optax

# S…
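The pattern described above can be sketched without optax at all. This is a minimal, assumed-API illustration (not optax's real `init`/`update` signatures): an SGD-like optimizer written as pure functions, with the learning rate stored in the state so the caller can mutate it between steps.

```python
import numpy as np

def init(learning_rate):
    # The hyperparameter lives in the optimizer state, not in a closure.
    return {"learning_rate": learning_rate}

def update(grads, state):
    # Pure function: output depends only on the inputs.
    updates = -state["learning_rate"] * grads
    return updates, state

params = np.array([1.0, 2.0])
grads = np.array([0.1, 0.1])

state = init(0.1)
updates, state = update(grads, state)
params = params + updates  # [0.99, 1.99]

# Because the optimizer is pure, changing its behavior means changing
# the state we pass in: here we halve the learning rate mid-training.
state["learning_rate"] *= 0.5
updates, state = update(grads, state)
params = params + updates  # [0.985, 1.985]
```

Recent optax versions also ship `optax.inject_hyperparams`, which wraps an optimizer so its numeric hyperparameters are exposed in `opt_state.hyperparams` and can be overwritten the same way.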

Replies: 1 comment, 7 replies (from @PabloAMC, @ddrous, @Ridhamz-nd, @hericks, and @vroulet)
Answer selected by borisdayma
Category: Q&A · 8 participants