diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 5782eb726..7dc3de827 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -30,7 +30,7 @@ @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): """ - Class containing the state of a `JaxSimModel` object. + Class containing the data of a `JaxSimModel` object. """ state: ODEState @@ -43,6 +43,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): default_factory=lambda: jnp.array(0, dtype=jnp.uint64) ) + def __hash__(self) -> int: + + return hash( + ( + hash(self.state), + hash(tuple(self.gravity.flatten().tolist())), + hash(self.soft_contacts_params), + hash(jnp.atleast_1d(self.time_ns).flatten().tolist()), + ) + ) + + def __eq__(self, other: JaxSimModelData) -> bool: + + if not isinstance(other, JaxSimModelData): + return False + + return hash(self) == hash(other) + def valid(self, model: js.model.JaxSimModel | None = None) -> bool: """ Check if the current state is valid for the given model. diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index e8c0a23a0..0db020bf5 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -30,7 +30,7 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) - """ # Get the intermediate representation parsed from the model description. - ir = model.description.get() + ir = model.description # Extract the indices of the frame and the link it is attached to. F = ir.frames[frame_idx - model.number_of_links()] @@ -51,7 +51,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int: The index of the frame. """ - frame_names = np.array([frame.name for frame in model.description.get().frames]) + frame_names = np.array([frame.name for frame in model.description.frames]) if frame_name in frame_names: idx_in_list = np.argwhere(frame_names == frame_name) @@ -72,7 +72,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str The name of the frame. """ - return model.description.get().frames[frame_index - model.number_of_links()].name + return model.description.frames[frame_index - model.number_of_links()].name @functools.partial(jax.jit, static_argnames=["frame_names"]) @@ -144,7 +144,7 @@ def transform( W_H_L = js.link.transform(model=model, data=data, link_index=L) # Get the static frame pose wrt the parent link. - frame = model.description.get().frames[frame_index - model.number_of_links()] + frame = model.description.frames[frame_index - model.number_of_links()] L_H_F = frame.pose # Combine the transforms computing the frame pose. diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 622af2498..01c2d4eca 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -11,7 +11,7 @@ import jaxsim.typing as jtp from jaxsim.math import Inertia, JointModel, supported_joint_motion from jaxsim.parsers.descriptions import JointDescription, ModelDescription -from jaxsim.utils import JaxsimDataclass +from jaxsim.utils import HashedNumpyArray, JaxsimDataclass @jax_dataclasses.pytree_dataclass @@ -32,8 +32,8 @@ class KynDynParameters(JaxsimDataclass): # Static link_names: Static[tuple[str]] - parent_array: Static[jtp.Vector] - support_body_array_bool: Static[jtp.Matrix] + _parent_array: Static[HashedNumpyArray] + _support_body_array_bool: Static[HashedNumpyArray] # Links link_parameters: LinkParameters @@ -45,6 +45,14 @@ class KynDynParameters(JaxsimDataclass): joint_model: JointModel joint_parameters: JointParameters | None + @property + def parent_array(self) -> jtp.Vector: + return self._parent_array.get() + + @property + def support_body_array_bool(self) -> jtp.Matrix: + return self._support_body_array_bool.get() + @staticmethod def build(model_description: ModelDescription) -> KynDynParameters: """ @@ -191,8 +199,8 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: return KynDynParameters( link_names=tuple(l.name for l in ordered_links), - parent_array=parent_array, - support_body_array_bool=support_body_array_bool, + _parent_array=HashedNumpyArray(array=parent_array), + _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool), link_parameters=link_parameters, joint_model=joint_model, joint_parameters=joint_parameters, @@ -204,23 +212,18 @@ def __eq__(self, other: KynDynParameters) -> bool: if not isinstance(other, KynDynParameters): return False - equal = True - equal = equal and self.number_of_links() == other.number_of_links() - equal = equal and self.number_of_joints() == other.number_of_joints() - equal = equal and jnp.allclose(self.parent_array, other.parent_array) - - return equal + return hash(self) == hash(other) def __hash__(self) -> int: - h = ( - hash(self.number_of_links()), - hash(self.number_of_joints()), - hash(tuple(self.parent_array.tolist())), + return hash( + ( + hash(self.number_of_links()), + hash(self.number_of_joints()), + hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())), + ) ) - return hash(h) - # ============================= # Helpers to extract parameters # ============================= @@ -388,7 +391,7 @@ def joint_transforms_and_motion_subspaces( pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)( jnp.array(self.joint_model.joint_types[1:]).astype(int), jnp.array(joint_positions), - jnp.array(self.joint_model.joint_axis), + jnp.array([j.axis for j in self.joint_model.joint_axis]), ) # Extract the transforms and motion subspaces of the joints. diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 272c328b6..a3a154268 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -32,18 +32,37 @@ class JaxSimModel(JaxsimDataclass): terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False ) + kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = ( + dataclasses.field(default=None, repr=False, compare=False, hash=False) + ) built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( default=None, repr=False, compare=False, hash=False ) - description: Static[ + _description: Static[ HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None] ] = dataclasses.field(default=None, repr=False, compare=False, hash=False) - kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = ( - dataclasses.field(default=None, repr=False, compare=False, hash=False) - ) + @property + def description(self) -> jaxsim.parsers.descriptions.ModelDescription: + return self._description.get() + + def __eq__(self, other: JaxSimModel) -> bool: + + if not isinstance(other, JaxSimModel): + return False + + return hash(self) == hash(other) + + def __hash__(self) -> int: + + return hash( + ( + hash(self.model_name), + hash(self.kin_dyn_parameters), + ) + ) # ======================== # Initialization and state @@ -137,7 +156,7 @@ def build( # Build the model model = JaxSimModel( model_name=model_name, - description=HashlessObject(obj=model_description), + _description=HashlessObject(obj=model_description), kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build( model_description=model_description ), @@ -260,7 +279,7 @@ def frame_names(self) -> tuple[str, ...]: The names of the links in the model. """ - return tuple([frame.name for frame in self.description.get().frames]) + return tuple(frame.name for frame in self.description.frames) # ===================== @@ -297,7 +316,7 @@ def reduce( # Copy the model description with a deep copy of the joints. intermediate_description = dataclasses.replace( - model.description.get(), joints=copy.deepcopy(model.description.get().joints) + model.description, joints=copy.deepcopy(model.description.joints) ) # Update the initial position of the joints. diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index db8edf4a9..aec0d7277 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -281,6 +281,24 @@ class PhysicsModelState(JaxsimDataclass): default_factory=lambda: jnp.zeros(3) ) + def __hash__(self) -> int: + + return hash( + ( + hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))), + hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))), + hash(tuple(self.base_position.flatten().tolist())), + hash(tuple(self.base_quaternion.flatten().tolist())), + ) + ) + + def __eq__(self, other: PhysicsModelState) -> bool: + + if not isinstance(other, PhysicsModelState): + return False + + return hash(self) == hash(other) + @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, @@ -593,6 +611,19 @@ class SoftContactsState(JaxsimDataclass): tangential_deformation: jtp.Matrix + def __hash__(self) -> int: + + return hash( + tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist()) + ) + + def __eq__(self, other: SoftContactsState) -> bool: + + if not isinstance(other, SoftContactsState): + return False + + return hash(self) == hash(other) + @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index 2e29553e9..106409b88 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -39,7 +39,7 @@ class JointModel: joint_dofs: Static[tuple[int, ...]] joint_names: Static[tuple[str, ...]] - joint_types: Static[tuple[JointType, ...]] + joint_types: Static[tuple[int, ...]] joint_axis: Static[tuple[JointGenericAxis, ...]] @staticmethod @@ -109,7 +109,7 @@ def build(description: ModelDescription) -> JointModel: joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]), - joint_axis=tuple([j.axis for j in ordered_joints]), + joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints), ) def parent_H_child( @@ -201,7 +201,7 @@ def predecessor_H_successor( pre_H_suc, S = supported_joint_motion( self.joint_types[joint_index], joint_position, - self.joint_axis[joint_index], + self.joint_axis[joint_index].axis, ) return pre_H_suc, S @@ -224,9 +224,9 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix: @jax.jit def supported_joint_motion( - joint_type: JointType, + joint_type: jtp.IntLike, joint_position: jtp.VectorLike, - joint_axis: JointGenericAxis, + joint_axis: jtp.VectorLike | None = None, /, ) -> tuple[jtp.Matrix, jtp.Array]: """ @@ -234,8 +234,8 @@ def supported_joint_motion( Args: joint_type: The type of the joint. - joint_axis: The axis of rotation or translation of the joint. joint_position: The position of the joint. + joint_axis: The optional 3D axis of rotation or translation of the joint. Returns: A tuple containing the homogeneous transformation and the motion subspace. @@ -244,26 +244,33 @@ def supported_joint_motion( # Prepare the joint position s = jnp.array(joint_position).astype(float) - def compute_F(): + def compute_F() -> tuple[jtp.Matrix, jtp.Array]: return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1)) - def compute_R(): + def compute_R() -> tuple[jtp.Matrix, jtp.Array]: + + # Get the additional argument specifying the joint axis. + # This is a metadata required by only some joint types. + axis = jnp.array(joint_axis).astype(float).squeeze() + pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_matrix( - Rotation.from_axis_angle(vector=s * joint_axis) - ) + rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis)) ) - S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()])) + S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis])) + return pre_H_suc, S - def compute_P(): - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array(s * joint_axis), - ) + def compute_P() -> tuple[jtp.Matrix, jtp.Array]: + + # Get the additional argument specifying the joint axis. + # This is a metadata required by only some joint types. + axis = jnp.array(joint_axis).astype(float).squeeze() + + pre_H_suc = jaxlie.SE3.from_translation(translation=jnp.array(s * axis)) + + S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)])) - S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)])) return pre_H_suc, S pre_H_suc, S = jax.lax.switch( diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 9abf7d257..a2139e0e6 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -30,9 +30,11 @@ class JointGenericAxis: axis: jtp.Vector def __hash__(self) -> int: - return hash((tuple(np.array(self.axis).tolist()))) + + return hash(tuple(self.axis.tolist())) def __eq__(self, other: JointGenericAxis) -> bool: + if not isinstance(other, JointGenericAxis): return False diff --git a/src/jaxsim/rbda/soft_contacts.py b/src/jaxsim/rbda/soft_contacts.py index f54c37057..e20722178 100644 --- a/src/jaxsim/rbda/soft_contacts.py +++ b/src/jaxsim/rbda/soft_contacts.py @@ -29,6 +29,23 @@ class SoftContactsParams(JaxsimDataclass): default_factory=lambda: jnp.array(0.5, dtype=float) ) + def __hash__(self) -> int: + + return hash( + ( + hash(tuple(jnp.atleast_1d(self.K).flatten().tolist())), + hash(tuple(jnp.atleast_1d(self.D).flatten().tolist())), + hash(tuple(jnp.atleast_1d(self.mu).flatten().tolist())), + ) + ) + + def __eq__(self, other: SoftContactsParams) -> bool: + + if not isinstance(other, SoftContactsParams): + return NotImplemented + + return hash(self) == hash(other) + @staticmethod def build( K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5 diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index 8d55d4ecc..d0b881ceb 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -1,5 +1,5 @@ from jax_dataclasses._copy_and_mutate import _Mutability as Mutability -from .hashless import HashlessObject from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing +from .wrappers import HashedNumpyArray, HashlessObject diff --git a/src/jaxsim/utils/hashless.py b/src/jaxsim/utils/hashless.py deleted file mode 100644 index 9a48fb437..000000000 --- a/src/jaxsim/utils/hashless.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -import dataclasses -from typing import Generic, TypeVar - -T = TypeVar("T") - - -@dataclasses.dataclass -class HashlessObject(Generic[T]): - - obj: T - - def get(self: HashlessObject[T]) -> T: - return self.obj - - def __hash__(self) -> int: - return 0 diff --git a/src/jaxsim/utils/wrappers.py b/src/jaxsim/utils/wrappers.py new file mode 100644 index 000000000..fd0a29dd8 --- /dev/null +++ b/src/jaxsim/utils/wrappers.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import dataclasses +from typing import Generic, TypeVar + +import jax +import jax_dataclasses +import numpy as np +import numpy.typing as npt + +T = TypeVar("T") + + +@dataclasses.dataclass +class HashlessObject(Generic[T]): + """ + A class that wraps an object and makes it hashless. + + This is useful for creating particular JAX pytrees. + For example, to create a pytree with a static leaf that is ignored + by JAX when it compares two instances to trigger a JIT recompilation. + """ + + obj: T + + def get(self: HashlessObject[T]) -> T: + return self.obj + + def __hash__(self) -> int: + + return 0 + + def __eq__(self, other: HashlessObject[T]) -> bool: + + if not isinstance(other, HashlessObject) and isinstance( + other.get(), type(self.get()) + ): + return False + + return hash(self) == hash(other) + + +@jax_dataclasses.pytree_dataclass +class HashedNumpyArray: + """ + A class that wraps a numpy array and makes it hashable. + + This is useful for creating particular JAX pytrees. + For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf. + + Note: + Calculating with the wrapper class the hash of a very large array can be + very expensive. If the array is large and only the equality operator is needed, + set `large_array=True` to use a faster comparison method. + """ + + array: jax.Array | npt.NDArray + + large_array: jax_dataclasses.Static[bool] = dataclasses.field( + default=False, repr=False, compare=False, hash=False + ) + + def get(self) -> jax.Array | npt.NDArray: + return self.array + + def __hash__(self) -> int: + + return hash(tuple(np.atleast_1d(self.array).flatten().tolist())) + + def __eq__(self, other: HashedNumpyArray) -> bool: + + if not isinstance(other, HashedNumpyArray): + return False + + if self.large_array: + return np.array_equal(self.array, other.array) + + return hash(self) == hash(other) diff --git a/tests/conftest.py b/tests/conftest.py index 843b2c300..b16d08208 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -177,18 +177,19 @@ def jaxsim_model_sphere() -> js.model.JaxSimModel: @pytest.fixture(scope="session") -def jaxsim_model_ergocub() -> js.model.JaxSimModel: +def ergocub_model_description_path() -> pathlib.Path: """ - Fixture providing the JaxSim model of the ErgoCub robot. + Fixture providing the path to the URDF model description of the ErgoCub robot. Returns: - The JaxSim model of the ErgoCub robot. + The path to the URDF model description of the ErgoCub robot. """ try: os.environ["ROBOT_DESCRIPTION_COMMIT"] = "v0.7.1" import robot_descriptions.ergocub_description + finally: _ = os.environ.pop("ROBOT_DESCRIPTION_COMMIT", None) @@ -198,7 +199,21 @@ def jaxsim_model_ergocub() -> js.model.JaxSimModel: ) ) - return build_jaxsim_model(model_description=model_urdf_path) + return model_urdf_path + + +@pytest.fixture(scope="session") +def jaxsim_model_ergocub( + ergocub_model_description_path: pathlib.Path, +) -> js.model.JaxSimModel: + """ + Fixture providing the JaxSim model of the ErgoCub robot. + + Returns: + The JaxSim model of the ErgoCub robot. + """ + + return build_jaxsim_model(model_description=ergocub_model_description_path) @pytest.fixture(scope="session") diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index 4e610a9e6..49d58f3db 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -17,12 +17,10 @@ def test_frame_index(jaxsim_models_types: js.model.JaxSimModel): # ===== frame_indices = tuple( - frame.index - for frame in model.description.get().frames - if frame.index is not None + frame.index for frame in model.description.frames if frame.index is not None ) - frame_names = np.array([frame.name for frame in model.description.get().frames]) + frame_names = np.array([frame.name for frame in model.description.frames]) for frame_idx, frame_name in zip(frame_indices, frame_names): assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_idx @@ -60,7 +58,7 @@ def test_frame_transforms( # Get all names of frames in the iDynTree model. frame_names = [ frame.name - for frame in model.description.get().frames + for frame in model.description.frames if frame.name in kin_dyn.frame_names() ] @@ -74,7 +72,7 @@ def test_frame_transforms( # Get indices of frames. frame_indices = tuple( frame.index - for frame in model.description.get().frames + for frame in model.description.frames if frame.index is not None and frame.name in frame_names ) @@ -115,7 +113,7 @@ def test_frame_jacobians( # Get all names of frames in the iDynTree model. frame_names = [ frame.name - for frame in model.description.get().frames + for frame in model.description.frames if frame.name in kin_dyn.frame_names() ] @@ -127,7 +125,7 @@ def test_frame_jacobians( # Get indices of frames. frame_indices = tuple( frame.index - for frame in model.description.get().frames + for frame in model.description.frames if frame.index is not None and frame.name in frame_names ) diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 83febf38d..712b61441 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -1,44 +1,30 @@ import io +import pathlib from contextlib import redirect_stdout import jax -import pytest -import rod.builder.primitives -import rod.urdf.exporter import jaxsim.api as js -# https://github.com/ami-iit/jaxsim/issues/103 -@pytest.mark.xfail(strict=True) -def test_call_jit_compiled_function_passing_different_objects(): - - # Create on-the-fly a ROD model of a box. - rod_model = ( - rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box") - .build_model() - .add_link() - .add_inertial() - .add_visual() - .add_collision() - .build() - ) - - # Export the URDF string. - urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string( - sdf=rod_model - ) +def test_call_jit_compiled_function_passing_different_objects( + ergocub_model_description_path: pathlib.Path, +): + # Create a first model from the URDF. model1 = js.model.JaxSimModel.build_from_model_description( - model_description=urdf_string, + model_description=ergocub_model_description_path, is_urdf=True, ) + # Create a second model from the URDF. model2 = js.model.JaxSimModel.build_from_model_description( - model_description=urdf_string, + model_description=ergocub_model_description_path, is_urdf=True, ) + # The objects should be different, but the comparison should return True. + assert id(model1) != id(model2) assert model1 == model2 assert hash(model1) == hash(model2) diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 34a5425ba..b1ccb71bc 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -52,7 +52,7 @@ def build_kindyncomputations_from_jaxsim_model( # Get the default positions already stored in the model description. removed_joint_positions_default = { str(j.name): float(j.initial_position) - for j in model.description.get()._joints_removed + for j in model.description._joints_removed if j.name not in considered_joints }