Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix passing different PyTrees to JIT-compiled functions #165

Merged
merged 10 commits into from
Jun 4, 2024
20 changes: 19 additions & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
"""
Class containing the state of a `JaxSimModel` object.
Class containing the data of a `JaxSimModel` object.
"""

state: ODEState
Expand All @@ -43,6 +43,24 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
)

def __hash__(self) -> int:

return hash(
(
hash(self.state),
hash(tuple(self.gravity.flatten().tolist())),
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved
hash(self.soft_contacts_params),
hash(jnp.atleast_1d(self.time_ns).flatten().tolist()),
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved
)
)

def __eq__(self, other: JaxSimModelData) -> bool:

if not isinstance(other, JaxSimModelData):
return False

return hash(self) == hash(other)

def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
"""
Check if the current state is valid for the given model.
Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -
"""

# Get the intermediate representation parsed from the model description.
ir = model.description.get()
ir = model.description

# Extract the indices of the frame and the link it is attached to.
F = ir.frames[frame_idx - model.number_of_links()]
Expand All @@ -51,7 +51,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
The index of the frame.
"""

frame_names = np.array([frame.name for frame in model.description.get().frames])
frame_names = np.array([frame.name for frame in model.description.frames])

if frame_name in frame_names:
idx_in_list = np.argwhere(frame_names == frame_name)
Expand All @@ -72,7 +72,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str
The name of the frame.
"""

return model.description.get().frames[frame_index - model.number_of_links()].name
return model.description.frames[frame_index - model.number_of_links()].name


@functools.partial(jax.jit, static_argnames=["frame_names"])
Expand Down Expand Up @@ -144,7 +144,7 @@ def transform(
W_H_L = js.link.transform(model=model, data=data, link_index=L)

# Get the static frame pose wrt the parent link.
frame = model.description.get().frames[frame_index - model.number_of_links()]
frame = model.description.frames[frame_index - model.number_of_links()]
L_H_F = frame.pose

# Combine the transforms computing the frame pose.
Expand Down
39 changes: 21 additions & 18 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jaxsim.typing as jtp
from jaxsim.math import Inertia, JointModel, supported_joint_motion
from jaxsim.parsers.descriptions import JointDescription, ModelDescription
from jaxsim.utils import JaxsimDataclass
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass


@jax_dataclasses.pytree_dataclass
Expand All @@ -32,8 +32,8 @@ class KynDynParameters(JaxsimDataclass):

# Static
link_names: Static[tuple[str]]
parent_array: Static[jtp.Vector]
support_body_array_bool: Static[jtp.Matrix]
_parent_array: Static[HashedNumpyArray]
_support_body_array_bool: Static[HashedNumpyArray]

# Links
link_parameters: LinkParameters
Expand All @@ -45,6 +45,14 @@ class KynDynParameters(JaxsimDataclass):
joint_model: JointModel
joint_parameters: JointParameters | None

@property
def parent_array(self) -> jtp.Vector:
return self._parent_array.get()

@property
def support_body_array_bool(self) -> jtp.Matrix:
return self._support_body_array_bool.get()

@staticmethod
def build(model_description: ModelDescription) -> KynDynParameters:
"""
Expand Down Expand Up @@ -191,8 +199,8 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

return KynDynParameters(
link_names=tuple(l.name for l in ordered_links),
parent_array=parent_array,
support_body_array_bool=support_body_array_bool,
_parent_array=HashedNumpyArray(array=parent_array),
_support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
link_parameters=link_parameters,
joint_model=joint_model,
joint_parameters=joint_parameters,
Expand All @@ -204,23 +212,18 @@ def __eq__(self, other: KynDynParameters) -> bool:
if not isinstance(other, KynDynParameters):
return False

equal = True
equal = equal and self.number_of_links() == other.number_of_links()
equal = equal and self.number_of_joints() == other.number_of_joints()
equal = equal and jnp.allclose(self.parent_array, other.parent_array)

return equal
return hash(self) == hash(other)

def __hash__(self) -> int:

h = (
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(tuple(self.parent_array.tolist())),
return hash(
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(tuple(jnp.atleast_1d(self.parent_array).flatten().tolist())),
)
)

return hash(h)

# =============================
# Helpers to extract parameters
# =============================
Expand Down Expand Up @@ -388,7 +391,7 @@ def joint_transforms_and_motion_subspaces(
pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
jnp.array(self.joint_model.joint_types[1:]).astype(int),
jnp.array(joint_positions),
jnp.array(self.joint_model.joint_axis),
jnp.array([j.axis for j in self.joint_model.joint_axis]),
)

# Extract the transforms and motion subspaces of the joints.
Expand Down
33 changes: 26 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,37 @@ class JaxSimModel(JaxsimDataclass):
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
)
kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False, compare=False, hash=False
)

description: Static[
_description: Static[
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)
@property
def description(self) -> jaxsim.parsers.descriptions.ModelDescription:
return self._description.get()

def __eq__(self, other: JaxSimModel) -> bool:

if not isinstance(other, JaxSimModel):
return False

return hash(self) == hash(other)

def __hash__(self) -> int:

return hash(
(
hash(self.model_name),
hash(self.kin_dyn_parameters),
)
)

# ========================
# Initialization and state
Expand Down Expand Up @@ -137,7 +156,7 @@ def build(
# Build the model
model = JaxSimModel(
model_name=model_name,
description=HashlessObject(obj=model_description),
_description=HashlessObject(obj=model_description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
Expand Down Expand Up @@ -260,7 +279,7 @@ def frame_names(self) -> tuple[str, ...]:
The names of the links in the model.
"""

return tuple([frame.name for frame in self.description.get().frames])
return tuple([frame.name for frame in self.description.frames])
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved


# =====================
Expand Down Expand Up @@ -297,7 +316,7 @@ def reduce(

# Copy the model description with a deep copy of the joints.
intermediate_description = dataclasses.replace(
model.description.get(), joints=copy.deepcopy(model.description.get().joints)
model.description, joints=copy.deepcopy(model.description.joints)
)

# Update the initial position of the joints.
Expand Down
31 changes: 31 additions & 0 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,24 @@ class PhysicsModelState(JaxsimDataclass):
default_factory=lambda: jnp.zeros(3)
)

def __hash__(self) -> int:

return hash(
(
hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))),
hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))),
hash(tuple(self.base_position.flatten().tolist())),
hash(tuple(self.base_quaternion.flatten().tolist())),
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved
)
)

def __eq__(self, other: PhysicsModelState) -> bool:

if not isinstance(other, PhysicsModelState):
return False

return hash(self) == hash(other)

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
Expand Down Expand Up @@ -593,6 +611,19 @@ class SoftContactsState(JaxsimDataclass):

tangential_deformation: jtp.Matrix

def __hash__(self) -> int:

return hash(
tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
)

def __eq__(self, other: SoftContactsState) -> bool:

if not isinstance(other, SoftContactsState):
return False

return hash(self) == hash(other)

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
Expand Down
43 changes: 25 additions & 18 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class JointModel:

joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
joint_types: Static[tuple[JointType, ...]]
joint_types: Static[tuple[int, ...]]
joint_axis: Static[tuple[JointGenericAxis, ...]]

@staticmethod
Expand Down Expand Up @@ -109,7 +109,7 @@ def build(description: ModelDescription) -> JointModel:
joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
joint_axis=tuple([j.axis for j in ordered_joints]),
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
)

def parent_H_child(
Expand Down Expand Up @@ -201,7 +201,7 @@ def predecessor_H_successor(
pre_H_suc, S = supported_joint_motion(
self.joint_types[joint_index],
joint_position,
self.joint_axis[joint_index],
self.joint_axis[joint_index].axis,
)

return pre_H_suc, S
Expand All @@ -224,18 +224,18 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:

@jax.jit
def supported_joint_motion(
joint_type: JointType,
joint_type: jtp.IntLike,
joint_position: jtp.VectorLike,
joint_axis: JointGenericAxis,
joint_axis: jtp.VectorLike | None = None,
/,
) -> tuple[jtp.Matrix, jtp.Array]:
"""
Compute the homogeneous transformation and motion subspace of a joint.

Args:
joint_type: The type of the joint.
joint_axis: The axis of rotation or translation of the joint.
joint_position: The position of the joint.
joint_axis: The optional 3D axis of rotation or translation of the joint.

Returns:
A tuple containing the homogeneous transformation and the motion subspace.
Expand All @@ -244,26 +244,33 @@ def supported_joint_motion(
# Prepare the joint position
s = jnp.array(joint_position).astype(float)

def compute_F():
def compute_F() -> tuple[jtp.Matrix, jtp.Array]:
return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

def compute_R():
def compute_R() -> tuple[jtp.Matrix, jtp.Array]:

# Get the additional argument specifying the joint axis.
# This is a metadata required by only some joint types.
axis = jnp.array(joint_axis).astype(float).squeeze()

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(
Rotation.from_axis_angle(vector=s * joint_axis)
)
rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))

return pre_H_suc, S

def compute_P():
pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array(s * joint_axis),
)
def compute_P() -> tuple[jtp.Matrix, jtp.Array]:

# Get the additional argument specifying the joint axis.
# This is a metadata required by only some joint types.
axis = jnp.array(joint_axis).astype(float).squeeze()

pre_H_suc = jaxlie.SE3.from_translation(translation=jnp.array(s * axis))

S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))

S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
return pre_H_suc, S

pre_H_suc, S = jax.lax.switch(
Expand Down
4 changes: 3 additions & 1 deletion src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ class JointGenericAxis:
axis: jtp.Vector

def __hash__(self) -> int:
return hash((tuple(np.array(self.axis).tolist())))

return hash(tuple(self.axis.tolist()))

def __eq__(self, other: JointGenericAxis) -> bool:

if not isinstance(other, JointGenericAxis):
return False

Expand Down
Loading