
Is it possible to pass Type[class] as parameter for a jit'ed function? #5770

Answered by jakevdp
quattro asked this question in Q&A

Yes, this is possible as long as the type argument is declared static, for example via `static_argnums` or `static_argnames` in `jax.jit`.
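To show what "static" means here, the sketch below uses two hypothetical classes, `Doubler` and `Squarer`, and a hypothetical function `run` that receives the class through `static_argnames`. Classes are hashable (by identity), so JAX can include the class in its compilation cache key. Calling `run` again with the same class reuses the compiled function. Passing a different class triggers a fresh trace. Without the static marking, JAX would try to interpret the class as an array argument and reject it.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Doubler:  # hypothetical example type
  @staticmethod
  def apply(x):
    return 2 * x


class Squarer:  # hypothetical example type
  @staticmethod
  def apply(x):
    return x ** 2


trace_count = 0  # incremented only when JAX traces the Python function


# op_type is static: JAX hashes it and compiles one version per distinct class.
@partial(jax.jit, static_argnames="op_type")
def run(x, op_type):
  global trace_count
  trace_count += 1  # Python side effect: runs at trace time only
  return op_type.apply(x)


x = jnp.arange(3.0)
doubled = run(x, op_type=Doubler)  # first call with Doubler: traces
doubled = run(x, op_type=Doubler)  # same static value: cached, no retrace
squared = run(x, op_type=Squarer)  # new static value: traces again
```

After these three calls `trace_count` is 2, because the second `Doubler` call hits the cache.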

Additionally, if you want a jitted function to return a custom object, you need to tell JAX how to convert it to a pytree (essentially a nested structure of arrays representing the object) and back.
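For a sense of what converting to a pytree involves, here is a minimal sketch using the functional `jax.tree_util.register_pytree_node` API rather than a decorator. The `Point` class and `scale` function are hypothetical. The flatten function returns the array children plus auxiliary data, and the unflatten function rebuilds the object. Once the class is registered, an instance can be flattened into its leaves and returned from a jitted function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten


class Point:  # hypothetical example class holding two arrays
  def __init__(self, x, y):
    self.x = x
    self.y = y


# Tell JAX how to take a Point apart into leaves and rebuild it.
register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),                 # flatten -> (children, aux_data)
    lambda aux_data, children: Point(*children),  # unflatten
)

p = Point(jnp.array(1.0), jnp.array(2.0))
leaves, treedef = tree_flatten(p)          # leaves are the two arrays
rebuilt = tree_unflatten(treedef, leaves)  # a new Point built from the leaves


@jax.jit
def scale(point, factor):
  # Returning a Point works because JAX knows how to flatten it.
  return Point(point.x * factor, point.y * factor)


scaled = scale(p, 3.0)
```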

Here's a quick example of creating a custom type, registering it with JAX, and then passing the type to a jitted function and returning the constructed object:

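```python
from functools import partial
from jax import jit
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Model:
  def __init__(self, x):
    self.x = x

  def tree_flatten(self):
    # specify how to serialize model into a JAX pytree
    return ((self.x,), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # rebuild the model from the flattened children
    return cls(*children)

@partial(jit, static_argnums=0)  # the type argument must be static
def create(model_type, x):  # hypothetical name for the jitted factory
  return model_type(x)

model = create(Model, 1.0)
```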
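`Model.tree_flatten` returns `None` as the second element of its tuple. That slot holds the auxiliary data: static, hashable metadata that travels alongside the arrays and is not traced. The sketch below assumes a hypothetical `NamedModel` that stores a string name there. It also uses a hypothetical `build` factory that takes both the type and the name as static arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class NamedModel:  # hypothetical variant of Model carrying static metadata
  def __init__(self, x, name):
    self.x = x
    self.name = name

  def tree_flatten(self):
    # children: arrays JAX traces; aux_data: static, hashable metadata
    return ((self.x,), self.name)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children, name=aux_data)


# Both the type and the name are static; only x is traced.
@partial(jax.jit, static_argnums=(0, 1))
def build(model_type, name, x):
  return model_type(x + 1.0, name)


m = build(NamedModel, "demo", jnp.zeros(2))
```

Because `name` is stored as aux_data, it comes back from the jitted function as a plain Python string, while `x` comes back as an array.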

Answer selected by quattro