Wanting a simple way to add state threading to a large, preexisting JAX codebase that makes heavy use of `grad` and `jit`, I wrote a new interpreter and set of primitives (linked here) that automate the process.
I'm aware that the JAX team has its own work on `run_state` to enable stateful computations, and that there's a plethora of third-party solutions to this problem; my approach was motivated by making the fewest possible changes to existing code.
The principle is to add calls to new `get_state` and `set_state` primitives in existing code, and then wrap everything at the top level in a `stateful` transformation. This transformation traces the function to a jaxpr, identifies every instance of these state primitives, builds a new jaxpr with the state values threaded through the computation, and returns the new (still pure) jaxpr as a callable. This works with the existing `jit`, `grad`, etc. transformations.
My question is: would there be any interest in integrating something like this as a new interpreter in JAX, similar to `sparse` or `checkify`? Does this approach fit with JAX's philosophy, or is it too "magic"?
Adapting my current implementation would take some time (it's my first interpreter and was built very quickly; it doesn't yet follow the structure of JAX's existing internal interpreters, and there are definitely still edge cases that break it), so I'd welcome feedback and thoughts before going further.