From 938e7565353d1f8625793a1676b0dbeb26c0402c Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 16 May 2022 23:34:01 +0800 Subject: [PATCH 1/7] fix #149: provide dozens of random samplings in NumPy --- brainpy/math/random.py | 1161 +++++++++++++++++++++-------- brainpy/math/tests/test_random.py | 146 +++- brainpy/math/utils.py | 2 +- 3 files changed, 993 insertions(+), 316 deletions(-) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index 10b1c4093..381e72cd8 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -1,27 +1,33 @@ # -*- coding: utf-8 -*- -import jax.experimental.host_callback + +from collections import namedtuple +from functools import partial + +import jax import numpy as np -import numpy.random -from jax import numpy as jnp -from jax import random as jr +from jax import lax, jit, vmap, numpy as jnp, random as jr, core +from jax.experimental.host_callback import call +from jax.experimental import checkify from jax.tree_util import register_pytree_node from brainpy.math.jaxarray import JaxArray, Variable - from .utils import wraps -from jax.experimental.host_callback import call as hcb_call - __all__ = [ - 'RandomState', - - 'seed', - - 'rand', 'randint', 'randn', 'random', 'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', - 'beta', 'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto', 'poisson', 'standard_cauchy', - 'standard_exponential', 'standard_gamma', 'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli', - - 'lognormal', + 'RandomState', 'Generator', + + 'seed', 'default_rng', + + 'rand', 'randint', 'random_integers', 'randn', 'random', + 'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta', + 'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto', + 'poisson', 'standard_cauchy', 'standard_exponential', 'standard_gamma', + 'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli', + 'lognormal', 'binomial', 'chisquare', 'dirichlet', 'geometric', 'f', + 'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal', + 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power', + 'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min', + 'zipf', 'maxwell' ] @@ -45,6 +51,326 @@ def _check_shape(name, shape, *param_shapes): raise ValueError(msg.format(name, s, shape)) +def _remove_jax_array(a): + return a.value if isinstance(a, JaxArray) else a + + +_tr_params = namedtuple( + "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"] +) + + +def _get_tr_params(n, p): + # See Table 1. Additionally, we pre-compute log(p), log1(-p) and the + # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5). 
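+  # (A sketch of the roles these constants play, following Hormann's BTRS
+  #  algorithm — an annotation, not part of the cited table: `c`, `b`, `a`
+  #  define the transformation from a uniform draw u to a candidate k;
+  #  `alpha` scales the dominating "hat" density; `u_r`/`v_r` bound the cheap
+  #  early-acceptance squeeze test; `m = floor((n + 1) * p)` is the mode of
+  #  Binomial(n, p); `log_p`, `log1_p`, `log_h` cache log(p), log1p(-p) and
+  #  the constant part of log f(k) used in the acceptance ratio.)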
+ mu = n * p + spq = jnp.sqrt(mu * (1 - p)) + c = mu + 0.5 + b = 1.15 + 2.53 * spq + a = -0.0873 + 0.0248 * b + 0.01 * p + alpha = (2.83 + 5.1 / b) * spq + u_r = 0.43 + v_r = 0.92 - 4.2 / b + m = jnp.floor((n + 1) * p).astype(n.dtype) + log_p = jnp.log(p) + log1_p = jnp.log1p(-p) + log_h = (m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) + ( + _stirling_approx_tail(m) + _stirling_approx_tail(n - m) + ) + return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h) + + +def _stirling_approx_tail(k): + precomputed = jnp.array([0.08106146679532726, + 0.04134069595540929, + 0.02767792568499834, + 0.02079067210376509, + 0.01664469118982119, + 0.01387612882307075, + 0.01189670994589177, + 0.01041126526197209, + 0.009255462182712733, + 0.008330563433362871, ]) + kp1 = k + 1 + kp1sq = (k + 1) ** 2 + return jnp.where( + k < 10, precomputed[k], + (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1, + ) + + +def _binomial_btrs(key, p, n): + """ + Based on the transformed rejection sampling algorithm (BTRS) from the + following reference: + + Hormann, "The Generation of Binonmial Random Variates" + (https://core.ac.uk/download/pdf/11007254.pdf) + """ + + def _btrs_body_fn(val): + _, key, _, _ = val + key, key_u, key_v = jr.split(key, 3) + u = jr.uniform(key_u) + v = jr.uniform(key_v) + u = u - 0.5 + k = jnp.floor( + (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c + ).astype(n.dtype) + return k, key, u, v + + def _btrs_cond_fn(val): + def accept_fn(k, u, v): + # See acceptance condition in Step 3. (Page 3) of TRS algorithm + # v <= f(k) * g_grad(u) / alpha + + m = tr_params.m + log_p = tr_params.log_p + log1_p = tr_params.log1_p + # See: formula for log(f(k)) at bottom of Page 5. + log_f = ( + (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0)) + + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p) + + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k)) + + tr_params.log_h + ) + g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b + return jnp.log((v * tr_params.alpha) / g) <= log_f + + k, key, u, v = val + early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r) + early_reject = (k < 0) | (k > n) + return lax.cond( + early_accept | early_reject, + (), + lambda _: ~early_accept, + (k, u, v), + lambda x: ~accept_fn(*x), + ) + + tr_params = _get_tr_params(n, p) + ret = lax.while_loop( + _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0) + ) # use k=-1 initially so that cond_fn returns True + return ret[0] + + +def _binomial_inversion(key, p, n): + def _binom_inv_body_fn(val): + i, key, geom_acc = val + key, key_u = jr.split(key) + u = jr.uniform(key_u) + geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1 + geom_acc = geom_acc + geom + return i + 1, key, geom_acc + + def _binom_inv_cond_fn(val): + i, _, geom_acc = val + return geom_acc <= n + + log1_p = jnp.log1p(-p) + ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0)) + return ret[0] + + +def _binomial_dispatch(key, p, n): + def dispatch(key, p, n): + is_le_mid = p <= 0.5 + pq = jnp.where(is_le_mid, p, 1 - p) + mu = n * pq + k = lax.cond( + mu < 10, + (key, pq, n), + lambda x: _binomial_inversion(*x), + (key, pq, n), + lambda x: _binomial_btrs(*x), + ) + return jnp.where(is_le_mid, k, n - k) + + # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types + cond0 = jnp.isfinite(p) & (n > 0) & (p > 0) + return lax.cond( + cond0 & (p < 1), + (key, p, n), + lambda x: dispatch(*x), + (), + lambda _: jnp.where(cond0, 
n, 0), + ) + + +@partial(jit, static_argnums=(3,)) +def _binomial(key, p, n, shape): + shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n)) + # reshape to map over axis 0 + p = jnp.reshape(jnp.broadcast_to(p, shape), -1) + n = jnp.reshape(jnp.broadcast_to(n, shape), -1) + key = jr.split(key, jnp.size(p)) + if jax.default_backend() == "cpu": + ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n)) + else: + ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n) + return jnp.reshape(ret, shape) + + +@partial(jit, static_argnums=(2,)) +def _categorical(key, p, shape): + # this implementation is fast when event shape is small, and slow otherwise + # Ref: https://stackoverflow.com/a/34190035 + shape = shape or p.shape[:-1] + s = jnp.cumsum(p, axis=-1) + r = jr.uniform(key, shape=shape + (1,)) + return jnp.sum(s < r, axis=-1) + + +def _scatter_add_one(operand, indices, updates): + return lax.scatter_add( + operand, + indices, + updates, + lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ), + ) + + +def _reshape(x, shape): + if isinstance(x, (int, float, np.ndarray, np.generic)): + return np.reshape(x, shape) + else: + return jnp.reshape(x, shape) + + +def _promote_shapes(*args, shape=()): + # adapted from lax.lax_numpy + if len(args) < 2 and not shape: + return args + else: + shapes = [jnp.shape(arg) for arg in args] + num_dims = len(lax.broadcast_shapes(shape, *shapes)) + return [ + _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg + for arg, s in zip(args, shapes) + ] + + +@partial(jit, static_argnums=(3, 4)) +def _multinomial(key, p, n, n_max, shape=()): + if jnp.shape(n) != jnp.shape(p)[:-1]: + broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) + n = jnp.broadcast_to(n, broadcast_shape) + p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) + shape = shape or p.shape[:-1] + if n_max == 0: + return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int)) + # get indices from categorical distribution then gather the result + indices = _categorical(key, p, (n_max,) + shape) + # mask out values when counts is heterogeneous + if jnp.ndim(n) > 0: + mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0] + mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype) + excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), + jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], + -1) + else: + mask = 1 + excess = 0 + # NB: we transpose to move batch shape to the front + indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T + samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype), + jnp.expand_dims(indices_2D, axis=-1), + jnp.ones(indices_2D.shape, dtype=indices.dtype)) + return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess + + +@partial(jit, static_argnums=(2, 3)) +def _von_mises_centered(key, concentration, shape, dtype=jnp.float64): + """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal. + + Returns + ------- + out: array_like + centered samples from von Mises + + References + ---------- + .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986; + Chapter 9, p. 473-476. 
http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf + + """ + shape = shape or jnp.shape(concentration) + dtype = jnp.result_type(dtype) + concentration = lax.convert_element_type(concentration, dtype) + concentration = jnp.broadcast_to(concentration, shape) + + s_cutoff_map = { + jnp.dtype(jnp.float16): 1.8e-1, + jnp.dtype(jnp.float32): 2e-2, + jnp.dtype(jnp.float64): 1.2e-4, + } + s_cutoff = s_cutoff_map.get(dtype) + + r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2) + rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration) + s_exact = (1.0 + rho ** 2) / (2.0 * rho) + + s_approximate = 1.0 / concentration + + s = jnp.where(concentration > s_cutoff, s_exact, s_approximate) + + def cond_fn(*args): + """check if all are done or reached max number of iterations""" + i, _, done, _, _ = args[0] + return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done))) + + def body_fn(*args): + i, key, done, _, w = args[0] + uni_ukey, uni_vkey, key = jr.split(key, 3) + u = jr.uniform( + key=uni_ukey, + shape=shape, + dtype=concentration.dtype, + minval=-1.0, + maxval=1.0, + ) + z = jnp.cos(jnp.pi * u) + w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done + y = concentration * (s - w) + v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype) + accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y) + return i + 1, key, accept | done, u, w + + init_done = jnp.zeros(shape, dtype=bool) + init_u = jnp.zeros(shape) + init_w = jnp.zeros(shape) + + _, _, done, u, w = lax.while_loop( + cond_fun=cond_fn, + body_fun=body_fn, + init_val=(jnp.array(0), key, init_done, init_u, init_w), + ) + + return jnp.sign(u) * jnp.arccos(w) + + +def _loc_scale(loc, scale, value): + if loc is None: + if scale is None: + return JaxArray(value) + else: + return JaxArray(value * scale) + else: + if scale is None: + return JaxArray(value + loc) + else: + return JaxArray(value * scale + loc) + + +def _check_py_seq(seq): + return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq + + class RandomState(Variable): """RandomState that track the random generator state. """ __slots__ = () @@ -57,7 +383,8 @@ def __init__(self, seed=None): seed : int, jax.DeviceArray, Optional The initial seed of the random number generator. 
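
    Examples
    --------
    A minimal sketch of typical use (illustrative only; values are random)::

      >>> rng = RandomState(42)
      >>> samples = rng.normal(size=(2, 3))  # a 2x3 JaxArray of standard-normal draws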
""" - if seed is None: seed = np.random.randint(0, 100000, 2, dtype=np.uint32) + if seed is None: + seed = np.random.randint(0, 100000, 2, dtype=np.uint32) if isinstance(seed, int): key = jr.PRNGKey(seed) else: @@ -109,10 +436,37 @@ def split_keys(self, n): def rand(self, *dn): return JaxArray(jr.uniform(self.split_key(), shape=dn, minval=0., maxval=1.)) - def randint(self, low, high=None, size=None, dtype=int): - return JaxArray(jr.randint(self.split_key(), shape=_size2shape(size), + def randint(self, low, high=None, size=None, dtype=jnp.int_): + low = _remove_jax_array(low) + high = _remove_jax_array(high) + if high is None: + high = low + low = 0 + high = _check_py_seq(high) + low = _check_py_seq(low) + if size is None: + size = lax.broadcast_shapes(jnp.shape(low), + jnp.shape(high)) + return JaxArray(jr.randint(self.split_key(), + shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)) + def random_integers(self, low, high=None, size=None): + low = _remove_jax_array(low) + high = _remove_jax_array(high) + low = _check_py_seq(low) + high = _check_py_seq(high) + if high is None: + high = low + low = 1 + high += 1 + if size is None: + size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) + return JaxArray(jr.randint(self.split_key(), + shape=_size2shape(size), + minval=low, + maxval=high)) + def randn(self, *dn): return JaxArray(jr.normal(self.split_key(), shape=dn)) @@ -122,18 +476,23 @@ def random(self, size=None): def random_sample(self, size=None): return self.random(size=size) - def randf(self, size=None): + def ranf(self, size=None): return self.random(size=size) def sample(self, size=None): return self.random(size=size) def choice(self, a, size=None, replace=True, p=None): - a = a.value if isinstance(a, JaxArray) else a - return JaxArray(jr.choice(self.split_key(), a=a, shape=_size2shape(size), replace=replace, p=p)) + a = _remove_jax_array(a) + p = _remove_jax_array(p) + a = _check_py_seq(a) + p = _check_py_seq(p) + return JaxArray(jr.choice(self.split_key(), a=a, shape=_size2shape(size), + replace=replace, p=p)) def permutation(self, x): x = x.value if isinstance(x, JaxArray) else x + x = _check_py_seq(x) return JaxArray(jr.permutation(self.split_key(), x)) def shuffle(self, x, axis=0): @@ -143,38 +502,83 @@ def shuffle(self, x, axis=0): def beta(self, a, b, size=None): a = a.value if isinstance(a, JaxArray) else a b = b.value if isinstance(b, JaxArray) else b + a = _check_py_seq(a) + b = _check_py_seq(b) + if size is None: + size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b)) return JaxArray(jr.beta(self.split_key(), a=a, b=b, shape=_size2shape(size))) - def exponential(self, scale=1.0, size=None): - assert scale == 1. - return JaxArray(jr.exponential(self.split_key(), shape=_size2shape(size))) - - def gamma(self, shape, scale=1.0, size=None): - assert scale == 1. - return JaxArray(jr.gamma(self.split_key(), a=shape, shape=_size2shape(size))) - - def gumbel(self, loc=0.0, scale=1.0, size=None): - assert loc == 0. - assert scale == 1. - return JaxArray(jr.gumbel(self.split_key(), shape=_size2shape(size))) - - def laplace(self, loc=0.0, scale=1.0, size=None): - assert loc == 0. - assert scale == 1. - return JaxArray(jr.laplace(self.split_key(), shape=_size2shape(size))) - - def logistic(self, loc=0.0, scale=1.0, size=None): - assert loc == 0. - assert scale == 1. 
- return JaxArray(jr.logistic(self.split_key(), shape=_size2shape(size))) - - def normal(self, loc=0.0, scale=1.0, size=None): - return JaxArray(jr.normal(self.split_key(), shape=_size2shape(size)) * scale + loc) + def exponential(self, scale=None, size=None): + scale = _remove_jax_array(scale) + scale = _check_py_seq(scale) + if size is None: + size = jnp.shape(scale) + r = jr.exponential(self.split_key(), shape=_size2shape(size)) + if scale is None: + return JaxArray(r) + else: + return JaxArray(r / scale) + + def gamma(self, shape, scale=None, size=None): + shape = _remove_jax_array(shape) + scale = _remove_jax_array(scale) + shape = _check_py_seq(shape) + scale = _check_py_seq(scale) + if size is None: + size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale)) + r = jr.gamma(self.split_key(), a=shape, shape=_size2shape(size)) + if scale is None: + return JaxArray(r) + else: + return JaxArray(r * scale) + + def gumbel(self, loc=None, scale=None, size=None): + loc = _remove_jax_array(loc) + scale = _remove_jax_array(scale) + loc = _check_py_seq(loc) + scale = _check_py_seq(scale) + if size is None: + size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) + return _loc_scale(loc, scale, jr.gumbel(self.split_key(), shape=size)) + + def laplace(self, loc=None, scale=None, size=None): + loc = _remove_jax_array(loc) + scale = _remove_jax_array(scale) + loc = _check_py_seq(loc) + scale = _check_py_seq(scale) + if size is None: + size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) + return _loc_scale(loc, scale, jr.laplace(self.split_key(), shape=size)) + + def logistic(self, loc=None, scale=None, size=None): + loc = _remove_jax_array(loc) + scale = _remove_jax_array(scale) + loc = _check_py_seq(loc) + scale = _check_py_seq(scale) + if size is None: + size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) + return _loc_scale(loc, scale, jr.logistic(self.split_key(), shape=size)) + + def normal(self, loc=None, scale=None, size=None): + loc = _remove_jax_array(loc) + scale = _remove_jax_array(scale) + loc = _check_py_seq(loc) + scale = _check_py_seq(scale) + if size is None: + size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc)) + return _loc_scale(loc, scale, jr.normal(self.split_key(), shape=size)) def pareto(self, a, size=None): + a = _remove_jax_array(a) + a = _check_py_seq(a) + if size is None: + size = jnp.shape(a) return JaxArray(jr.pareto(self.split_key(), b=a, shape=_size2shape(size))) def poisson(self, lam=1.0, size=None): + lam = _check_py_seq(_remove_jax_array(lam)) + if size is None: + size = jnp.shape(lam) return JaxArray(jr.poisson(self.split_key(), lam=lam, shape=_size2shape(size))) def standard_cauchy(self, size=None): @@ -184,204 +588,493 @@ def standard_exponential(self, size=None): return JaxArray(jr.exponential(self.split_key(), shape=_size2shape(size))) def standard_gamma(self, shape, size=None): + shape = _remove_jax_array(shape) + shape = _check_py_seq(shape) + if size is None: + size = jnp.shape(shape) return JaxArray(jr.gamma(self.split_key(), a=shape, shape=_size2shape(size))) def standard_normal(self, size=None): return JaxArray(jr.normal(self.split_key(), shape=_size2shape(size))) def standard_t(self, df, size=None): + df = _remove_jax_array(df) + df = _check_py_seq(df) + if size is None: + size = jnp.shape(size) return JaxArray(jr.t(self.split_key(), df=df, shape=_size2shape(size))) def uniform(self, low=0.0, high=1.0, size=None): - return JaxArray(jr.uniform(self.split_key(), shape=_size2shape(size), minval=low, maxval=high)) 
- - def truncated_normal(self, lower, upper, size, scale=1.): + low = _remove_jax_array(low) + high = _remove_jax_array(high) + low = _check_py_seq(low) + high = _check_py_seq(high) + if size is None: + size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) + return JaxArray(jr.uniform(self.split_key(), + shape=_size2shape(size), + minval=low, + maxval=high)) + + def truncated_normal(self, lower, upper, size, scale=None): + lower = _remove_jax_array(lower) + lower = _check_py_seq(lower) + upper = _remove_jax_array(upper) + upper = _check_py_seq(upper) + scale = _remove_jax_array(scale) + scale = _check_py_seq(scale) + if size is None: + size = lax.broadcast_shapes(jnp.shape(lower), + jnp.shape(upper), + jnp.shape(scale)) rands = jr.truncated_normal(self.split_key(), lower=lower, upper=upper, shape=_size2shape(size)) - return JaxArray(rands * scale) + if scale is None: + return JaxArray(rands) + else: + return JaxArray(rands * scale) def bernoulli(self, p, size=None): + p = _remove_jax_array(p) + p = _check_py_seq(p) + checkify.check(jnp.all(jnp.logical_and(p >= 0, p <= 1)), 'Bernoulli parameter p should be within [0, 1]') + if size is None: + size = jnp.shape(p) return JaxArray(jr.bernoulli(self.split_key(), p=p, shape=_size2shape(size))) - def lognormallognormal(self, mean=0.0, sigma=1.0, size=None): + def lognormal(self, mean=None, sigma=None, size=None): + mean = _check_py_seq(_remove_jax_array(mean)) + sigma = _check_py_seq(_remove_jax_array(sigma)) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(mean), + jnp.shape(sigma)) samples = jr.normal(self.split_key(), shape=_size2shape(size)) - samples = samples * sigma + mean - samples = jnp.exp(samples) + samples = _loc_scale(mean, sigma, samples) + samples = jnp.exp(samples.value) + return JaxArray(samples) + + def binomial(self, n, p, size=None): + n = n.value if isinstance(n, JaxArray) else n + p = p.value if isinstance(p, JaxArray) else p + n = _check_py_seq(n) + p = _check_py_seq(p) + checkify.check(jnp.all(jnp.logical_and(p >= 0, p <= 1)), '"p" must be in [0, 1].') + if size is None: + size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)) + return JaxArray(_binomial(self.split_key(), p, n, shape=_size2shape(size))) + + def chisquare(self, df, size=None): + df = _check_py_seq(_remove_jax_array(df)) + if size is None: + if jnp.ndim(df) == 0: + dist = jr.normal(self.split_key(), (df,)) ** 2 + dist = dist.sum() + else: + raise NotImplementedError('Do not support non-scale "df" when "size" is None') + else: + dist = jr.normal(self.split_key(), (df,) + _size2shape(size)) ** 2 + dist = dist.sum(axis=0) + return JaxArray(dist) + + def dirichlet(self, alpha, size=None): + alpha = _check_py_seq(_remove_jax_array(alpha)) + return JaxArray(jr.dirichlet(self.split_key(), alpha=alpha, shape=_size2shape(size))) + + def geometric(self, p, size=None): + p = _remove_jax_array(p) + p = _check_py_seq(p) + if size is None: + size = jnp.shape(p) + u = jr.uniform(self.split_key(), size) + r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) + return JaxArray(r) + + def multinomial(self, n, pvals, size=None): + n = _check_py_seq(_remove_jax_array(n)) + pvals = _check_py_seq(_remove_jax_array(pvals)) + checkify.check(jnp.sum(pvals[:-1]) <= 1., 'We require `sum(pvals[:-1]) <= 1`.') + if isinstance(n, jax.core.Tracer): + raise ValueError("The total count parameter `n` should not be a jax abstract array.") + size = _size2shape(size) + n_max = int(np.max(jax.device_get(n))) + batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n)) + 
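+    # `n_max` must be a concrete Python int: `_multinomial` draws `n_max`
+    # categorical samples per event (a static shape under `jit`), then masks
+    # out the excess draws wherever the counts in `n` are smaller.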
return JaxArray(_multinomial(self.split_key(), pvals, n, n_max, batch_shape + size)) + + def multivariate_normal(self, mean, cov, size=None): + mean = _check_py_seq(_remove_jax_array(mean)) + cov = _check_py_seq(_remove_jax_array(cov)) + size = _size2shape(size) + scale = jnp.linalg.cholesky(cov) + batch_shape = lax.broadcast_shapes(jnp.shape(mean)[:-2], jnp.shape(scale)[:-2]) + event_shape = jnp.shape(scale)[-1:] + eps = jr.normal(self.split_key(), shape=size + batch_shape + event_shape) + r = mean + jnp.squeeze(jnp.matmul(scale, eps[..., jnp.newaxis]), axis=-1) + return JaxArray(r) + + def rayleigh(self, scale=1.0, size=None): + scale = _check_py_seq(_remove_jax_array(scale)) + if size is None: + size = jnp.shape(scale) + x = jnp.sqrt(-2. * jnp.log(jr.uniform(self.split_key(), shape=size, minval=0, maxval=1))) + return JaxArray(x * scale) + + def triangular(self, size=None): + size = _size2shape(size) + bernoulli_samples = jr.bernoulli(self.split_key(), p=0.5, shape=size) + return JaxArray(2 * bernoulli_samples - 1) + + def vonmises(self, mu, kappa, size=None): + mu = _check_py_seq(_remove_jax_array(mu)) + kappa = _check_py_seq(_remove_jax_array(kappa)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa)) + size = _size2shape(size) + samples = _von_mises_centered(self.split_key(), kappa, size) + samples = samples + mu + samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi return JaxArray(samples) + def weibull(self, a, size=None): + a = _check_py_seq(_remove_jax_array(a)) + if size is None: + size = jnp.shape(a) + else: + if jnp.size(a) > 1: + raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}') + size = _size2shape(size) + random_uniform = jr.uniform(key=self.split_key(), shape=size, minval=0, maxval=1) + r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) + return JaxArray(r) + def weibull_min(self, a, scale=None, size=None): + """Sample from a Weibull minimum distribution. + Parameters + ---------- + a: float, array_like + The concentration parameter of the distribution. + scale: float, array_like + The scale parameter of the distribution. + size: optional, int, tuple of int + The shape added to the parameters loc and scale broadcastable shape. + + Returns + ------- + out: array_like + The sampling results. + """ + a = _check_py_seq(_remove_jax_array(a)) + scale = _check_py_seq(_remove_jax_array(scale)) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale)) + else: + if jnp.size(a) > 1: + raise ValueError(f'"a" should be a scalar when "size" is provided. 
But we got {a}')
+      size = _size2shape(size)
+    random_uniform = jr.uniform(key=self.split_key(), shape=size, minval=0, maxval=1)
+    r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
+    if scale is not None:
+      r /= scale
+    return JaxArray(r)
+
+  def maxwell(self, size=None):
+    shape = core.canonicalize_shape(_size2shape(size)) + (3,)
+    norm_rvs = jr.normal(key=self.split_key(), shape=shape)
+    return JaxArray(jnp.linalg.norm(norm_rvs, axis=-1))
+
+  def negative_binomial(self, n, p, size=None):
+    n = _check_py_seq(_remove_jax_array(n))
+    p = _check_py_seq(_remove_jax_array(p))
+    if size is None:
+      size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
+    size = _size2shape(size)
+    logits = jnp.log(p) - jnp.log1p(-p)
+    rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size)
+    return JaxArray(self.poisson(lam=rate))
+
+  def wald(self, mean, scale, size=None):
+    mean = _check_py_seq(_remove_jax_array(mean))
+    scale = _check_py_seq(_remove_jax_array(scale))
+    if size is None:
+      size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
+    size = _size2shape(size)
+    sampled_chi2 = jnp.square(self.randn(*size).value)
+    sampled_uniform = self.uniform(size=size).value
+    # Wikipedia defines an intermediate x with the formula
+    #   x = loc + loc ** 2 * y / (2 * conc)
+    #       - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
+    # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
+    # Let us write
+    #   w = loc * y / (2 * conc)
+    # Then we can extract the common factor in the last two terms to obtain
+    #   x = loc + loc * w * (1 - sqrt(2 / w + 1))
+    # Now we see that the Wikipedia formula suffers from catastrophic
+    # cancellation for large w (e.g., if conc << loc).
+    #
+    # Fortunately, we can fix this by multiplying both sides
+    # by 1 + sqrt(2 / w + 1). We get
+    #   x * (1 + sqrt(2 / w + 1)) =
+    #     = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
+    #     = loc * (sqrt(2 / w + 1) - 1)
+    # The term sqrt(2 / w + 1) + 1 no longer presents numerical
+    # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
+    # sqrt1pm1(2 / w), which we know how to compute accurately.
+    # This just leaves the matter of small w, where 2 / w may
+    # overflow. In the limit as w -> 0, x -> loc, so we just mask
+    # that case.
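+    # Summary of the stable computation below (with y = sampled_chi2):
+    #   2 / w = 4 * scale / (mean * y)                      -> `sqrt1pm1_arg`
+    #   x = mean * (sqrt(1 + 2/w) - 1) / (1 + sqrt(1 + 2/w))
+    # evaluated via expm1/log1p so both large and small w stay accurate.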
+ sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above + safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0) + denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0) + ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator + sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above + res = jnp.where(sampled_uniform <= mean / (mean + sampled), + sampled, + jnp.square(mean) / sampled) + return JaxArray(res) + + def noncentral_chisquare(self, df, nonc, size=None): + df = _check_py_seq(_remove_jax_array(df)) + nonc = _check_py_seq(_remove_jax_array(nonc)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc)) + size = _size2shape(size) + i = jr.poisson(self.split_key(), 0.5 * nonc, shape=size) + n = jr.normal(self.split_key(), shape=size) + jnp.sqrt(nonc) + cond = jnp.greater(df, 1.0) + df2 = jnp.where(cond, df - 1.0, df + 2.0 * i) + chi2 = 2.0 * jr.gamma(self.split_key(), 0.5 * df2, shape=size) + return JaxArray(jnp.where(cond, chi2 + n * n, chi2)) + + def zipf(self, a, size=None): + a = _check_py_seq(_remove_jax_array(a)) + if size is None: + size = jnp.shape(a) + return JaxArray(call(lambda x: np.random.zipf(x, size), + a, + result_shape=jax.ShapeDtypeStruct(size, jnp.int_))) + + def power(self, a, size=None): + a = _check_py_seq(_remove_jax_array(a)) + if size is None: + size = jnp.shape(a) + size = _size2shape(size) + return JaxArray(call(lambda a: np.random.power(a=a, size=size), + a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))) + + def f(self, dfnum, dfden, size=None): + dfnum = _remove_jax_array(dfnum) + dfden = _remove_jax_array(dfden) + dfnum = _check_py_seq(dfnum) + dfden = _check_py_seq(dfden) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) + size = _size2shape(size) + d = {'dfnum': dfnum, 'dfden': dfden} + return JaxArray(call(lambda x: np.random.f(dfnum=x['dfnum'], + dfden=x['dfden'], + size=size), + d, + result_shape=jax.ShapeDtypeStruct(size, jnp.float_))) + + def hypergeometric(self, ngood, nbad, nsample, size=None): + ngood = _check_py_seq(_remove_jax_array(ngood)) + nbad = _check_py_seq(_remove_jax_array(nbad)) + nsample = _check_py_seq(_remove_jax_array(nsample)) + + if size is None: + size = lax.broadcast_shapes(jnp.shape(ngood), + jnp.shape(nbad), + jnp.shape(nsample)) + size = _size2shape(size) + d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} + return JaxArray(call(lambda x: np.random.hypergeometric(ngood=x['ngood'], + nbad=x['nbad'], + nsample=x['nsample'], + size=size), + d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))) + + def logseries(self, p, size=None): + p = _check_py_seq(_remove_jax_array(p)) + if size is None: + size = jnp.shape(p) + size = _size2shape(size) + return JaxArray(call(lambda p: np.random.logseries(p=p, size=size), + p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_))) + + def noncentral_f(self, dfnum, dfden, nonc, size=None): + dfnum = _check_py_seq(_remove_jax_array(dfnum)) + dfden = _check_py_seq(_remove_jax_array(dfden)) + nonc = _check_py_seq(_remove_jax_array(nonc)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(dfnum), + jnp.shape(dfden), + jnp.shape(nonc)) + size = _size2shape(size) + d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} + return JaxArray(call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], + dfden=x['dfden'], + nonc=x['nonc'], + size=size), + d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_))) + + +# alias +Generator = RandomState + +# register pytree 
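+# (so a RandomState can be carried through `jit`/`vmap` boundaries: it is
+#  flattened to its raw PRNG key and rebuilt from that key on the other side)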
register_pytree_node(RandomState, lambda t: ((t.value,), None), lambda aux_data, flat_contents: RandomState(*flat_contents)) +# default random genrator DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32)) +@wraps(np.random.default_rng) +def default_rng(seed=None): + return RandomState(seed) + + @wraps(np.random.seed) def seed(seed=None): - global DEFAULT DEFAULT.seed(np.random.randint(0, 100000) if seed is None else seed) @wraps(np.random.rand) def rand(*dn): - return JaxArray(jr.uniform(DEFAULT.split_key(), shape=dn, minval=0., maxval=1.)) + return DEFAULT.rand(*dn) @wraps(np.random.randint) -def randint(low, high=None, size=None, dtype=int): - if high is None: - high = low - low = 0 - high = jnp.asarray(high) - low = jnp.asarray(low) - if size is None: - size = np.broadcast(low, high).shape +def randint(low, high=None, size=None, dtype=jnp.int_): + return DEFAULT.randint(low, high=high, size=size, dtype=dtype) + - return JaxArray(jr.randint(DEFAULT.split_key(), shape=_size2shape(size), - minval=low, maxval=high, dtype=dtype)) +@wraps(np.random.random_integers) +def random_integers(low, high=None, size=None): + return DEFAULT.random_integers(low, high=high, size=size) @wraps(np.random.randn) def randn(*dn): - return JaxArray(jr.normal(DEFAULT.split_key(), shape=dn)) + return DEFAULT.randn(*dn) @wraps(np.random.random) def random(size=None): - return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=0., maxval=1.)) + return DEFAULT.random(size) @wraps(np.random.random_sample) def random_sample(size=None): - return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=0., maxval=1.)) + return DEFAULT.random_sample(size) -ranf = random_sample -sample = random_sample +@wraps(np.random.ranf) +def ranf(size=None): + return DEFAULT.ranf(size) + + +@wraps(np.random.sample) +def sample(size=None): + return DEFAULT.sample(size) @wraps(np.random.choice) def choice(a, size=None, replace=True, p=None): - a = a.value if isinstance(a, JaxArray) else a - if p is not None: - p = jnp.asarray(p) - return JaxArray(jr.choice(DEFAULT.split_key(), a=a, shape=_size2shape(size), replace=replace, p=p)) + a = _remove_jax_array(a) + return DEFAULT.choice(a=a, size=size, replace=replace, p=p) @wraps(np.random.permutation) def permutation(x): - x = x.value if isinstance(x, JaxArray) else x - return JaxArray(jr.permutation(DEFAULT.split_key(), x)) + return DEFAULT.permutation(x) @wraps(np.random.shuffle) def shuffle(x, axis=0): - assert isinstance(x, JaxArray), f'Must be a JaxArray, but got {type(x)}' - x.value = jr.permutation(DEFAULT.split_key(), x.value, axis=axis) + DEFAULT.shuffle(x, axis) @wraps(np.random.beta) def beta(a, b, size=None): - a = jnp.asarray(a) - b = jnp.asarray(b) - return JaxArray(jr.beta(DEFAULT.split_key(), a=a, b=b, shape=_size2shape(size))) + return DEFAULT.beta(a, b, size=size) @wraps(np.random.exponential) -def exponential(scale=1.0, size=None): - scale = jnp.asarray(scale) - return JaxArray(jr.exponential(DEFAULT.split_key(), shape=_size2shape(size)) / scale) +def exponential(scale=None, size=None): + return DEFAULT.exponential(scale, size) @wraps(np.random.gamma) -def gamma(shape, scale=1.0, size=None): - shape = jnp.asarray(shape) - scale = jnp.asarray(scale) - return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size)) * scale) +def gamma(shape, scale=None, size=None): + return DEFAULT.gamma(shape, scale, size=size) @wraps(np.random.gumbel) -def gumbel(loc=0.0, scale=1.0, size=None): - loc = 
jnp.asarray(loc) - scale = jnp.asarray(scale) - return JaxArray(jr.gumbel(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) +def gumbel(loc=None, scale=None, size=None): + return DEFAULT.gumbel(loc, scale, size=size) @wraps(np.random.laplace) -def laplace(loc=0.0, scale=1.0, size=None): - loc = jnp.asarray(loc) - scale = jnp.asarray(scale) - return JaxArray(jr.laplace(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) +def laplace(loc=None, scale=None, size=None): + return DEFAULT.laplace(loc, scale, size) @wraps(np.random.logistic) -def logistic(loc=0.0, scale=1.0, size=None): - loc = jnp.asarray(loc) - scale = jnp.asarray(scale) - return JaxArray(jr.logistic(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) +def logistic(loc=None, scale=None, size=None): + return DEFAULT.logistic(loc, scale, size) @wraps(np.random.normal) -def normal(loc=0.0, scale=1.0, size=None): - loc = jnp.asarray(loc) - scale = jnp.asarray(scale) - return JaxArray(jr.normal(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) +def normal(loc=None, scale=None, size=None): + return DEFAULT.normal(loc, scale, size) @wraps(np.random.pareto) def pareto(a, size=None): - a = jnp.asarray(a) - return JaxArray(jr.pareto(DEFAULT.split_key(), b=a, shape=_size2shape(size))) + return DEFAULT.pareto(a, size) @wraps(np.random.poisson) def poisson(lam=1.0, size=None): - lam = jnp.asarray(lam) - return JaxArray(jr.poisson(DEFAULT.split_key(), lam=lam, shape=_size2shape(size))) + return DEFAULT.poisson(lam, size) @wraps(np.random.standard_cauchy) def standard_cauchy(size=None): - return JaxArray(jr.cauchy(DEFAULT.split_key(), shape=_size2shape(size))) + return DEFAULT.standard_cauchy(size) @wraps(np.random.standard_exponential) def standard_exponential(size=None): - return JaxArray(jr.exponential(DEFAULT.split_key(), shape=_size2shape(size))) + return DEFAULT.standard_exponential(size) @wraps(np.random.standard_gamma) def standard_gamma(shape, size=None): - shape = jnp.asarray(shape) - return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size))) + return DEFAULT.standard_gamma(shape, size) @wraps(np.random.standard_normal) def standard_normal(size=None): - return JaxArray(jr.normal(DEFAULT.split_key(), shape=_size2shape(size))) + return DEFAULT.standard_normal(size) @wraps(np.random.standard_t) def standard_t(df, size=None): - df = jnp.asarray(df) - return JaxArray(jr.t(DEFAULT.split_key(), df=df, shape=_size2shape(size))) + return DEFAULT.standard_t(df, size) @wraps(np.random.uniform) def uniform(low=0.0, high=1.0, size=None): - low = jnp.asarray(low) - high = jnp.asarray(high) - if size is None: - size = np.broadcast(low, high).shape + return DEFAULT.uniform(low, high, size) - return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high)) - -def truncated_normal(lower, upper, size=None, scale=1.): +@wraps(jr.truncated_normal) +def truncated_normal(lower, upper, size=None, scale=None): """Sample truncated standard normal random values with given shape and dtype. Parameters @@ -408,258 +1101,138 @@ def truncated_normal(lower, upper, size=None, scale=1.): ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``. Returns values in the open interval ``(lower, upper)``. 
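
  Examples
  --------
  A minimal sketch: three standard-normal draws truncated to the open
  interval ``(-2, 2)`` (illustrative only; values are random)::

    >>> truncated_normal(-2., 2., size=(3,))  # doctest: +SKIP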
""" - lower = jnp.asarray(lower) - upper = jnp.asarray(upper) - if size is None: - size = np.broadcast(lower, upper).shape - - rands = jr.truncated_normal(DEFAULT.split_key(), - lower=lower, - upper=upper, - shape=_size2shape(size)) - return JaxArray(rands * scale) + return DEFAULT.truncated_normal(lower, upper, size, scale) +@wraps(jr.bernoulli) def bernoulli(p=0.5, size=None): """Sample Bernoulli random values with given shape and mean. - Args: - p: optional, a float or array of floats for the mean of the random - variables. Must be broadcast-compatible with ``shape`` and the values - should be within [0, 1]. Default 0.5. - size: optional, a tuple of nonnegative integers representing the result - shape. Must be broadcast-compatible with ``p.shape``. The default (None) - produces a result shape equal to ``p.shape``. + Parameters + ---------- + p: float, array_like, optional + A float or array of floats for the mean of the random + variables. Must be broadcast-compatible with ``shape`` and the values + should be within [0, 1]. Default 0.5. + size: optional, tuple of int, int + A tuple of nonnegative integers representing the result + shape. Must be broadcast-compatible with ``p.shape``. The default (None) + produces a result shape equal to ``p.shape``. - Returns: + Returns + ------- + out: array_like A random array with boolean dtype and shape given by ``shape`` if ``shape`` is not None, or else ``p.shape``. """ - p = jnp.asarray(p) - if jnp.unique(jnp.logical_and(p >= 0, p <= 1)) != jnp.array([True]): - raise ValueError(r'Bernoulli parameter p should be within [0, 1], but we got {}'.format(p)) - - if size is None: - size = p.shape - - return JaxArray(jr.bernoulli(DEFAULT.split_key(), p=p, shape=_size2shape(size))) + return DEFAULT.bernoulli(p, size) @wraps(np.random.lognormal) -def lognormal(mean=0.0, sigma=1.0, size=None): - mean = jnp.asarray(mean) - sigma = jnp.asarray(sigma) - samples = jr.normal(DEFAULT.split_key(), shape=_size2shape(size)) - samples = samples * sigma + mean - samples = jnp.exp(samples) - return JaxArray(samples) +def lognormal(mean=None, sigma=None, size=None): + return DEFAULT.lognormal(mean, sigma, size) @wraps(np.random.binomial) def binomial(n, p, size=None): - if size is None: - size = np.broadcast(n, p).shape - size = _size2shape(size) - d = {'n': n, 'p': p, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.binomial(n=x['n'], p=x['p'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, int))) + return DEFAULT.binomial(n, p, size) @wraps(np.random.chisquare) def chisquare(df, size=None): - if size is None: - size = np.shape(df) - size = _size2shape(size) - d = {'df': df, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.chisquare(df=x['df'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.chisquare(df, size) @wraps(np.random.dirichlet) def dirichlet(alpha, size=None): - size = _size2shape(size) - d = {'alpha': alpha, 'size': size} - output_shape = size + np.shape(alpha) - return JaxArray(hcb_call(lambda x: np.random.dirichlet(alpha=x['alpha'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(output_shape, float))) - - -@wraps(np.random.f) -def f(dfnum, dfden, size=None): - if size is None: - size = np.broadcast(dfnum, dfden).shape - size = _size2shape(size) - d = {'dfnum': dfnum, 'dfden': dfden, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.f(dfnum=x['dfnum'], dfden=x['dfden'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return 
DEFAULT.dirichlet(alpha, size) @wraps(np.random.geometric) def geometric(p, size=None): - if size is None: - size = np.shape(p) - size = _size2shape(size) - d = {'p': p, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.geometric(p=x['p'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, int))) + return DEFAULT.geometric(p, size) + + +@wraps(np.random.f) +def f(dfnum, dfden, size=None): + return DEFAULT.f(dfnum, dfden, size) @wraps(np.random.hypergeometric) def hypergeometric(ngood, nbad, nsample, size=None): - if size is None: - size = np.broadcast(ngood, nbad, nsample).shape - size = _size2shape(size) - d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.hypergeometric(ngood=x['ngood'], nbad=x['nbad'], - nsample=x['nsample'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, int))) + return DEFAULT.hypergeometric(ngood, nbad, nsample, size) @wraps(np.random.logseries) def logseries(p, size=None): - if size is None: - size = np.shape(p) - size = _size2shape(size) - d = {'p': p, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.logseries(p=x['p'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, int))) + return DEFAULT.logseries(p, size) @wraps(np.random.multinomial) def multinomial(n, pvals, size=None): - size = _size2shape(size) - d = {'n': n, 'pvals': pvals, 'size': size} - output_shape = size + np.shape(pvals) - return JaxArray(hcb_call(lambda x: np.random.multinomial(n=x['n'], pvals=x['pvals'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(output_shape, int))) - - -def _packed_multivariate_normal(d): - candidate_str = ['warn', 'raise', 'ignore'] - selected = np.array([d['warn'], d['raise'], d['ignore']]) + return DEFAULT.multinomial(n, pvals, size) - return np.random.multivariate_normal(mean=d['mean'], cov=d['cov'], size=d['size'], - check_valid=candidate_str[np.arange(3)[selected][0]], - tol=d['tol']) @wraps(np.random.multivariate_normal) -def multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8): - size = _size2shape(size) - - if not (check_valid == 'warn' or check_valid == 'raise' or check_valid == 'ignore'): - raise ValueError(r'multivariate_normal argument check_valid should be "warn", "raise", ' - 'or "ignore", but we got {}'.format(check_valid)) - - d = {'mean': mean, 'cov': cov, 'size': size, - 'warn': True if check_valid == 'warn' else False, - 'raise': True if check_valid == 'raise' else False, - 'ignore': True if check_valid == 'ignore' else False, - 'tol': tol} - output_shape = size + np.shape(mean) - - return JaxArray(hcb_call(_packed_multivariate_normal, d, - result_shape=jax.ShapeDtypeStruct(output_shape, float))) +def multivariate_normal(mean, cov, size=None): + return DEFAULT.multivariate_normal(mean, cov, size) @wraps(np.random.negative_binomial) def negative_binomial(n, p, size=None): - if size is None: - size = np.broadcast(n, p).shape - size = _size2shape(size) - d = {'n': n, 'p': p, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.negative_binomial(n=x['n'], p=x['p'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, int))) + return DEFAULT.negative_binomial(n, p, size) @wraps(np.random.noncentral_chisquare) def noncentral_chisquare(df, nonc, size=None): - if size is None: - size = np.broadcast(df, nonc).shape - size = _size2shape(size) - d = {'df': df, 'nonc': nonc, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.noncentral_chisquare(df=x['df'], nonc=x['nonc'], 
size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.noncentral_chisquare(df, nonc, size) @wraps(np.random.noncentral_f) def noncentral_f(dfnum, dfden, nonc, size=None): - if size is None: - size = np.broadcast(dfnum, dfden, nonc).shape - size = _size2shape(size) - d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], dfden=x['dfden'], - nonc=x['nonc'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.noncentral_f(dfnum, dfden, nonc, size) @wraps(np.random.power) def power(a, size=None): - if size is None: - size = np.shape(a) - size = _size2shape(size) - d = {'a': a, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.power(a=x['a'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.power(a, size) @wraps(np.random.rayleigh) def rayleigh(scale=1.0, size=None): - if size is None: - size = np.shape(scale) - size = _size2shape(size) - d = {'scale': scale, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.rayleigh(scale=x['scale'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.rayleigh(scale, size) @wraps(np.random.triangular) -def triangular(left, mode, right, size=None): - if size is None: - size = np.broadcast(left, mode, right).shape - size = _size2shape(size) - d = {'left': left, 'mode': mode, 'right': right, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.triangular(left=x['left'], mode=x['mode'], - right=x['right'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) +def triangular(size=None): + return DEFAULT.triangular(size) @wraps(np.random.vonmises) def vonmises(mu, kappa, size=None): - if size is None: - size = np.broadcast(mu, kappa).shape - size = _size2shape(size) - d = {'mu': mu, 'kappa': kappa, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.vonmises(mu=x['mu'], kappa=x['kappa'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.vonmises(mu, kappa, size) @wraps(np.random.wald) def wald(mean, scale, size=None): - if size is None: - size = np.broadcast(mean, scale).shape - size = _size2shape(size) - d = {'mean': mean, 'scale': scale, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.wald(mean=x['mean'], scale=x['scale'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.wald(mean, scale, size) @wraps(np.random.weibull) def weibull(a, size=None): - if size is None: - size = np.shape(a) - size = _size2shape(size) - d = {'a': a, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.weibull(a=x['a'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, float))) + return DEFAULT.weibull(a, size) + + +@wraps(jr.weibull_min) +def weibull_min(a, scale=None, size=None): + return DEFAULT.weibull_min(a, scale, size) @wraps(np.random.zipf) def zipf(a, size=None): - if size is None: - size = np.shape(a) - size = _size2shape(size) - d = {'a': a, 'size': size} - return JaxArray(hcb_call(lambda x: np.random.zipf(a=x['a'], size=x['size']), - d, result_shape=jax.ShapeDtypeStruct(size, int))) + return DEFAULT.zipf(a, size) +@wraps(jr.maxwell) +def maxwell(size=None): + return DEFAULT.maxwell(size) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index 055615250..aeb010b42 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ 
-1,10 +1,13 @@ import unittest +import jax import jax.numpy as jnp -import jax.random as jr +import numpy as np +import numpy.random as nr + import brainpy.math as bm import brainpy.math.random as br -import numpy.random as nr + class TestRandom(unittest.TestCase): def test_seed(self): @@ -177,7 +180,7 @@ def test_standard_t(self): def test_standard_uniform1(self): a = bm.random.uniform() self.assertTupleEqual(a.shape, ()) - self.assertTrue(0 <=a< 1) + self.assertTrue(0 <= a < 1) def test_uniform2(self): a = bm.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3) @@ -237,6 +240,9 @@ def test_lognormal3(self): def test_binomial1(self): a = bm.random.binomial(5, 0.5) + b = np.random.binomial(5, 0.5) + print(a) + print(b) self.assertIsInstance(a, bm.JaxArray) self.assertTupleEqual(a.shape, ()) self.assertTrue(a.dtype, int) @@ -247,7 +253,7 @@ def test_binomial2(self): self.assertTrue((a >= 0).all() and (a <= 5).all()) def test_binomial3(self): - a = bm.random.binomial(n=[2, 3, 4], p=[[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]) + a = bm.random.binomial(n=bm.asarray([2, 3, 4]), p=bm.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]])) self.assertTupleEqual(a.shape, (2, 3)) def test_chisquare1(self): @@ -257,8 +263,16 @@ def test_chisquare1(self): self.assertTrue(a.dtype, float) def test_chisquare2(self): - a = bm.random.chisquare(df=[2, 3, 4]) - self.assertTupleEqual(a.shape, (3,)) + with self.assertRaises(NotImplementedError): + a = bm.random.chisquare(df=[2, 3, 4]) + + def test_chisquare3(self): + a = bm.random.chisquare(df=2, size=100) + self.assertTupleEqual(a.shape, (100,)) + + def test_chisquare4(self): + a = bm.random.chisquare(df=2, size=(100, 10)) + self.assertTupleEqual(a.shape, (100, 10)) def test_dirichlet1(self): a = bm.random.dirichlet((10, 5, 3)) @@ -293,58 +307,148 @@ def test_logseries(self): self.assertTupleEqual(a.shape, (4, 3)) def test_multinominal1(self): - a = bm.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) - self.assertTupleEqual(a.shape, (4, 2, 3)) + a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) + print(a, a.shape) + b = bm.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) + print(b, b.shape) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (4, 2, 3)) def test_multinominal2(self): a = bm.random.multinomial(100, (0.5, 0.2, 0.3)) self.assertTupleEqual(a.shape, (3,)) self.assertTrue(a.sum() == 100) + def test_multinominal3(self): + with self.assertRaises(ValueError): + a = bm.random.multinomial(100, (0.5, 0.6, 0.3)) + with self.assertRaises(ValueError): + f = jax.jit(bm.random.multinomial, static_argnums=2) + a = f(100, (0.5, 0.6, 0.3), 2) + def test_multivariate_normal1(self): - a = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) + a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) + b = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) + print('test_multivariate_normal1') + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) self.assertTupleEqual(a.shape, (3, 2)) def test_multivariate_normal2(self): - a = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], check_valid='ignore') + a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) + b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) + self.assertTupleEqual(a.shape, b.shape) self.assertTupleEqual(a.shape, (2,)) def test_negative_binomial(self): - a = bm.random.negative_binomial([3., 10.], 0.5) - self.assertTupleEqual(a.shape, (2,)) + a = np.random.negative_binomial([3., 10.], 0.5) + b = 
bm.random.negative_binomial([3., 10.], 0.5) + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (2,)) + + def test_negative_binomial2(self): + a = np.random.negative_binomial(3., 0.5, 10) + b = bm.random.negative_binomial(3., 0.5, 10) + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (10,)) def test_noncentral_chisquare(self): - a = bm.random.noncentral_chisquare(3, [3., 2.], (4, 2)) - self.assertTupleEqual(a.shape, (4, 2)) + a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2)) + b = bm.random.noncentral_chisquare(3, [3., 2.], (4, 2)) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (4, 2)) + + def test_noncentral_chisquare2(self): + a = bm.random.noncentral_chisquare(3, [3., 2.]) + self.assertTupleEqual(a.shape, (2,)) def test_noncentral_f(self): a = bm.random.noncentral_f(3, 20, 3., 100) self.assertTupleEqual(a.shape, (100,)) def test_power(self): - a = bm.random.power(2, (4, 2)) - self.assertTupleEqual(a.shape, (4, 2)) + a = np.random.power(2, (4, 2)) + b = bm.random.power(2, (4, 2)) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (4, 2)) def test_rayleigh(self): a = bm.random.power(2., (4, 2)) self.assertTupleEqual(a.shape, (4, 2)) def test_triangular(self): - a = bm.random.triangular([-1., 0.], 1., [[2., 5.], [3., 3.]]) + a = bm.random.triangular((2, 2)) self.assertTupleEqual(a.shape, (2, 2)) def test_vonmises(self): - a = bm.random.vonmises(2., 2.) - self.assertTupleEqual(a.shape, ()) + a = np.random.vonmises(2., 2.) + b = bm.random.vonmises(2., 2.) + print(a, b) + self.assertTupleEqual(np.shape(a), b.shape) + self.assertTupleEqual(b.shape, ()) + + def test_vonmises2(self): + a = np.random.vonmises(2., 2., 10) + b = bm.random.vonmises(2., 2., 10) + print(a, b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (10,)) def test_wald(self): - a = bm.random.wald([2., 0.5], 2.) - self.assertTupleEqual(a.shape, (2,)) + a = np.random.wald([2., 0.5], 2.) + b = bm.random.wald([2., 0.5], 2.) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (2,)) + + def test_wald2(self): + a = np.random.wald(2., 2., 100) + b = bm.random.wald(2., 2., 100) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (100,)) def test_weibull(self): a = bm.random.weibull(2., (4, 2)) self.assertTupleEqual(a.shape, (4, 2)) + def test_weibull2(self): + a = bm.random.weibull(2., ) + self.assertTupleEqual(a.shape, ()) + + def test_weibull3(self): + a = bm.random.weibull([2., 3.], ) + self.assertTupleEqual(a.shape, (2,)) + + def test_weibull_min(self): + a = bm.random.weibull_min(2., 2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_weibull_min2(self): + a = bm.random.weibull_min(2., 2.) + self.assertTupleEqual(a.shape, ()) + + def test_weibull_min3(self): + a = bm.random.weibull_min([2., 3.], 2.) 
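+    # a vector `a` with a scalar `scale` broadcasts: one sample per entry of `a`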
+ self.assertTupleEqual(a.shape, (2,)) + def test_zipf(self): a = bm.random.zipf(2., (4, 2)) self.assertTupleEqual(a.shape, (4, 2)) + + def test_zipf2(self): + a = np.random.zipf([1.1, 2.]) + b = bm.random.zipf([1.1, 2.]) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (2,)) + + def test_maxwell(self): + a = bm.random.maxwell(10) + self.assertTupleEqual(a.shape, (10,)) + + def test_maxwell2(self): + a = bm.random.maxwell() + self.assertTupleEqual(a.shape, ()) diff --git a/brainpy/math/utils.py b/brainpy/math/utils.py index ab3e9e0fd..9432a2b22 100644 --- a/brainpy/math/utils.py +++ b/brainpy/math/utils.py @@ -20,7 +20,7 @@ def wraps(fun: Callable): def wrap(op): docstr = getattr(fun, "__doc__", None) op.__doc__ = docstr - op.__np_wrapped__ = fun + op.__wrapped__ = fun for attr in ['__name__', '__qualname__']: try: value = getattr(fun, attr) From 3d7a90d377603e7058133e98bd72ddc18caea641 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 16 May 2022 23:52:47 +0800 Subject: [PATCH 2/7] update docs and tests --- brainpy/math/random.py | 13 ++++----- changelog.rst | 63 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index 381e72cd8..276bb2581 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -539,7 +539,7 @@ def gumbel(self, loc=None, scale=None, size=None): scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - return _loc_scale(loc, scale, jr.gumbel(self.split_key(), shape=size)) + return _loc_scale(loc, scale, jr.gumbel(self.split_key(), shape=_size2shape(size))) def laplace(self, loc=None, scale=None, size=None): loc = _remove_jax_array(loc) @@ -548,7 +548,7 @@ def laplace(self, loc=None, scale=None, size=None): scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - return _loc_scale(loc, scale, jr.laplace(self.split_key(), shape=size)) + return _loc_scale(loc, scale, jr.laplace(self.split_key(), shape=_size2shape(size))) def logistic(self, loc=None, scale=None, size=None): loc = _remove_jax_array(loc) @@ -557,7 +557,7 @@ def logistic(self, loc=None, scale=None, size=None): scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - return _loc_scale(loc, scale, jr.logistic(self.split_key(), shape=size)) + return _loc_scale(loc, scale, jr.logistic(self.split_key(), shape=_size2shape(size))) def normal(self, loc=None, scale=None, size=None): loc = _remove_jax_array(loc) @@ -566,7 +566,7 @@ def normal(self, loc=None, scale=None, size=None): scale = _check_py_seq(scale) if size is None: size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc)) - return _loc_scale(loc, scale, jr.normal(self.split_key(), shape=size)) + return _loc_scale(loc, scale, jr.normal(self.split_key(), shape=_size2shape(size))) def pareto(self, a, size=None): a = _remove_jax_array(a) @@ -717,12 +717,11 @@ def rayleigh(self, scale=1.0, size=None): scale = _check_py_seq(_remove_jax_array(scale)) if size is None: size = jnp.shape(scale) - x = jnp.sqrt(-2. * jnp.log(jr.uniform(self.split_key(), shape=size, minval=0, maxval=1))) + x = jnp.sqrt(-2. 
     return JaxArray(x * scale)

   def triangular(self, size=None):
-    size = _size2shape(size)
-    bernoulli_samples = jr.bernoulli(self.split_key(), p=0.5, shape=size)
+    bernoulli_samples = jr.bernoulli(self.split_key(), p=0.5, shape=_size2shape(size))
     return JaxArray(2 * bernoulli_samples - 1)

   def vonmises(self, mu, kappa, size=None):
diff --git a/changelog.rst b/changelog.rst
index d229a57f1..a9add84e3 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -6,6 +6,69 @@ brainpy 2.x (LTS)
 *****************

+
+
+
+Version 2.1.11 (2022.05.15)
+===========================
+
+
+What's Changed
+~~~~~~~~~~~~~~
+
+* fix: cross-correlation bug by `@ztqakita `_ in `#201 `_
+* update apis, test and docs of numpy ops by `@chaoming0625 `_ in `#202 `_
+* docs: add sphinx_book_theme by `@ztqakita `_ in `#203 `_
+* fix: add requirements-doc.txt by `@ztqakita `_ in `#204 `_
+* update control flow, integrators, operators, and docs by `@chaoming0625 `_ in `#205 `_
+* improve oo-to-function transformation speed by `@chaoming0625 `_ in `#208 `_
+
+**Full Changelog**\ : `V2.1.10...V2.1.11 `_
+
+
+
+Version 2.1.10 (2022.05.05)
+===========================
+
+
+What's Changed
+~~~~~~~~~~~~~~
+
+* update control flow APIs and Docs by `@chaoming0625 `_ in `#192 `_
+* doc: update docs of dynamics simulation by `@chaoming0625 `_ in `#193 `_
+* fix `#125 `_: add channel models and two-compartment Pinsky-Rinzel model by `@chaoming0625 `_ in `#194 `_
+* JIT errors do not change Variable values by `@chaoming0625 `_ in `#195 `_
+* fix a bug in math.activations.py by `@c-xy17 `_ in `#196 `_
+* Functionality improvements by `@chaoming0625 `_ in `#197 `_
+* update rate docs by `@chaoming0625 `_ in `#198 `_
+* update brainpy.dyn doc by `@chaoming0625 `_ in `#199 `_
+
+**Full Changelog**\ : `V2.1.8...V2.1.10 `_
+
+
+
+Version 2.1.8 (2022.04.26)
+==========================
+
+
+What's Changed
+~~~~~~~~~~~~~~
+
+* Fix `#120 `_ by `@chaoming0625 `_ in `#178 `_
+* feat: brainpy.Collector supports addition and subtraction by `@chaoming0625 `_ in `#179 `_
+* feat: delay variables support "indices" and "reset()" function by `@chaoming0625 `_ in `#180 `_
+* Support reset functions in neuron and synapse models by `@chaoming0625 `_ in `#181 `_
+* ``update()`` function no longer needs ``_t`` and ``_dt`` by `@chaoming0625 `_ in `#183 `_
+* small updates by `@chaoming0625 `_ in `#188 `_
+* feat: easier control flows with ``brainpy.math.ifelse`` by `@chaoming0625 `_ in `#189 `_
+* feat: update delay couplings of ``DiffusiveCoupling`` and ``AdditiveCoupling`` by `@chaoming0625 `_ in `#190 `_
+* update version and changelog by `@chaoming0625 `_ in `#191 `_
+
+**Full Changelog**\ : `V2.1.7...V2.1.8 `_
+
+
+
 Version 2.1.7 (2022.04.22)
 ==========================

From 860b5306fd7788e803a337cdd541f13512ba1f89 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 17 May 2022 10:28:04 +0800
Subject: [PATCH 3/7] update random apis

---
 .gitignore                        |   2 +-
 brainpy/math/random.py            | 142 +++++++++++++++++++++++++++---
 brainpy/math/tests/test_random.py |  12 ++-
 3 files changed, 141 insertions(+), 15 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1d40ddd9e..3f87f9728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
 publishment.md
 #experimental/
 .vscode
-
+io_test_tmp*
 brainpy/base/tests/io_test_tmp*

diff --git a/brainpy/math/random.py b/brainpy/math/random.py
index 276bb2581..5a676162c 100644
--- a/brainpy/math/random.py
+++ b/brainpy/math/random.py
@@ -3,9 +3,11 @@
 from collections import namedtuple
 from functools import partial
+from operator import index

 import jax
 import numpy as np
 from jax import lax, jit, vmap, numpy as jnp, random as jr, core
+from jax._src import dtypes
 from jax.experimental.host_callback import call
 from jax.experimental import checkify
 from jax.tree_util import register_pytree_node
@@ -27,7 +29,7 @@
   'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal',
   'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power',
   'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min',
-  'zipf', 'maxwell'
+  'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical',
 ]
@@ -43,18 +45,28 @@

 def _check_shape(name, shape, *param_shapes):
-  for s in param_shapes:
-    if s != shape:
+  shape = core.as_named_shape(shape)
+  if param_shapes:
+    shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
+    if shape.positional != shape_:
       msg = ("{} parameter shapes must be broadcast-compatible with shape "
              "argument, and the result of broadcasting the shapes must equal "
             "the shape argument, but got result {} for shape argument {}.")
-      raise ValueError(msg.format(name, s, shape))
+      raise ValueError(msg.format(name, shape_, shape))

 def _remove_jax_array(a):
   return a.value if isinstance(a, JaxArray) else a

+def _const(example, val):
+  dtype = dtypes.dtype(example, canonicalize=True)
+  if dtypes.is_python_scalar(example):
+    val = dtypes.scalar_type_of(example)(val)
+    return val if dtype == dtypes.dtype(val, canonicalize=True) else np.array(val, dtype)
+  return np.array(val, dtype)
+
 _tr_params = namedtuple(
     "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
 )
@@ -702,15 +714,36 @@ def multinomial(self, n, pvals, size=None):
     batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
     return JaxArray(_multinomial(self.split_key(), pvals, n, n_max, batch_shape + size))

-  def multivariate_normal(self, mean, cov, size=None):
+  def multivariate_normal(self, mean, cov, size=None, method: str = 'cholesky'):
+    if method not in {'svd', 'eigh', 'cholesky'}:
+      raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
     mean = _check_py_seq(_remove_jax_array(mean))
     cov = _check_py_seq(_remove_jax_array(cov))
-    size = _size2shape(size)
-    scale = jnp.linalg.cholesky(cov)
-    batch_shape = lax.broadcast_shapes(jnp.shape(mean)[:-2], jnp.shape(scale)[:-2])
-    event_shape = jnp.shape(scale)[-1:]
-    eps = jr.normal(self.split_key(), shape=size + batch_shape + event_shape)
-    r = mean + jnp.squeeze(jnp.matmul(scale, eps[..., jnp.newaxis]), axis=-1)
+
+    if not jnp.ndim(mean) >= 1:
+      raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
+    if not jnp.ndim(cov) >= 2:
+      raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
+    n = mean.shape[-1]
+    if jnp.shape(cov)[-2:] != (n, n):
+      raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
+                       f"but got cov.shape == {jnp.shape(cov)}.")
+    if size is None:
+      size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
+    else:
+      size = _size2shape(size)
+      _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
+
+    if method == 'svd':
+      (u, s, _) = jnp.linalg.svd(cov)
+      factor = u * jnp.sqrt(s[..., None, :])
+    elif method == 'eigh':
+      (w, v) = jnp.linalg.eigh(cov)
+      factor = v * jnp.sqrt(w[..., None, :])
+    else:  # 'cholesky'
+      factor = jnp.linalg.cholesky(cov)
+    normal_samples = jr.normal(self.split_key(), size + mean.shape[-1:])
+    r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
     return JaxArray(r)

   def rayleigh(self, scale=1.0, size=None):
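An editorial sketch (not part of the patch) of why the three `method` options
above are interchangeable in distribution: any factor F with F @ F.T == cov
maps standard normal draws onto N(mean, cov). For 'svd', cov = U S U^T gives
F = U sqrt(S); for 'eigh', F = V sqrt(W); 'cholesky' uses the triangular
factor directly:

    import jax.numpy as jnp
    from jax import random as jr

    key = jr.PRNGKey(0)
    cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
    factor = jnp.linalg.cholesky(cov)   # F with F @ F.T == cov
    eps = jr.normal(key, (10000, 2))    # standard normal draws
    samples = eps @ factor.T            # per sample: factor @ eps, like the einsum above
    print(jnp.cov(samples.T))           # approaches cov as the sample count grows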
@@ -833,6 +866,29 @@ def wald(self, mean, scale, size=None):
                                  jnp.square(mean) / sampled)
     return JaxArray(res)

+  def t(self, df, size=None):
+    df = _check_py_seq(_remove_jax_array(df))
+    if size is None:
+      size = np.shape(df)
+    else:
+      size = _size2shape(size)
+      _check_shape("t", size, np.shape(df))
+    keys = self.split_keys(2)
+    n = jr.normal(keys[0], size)
+    two = _const(n, 2)
+    half_df = lax.div(df, two)
+    g = jr.gamma(keys[1], half_df, size)
+    return JaxArray(n * jnp.sqrt(half_df / g))
+
+  def orthogonal(self, n: int, size=None):
+    size = _size2shape(size)
+    _check_shape("orthogonal", size)
+    n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
+    z = jr.normal(self.split_key(), size + (n, n))
+    q, r = jnp.linalg.qr(z)
+    d = jnp.diagonal(r, 0, -2, -1)
+    return JaxArray(q * jnp.expand_dims(d / abs(d), -2))
+
   def noncentral_chisquare(self, df, nonc, size=None):
     df = _check_py_seq(_remove_jax_array(df))
     nonc = _check_py_seq(_remove_jax_array(nonc))
@@ -918,6 +974,19 @@ def noncentral_f(self, dfnum, dfden, nonc, size=None):
                                          size=size),
                          d,
                          result_shape=jax.ShapeDtypeStruct(size, jnp.float_)))

+  def loggamma(self, a, size=None):
+    a = _check_py_seq(_remove_jax_array(a))
+    if size is None:
+      size = jnp.shape(a)
+    return JaxArray(jr.loggamma(self.split_key(), a, shape=size))
+
+  def categorical(self, logits, axis: int = -1, size=None):
+    logits = _check_py_seq(_remove_jax_array(logits))
+    if size is None:
+      size = list(jnp.shape(logits))
+      size.pop(axis)
+    return JaxArray(jr.categorical(self.split_key(), logits, axis=axis, shape=size))
+
 # alias
 Generator = RandomState
@@ -1173,8 +1242,8 @@ def multinomial(n, pvals, size=None):

 @wraps(np.random.multivariate_normal)
-def multivariate_normal(mean, cov, size=None):
-  return DEFAULT.multivariate_normal(mean, cov, size)
+def multivariate_normal(mean, cov, size=None, method: str = 'cholesky'):
+  return DEFAULT.multivariate_normal(mean, cov, size, method)

 @wraps(np.random.negative_binomial)
@@ -1235,3 +1304,50 @@ def zipf(a, size=None):
 @wraps(jr.maxwell)
 def maxwell(size=None):
   return DEFAULT.maxwell(size)
+
+
+def t(df, size=None):
+  """Sample Student’s t random values.
+
+  Parameters
+  ----------
+  df: float, array_like
+    A float or array of floats broadcast-compatible with `size`, representing
+    the parameter of the distribution.
+  size: optional, int, tuple of int
+    A tuple of non-negative integers specifying the result shape.
+    Must be broadcast-compatible with `df`. The default (None) produces a
+    result shape equal to `df.shape`.
+
+  Returns
+  -------
+  out: array_like
+    The sampled values.
+  """
+  return DEFAULT.t(df, size)
+
+
+def orthogonal(n: int, size=None):
+  """Sample uniformly from the orthogonal group `O(n)`.
+
+  Parameters
+  ----------
+  n: int
+    An integer indicating the resulting dimension.
+  size: optional, int, tuple of int
+    The batch dimensions of the result.
+
+  Returns
+  -------
+  out: JaxArray
+    The sampled results.
+  """
+  return DEFAULT.orthogonal(n, size)
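A short usage sketch for the two wrappers above (editorial, assuming the
default global `RandomState` and the `JaxArray.value` attribute used elsewhere
in this patch):

    import jax.numpy as jnp
    import brainpy.math as bm

    bm.random.seed(0)
    q = bm.random.orthogonal(4)         # one 4x4 orthogonal matrix
    # Q^T Q should be (numerically) the identity:
    print(jnp.allclose(q.value.T @ q.value, jnp.eye(4), atol=1e-5))

    samples = bm.random.t(3., size=(1000,))   # Student's t with df = 3
    print(samples.shape)                      # (1000,)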
+ """ + return DEFAULT.orthogonal(n, size) + + +@wraps(jr.loggamma) +def loggamma(a, size=None): + return DEFAULT.loggamma(a, size) + + +@wraps(jr.categorical) +def categorical(logits, axis:int= -1, size=None): + return DEFAULT.categorical(logits, axis, size) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index aeb010b42..3e5592924 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -337,7 +337,9 @@ def test_multivariate_normal1(self): def test_multivariate_normal2(self): a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) - b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) + b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd') + print(a) + print(b) self.assertTupleEqual(a.shape, b.shape) self.assertTupleEqual(a.shape, (2,)) @@ -452,3 +454,11 @@ def test_maxwell(self): def test_maxwell2(self): a = bm.random.maxwell() self.assertTupleEqual(a.shape, ()) + + def test_t(self): + a = bm.random.t(1., size=10) + self.assertTupleEqual(a.shape, (10,)) + + def test_t2(self): + a = bm.random.t([1., 2.], size=None) + self.assertTupleEqual(a.shape, (2,)) From b704134a7eb42874c5d20e3e6e98edc7fb1ff687 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 17 May 2022 10:41:58 +0800 Subject: [PATCH 4/7] fix random tests --- .github/workflows/Windows_CI.yml | 2 +- brainpy/math/tests/test_random.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/Windows_CI.yml b/.github/workflows/Windows_CI.yml index 9ed248a9f..c2162c1be 100644 --- a/.github/workflows/Windows_CI.yml +++ b/.github/workflows/Windows_CI.yml @@ -29,7 +29,7 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest python -m pip install numpy==1.21.0 - python -m pip install "jax[cpu]==0.3.5" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver + python -m pip install "jax[cpu]==0.3.7" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver python -m pip install -r requirements-win.txt python -m pip install tqdm brainpylib python setup.py install diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index 3e5592924..e764c57a6 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -335,13 +335,13 @@ def test_multivariate_normal1(self): self.assertTupleEqual(a.shape, b.shape) self.assertTupleEqual(a.shape, (3, 2)) - def test_multivariate_normal2(self): - a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) - b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd') - print(a) - print(b) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(a.shape, (2,)) + # def test_multivariate_normal2(self): + # a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) + # b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd') + # print(a) + # print(b) + # self.assertTupleEqual(a.shape, b.shape) + # self.assertTupleEqual(a.shape, (2,)) def test_negative_binomial(self): a = np.random.negative_binomial([3., 10.], 0.5) From 51e6df5a1154ee1d4dab98ba3cf11b6733dfffc2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 17 May 2022 11:08:49 +0800 Subject: [PATCH 5/7] lower windows jaxlib version to 0.3.2 because lateasts have bugs --- .github/workflows/Windows_CI.yml | 2 +- brainpy/math/tests/test_random.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git 
diff --git a/.github/workflows/Windows_CI.yml b/.github/workflows/Windows_CI.yml
index c2162c1be..1e4f427f6 100644
--- a/.github/workflows/Windows_CI.yml
+++ b/.github/workflows/Windows_CI.yml
@@ -29,7 +29,7 @@ jobs:
         python -m pip install --upgrade pip
         python -m pip install flake8 pytest
         python -m pip install numpy==1.21.0
-        python -m pip install "jax[cpu]==0.3.7" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
+        python -m pip install "jax[cpu]==0.3.2" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
         python -m pip install -r requirements-win.txt
         python -m pip install tqdm brainpylib
         python setup.py install
diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py
index e764c57a6..485449e2a 100644
--- a/brainpy/math/tests/test_random.py
+++ b/brainpy/math/tests/test_random.py
@@ -327,6 +327,7 @@ def test_multinominal3(self):
     a = f(100, (0.5, 0.6, 0.3), 2)

   def test_multivariate_normal1(self):
+    # self.skipTest('Windows jaxlib error')
     a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
     b = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
     print('test_multivariate_normal1')
@@ -335,13 +336,13 @@ def test_multivariate_normal1(self):
     self.assertTupleEqual(a.shape, b.shape)
     self.assertTupleEqual(a.shape, (3, 2))

-  # def test_multivariate_normal2(self):
-  #   a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
-  #   b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
-  #   print(a)
-  #   print(b)
-  #   self.assertTupleEqual(a.shape, b.shape)
-  #   self.assertTupleEqual(a.shape, (2,))
+  def test_multivariate_normal2(self):
+    a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
+    b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
+    print(a)
+    print(b)
+    self.assertTupleEqual(a.shape, b.shape)
+    self.assertTupleEqual(a.shape, (2,))

   def test_negative_binomial(self):
     a = np.random.negative_binomial([3., 10.], 0.5)

From c61ef17d61ff36727bafbf170c81b377c99778ed Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 17 May 2022 14:25:06 +0800
Subject: [PATCH 6/7] fix random.loggamma compatibility

---
 brainpy/math/random.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/brainpy/math/random.py b/brainpy/math/random.py
index 5a676162c..b5ac71338 100644
--- a/brainpy/math/random.py
+++ b/brainpy/math/random.py
@@ -1343,8 +1343,22 @@ def orthogonal(n: int, size=None):
   return DEFAULT.orthogonal(n, size)

-@wraps(jr.loggamma)
 def loggamma(a, size=None):
+  """Sample log-gamma random values.
+
+  Parameters
+  ----------
+  a: float, array_like
+    A float or array of floats broadcast-compatible with `size`, representing
+    the parameter of the distribution.
+  size: optional, int, tuple of int
+    A tuple of non-negative integers specifying the result shape.
+    Must be broadcast-compatible with `a`. The default (None) produces a
+    result shape equal to `a.shape`.
+
+  Returns
+  -------
+  out: array_like
+    The sampled results.
+ """ return DEFAULT.loggamma(a, size) From 184d720949dcf1080c972c4bac55fcd6789ddfdf Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 17 May 2022 16:38:05 +0800 Subject: [PATCH 7/7] fix: fix JaxArray op errors --- brainpy/math/jaxarray.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 69b659330..99fd98422 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -236,7 +236,7 @@ def __sub__(self, oc): return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc)) def __rsub__(self, oc): - return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) - self._value) def __isub__(self, oc): # a -= b @@ -249,7 +249,7 @@ def __mul__(self, oc): return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc)) def __rmul__(self, oc): - return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) * self._value) def __imul__(self, oc): # a *= b @@ -258,17 +258,17 @@ def __imul__(self, oc): self._value = self._value * (oc._value if isinstance(oc, JaxArray) else oc) return self - def __div__(self, oc): - return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc)) + # def __div__(self, oc): + # return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc)) def __rdiv__(self, oc): - return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value) def __truediv__(self, oc): return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc)) def __rtruediv__(self, oc): - return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value) def __itruediv__(self, oc): # a /= b @@ -281,7 +281,7 @@ def __floordiv__(self, oc): return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc)) def __rfloordiv__(self, oc): - return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) // self._value) def __ifloordiv__(self, oc): # a //= b @@ -291,16 +291,16 @@ def __ifloordiv__(self, oc): return self def __divmod__(self, oc): - return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc)) def __rdivmod__(self, oc): - return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc)) def __mod__(self, oc): return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc)) def __rmod__(self, oc): - return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) % self._value) def __imod__(self, oc): # a %= b @@ -313,7 +313,7 @@ def __pow__(self, oc): return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc)) def __rpow__(self, oc): - return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc)) + return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ** self._value) def __ipow__(self, oc): # a **= b @@ -326,7 +326,7 @@ def 
@@ -326,7 +326,7 @@ def __matmul__(self, oc):
     return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))

   def __rmatmul__(self, oc):
-    return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))
+    return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) @ self._value)

   def __imatmul__(self, oc):
     # a @= b
@@ -339,7 +339,7 @@ def __and__(self, oc):
     return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))

   def __rand__(self, oc):
-    return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))
+    return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) & self._value)

   def __iand__(self, oc):
     # a &= b
@@ -352,7 +352,7 @@ def __or__(self, oc):
     return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))

   def __ror__(self, oc):
-    return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))
+    return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) | self._value)

   def __ior__(self, oc):
     # a |= b
@@ -365,7 +365,7 @@ def __xor__(self, oc):
     return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))

   def __rxor__(self, oc):
-    return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))
+    return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ^ self._value)

   def __ixor__(self, oc):
     # a ^= b
@@ -378,7 +378,7 @@ def __lshift__(self, oc):
     return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))

   def __rlshift__(self, oc):
-    return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))
+    return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) << self._value)

   def __ilshift__(self, oc):
     # a <<= b
@@ -391,7 +391,7 @@ def __rshift__(self, oc):
     return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))

   def __rrshift__(self, oc):
-    return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))
+    return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) >> self._value)

   def __irshift__(self, oc):
     # a >>= b
     self._value = self._value >> (oc._value if isinstance(oc, JaxArray) else oc)
     return self
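A closing usage sketch (editorial): the in-place forms rebind the wrapped
value rather than building a new JaxArray, so existing references observe the
change:

    import jax.numpy as jnp
    from brainpy.math import JaxArray  # assumed import path

    a = JaxArray(jnp.array([4, 2]))
    b = a            # same underlying object
    a >>= 1          # __irshift__ rebinds a._value to a._value >> 1
    print(b)         # reflects the shift: [2, 1]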