diff --git a/jax/_src/random.py b/jax/_src/random.py index d8238c4e555c..1f2a46a9dc5a 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2069,7 +2069,7 @@ def wald(key: KeyArray, Returns: A random array with the specified dtype and with shape given by ``shape`` if - ``shape`` is not None, or else by ``mean.shape`` and ``scale.shape``. + ``shape`` is not None, or else by ``mean.shape``. """ key, _ = _check_prng_key(key) if not dtypes.issubdtype(dtype, np.floating):