Skip to content

Commit

Permalink
Store JointGenericAxis instead of the JAX numpy array in JointModel
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jun 3, 2024
1 parent d534de5 commit 6420495
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def joint_transforms_and_motion_subspaces(
pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
jnp.array(self.joint_model.joint_types[1:]).astype(int),
jnp.array(joint_positions),
jnp.array(self.joint_model.joint_axis),
jnp.array([j.axis for j in self.joint_model.joint_axis]),
)

# Extract the transforms and motion subspaces of the joints.
Expand Down
43 changes: 25 additions & 18 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class JointModel:

joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
joint_types: Static[tuple[JointType, ...]]
joint_types: Static[tuple[int, ...]]
joint_axis: Static[tuple[JointGenericAxis, ...]]

@staticmethod
Expand Down Expand Up @@ -109,7 +109,7 @@ def build(description: ModelDescription) -> JointModel:
joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
joint_axis=tuple([j.axis for j in ordered_joints]),
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
)

def parent_H_child(
Expand Down Expand Up @@ -201,7 +201,7 @@ def predecessor_H_successor(
pre_H_suc, S = supported_joint_motion(
self.joint_types[joint_index],
joint_position,
self.joint_axis[joint_index],
self.joint_axis[joint_index].axis,
)

return pre_H_suc, S
Expand All @@ -224,18 +224,18 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:

@jax.jit
def supported_joint_motion(
joint_type: JointType,
joint_type: jtp.IntLike,
joint_position: jtp.VectorLike,
joint_axis: JointGenericAxis,
joint_axis: jtp.VectorLike | None = None,
/,
) -> tuple[jtp.Matrix, jtp.Array]:
"""
Compute the homogeneous transformation and motion subspace of a joint.
Args:
joint_type: The type of the joint.
joint_axis: The axis of rotation or translation of the joint.
joint_position: The position of the joint.
joint_axis: The optional 3D axis of rotation or translation of the joint.
Returns:
A tuple containing the homogeneous transformation and the motion subspace.
Expand All @@ -244,26 +244,33 @@ def supported_joint_motion(
# Prepare the joint position
s = jnp.array(joint_position).astype(float)

def compute_F():
def compute_F() -> tuple[jtp.Matrix, jtp.Array]:
return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

def compute_R():
def compute_R() -> tuple[jtp.Matrix, jtp.Array]:

# Get the additional argument specifying the joint axis.
# This is a metadata required by only some joint types.
axis = jnp.array(joint_axis).astype(float).squeeze()

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(
Rotation.from_axis_angle(vector=s * joint_axis)
)
rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))

return pre_H_suc, S

def compute_P():
pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array(s * joint_axis),
)
def compute_P() -> tuple[jtp.Matrix, jtp.Array]:

# Get the additional argument specifying the joint axis.
# This is a metadata required by only some joint types.
axis = jnp.array(joint_axis).astype(float).squeeze()

pre_H_suc = jaxlie.SE3.from_translation(translation=jnp.array(s * axis))

S = jnp.vstack(jnp.hstack([axis, jnp.zeros(3)]))

S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
return pre_H_suc, S

pre_H_suc, S = jax.lax.switch(
Expand Down
4 changes: 3 additions & 1 deletion src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ class JointGenericAxis:
axis: jtp.Vector

def __hash__(self) -> int:
return hash((tuple(np.array(self.axis).tolist())))

return hash(tuple(self.axis.tolist()))

def __eq__(self, other: JointGenericAxis) -> bool:

if not isinstance(other, JointGenericAxis):
return False

Expand Down

0 comments on commit 6420495

Please sign in to comment.