-
Hi, I'm trying to understand how RNGs should be passed when there are multiple dropouts. Below is one example:

import jax
import jax.numpy as jnp
import optax
from typing import Optional
from flax import linen as nn
from flax.training import train_state
from jax import random
from jax import lax
from flax.linen.module import merge_param
from typing import Sequence

class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

class MyModelMultiple(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x

@jax.jit
def train_step(state: TrainState, xs, ys, dropout_key):
  # Fold the step number into the key so each step gets fresh dropout masks.
  dropout_train_key = jax.random.fold_in(
      key=dropout_key, data=state.step)

  def loss_fn(params):
    yhats = state.apply_fn(
        {'params': params}, xs, training=True,
        rngs={'dropout': dropout_train_key})
    loss = jnp.mean((ys - yhats) ** 2)
    return loss

  grad_fn = jax.value_and_grad(loss_fn)
  loss, grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss

model = MyModelMultiple(num_neurons=3)
# xs, ys, params_key, and dropout_key are assumed to be defined elsewhere.
rng1, rng2, rng3 = jax.random.split(dropout_key, 3)

print("* Init")
variables = model.init(params_key, xs, training=False)
params = variables['params']
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optax.adam(1e-3),
)

print("* Training")
for i in range(1001):
  state, loss = train_step(state, xs, ys, dropout_key)
  if i % 100 == 0:
    print(f'Iteration {i}: {loss}')

The more I look into this, the more it seems that those three Dropout layers will end up using the same rng, which would make all three behave identically. So I changed the code to:
import jax
import jax.numpy as jnp
import optax
from typing import Optional, Sequence
from flax import linen as nn
from flax.linen.stochastic import KeyArray
from flax.training import train_state
from jax import random
from jax import lax
from flax.linen.module import merge_param

class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

class MyModelMultiple(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    # Each Dropout now draws from its own rng collection.
    x = nn.Dropout(rate=0.5, deterministic=not training, rng_collection='drop1')(x)
    x = nn.Dropout(rate=0.5, deterministic=not training, rng_collection='drop2')(x)
    x = nn.Dropout(rate=0.5, deterministic=not training, rng_collection='drop3')(x)
    return x

@jax.jit
def train_step(state: TrainState, xs, ys, dropout_key):
  dropout_train_key = jax.random.fold_in(
      key=dropout_key, data=state.step)
  # Split the per-step key into one key per dropout collection.
  dkey1, dkey2, dkey3 = jax.random.split(dropout_train_key, 3)

  def loss_fn(params):
    yhats = state.apply_fn(
        {'params': params}, xs, training=True,
        rngs={'drop1': dkey1, 'drop2': dkey2, 'drop3': dkey3})
    loss = jnp.mean((ys - yhats) ** 2)
    return loss

  grad_fn = jax.value_and_grad(loss_fn)
  loss, grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss

model = MyModelMultiple(num_neurons=3)
rng1, rng2, rng3 = jax.random.split(dropout_key, 3)

print("* Init")
variables = model.init(params_key, xs, training=False)
params = variables['params']
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optax.adam(1e-3),
)

print("* Training")
for i in range(1001):
  state, loss = train_step(state, xs, ys, dropout_key)
  if i % 100 == 0:
    print(f'Iteration {i}: {loss}')

Note that I'm now using dkey1, dkey2, and dkey3, derived like this:

dropout_train_key = jax.random.fold_in(
    key=dropout_key, data=state.step)
dkey1, dkey2, dkey3 = jax.random.split(dropout_train_key, 3)
...
yhats = state.apply_fn(
    {'params': params}, xs, training=True,
    rngs={'drop1': dkey1, 'drop2': dkey2, 'drop3': dkey3})

How does this sound? Is this the standard approach for using multiple dropouts? I couldn't easily find a relevant code example demonstrating multiple dropout layers.
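For reference, the snippets above assume data and PRNG keys set up roughly as follows; the shapes and targets are just placeholders:

import jax
import jax.numpy as jnp

root_key = jax.random.PRNGKey(0)
params_key, dropout_key, data_key = jax.random.split(root_key, 3)
xs = jax.random.uniform(data_key, (16, 3))  # toy inputs
ys = 2.0 * xs                               # toy targets with the same shape as the model output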
-
Hi @minkooseo,
I think the default behaviour will result in a unique rng for each submodule. If you pass in the rng as part of the 'dropout' rng collection, each Dropout submodule should derive its own key from it. As a result, I don't believe you need to be concerned about passing a single rng for "dropout" when you call the apply method, as it should handle the generation of unique rngs for the dropout submodules automatically. I may have missed something, so take my answer with a grain of salt.
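A minimal sketch to check this, assuming a parameter-free toy module with two Dropout layers that share the single 'dropout' collection:

import jax
import jax.numpy as jnp
from flax import linen as nn

class TwoDropouts(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dropout(rate=0.5, deterministic=False)(x)
    b = nn.Dropout(rate=0.5, deterministic=False)(x)
    return a, b

x = jnp.ones((1, 16))
# This module has no params, so an empty variables dict is enough.
a, b = TwoDropouts().apply({}, x, rngs={'dropout': jax.random.PRNGKey(0)})
print(a)  # if each Dropout derives its own rng, the zero patterns of a and b differ
print(b)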
-
Thank you for the discussion. I think the 'scope' is what prevents the Dropout layers from using the same 'dropout' rng key and producing the same random numbers (see the Dropout implementation, lines 637 to 652 at commit fe54d39). Test code:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.module import merge_param
from typing import Sequence, Optional
from flax.core.scope import Scope
# Copy paste of Dropout, but added debug message of rng key
class MyDropout(nn.Module):
  rate: float
  broadcast_dims: Sequence[int] = ()
  deterministic: Optional[bool] = None
  rng_collection: str = 'dropout'

  @nn.compact
  def __call__(self, inputs, deterministic: Optional[bool] = None):
    deterministic = merge_param(
        'deterministic', self.deterministic, deterministic)
    if (self.rate == 0.) or deterministic:
      return inputs
    # Prevent gradient NaNs in 1.0 edge-case.
    if self.rate == 1.0:
      return jnp.zeros_like(inputs)
    keep_prob = 1. - self.rate
    rng = self.make_rng(self.rng_collection)
    jax.debug.print(f'{self.parent=}, {self.name=}, {self.scope.name=}, {self.scope.rngs.keys()}, {self.scope.rngs.values()}, {self.scope.rng_counters.keys()=}, {self.scope.rng_counters.values()=}, {self.rng_collection=}, {str(rng)=}')
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

class MyModel(nn.Module):
  dropout_rate: float

  @nn.compact
  def __call__(self, x, is_training: bool):
    x = nn.Dense(10)(x)
    x = MyDropout(rate=0.5, deterministic=not is_training)(x)
    x = MyDropout(rate=0.5, deterministic=not is_training)(x)
    return x
root_key = jax.random.PRNGKey(0)
params_key, dropout_key, data_key = jax.random.split(root_key, 3)
model = MyModel(0.5)
x = jax.random.uniform(data_key, (10, 10))
params = model.init(params_key, x, is_training=False)['params']
y = model.apply({'params': params}, x, is_training=True, rngs={'dropout': dropout_key})

Output: (debug prints showing the scope path and the derived rng key for each MyDropout)

So, each MyDropout instance gets a suffix, 'MyDropout_0' and 'MyDropout_1'. They're given the same seed (dropout_key), but the scope derives a different rng key for each of them. Thus, it suffices to give only the 'dropout' rng in the code.
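On a related note, here is a minimal sketch (using a bare nn.Dropout, not code from this thread) of why the train_step above also folds state.step into the root dropout key: the same root key with a different step yields a different derived key, and therefore a different mask at each training step.

import jax
import jax.numpy as jnp
from flax import linen as nn

drop = nn.Dropout(rate=0.5, deterministic=False)
x = jnp.ones((1, 8))
root = jax.random.PRNGKey(0)

for step in range(3):
  # Same root key, different step -> different derived key -> different mask.
  step_key = jax.random.fold_in(root, step)
  print(step, drop.apply({}, x, rngs={'dropout': step_key}))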
-
We have long been pending a "Randomness Guide" explaining how make_rng works and its interaction with lifted transforms. For now, here is the basic idea (BTW this is pseudocode; internal names are different). You have a path: tuple[str, ...] which is built by the Module system, and a count: int that keeps track of how many times make_rng has been called. The trick is to hash the tuple (*self.path, self.count) using hashlib and turn it into a uint32; in the example below this is done in the _stable_hash method. That integer is the fold_data you pass to jax.random.fold_in to produce a unique derived key from a root key.

def make_rng(self) -> jax.Array:
  fold_data = self._stable_hash((*self.path, self.count))
  self.count += 1
  return random.fold_in(self.root, fold_data)

@staticmethod
def _stable_hash(data: tuple[int | str, ...]) -> int:
  hash_str = " ".join(str(x) for x in data)
  _hash = hashlib.blake2s(hash_str.encode())
  hash_bytes = _hash.digest()
  # uint32 is represented as 4 bytes in big endian
  return int.from_bytes(hash_bytes[:4], byteorder="big")
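To make the idea concrete, here is a small self-contained sketch (the module paths below are hypothetical examples, not Flax internals) showing how two dropout call sites end up with different derived keys even though they share one root key:

import hashlib
from jax import random

def stable_hash(data):
  # Hash the path tuple into a uint32, mirroring the pseudocode above.
  hash_str = " ".join(str(x) for x in data)
  hash_bytes = hashlib.blake2s(hash_str.encode()).digest()
  return int.from_bytes(hash_bytes[:4], byteorder="big")

root = random.PRNGKey(0)
# Two hypothetical call sites: same root key, different module paths.
key_a = random.fold_in(root, stable_hash(('MyModelMultiple_0', 'Dropout_0', 0)))
key_b = random.fold_in(root, stable_hash(('MyModelMultiple_0', 'Dropout_1', 0)))
print(key_a)
print(key_b)  # distinct keys, hence distinct dropout masks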