Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gamma, Negative Binomial, and Generalized Poisson Distribution #145

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions appletree/randgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from appletree.utils import exporter


export, __all__ = exporter(export_self=False)

if jax.config.x64_enabled:
Expand Down Expand Up @@ -85,6 +86,32 @@ def poisson(key, lam, shape=()):
return key, rvs.astype(INT)


@export
@partial(jit, static_argnums=(3,))
def gamma(key, alpha, beta, shape=()):
"""Gamma distribution random sampler.

Args:
key: seed for random generator.
alpha: <jnp.array>-like shape in gamma distribution.
beta: <jnp.array>-like rate in normal distribution.
shape: output shape.
If not given, output has shape jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta)).

Returns:
an updated seed, random variables.

"""
key, seed = random.split(key)

shape = shape or jnp.broadcast_shapes(jnp.shape(alpha), jnp.shape(beta))
alpha = jnp.broadcast_to(alpha, shape).astype(FLOAT)
beta = jnp.broadcast_to(beta, shape).astype(FLOAT)

rvs = random.gamma(seed, alpha, shape=shape) / beta
return key, rvs.astype(FLOAT)


@export
@partial(jit, static_argnums=(3,))
def normal(key, mean, std, shape=()):
Expand Down Expand Up @@ -249,6 +276,78 @@ def _binomial_dispatch(seed, p, n):
return key, jnp.reshape(ret, shape)


@export
@partial(jit, static_argnums=(3, 4))
def negative_binomial(key, p, n, shape=()):
"""Negative binomial distribution random sampler. Using Gamma–Poisson mixture.

Args:
key: seed for random generator.
p: <jnp.array>-like probability of a single success in negative binomial distribution.
n: <jnp.array>-like number of successes in negative binomial distribution.
shape: output shape.
If not given, output has shape jnp.broadcast_shapes(jnp.shape(p), jnp.shape(n)).

Returns:
an updated seed, random variables.

References:
1. https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture # noqa
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.nbinom.html

"""

key, lam = gamma(key, n, p / (1 - p), shape)

key, rvs = poisson(key, lam)
return key, rvs


@export
@partial(jit, static_argnums=(3,))
def generalized_poisson(key, lam, eta, shape=()):
"""Generalized Poisson Distribution(GPD) random sampler.

Args:
key: seed for random generator.
lam: <jnp.array>-like expectation(location parameter) in GPD.
eta: <jnp.array>-like scale parameter in GPD, within [0, 1).
shape: output shape. If not given, output has shape jnp.shape(lam).

Returns:
an updated seed, random variables.

References:
1. https://gist.github.com/danmackinlay/00e957b11c488539bd3e2a3804922b9d
2. https://search.r-project.org/CRAN/refmans/LaplacesDemon/html/dist.Generalized.Poisson.html # noqa

"""

shape = shape or jnp.broadcast_shapes(jnp.shape(lam), jnp.shape(eta))
lam = jnp.broadcast_to(lam, shape).astype(FLOAT)
eta = jnp.broadcast_to(eta, shape).astype(FLOAT)

key, population = poisson(key, lam * (1 - eta), shape)

offspring = jnp.copy(population)

def cond_fun(args):
return jnp.any(args[1] > 0)

def body_fun(args):
key, offspring = poisson(args[0], eta * args[1])
population = args[2] + offspring
return key, offspring, population

key, offspring, population = jax.lax.while_loop(
cond_fun,
body_fun,
(key, offspring, population),
)

return key, population.astype(INT)


@export
@jit
def uniform_key_vectorized(key):
Expand Down
Loading