Why did JAX choose jaxpr instead of the Python AST? #7013
-
Hi, I am learning JAX and some other Python JIT projects, e.g. llvm-npcomp and Numba. I noticed that JAX converts Python source code into jaxpr, its IR, by tracing the code's execution, while llvm-npcomp and Numba do the same job by heavily using the Python AST. Since I am just a newbie to these JIT techniques, I can't see an obvious benefit of "tracing the code": both the "Python AST" and "tracing" approaches look fine to me. Is there some problem that only the "tracing" approach can handle, while the "Python AST" approach cannot? Sincerely.
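To make the contrast concrete, here is a toy tracer. It is purely illustrative and is not how JAX is implemented; the `Trace`/`Tracer` classes, the `trace` helper, and the example function `f` are all hypothetical. The AST of `f` contains a `For` node that an AST-based compiler would have to understand and lower. Tracing instead lets Python run the loop and records only the arithmetic performed on the placeholder value, so the loop is unrolled and never reaches the IR.

```python
import ast


# Hypothetical toy tracer, for illustration only (NOT JAX's implementation).
# A Tracer stands in for a real value and records every operation applied
# to it as a flat, jaxpr-like equation.
class Tracer:
    def __init__(self, trace, name):
        self.trace = trace
        self.name = name

    def _record(self, op, other):
        # Constants are printed literally; other tracers by their variable name.
        other_name = other.name if isinstance(other, Tracer) else repr(other)
        out = self.trace.new_var()
        self.trace.eqns.append(f"{out.name} = {op} {self.name} {other_name}")
        return out

    # Only tracer-on-the-left arithmetic is supported in this toy.
    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


class Trace:
    def __init__(self):
        self.eqns = []
        self.count = 0

    def new_var(self):
        self.count += 1
        return Tracer(self, f"v{self.count}")


def trace(fn, n_args):
    """Run fn on placeholder tracers; return (input names, equations, output name)."""
    t = Trace()
    inputs = [t.new_var() for _ in range(n_args)]
    out = fn(*inputs)
    return [v.name for v in inputs], t.eqns, out.name


SOURCE = '''
def f(x, n=3):
    for _ in range(n):
        x = x * 2
    return x + 1
'''

# AST route: the compiler sees the loop itself and must know how to lower it.
node_types = {type(node).__name__ for node in ast.walk(ast.parse(SOURCE))}
print("For" in node_types)  # True

# Tracing route: Python executes the loop; only the arithmetic is recorded.
namespace = {}
exec(SOURCE, namespace)
inputs, eqns, output = trace(namespace["f"], 1)
for eqn in eqns:
    print(eqn)
# v2 = mul v1 2
# v3 = mul v2 2
# v4 = mul v3 2
# v5 = add v4 1
```

The point of the sketch is that any Python feature used only to build the computation (here a loop over a static bound) is executed by Python itself, so the IR never needs to support it.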
-
This is an interesting question. I think the main reason is that the Python AST is far too general, and JAX does not implement everything that Python implements. For example, transformable JAX programs are linear and functional: JAX variables cannot be modified in-place, and JAX does not have general control flow like `if`/`else` statements or loops! For that reason, the Python AST is not a great fit for representing the set of programs that JAX is able to transform and compile, and a more limited, JAX-specific IR is a better choice. Does that answer your question?
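A small sketch of what that restricted IR looks like in practice, using JAX's public `jax.make_jaxpr`, `jax.lax.cond`, and `.at[...].set` APIs. The functions `f`, `g`, and `g_cond` are hypothetical examples, the sketch assumes a recent JAX install, and the exact printed jaxprs differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A pure function: jnp operations return new values instead of mutating.
    return jnp.sin(x) * 2.0


# make_jaxpr traces f with abstract values and returns the recorded jaxpr.
f_jaxpr = jax.make_jaxpr(f)(1.0)
print(f_jaxpr)  # a short list of equations using the `sin` and `mul` primitives


def g(x):
    # Python `if` on a traced value: the tracer has no concrete True/False
    # during tracing, so this branch cannot be recorded.
    if x > 0:
        return x
    return -x


if_failed = False
try:
    jax.make_jaxpr(g)(1.0)
except jax.errors.ConcretizationTypeError:
    if_failed = True


def g_cond(x):
    # Data-dependent branching has to be expressed with lax.cond, which is
    # recorded as a single `cond` primitive holding both branches.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


cond_jaxpr = jax.make_jaxpr(g_cond)(1.0)
print(cond_jaxpr)

# In-place mutation is rejected; the functional .at[...] update returns a new array.
arr = jnp.zeros(3)
mutation_failed = False
try:
    arr[0] = 1.0
except TypeError:
    mutation_failed = True
updated = arr.at[0].set(1.0)  # arr itself is unchanged
```

Roughly speaking, anything the jaxpr cannot express (Python branching on traced values, in-place writes) is rejected at trace time, so transformations such as `jit` and `grad` only have to deal with the small set of primitives that tracing can produce.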