summary(results) takes a long time to execute #23

Answered by Joshuaalbert
nicochunger asked this question in Q&A

Thanks @nicochunger, neat model! This problem comes from the way JAX works. When you call a JAX operation, it immediately hands control back to Python and lets Python keep going until you actually need the result of the computation (asynchronous dispatch). So when you call results = jit(ns)(random.PRNGKey(35111651), termination_frac=0.001), control returns to Python once compilation is done and the computation has started running, not once it has finished. For me, compilation takes 40 seconds.
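A minimal sketch of this effect, assuming a recent JAX install. The function heavy and the array size are made up for illustration and have nothing to do with the nested sampler. The first call is blocked on outside the timings so that compilation is excluded; the second call is then timed once when it returns and once after block_until_ready.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # Hypothetical workload: a few chained matrix products to keep the device busy.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((1000, 1000)) / 1000.0

# First call traces and compiles; block so compile time stays out of the timings below.
heavy(x).block_until_ready()

t0 = time.perf_counter()
y = heavy(x)  # returns once the work has been dispatched
t_dispatch = time.perf_counter() - t0

y.block_until_ready()  # waits until the result has actually been computed
t_ready = time.perf_counter() - t0

print(f"call returned after {t_dispatch:.4f} s, result ready after {t_ready:.4f} s")
```

With asynchronous dispatch the first number is typically much smaller than the second; the exact timings depend on the backend and hardware.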

summary(results) then blocks until the results have actually been computed, which is where the apparent wait time comes from.
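To illustrate where the wait ends up, here is a hedged sketch in which compute and summarize are hypothetical stand-ins for the jitted sampler run and for summary. The jitted call returns quickly, and the time is instead spent in the first function that reads a value back on the host (float(...) here).

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    # Hypothetical stand-in for the expensive jitted run.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


def summarize(result):
    # Hypothetical stand-in for summary(results): converting to a Python float
    # needs the actual value, so it waits for the computation to finish.
    return float(result.mean())


x = jnp.ones((1000, 1000)) / 1000.0
compute(x).block_until_ready()  # compile once, outside the timings

t0 = time.perf_counter()
result = compute(x)
t_call = time.perf_counter() - t0

t1 = time.perf_counter()
value = summarize(result)
t_summary = time.perf_counter() - t1

print(f"compute returned in {t_call:.4f} s, summarize took {t_summary:.4f} s")
```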

To make sure a computation is done, you could block on any of the arrays in the results like this:

results = jit(ns)(
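A self-contained sketch of blocking on the results, in which Results and run are hypothetical stand-ins for the sampler's results container and jit(ns). It relies only on NamedTuples being JAX pytrees, jax.tree_util.tree_leaves, and each array's block_until_ready method.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Results(NamedTuple):
    # Hypothetical stand-in for the sampler's results; NamedTuples are pytrees.
    log_z: jnp.ndarray
    samples: jnp.ndarray


@jax.jit
def run(key):
    # Hypothetical stand-in for jit(ns)(...): a jitted function returning a pytree of arrays.
    samples = jax.random.normal(key, (10_000, 3))
    log_z = jnp.log(jnp.mean(jnp.exp(-0.5 * samples**2)))
    return Results(log_z=log_z, samples=samples)


results = run(jax.random.PRNGKey(0))

# Block on every array in the results pytree; after this loop all of them are computed.
for leaf in jax.tree_util.tree_leaves(results):
    leaf.block_until_ready()
```

Blocking on any single array, as suggested above, is the lighter option; looping over every leaf just avoids picking one. Recent JAX versions also offer jax.block_until_ready, which accepts a whole pytree.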

Replies: 2 comments 4 replies

Answer selected by Joshuaalbert
Category: Q&A
Labels: v0 (Pre-release)
2 participants