Skip to content

Commit

Permalink
Merge pull request #12173 from froystig:random-unwrap
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 471281023
  • Loading branch information
jax authors committed Aug 31, 2022
2 parents da24b99 + 0d3630b commit c1192f3
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,8 @@ def random_wrap_batch_rule(batched_args, batch_dims, *, impl):


def random_unwrap(keys):
assert isinstance(keys, PRNGKeyArray)
if not isinstance(keys, PRNGKeyArray):
raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
return random_unwrap_p.bind(keys)

random_unwrap_p = core.Primitive('random_unwrap')
Expand Down
8 changes: 8 additions & 0 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ def split(key: KeyArray, num: int = 2) -> KeyArray:
key, wrapped = _check_prng_key(key)
return _return_prng_keys(wrapped, _split(key, num))

def _key_data(keys: KeyArray) -> jnp.ndarray:
assert isinstance(keys, prng.PRNGKeyArray)
return prng.random_unwrap(keys)

def key_data(keys: KeyArray) -> jnp.ndarray:
keys, _ = _check_prng_key(keys)
return _key_data(keys)


### random samplers

Expand Down
1 change: 1 addition & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
gamma as gamma,
generalized_normal as generalized_normal,
gumbel as gumbel,
key_data as key_data,
laplace as laplace,
logistic as logistic,
loggamma as loggamma,
Expand Down
36 changes: 36 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,42 @@ def test_random_wrap_vmap(self):
self.assertIsInstance(keys, random.KeyArray)
self.assertEqual(keys.shape, (3,))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_internal" if use_internal else "",
"use_internal": use_internal}
for use_internal in [False, True]))
def test_random_unwrap(self, use_internal):
unwrap = prng_internal.random_unwrap if use_internal else random.key_data
def f(k): return unwrap(k)
k = self.make_keys(3, 4)
out = f(k)
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
out = jax.jit(f)(k)
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
out = jax.vmap(f)(k)
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))
out = jax.vmap(jax.jit(f))(k)
self.assertEqual(out.dtype, np.dtype('uint32'))
self.assertEqual(out.shape[:2], (3, 4))

# TODO(frostig): simplify when we always enable_custom_prng
if not (config.jax_enable_custom_prng and use_internal):
return

x = jnp.arange(12, dtype=np.dtype('uint32')).reshape(3, 4)
self.assertRaisesRegex(
TypeError, 'random_unwrap takes key array operand, got .*',
lambda: f(x))
self.assertRaisesRegex(
TypeError, 'random_unwrap takes key array operand, got .*',
lambda: jax.jit(f)(x))
self.assertRaisesRegex(
TypeError, 'random_unwrap takes key array operand, got .*',
lambda: jax.vmap(f)(x))

def test_eval_shape_keys_in(self):
def f(key):
return prng_internal.random_bits(key, bit_width=32, shape=(5,))
Expand Down

0 comments on commit c1192f3

Please sign in to comment.