Is there a direct way to compute a binomial coefficient which depends on some variable in a
Answered by jakevdp on Jun 22, 2021
I suspect the best way to compute binomial coefficients non-statically in JAX will be with gammaln; for example:

    import scipy.special
    from jax.scipy.special import gammaln
    import jax.numpy as jnp
    from jax import jit

    def comb(N, k):
        # C(N, k) = N! / (k! (N - k)!), evaluated in log space via log-gamma
        return jnp.exp(gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1))

    print(jit(comb)(5, 3))
    # 10.00001

    print(scipy.special.comb(5, 3))
    # 10.0
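
Since the gammaln expression above returns a float that is only approximately integral (10.00001 rather than 10), one option is to round the result. The following is a minimal sketch of that idea, not part of the original answer: the name `comb_rounded` and the explicit cast to float32 are my own assumptions. It also passes an array for k to show that the arguments can be traced values rather than static ones.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


@jax.jit
def comb_rounded(n, k):
    """Binomial coefficient via log-gamma, rounded to the nearest whole number.

    Hypothetical helper for illustration; the result is still a float array.
    """
    # Cast explicitly so integer inputs are handled as floats by gammaln.
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    log_comb = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
    # Rounding removes the small floating-point error introduced by exp().
    return jnp.round(jnp.exp(log_comb))


# k is an array here, so one call evaluates C(5, k) for k = 0..5.
ks = jnp.arange(6)
row = comb_rounded(5, ks)
print(row)
```

Rounding is only trustworthy while the float32 error stays well below 0.5, so this should be fine for modest n but not for large coefficients. In that case it is probably safer to keep working with the log value directly.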
Answer selected by maurorigo