Error when passing two different pytrees with the same structure to a JIT-compiled function #103
This is basically the updated version of the problem that occurred with the previous OOP APIs in #84. cc @flferretti
The example is failing with the following error for me:
Installed packages:
Perhaps a specific rod version is required?
I guess it needs ami-iit/rod@0ce60b2, included in ami-iit/rod#27.
Also, #83 actually needed ami-iit/rod#27 for proper color support, so we need to bump the dependencies before releasing jaxsim.
Playing around a bit, I was able to reduce the MWE so that it no longer depends on jaxsim at all:

import dataclasses

import jax
import jax_dataclasses
import numpy as np
import numpy.typing as npt
from jax_dataclasses import Static


@jax_dataclasses.pytree_dataclass
class MWE103GroundContact:
    body: Static[npt.NDArray] = dataclasses.field(
        default_factory=lambda: np.array([])
    )

    @staticmethod
    def build_from() -> "MWE103GroundContact":
        # Build the object
        gc = MWE103GroundContact(body=np.zeros(8))
        return gc


def local_total_mass(model: MWE103GroundContact) -> float:
    return 1.0


dummy_contact1 = MWE103GroundContact.build_from()
dummy_contact2 = MWE103GroundContact.build_from()

local_total_mass_jit1 = jax.jit(local_total_mass)
local_total_mass_jit2 = jax.jit(local_total_mass)

_ = local_total_mass_jit1(model=dummy_contact1)
_ = local_total_mass_jit2(model=dummy_contact2)

Somehow the problem is related to the
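A minimal sketch of what we suspect goes wrong in the reduced MWE: `Static` fields are stored in the pytree's auxiliary (treedef) data, and when the second jitted call looks up the trace cached by the first one (presumably shared, since both wrappers wrap the same `local_total_mass`), the two treedefs are compared with `==`. The snippet reproduces only that comparison with plain NumPy; `aux1`, `aux2` and `aux_data_equal` are hypothetical stand-ins, not JAX internals.

```python
import numpy as np

# Hypothetical stand-ins for the static (aux) data carried by the treedefs of
# dummy_contact1 and dummy_contact2: equal values, distinct array objects.
aux1 = (np.zeros(8),)
aux2 = (np.zeros(8),)


def aux_data_equal(a: tuple, b: tuple) -> bool:
    """Compare two aux-data tuples via ``==``, as a cache lookup would."""
    return a == b


# Same tuple object: tuple comparison short-circuits on item identity,
# so the arrays are never compared element-wise.
print(aux_data_equal(aux1, aux1))  # True

# Distinct arrays: ndarray.__eq__ returns an 8-element bool array, and
# reducing it to a single Python bool is ambiguous, so NumPy raises.
try:
    aux_data_equal(aux1, aux2)
except ValueError as err:
    print(f"ValueError: {err}")
```

This would also explain why passing the *same* object twice works: the identity check inside tuple comparison never reaches `ndarray.__eq__`.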
This is consistent with what was found in #84 (comment).
I think the answer to this question clarifies everything: jax-ml/jax#19547 (comment).
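If, as we suspect, the requirement boils down to "static pytree data must be hashable, and `__eq__` must return a plain bool", one possible workaround is to wrap the NumPy array in a small hashable container. The `HashableArray` class below is a hypothetical helper written for this thread, not part of jaxsim, `jax_dataclasses` or JAX.

```python
import numpy as np


class HashableArray:
    """Hypothetical read-only wrapper making a NumPy array usable as static data.

    ``__eq__`` returns a plain bool and ``__hash__`` is consistent with it.
    """

    def __init__(self, array) -> None:
        # Copy and freeze the data so that the hash cannot change later on.
        self.array = np.array(array, copy=True)
        self.array.setflags(write=False)

    def _key(self) -> tuple:
        # Shape, dtype and raw bytes fully identify the array contents.
        return (self.array.shape, self.array.dtype.str, self.array.tobytes())

    def __eq__(self, other):
        if not isinstance(other, HashableArray):
            return NotImplemented
        # Comparing the same key used by __hash__ keeps the two consistent.
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())

    def __repr__(self) -> str:
        return f"HashableArray({self.array!r})"


# Two distinct wrappers around equal arrays now compare and hash as equal.
aux1 = (HashableArray(np.zeros(8)),)
aux2 = (HashableArray(np.zeros(8)),)
print(aux1 == aux2)  # True
```

In the MWE, the field would then become `body: Static[HashableArray]`, built with `HashableArray(np.zeros(8))`; we have not verified this against `jax_dataclasses` here. Comparing raw bytes is deliberately conservative: arrays that differ only in representation (e.g. `0.0` vs `-0.0`) count as different, which at worst causes a retrace rather than a wrong cache hit.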
I think I described the root issue behind all the
While doing some additional tests, I found that a similar problem arises. This time, it seems to be related to the integrators:

Traceback (most recent call last):
  File "/home/flferretti/git/comodo/src/comodo/jaxsimSimulator/test.py", line 55, in <module>
    integrator_state2 = integrator2.init(x0=data2.state, t0=0, dt=1e-3)
  File "/home/flferretti/jaxsim/src/jaxsim/integrators/common.py", line 166, in init
    _ = integrator(x0, t0, dt, **kwargs)
  File "/home/flferretti/jaxsim/src/jaxsim/integrators/common.py", line 283, in __call__
    z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
  File "/home/flferretti/jaxsim/src/jaxsim/integrators/common.py", line 417, in _compute_next_state
    K, _ = jax.lax.scan(
  File "/home/flferretti/jaxsim/src/jaxsim/integrators/common.py", line 403, in scan_body
    ki = jax.lax.cond(
  File "/home/flferretti/jaxsim/src/jaxsim/integrators/common.py", line 374, in <lambda>
    get_ẋ0 = lambda: self.params.get("dxdt0", f(x0, t0)[0])
  File "/home/flferretti/jaxsim/src/jaxsim/integrators/common.py", line 365, in <lambda>
    f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
  File "/home/flferretti/jaxsim/src/jaxsim/api/ode.py", line 69, in f
    return system_dynamics(
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..
The error occurred while tracing the function <lambda> at /home/flferretti/jaxsim/src/jaxsim/integrators/common.py:374 for cond. This value became a tracer due to JAX operations on these lines:

  operation a:bool[1] = eq b c
    from line /home/flferretti/jaxsim/src/jaxsim/api/ode.py:69 (f)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

MWE:

import jax.numpy as jnp
import jaxsim.api as js
import rod.builder.primitives
import rod.urdf.exporter
from jaxsim import integrators

# Create a ROD model of a box on the fly.
rod_model = (
    rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box")
    .build_model()
    .add_link()
    .add_inertial()
    .add_visual()
    .add_collision()
    .build()
)

# Export the URDF string.
urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
    sdf=rod_model, pretty=True
)

model1 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_string,
    is_urdf=True,
)
model2 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_string,
    is_urdf=True,
)

# Build the data
data1 = js.data.JaxSimModelData.build(model=model1)
data2 = js.data.JaxSimModelData.build(model=model2)

# Create the integrators
integrator1 = integrators.fixed_step.Heun2SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model1,
        data=data1,
        system_dynamics=js.ode.system_dynamics,
    ),
)
integrator2 = integrators.fixed_step.Heun2SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model2,
        data=data2,
        system_dynamics=js.ode.system_dynamics,
    ),
)

# ! Try to initialize the integrator
integrator_state1 = integrator1.init(x0=data1.state, t0=0, dt=1e-3)
integrator_state2 = integrator2.init(x0=data2.state, t0=0, dt=1e-3)
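The traceback above says that an `eq` producing a `bool[1]` inside the dynamics (called from `ode.py:69`) was converted to a Python bool while tracing the `cond` branch. We have not pinned down which comparison in `system_dynamics` does this; the generic sketch below, with hypothetical functions `select_python_if` and `select_traced`, only illustrates the pattern that raises `TracerBoolConversionError` under `jax.jit` and one traceable alternative.

```python
import jax
import jax.numpy as jnp


def select_python_if(x: jax.Array, y: jax.Array) -> jax.Array:
    # Python control flow needs bool(x == y), which is impossible on a tracer.
    if x == y:
        return x
    return y


def select_traced(x: jax.Array, y: jax.Array) -> jax.Array:
    # Traceable alternative: keep the comparison as an array and select with jnp.where.
    return jnp.where(x == y, x, y)


x = jnp.array([1.0])
y = jnp.array([2.0])

try:
    jax.jit(select_python_if)(x, y)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

result = jax.jit(select_traced)(x, y)
```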
After some testing, I found out that the objects Setting
Furthermore, I believe that we should pay attention to the
Thanks for the investigation @flferretti. Since
Even with the new functional APIs, the following problem still affects our code:
An MWE to reproduce the problem is the following:
I've already tried playing a bit with JaxSimModel.__eq__ and JaxSimModel.__hash__, with no luck. Marking this issue as help wanted in case this problem becomes annoying to someone.
Interestingly, running it in a context in which jax.jit is disabled works fine:
I suspect that there is something strange in the JAX cache that we are either overlooking or not handling properly.
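For reference, this is how running with jit disabled typically looks. As far as we know, inside `jax.disable_jit()` a jitted function runs eagerly, op by op, so neither tracing nor the jit cache is involved, which would be consistent with the cache being the culprit. The function `total` is a hypothetical placeholder, not jaxsim code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def total(x: jax.Array) -> jax.Array:
    # Under jax.jit, `x` is a tracer; under jax.disable_jit() it is a concrete array.
    return jnp.sum(x)


x = jnp.arange(4.0)

with jax.disable_jit():
    eager = total(x)  # runs eagerly, without tracing

compiled = total(x)  # traced, compiled and cached as usual
```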