From 4d14208dedf4cc0e78811f84e55fade958b1d86b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 16:52:48 +0200 Subject: [PATCH] Address review - Use explicit joint type names - Use `jnp.newaxis` instead of `None` - Use dataclass for `JointType` instead of metaclass Co-authored-by: Diego Ferigo --- src/jaxsim/api/kin_dyn_parameters.py | 4 +-- src/jaxsim/math/joint_model.py | 8 +++--- src/jaxsim/parsers/descriptions/joint.py | 32 +++++------------------- src/jaxsim/parsers/rod/parser.py | 2 +- src/jaxsim/parsers/rod/utils.py | 6 ++--- 5 files changed, 16 insertions(+), 36 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index c5df87d8b..b9c536320 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -394,8 +394,8 @@ def joint_transforms_and_motion_subspaces( # Extract the transforms and motion subspaces of the joints. # We stack the base transform W_H_B at index 0, and a dummy motion subspace # for either the fixed or free-floating joint connecting the world to the base. - pre_H_suc = jnp.vstack([W_H_B[None, ...], pre_H_suc_J]) - S = jnp.vstack([jnp.zeros((6, 1))[None, ...], S_J]) + pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J]) + S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J]) # Extract the successor-to-child fixed transforms. # Note that here we include also the index 0 since suc_H_child[0] stores the diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index a6f3776ec..635fe2003 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -108,7 +108,7 @@ def build(description: ModelDescription) -> JointModel: # Static attributes joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), - joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]), + joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]), joint_axis=tuple([j.axis for j in ordered_joints]), ) @@ -265,9 +265,9 @@ def compute_P(): pre_H_suc, S = jax.lax.switch( index=joint_type, branches=( - compute_F, # JointType.F - compute_R, # JointType.R - compute_P, # JointType.P + compute_F, # JointType.Fixed + compute_R, # JointType.Revolute + compute_P, # JointType.Prismatic ), ) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 0a4d21d84..9abf7d257 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import Tuple, Union +from typing import ClassVar, Tuple, Union import jax_dataclasses import numpy as np @@ -13,31 +13,11 @@ from .link import LinkDescription -class _JointTypeMeta(type): - def __new__(cls, name, bases, dct): - cls_instance = super().__new__(cls, name, bases, dct) - - # Assign integer values to the descriptors - cls_instance.F = 0 - cls_instance.R = 1 - cls_instance.P = 2 - - return cls_instance - - -class JointType(metaclass=_JointTypeMeta): - """ - Type of supported joints. - """ - - class F: - pass - - class R: - pass - - class P: - pass +@dataclasses.dataclass(frozen=True) +class JointType: + Fixed: ClassVar[int] = 0 + Revolute: ClassVar[int] = 1 + Prismatic: ClassVar[int] = 2 @jax_dataclasses.pytree_dataclass diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index ed49e7702..c2e396d57 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -352,7 +352,7 @@ def build_model_description( considered_joints=[ j.name for j in sdf_data.joint_descriptions - if j.jtype is not descriptions.JointType.F + if j.jtype is not descriptions.JointType.Fixed ], ) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index cfa0d508e..a001a1da7 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -76,7 +76,7 @@ def joint_to_joint_type( joint_type = joint.type if joint_type == "fixed": - return descriptions.JointType.F + return descriptions.JointType.Fixed if not (axis.xyz is not None and axis.xyz.xyz is not None): raise ValueError("Failed to read axis xyz data") @@ -86,10 +86,10 @@ def joint_to_joint_type( axis_xyz = axis_xyz / np.linalg.norm(axis_xyz) if joint_type in {"revolute", "continuous"}: - return descriptions.JointType.R + return descriptions.JointType.Revolute if joint_type == "prismatic": - return descriptions.JointType.P + return descriptions.JointType.Prismatic raise ValueError("Joint not supported", axis_xyz, joint_type)