RNG for multiple dropout layers #3262

Answered by cgarciae
minkooseo asked this question in Q&A

A "Randomness Guide" explaining how make_rng works and how it interacts with lifted transforms has been pending for a long time. For now, here is the basic idea (BTW this is pseudocode; the internal names are different):

You have a path: tuple[str, ...], which is built by the Module system, and a count: int that tracks how many times make_rng has been called. The trick is to hash the tuple (*self.path, self.count) using hashlib and turn the result into a uint32; in the example below this is done in the _stable_hash method. That integer is the fold_data you pass to jax.random.fold_in to derive a key from the root key. Because the path differs per module and the count increments on every call, each dropout layer (and each call within a layer) gets its own derived key.

  def make_rng(self) -> jax.Array:
    fold_data 
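
A minimal, self-contained sketch of the scheme described above, not Flax's actual implementation. The RngStream class, its field names, and the choice of SHA-1 over repr(...) as the byte encoding are all assumptions made for illustration; only hashlib and jax.random.fold_in are taken from the description.

```python
import hashlib
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class RngStream:
    """Hypothetical stand-in for the per-module RNG state (not a Flax class)."""

    root_key: jax.Array        # root PRNG key shared by all modules
    path: tuple[str, ...]      # module path, e.g. ("encoder", "dropout_0")
    count: int = 0             # number of make_rng calls so far

    @staticmethod
    def _stable_hash(data: tuple) -> int:
        # Python's built-in hash() is salted per process for strings, so it is
        # not reproducible across runs; hashlib gives a deterministic digest.
        # Encoding via repr() is an assumption of this sketch.
        digest = hashlib.sha1(repr(data).encode("utf-8")).digest()
        # Keep the first 4 bytes so the result fits in a uint32.
        return int.from_bytes(digest[:4], byteorder="big")

    def make_rng(self) -> jax.Array:
        fold_data = self._stable_hash((*self.path, self.count))
        self.count += 1
        # Derive a new key from the root key; the root key itself is unchanged.
        return jax.random.fold_in(self.root_key, jnp.uint32(fold_data))


root = jax.random.PRNGKey(0)
dropout_0 = RngStream(root, ("encoder", "dropout_0"))
dropout_1 = RngStream(root, ("encoder", "dropout_1"))

key_a = dropout_0.make_rng()   # derived from (path of dropout_0, count=0)
key_b = dropout_0.make_rng()   # same path, count=1
key_c = dropout_1.make_rng()   # different path, count=0
```

Since every derived key depends only on the root key, the path, and the call count, rebuilding the same module tree with the same root key reproduces the same dropout masks, while two dropout layers at different paths draw from different keys.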

Answer selected by minkooseo