Skip to content

Commit

Permalink
Add Gamma, Negative Binomial, and Generalized Poisson Distribution (#145
Browse files Browse the repository at this point in the history
)

* Add Gamma and Negative binomial distribution

* Implement Generalized Poisson Distribution(GPD)
  • Loading branch information
dachengx authored Mar 4, 2024
1 parent 138bb56 commit 6807966
Showing 1 changed file with 99 additions and 0 deletions.
99 changes: 99 additions & 0 deletions appletree/randgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from appletree.utils import exporter


export, __all__ = exporter(export_self=False)

if jax.config.x64_enabled:
Expand Down Expand Up @@ -85,6 +86,32 @@ def poisson(key, lam, shape=()):
return key, rvs.astype(INT)


@export
@partial(jit, static_argnums=(3,))
def gamma(key, alpha, beta, shape=()):
"""Gamma distribution random sampler.
Args:
key: seed for random generator.
alpha: <jnp.array>-like shape in gamma distribution.
beta: <jnp.array>-like rate in normal distribution.
shape: output shape.
If not given, output has shape jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta)).
Returns:
an updated seed, random variables.
"""
key, seed = random.split(key)

shape = shape or jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta))
alpha = jnp.broadcast_to(alpha, shape).astype(FLOAT)
beta = jnp.broadcast_to(beta, shape).astype(FLOAT)

rvs = random.gamma(seed, alpha, shape=shape) / beta
return key, rvs.astype(FLOAT)


@export
@partial(jit, static_argnums=(3,))
def normal(key, mean, std, shape=()):
Expand Down Expand Up @@ -249,6 +276,78 @@ def _binomial_dispatch(seed, p, n):
return key, jnp.reshape(ret, shape)


@export
@partial(jit, static_argnums=(3, 4))
def negative_binomial(key, p, n, shape=()):
"""Negative binomial distribution random sampler. Using Gamma–Poisson mixture.
Args:
key: seed for random generator.
p: <jnp.array>-like probability of a single success in negative binomial distribution.
n: <jnp.array>-like number of successes in negative binomial distribution.
shape: output shape.
If not given, output has shape jnp.broadcast_shapes(jnp.shape(p), jnp.shape(n)).
Returns:
an updated seed, random variables.
References:
1. https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture # noqa
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.nbinom.html
"""

key, lam = gamma(key, n, p / (1 - p), shape)

key, rvs = poisson(key, lam)
return key, rvs


@export
@partial(jit, static_argnums=(3,))
def generalized_poisson(key, lam, eta, shape=()):
"""Generalized Poisson Distribution(GPD) random sampler.
Args:
key: seed for random generator.
lam: <jnp.array>-like expectation(location parameter) in GPD.
eta: <jnp.array>-like scale parameter in GPD, within [0, 1).
shape: output shape. If not given, output has shape jnp.shape(lam).
Returns:
an updated seed, random variables.
References:
1. https://gist.github.com/danmackinlay/00e957b11c488539bd3e2a3804922b9d
2. https://search.r-project.org/CRAN/refmans/LaplacesDemon/html/dist.Generalized.Poisson.html # noqa
"""

shape = shape or jnp.broadcast_shapes(jnp.shape(lam), jnp.shape(eta))
lam = jnp.broadcast_to(lam, shape).astype(FLOAT)
eta = jnp.broadcast_to(eta, shape).astype(FLOAT)

key, population = poisson(key, lam * (1 - eta), shape)

offspring = jnp.copy(population)

def cond_fun(args):
return jnp.any(args[1] > 0)

def body_fun(args):
key, offspring = poisson(args[0], eta * args[1])
population = args[2] + offspring
return key, offspring, population

key, offspring, population = jax.lax.while_loop(
cond_fun,
body_fun,
(key, offspring, population),
)

return key, population.astype(INT)


@export
@jit
def uniform_key_vectorized(key):
Expand Down

0 comments on commit 6807966

Please sign in to comment.