From 68079660c58c89a692a9374c52c1ff3dd4a3794f Mon Sep 17 00:00:00 2001 From: Dacheng Xu Date: Mon, 4 Mar 2024 22:56:42 +0800 Subject: [PATCH] Add Gamma, Negative Binomial, and Generalized Poisson Distribution (#145) * Add Gamma and Negative binomial distribution * Implement Generalized Poisson Distribution(GPD) --- appletree/randgen.py | 99 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/appletree/randgen.py b/appletree/randgen.py index 53be1224..f1235653 100644 --- a/appletree/randgen.py +++ b/appletree/randgen.py @@ -11,6 +11,7 @@ from appletree.utils import exporter + export, __all__ = exporter(export_self=False) if jax.config.x64_enabled: @@ -85,6 +86,32 @@ def poisson(key, lam, shape=()): return key, rvs.astype(INT) +@export +@partial(jit, static_argnums=(3,)) +def gamma(key, alpha, beta, shape=()): + """Gamma distribution random sampler. + + Args: + key: seed for random generator. + alpha: -like shape in gamma distribution. + beta: -like rate in normal distribution. + shape: output shape. + If not given, output has shape jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta)). + + Returns: + an updated seed, random variables. + + """ + key, seed = random.split(key) + + shape = shape or jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta)) + alpha = jnp.broadcast_to(alpha, shape).astype(FLOAT) + beta = jnp.broadcast_to(beta, shape).astype(FLOAT) + + rvs = random.gamma(seed, alpha, shape=shape) / beta + return key, rvs.astype(FLOAT) + + @export @partial(jit, static_argnums=(3,)) def normal(key, mean, std, shape=()): @@ -249,6 +276,78 @@ def _binomial_dispatch(seed, p, n): return key, jnp.reshape(ret, shape) +@export +@partial(jit, static_argnums=(3, 4)) +def negative_binomial(key, p, n, shape=()): + """Negative binomial distribution random sampler. Using Gamma–Poisson mixture. + + Args: + key: seed for random generator. + p: -like probability of a single success in negative binomial distribution. + n: -like number of successes in negative binomial distribution. + shape: output shape. + If not given, output has shape jnp.broadcast_shapes(jnp.shape(p), jnp.shape(n)). + + Returns: + an updated seed, random variables. + + References: + 1. https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture # noqa + 2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.nbinom.html + + """ + + key, lam = gamma(key, n, p / (1 - p), shape) + + key, rvs = poisson(key, lam) + return key, rvs + + +@export +@partial(jit, static_argnums=(3,)) +def generalized_poisson(key, lam, eta, shape=()): + """Generalized Poisson Distribution(GPD) random sampler. + + Args: + key: seed for random generator. + lam: -like expectation(location parameter) in GPD. + eta: -like scale parameter in GPD, within [0, 1). + shape: output shape. If not given, output has shape jnp.shape(lam). + + Returns: + an updated seed, random variables. + + References: + 1. https://gist.github.com/danmackinlay/00e957b11c488539bd3e2a3804922b9d + 2. https://search.r-project.org/CRAN/refmans/LaplacesDemon/html/dist.Generalized.Poisson.html # noqa + + """ + + shape = shape or jnp.broadcast_shapes(jnp.shape(lam), jnp.shape(eta)) + lam = jnp.broadcast_to(lam, shape).astype(FLOAT) + eta = jnp.broadcast_to(eta, shape).astype(FLOAT) + + key, population = poisson(key, lam * (1 - eta), shape) + + offspring = jnp.copy(population) + + def cond_fun(args): + return jnp.any(args[1] > 0) + + def body_fun(args): + key, offspring = poisson(args[0], eta * args[1]) + population = args[2] + offspring + return key, offspring, population + + key, offspring, population = jax.lax.while_loop( + cond_fun, + body_fun, + (key, offspring, population), + ) + + return key, population.astype(INT) + + @export @jit def uniform_key_vectorized(key):