Hi all, thanks again for providing such a powerful framework for computation and optimization. I have some complex code that takes a class type as input, to allow flexibility in how some inference is performed. Specifically, I've ported some of the trust-region optimization code from SciPy into my project and have a function that looks like this:

```python
def infer_trustregion(model: BaseModel,
                      cls: Type[BaseQuadraticProblem] = CGSteihaugSubproblem,
                      options: TrustRegionOptions = None) -> AlphaResults:
    some_params = model.init()
    sub_p = cls(some_params)
    [...]
```

I've already extended my … One work-around would be to enable … Is it possible in …
Yes, this is possible as long as the type argument is declared as static (a class is not an array, so JAX cannot trace it). Additionally, if you want a jitted function to return a custom object, you need to tell JAX how to convert it to a pytree (basically a structure of arrays representing the object). Here's a quick example of creating a custom type, registering it with JAX, passing the type to a jitted function, and returning the constructed object:

```python
from functools import partial

from jax import jit
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Model:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # Specify how to flatten the model into pytree children (plus aux data).
        return ((self.x,), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Specify how to rebuild a model from its pytree children.
        return cls(*children)

    def __repr__(self):
        return f"Model({self.x})"


# The Model class must be passed as a static argument.
@partial(jit, static_argnums=0)
def make_model(cls, y):
    x = y ** 2
    return cls(x)


print(make_model(Model, 2))
# Model(4)
```

Hopefully you can modify this approach to your own application!
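As a rough sketch of how this might carry over to the trust-region setup in the question: the subproblem class is passed as a static argument (here via `static_argnames`), constructed inside the jitted function, and returned as a pytree. The class names `GradientStepSubproblem` and `ScaledGradientSubproblem`, the `solve(radius)` method, and `infer_step` are all hypothetical stand-ins; the "solvers" are toy steps along the negative gradient, not a Steihaug-CG implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class GradientStepSubproblem:
    """Hypothetical stand-in for a trust-region subproblem class."""

    def __init__(self, g):
        self.g = g  # gradient at the current iterate

    def tree_flatten(self):
        return ((self.g,), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def solve(self, radius):
        # Toy step: move a distance `radius` along -g.
        return -radius * self.g / jnp.linalg.norm(self.g)


# Registration is per exact type, so the subclass is registered too.
@register_pytree_node_class
class ScaledGradientSubproblem(GradientStepSubproblem):
    def solve(self, radius):
        # Toy variant: take only half of the trust-region radius.
        return -0.5 * radius * self.g / jnp.linalg.norm(self.g)


@partial(jax.jit, static_argnames="cls")
def infer_step(x, g, radius, cls):
    sub_p = cls(g)  # cls is a plain Python class, fixed at trace time
    return x + sub_p.solve(radius), sub_p


x0 = jnp.zeros(2)
g = jnp.array([3.0, 4.0])
x1, sub = infer_step(x0, g, 1.0, cls=GradientStepSubproblem)
# x1 is approximately [-0.6, -0.8]; sub is a GradientStepSubproblem
```

Since static arguments are part of the compilation cache key, each distinct class triggers its own trace and compilation, which is usually what you want when the class selects a different algorithm.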
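To see concretely what "converting it to a pytree" means, here is a small sketch that round-trips a `Model` like the one in the answer through `jax.tree_util`, outside of `jit`. Flattening yields the array-like leaves plus a tree definition, and unflattening calls `tree_unflatten` to rebuild the object; this is the same machinery `jit` relies on when it returns a custom object.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Model:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # Children (leaves) first, then static auxiliary data.
        return ((self.x,), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


m = Model(4)
leaves, treedef = jax.tree_util.tree_flatten(m)
print(leaves)  # [4]

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(rebuilt).__name__, rebuilt.x)  # Model 4

# Because Model is a registered pytree, tree_map reaches inside it.
doubled = jax.tree_util.tree_map(lambda v: v * 2, m)
print(doubled.x)  # 8
```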