From 8398f0b630bebc436379f902524dd4e58812c45b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Jun 2024 15:54:06 +0200 Subject: [PATCH] Define `jaxsim.typing.VelRepr` as `Int` Co-authored-by: Diego Ferigo --- src/jaxsim/api/common.py | 8 ++++---- src/jaxsim/api/data.py | 12 ++++++------ src/jaxsim/api/references.py | 10 +++++----- src/jaxsim/typing.py | 2 ++ tests/test_api_com.py | 3 ++- tests/test_api_data.py | 3 ++- tests/test_api_frame.py | 3 ++- tests/test_api_link.py | 7 ++++--- tests/test_api_model.py | 7 ++++--- tests/test_automatic_differentiation.py | 2 +- tests/test_contact.py | 3 ++- tests/test_simulations.py | 5 +++-- 12 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 465690947..e0459fd8c 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -35,13 +35,13 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): Base class for model data structures with velocity representation. """ - velocity_representation: int = dataclasses.field( + velocity_representation: jtp.VelRepr = dataclasses.field( default=VelRepr.Inertial, kw_only=True ) @contextlib.contextmanager def switch_velocity_representation( - self, velocity_representation: int + self, velocity_representation: jtp.VelRepr ) -> ContextManager[Self]: """ Context manager to temporarily switch the velocity representation. @@ -83,7 +83,7 @@ def switch_velocity_representation( @functools.partial(jax.jit, static_argnames=["is_force"]) def inertial_to_other_representation( array: jtp.Array, - other_representation: int, + other_representation: jtp.VelRepr, transform: jtp.Matrix, *, is_force: bool, @@ -148,7 +148,7 @@ def to_mixed(): @functools.partial(jax.jit, static_argnames=["is_force"]) def other_representation_to_inertial( array: jtp.Array, - other_representation: int, + other_representation: jtp.VelRepr, transform: jtp.Matrix, *, is_force: bool, diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index c3dfc2a48..0ad559e95 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -83,7 +83,7 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool: @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: int = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. @@ -112,7 +112,7 @@ def build( standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, soft_contacts_state: js.ode_data.SoftContactsState | None = None, soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None, - velocity_representation: int = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, ) -> JaxSimModelData: """ @@ -631,7 +631,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: def reset_base_linear_velocity( self, linear_velocity: jtp.VectorLike, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base linear velocity. @@ -659,7 +659,7 @@ def reset_base_linear_velocity( def reset_base_angular_velocity( self, angular_velocity: jtp.VectorLike, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base angular velocity. @@ -687,7 +687,7 @@ def reset_base_angular_velocity( def reset_base_velocity( self, base_velocity: jtp.VectorLike, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base 6D velocity. @@ -732,7 +732,7 @@ def random_model_data( model: js.model.JaxSimModel, *, key: jax.Array | None = None, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, base_pos_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index ddc350eba..ecbf79720 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -32,7 +32,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: int = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with zero references. @@ -55,7 +55,7 @@ def build( joint_force_references: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, data: js.data.JaxSimModelData | None = None, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with the given references. @@ -225,7 +225,7 @@ def check_not_inertial() -> None: false_fun=lambda: None, ) - def not_inertial(velocity_representation: int) -> jtp.Matrix: + def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix: # Helper function to convert a single 6D force to the active representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: @@ -468,7 +468,7 @@ def check_not_inertial() -> None: ) # If inertial-fixed representation, we can directly store the link forces. - def inertial(velocity_representation: int) -> JaxSimModelReferences: + def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: W_f_L = f_L return replace( forces=self.input.physics_model.f_ext.at[link_idxs, :].set( @@ -476,7 +476,7 @@ def inertial(velocity_representation: int) -> JaxSimModelReferences: ) ) - def not_inertial(velocity_representation: int) -> JaxSimModelReferences: + def not_inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: # Helper function to convert a single 6D force to the inertial representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert_using_link_frame( diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 5d56467c0..048dd8169 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -40,3 +40,5 @@ IntLike = Int BoolLike = Bool FloatLike = Float + +VelRepr = Int diff --git a/tests/test_api_com.py b/tests/test_api_com.py index 83a17bf75..e019a53fd 100644 --- a/tests/test_api_com.py +++ b/tests/test_api_com.py @@ -2,6 +2,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -9,7 +10,7 @@ def test_com_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_data.py b/tests/test_api_data.py index c22ba7687..f159026a8 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -3,6 +3,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.utils import Mutability @@ -21,7 +22,7 @@ def test_data_valid( def test_data_joint_indexing( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index ccd750381..10d5f299b 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -3,6 +3,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -95,7 +96,7 @@ def test_frame_transforms( def test_frame_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_link.py b/tests/test_api_link.py index be31acbd1..027865d4f 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -4,6 +4,7 @@ import jaxsim.api as js import jaxsim.math +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -117,7 +118,7 @@ def test_link_transforms( def test_link_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -184,7 +185,7 @@ def test_link_jacobians( def test_link_bias_acceleration( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -216,7 +217,7 @@ def test_link_bias_acceleration( def test_link_jacobian_derivative( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 087c9d721..be6bfc97f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -8,6 +8,7 @@ import jaxsim.api as js import jaxsim.math +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -221,7 +222,7 @@ def test_model_creation_and_reduction( def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -268,7 +269,7 @@ def test_model_properties( def test_model_rbda( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, - velocity_representation: int, + velocity_representation: jtp.VelRepr, ): model = jaxsim_models_types @@ -479,7 +480,7 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index fdf6a454e..c42ab5360 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -25,7 +25,7 @@ def get_random_data_and_references( model: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, key: jax.Array, ) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: diff --git a/tests/test_contact.py b/tests/test_contact.py index 20af3dab2..5bcdc2cd6 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -2,12 +2,13 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr def test_collidable_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 668fcc6dc..c85ac282d 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -5,12 +5,13 @@ import jaxsim.api as js import jaxsim.integrators import jaxsim.rbda +import jaxsim.typing as jtp from jaxsim import VelRepr def test_box_with_external_forces( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, ): """ This test simulates a box falling due to gravity. @@ -95,7 +96,7 @@ def test_box_with_external_forces( def test_box_with_zero_gravity( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jnp.ndarray, ):