random normal numbers, seem to be returned sorted and not completely random #15794
Hi, I'm trying to learn JAX from scratch. I've been trying to generate some synthetic data with a linear relationship. To my surprise, when using the raw epsilons, the data seems to be ordered, so the points at the extremes get a larger error than the ones at the center. The shuffled version works fine, with errors normally distributed around the linear relationship.
Replies: 1 comment
The issue is that you're re-using the same random key three times, so your random draws are not independent. If you do this instead, you should get the results you expect:

```python
import jax.numpy as jnp
from jax import random

# `key`, `n`, and `p` are defined as in the question.
key1, key2, key3 = random.split(key, 3)  # one independent subkey per draw
X = random.uniform(key1, (n, p - 1))
X = jnp.concatenate([jnp.ones((n, 1)), X], axis=1)  # prepend the intercept column
B = random.randint(key2, (p, 1), 0, 10)
epsilon = random.normal(key3, (n, 1)) * 0.3
```
JAX's random number generator is different than `numpy` and other tools in that it is not stateful, so you must explicitly generate additional keys when creating multiple independent streams. You can read more about this at JAX Sharp Bits: Random Numbers.
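To see why the errors looked "sorted", here is a minimal sketch comparing a reused key against split keys. The sample size `n`, the seed, and the `corr` helper are illustrative assumptions, not taken from the original post. As far as I understand, when the same key and the same shape are passed to `random.uniform` and `random.normal`, both draws are derived from the same underlying random bits, so the normal values end up being a monotone function of the uniform ones. That would be consistent with the largest errors landing at the extremes of `X`.

```python
import jax.numpy as jnp
from jax import random

n = 1000                 # illustrative sample size (assumption)
key = random.PRNGKey(0)  # illustrative seed (assumption)

# Reused key: both draws are generated from the same key and shape.
x_reused = random.uniform(key, (n, 1))
eps_reused = random.normal(key, (n, 1))

# Split keys: each draw gets its own subkey.
key_x, key_eps = random.split(key)
x_split = random.uniform(key_x, (n, 1))
eps_split = random.normal(key_eps, (n, 1))


def corr(a, b):
    """Pearson correlation between two (n, 1) arrays (hypothetical helper)."""
    return float(jnp.corrcoef(a.ravel(), b.ravel())[0, 1])


corr_reused = corr(x_reused, eps_reused)
corr_split = corr(x_split, eps_split)
print("reused key:", corr_reused)
print("split keys:", corr_split)
```

With the reused key the correlation should come out close to 1, while with split keys it should be close to 0, up to sampling noise.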