summary(results) takes a long time to execute #23
-
Hello, I've attached the Python script I'm using and the data. I'm doing some basic tests of exoplanet detection with radial velocities. Thanks in advance!
-
Small update: I've noticed that the time it stays stuck is proportional to the number of live points I use. With only a few live points in the nested sampling, …
-
Thanks @nicochunger, neat model! This problem is due to the way JAX works. When you run a JAX op, it immediately passes control back to Python and lets Python continue until you actually need the result of the computation (JAX calls this asynchronous dispatch). So what's happening is that you're calling

```python
results = jit(ns)(random.PRNGKey(35111651), termination_frac=0.001)
```

and it hands control back to Python as soon as compilation is done and the computation has started running, not when it has finished. For me compilation takes 40 seconds. Then `summary` blocks until the `results` are actually done, which gives you the apparent wait time.

To make sure a computation is done, you could block on any of the arrays in the `results` like this:

```python
results = jit(ns)(random.PRNGKey(35111651), termination_frac=0.001)
results.num_likelihood_evaluations.block_until_ready()
```

I personally accomplish this with plotting (which blocks until the results are ready):

```python
results = jit(ns)(random.PRNGKey(35111651), termination_frac=0.001)
plot_diagnostics(results)
plot_cornerplot(results)
summary(results)
```

For me, when I do this, compile+run takes 4 minutes, and `summary` takes 1 second.

Note, you can rewrite your `true_anomaly` using `while_loop` like this:

```python
import jax.numpy as jnp
from jax.lax import while_loop


def true_anomaly(ma, ecc):
    # Newton's method on Kepler's equation E - ecc*sin(E) = ma, starting from E = ma.
    E = jnp.array(ma, copy=True)  # ma.copy()
    # This is usually a dynamic algorithm with a while loop until a stopping
    # criterion is met. Not sure yet if I can do that with jax.
    # For now I just fix the number of loops.
    def body(state):
        (i, E) = state
        E0 = E
        diff = E - ecc * jnp.sin(E) - ma
        deriv = 1 - ecc * jnp.cos(E)
        deltaE = diff / deriv
        E = E0 - deltaE  # one Newton step
        return (i + 1, E)

    def cond(state):
        (i, E) = state
        return i < 5  # fixed number of Newton iterations

    (_, E) = while_loop(
        cond,
        body,
        init_val=(jnp.asarray(0), E))
    # Convert the eccentric anomaly E to the true anomaly v.
    v = 2 * jnp.arctan(jnp.sqrt((1 + ecc) / (1 - ecc)) * jnp.tan(E / 2))
    return v
```
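The comment in `true_anomaly` asks whether the usual dynamic stopping rule is possible in JAX. `lax.while_loop` does accept a data-dependent condition, so a tolerance-based variant can be sketched as below. `true_anomaly_tol`, `tol` and `max_iter` are hypothetical names, not part of jaxns or the original script. The defaults assume JAX's default float32 precision, and `ecc` is assumed to be a scalar (or the same shape as `ma`).

```python
import jax
import jax.numpy as jnp
from jax import lax


def true_anomaly_tol(ma, ecc, tol=1e-6, max_iter=50):
    """Hypothetical variant of true_anomaly: Newton's method on Kepler's
    equation E - ecc*sin(E) = ma, stopping once every |delta E| < tol
    (or after max_iter steps) instead of after a fixed 5 iterations."""
    ma = jnp.asarray(ma, dtype=jnp.float32)
    ecc = jnp.asarray(ecc, dtype=jnp.float32)  # assumed scalar or same shape as ma

    def cond(state):
        i, _, delta = state
        # Keep iterating while any element is still above tolerance,
        # with max_iter as a safety cap so the loop always terminates.
        return (i < max_iter) & jnp.any(jnp.abs(delta) > tol)

    def body(state):
        i, E, _ = state
        # Newton step: delta = f(E) / f'(E) with f(E) = E - ecc*sin(E) - ma.
        delta = (E - ecc * jnp.sin(E) - ma) / (1 - ecc * jnp.cos(E))
        return (i + 1, E - delta, delta)

    # Start from E = ma; delta = inf guarantees at least one Newton step.
    init = (jnp.int32(0), ma, jnp.full_like(ma, jnp.inf))
    _, E, _ = lax.while_loop(cond, body, init)
    # Convert the eccentric anomaly E to the true anomaly v.
    return 2 * jnp.arctan(jnp.sqrt((1 + ecc) / (1 - ecc)) * jnp.tan(E / 2))
```

It works under `jax.jit`, e.g. `jax.jit(true_anomaly_tol)(ma, ecc)`. One caveat is that `lax.while_loop` supports forward-mode but not reverse-mode differentiation, so `jax.grad` through this function fails. That should not matter for nested sampling, which only evaluates the likelihood.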
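The asynchronous dispatch behind the apparent slowness of `summary` can be reproduced without jaxns. The sketch below uses `heavy`, a hypothetical stand-in for the jitted nested sampler. The call returns as soon as the work is queued, and `block_until_ready()` waits for it to finish. The printed timings depend on the machine, and for small workloads the gap between the two may be small.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # Hypothetical stand-in for an expensive jitted computation such as jit(ns).
    for _ in range(10):
        x = jnp.tanh(x @ x.T / x.shape[0])
    return x.sum()


x = jnp.ones((500, 500))
heavy(x).block_until_ready()  # warm-up: compile once and wait, so compilation is not timed below

t0 = time.perf_counter()
out = heavy(x)  # asynchronous dispatch: returns once the work is queued
t_dispatch = time.perf_counter() - t0
out.block_until_ready()  # wait until the computation has actually finished
t_total = time.perf_counter() - t0

print(f"returned after {t_dispatch:.4f} s, finished after {t_total:.4f} s")
```

In the jaxns case, blocking on `results.num_likelihood_evaluations` plays the role of `out.block_until_ready()` here. Recent JAX releases also provide `jax.block_until_ready`, which blocks on every array in a pytree such as `results`.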