Binomial coefficient in jax #7044

Answered by jakevdp
maurorigo asked this question in Q&A
Jun 21, 2021 · 1 comment · 1 reply
I suspect the best way to compute binomial coefficients in JAX when the arguments are not static (that is, when they are traced values under jit) will be with gammaln; for example:

import scipy.special
from jax.scipy.special import gammaln
import jax.numpy as jnp
from jax import jit

def comb(N, k):
  return jnp.exp(gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1))

print(jit(comb)(5, 3))
# 10.00001
print(scipy.special.comb(5, 3))
# 10.0
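
The small error in the jitted result above (10.00001 rather than 10) comes from exponentiating a floating-point log-gamma difference. If the inputs are known to be non-negative integers with k <= N and the result is small enough to be represented exactly in the working float type, one possible workaround is to round after exponentiating. The sketch below is only an illustration under those assumptions; the names log_comb and comb_rounded are hypothetical helpers, not part of JAX or SciPy.

```python
import jax.numpy as jnp
from jax import jit
from jax.scipy.special import gammaln


def log_comb(n, k):
    # Hypothetical helper: log of the binomial coefficient via log-gamma.
    # Staying in log space avoids overflow when C(n, k) itself is very large.
    return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)


@jit
def comb_rounded(n, k):
    # Hypothetical helper: exponentiate, then round to the nearest whole number.
    # Assumes integer-valued n and k with 0 <= k <= n, and a result small
    # enough that the float type can represent it exactly.
    return jnp.round(jnp.exp(log_comb(n, k)))


print(comb_rounded(5, 3))
# Works elementwise on arrays as well, since gammaln is elementwise.
print(comb_rounded(jnp.array([5, 6, 10]), 2))
```

For large N, rounding no longer recovers the exact integer, and working with log_comb directly is probably the safer choice.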

Answer selected by maurorigo