Full Module RNG Scope #3375
-
Hey @joeryjoery,
Maybe this could be implemented by iteratively running rngs = module.get_rng_keys(*args, **kwargs). You'd still have to pass the same inputs to
It's tricky to decide how this unique RNG would behave under lifted transforms; maybe we could give it a name like
-
Hi, if I implement a Module that makes use of make_rng (e.g., Dropout), I now have to manually specify the rng every time I call the Module. Although I completely agree that the RNG streams should be split into distinct paths as they are now, it is quite annoying to do this bookkeeping at every call to Module.apply.
Is there already something implemented that handles this automatically? Or could it be added in the future?
I'd also argue that Flax should allow me to pass a single global RNG key to either init or apply, while keeping the option to pass the RNGs explicitly. This has no consequences for reproducibility, since the RNGs need to be split somewhere anyway.
At the moment, the RNG design forces me to write workarounds in my code that pass around a namespace tracking all registered RNG scopes. This is quite error-prone, and ideally it would be handled internally.