Random normal numbers seem to be returned sorted and not completely random #15794

Answered by jakevdp
AlejandroBaron asked this question in Q&A

The issue is that you're reusing the same random key three times, so your random draws are not independent. If you split the key into three subkeys, one per draw, you should get the results you expect:

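```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)  # hypothetical seed; use your own key here
n, p = 100, 3            # hypothetical placeholder sizes

# Split once into three independent subkeys, one per random draw.
key1, key2, key3 = random.split(key, 3)

X = random.uniform(key1, (n, p - 1))
X = jnp.concatenate([jnp.ones((n, 1)), X], axis=1)  # prepend a column of ones
B = random.randint(key2, (p, 1), 0, 10)

epsilon = random.normal(key3, (n, 1)) * 0.3
```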
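To see the dependence directly, here is a small sketch (the seed and sample size are arbitrary, chosen only for illustration). Calling a JAX sampler twice with the same key and shape returns identical numbers. In addition, assuming current JAX internals, where `random.normal` is computed by monotonically transforming uniform samples built from the key's random bits, a uniform draw and a normal draw that share a key and shape come out with exactly the same ranking. That is one way reused keys can make samples look "sorted" relative to each other.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)  # arbitrary seed, for illustration only

# Reusing one key: samplers are pure functions of the key, so two
# "independent" draws with the same key and shape are identical.
a = random.normal(key, (8,))
b = random.normal(key, (8,))
same_key_identical = bool(jnp.array_equal(a, b))

# Splitting first yields distinct subkeys, so the draws differ.
k1, k2 = random.split(key)
split_identical = bool(jnp.array_equal(random.normal(k1, (8,)),
                                       random.normal(k2, (8,))))

# Reusing one key across distributions: in current JAX, random.normal
# monotonically transforms uniform samples built from the same bits,
# so the uniform and normal draws share the same ranking.
u = random.uniform(key, (8,))
z = random.normal(key, (8,))
same_ranking = bool(jnp.array_equal(jnp.argsort(u), jnp.argsort(z)))

print(same_key_identical, split_identical, same_ranking)
```

The ranking check relies on an implementation detail rather than a documented guarantee, but the first check (same key gives the same numbers) holds by design.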

JAX's random number generator differs from NumPy's and those of other tools in that it is not stateful: each sampling function is a pure function of the key you pass in. To create multiple independent streams, you must explicitly generate additional keys (for example with `random.split`). You can read more about this at JAX Sharp Bits: Random Numbers.
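A minimal sketch of the contrast, with arbitrary seeds and sizes: NumPy's `Generator` advances hidden internal state on each call, while in JAX you thread the key through your code yourself. The `key, subkey = random.split(key)` pattern shown here is a common idiom for this, not the only option.

```python
import numpy as np
from jax import random

# NumPy: the Generator holds hidden state that advances on every call,
# so repeated calls on the same object return new numbers.
rng = np.random.default_rng(0)
x1 = rng.normal(size=3)
x2 = rng.normal(size=3)

# JAX: no hidden state. Thread the key through explicitly, splitting
# off a fresh subkey for each draw and carrying the other half forward.
key = random.PRNGKey(0)
key, subkey = random.split(key)
y1 = random.normal(subkey, (3,))
key, subkey = random.split(key)
y2 = random.normal(subkey, (3,))

print(np.array_equal(x1, x2), np.array_equal(np.asarray(y1), np.asarray(y2)))
```

Because the key is carried forward rather than reused, each draw gets its own bits and the two JAX samples differ, just as the two NumPy samples do.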

Answer selected by AlejandroBaron