From 688664b91e810ff93902b0f27d2a306e3f1897ac Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 12 Mar 2024 09:11:36 +0100 Subject: [PATCH 01/35] Remove existing tests of OOP APIs --- tests/test_ad_physics.py | 190 --------------- tests/test_eom.py | 130 ---------- tests/test_forward_dynamics.py | 71 ------ tests/test_jax_oop.py | 422 --------------------------------- tests/utils_models.py | 56 ----- tests/utils_rng.py | 96 -------- 6 files changed, 965 deletions(-) delete mode 100644 tests/test_ad_physics.py delete mode 100644 tests/test_eom.py delete mode 100644 tests/test_forward_dynamics.py delete mode 100644 tests/test_jax_oop.py delete mode 100644 tests/utils_models.py delete mode 100644 tests/utils_rng.py diff --git a/tests/test_ad_physics.py b/tests/test_ad_physics.py deleted file mode 100644 index 9c7db2c0e..000000000 --- a/tests/test_ad_physics.py +++ /dev/null @@ -1,190 +0,0 @@ -import jax.numpy as jnp -import numpy as np -import pytest -from jax.test_util import check_grads -from pytest import param as p - -from jaxsim.high_level.common import VelRepr -from jaxsim.high_level.model import Model - -from . import utils_models, utils_rng -from .utils_models import Robot - - -@pytest.mark.parametrize( - "robot, vel_repr", - [ - p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), - p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), - p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), - ], -) -def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None: - """Unit test of the application of Automatic Differentiation on RBD algorithms.""" - - robot = Robot.Ur10 - vel_repr = VelRepr.Inertial - - # Initialize the gravity - gravity = np.array([0, 0, -10.0]) - - # Get the URDF of the robot - urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot) - - # Build the high-level model - model = Model.build_from_model_description( - model_description=urdf_file_path, - vel_repr=vel_repr, - gravity=gravity, - is_urdf=True, - ).mutable(mutable=True, validate=True) - - # Initialize the model with a random state - model.data.model_state = utils_rng.random_physics_model_state( - physics_model=model.physics_model - ) - - # Initialize the model with a random input - model.data.model_input = utils_rng.random_physics_model_input( - physics_model=model.physics_model - ) - - # ======================== - # Extract state and inputs - # ======================== - - # Extract the physics model used in the low-level physics algorithms - physics_model = model.physics_model - - # State - s = model.joint_positions() - ṡ = model.joint_velocities() - xfb = model.data.model_state.xfb() - - # Inputs - f_ext = model.external_forces() - tau = model.joint_generalized_forces_targets() - - # Perturbation used for computing finite differences - ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) - - # ===================================================== - # Check first-order and second-order derivatives of ABA - # ===================================================== - - import jaxsim.physics.algos.aba - - aba = lambda xfb, s, ṡ, tau, f_ext: jaxsim.physics.algos.aba.aba( - model=physics_model, xfb=xfb, q=s, qd=ṡ, tau=tau, f_ext=f_ext - ) - - check_grads( - f=aba, - args=(xfb, s, ṡ, tau, f_ext), - order=2, - modes=["rev", "fwd"], - eps=ε, - ) - - # ====================================================== - # Check first-order and second-order derivatives of RNEA - # ====================================================== - - import jaxsim.physics.algos.rnea - - W_v̇_WB = utils_rng.get_rng().uniform(size=6, low=-1) - s̈ = utils_rng.get_rng().uniform(size=physics_model.dofs(), low=-1) - - rnea = lambda xfb, s, ṡ, s̈, W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea( - model=physics_model, xfb=xfb, q=s, qd=ṡ, qdd=s̈, a0fb=W_v̇_WB, f_ext=f_ext - ) - - check_grads( - f=rnea, - args=(xfb, s, ṡ, s̈, W_v̇_WB, f_ext), - order=2, - modes=["rev", "fwd"], - eps=ε, - ) - - # ====================================================== - # Check first-order and second-order derivatives of CRBA - # ====================================================== - - import jaxsim.physics.algos.crba - - crba = lambda s: jaxsim.physics.algos.crba.crba(model=physics_model, q=s) - - check_grads( - f=crba, - args=(s,), - order=2, - modes=["rev", "fwd"], - eps=ε, - ) - - # ==================================================== - # Check first-order and second-order derivatives of FK - # ==================================================== - - import jaxsim.physics.algos.forward_kinematics - - fk = ( - lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( - model=physics_model, xfb=xfb, q=s - ) - ) - - check_grads( - f=fk, - args=(xfb, s), - order=2, - modes=["rev", "fwd"], - eps=ε, - ) - - # ========================================================== - # Check first-order and second-order derivatives of Jacobian - # ========================================================== - - import jaxsim.physics.algos.jacobian - - link_indices = [l.index() for l in model.links()] - - jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian( - model=physics_model, q=s, body_index=link_indices[-1] - ) - - check_grads( - f=jacobian, - args=(s,), - order=2, - modes=["rev", "fwd"], - eps=ε, - ) - - # ===================================================================== - # Check first-order and second-order derivatives of soft contacts model - # ===================================================================== - - import jaxsim.physics.algos.soft_contacts - - p = utils_rng.get_rng().uniform(size=3, low=-1) - v = utils_rng.get_rng().uniform(size=3, low=-1) - m = utils_rng.get_rng().uniform(size=3, low=-1) - - parameters = jaxsim.physics.algos.soft_contacts.SoftContactsParams.build( - K=10_000, D=20.0, mu=0.5 - ) - - soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts( - parameters=parameters - ).contact_model(position=p, velocity=v, tangential_deformation=m) - - check_grads( - f=soft_contacts, - args=(p, v, m), - order=2, - modes=["rev", "fwd"], - eps=ε, - ) diff --git a/tests/test_eom.py b/tests/test_eom.py deleted file mode 100644 index 4c0926382..000000000 --- a/tests/test_eom.py +++ /dev/null @@ -1,130 +0,0 @@ -import pathlib - -import jax.numpy as jnp -import numpy as np -import pytest -from pytest import param as p - -from jaxsim.high_level.common import VelRepr -from jaxsim.high_level.model import Model - -from . import utils_idyntree, utils_models, utils_rng -from .utils_models import Robot - - -@pytest.mark.parametrize( - "robot, vel_repr", - [ - p(*[Robot.DoublePendulum, VelRepr.Inertial], id="DoublePendulum-Inertial"), - p(*[Robot.DoublePendulum, VelRepr.Body], id="DoublePendulum-Body"), - p(*[Robot.DoublePendulum, VelRepr.Mixed], id="DoublePendulum-Mixed"), - p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), - p(*[Robot.Ur10, VelRepr.Body], id="Ur10-Body"), - p(*[Robot.Ur10, VelRepr.Mixed], id="Ur10-Mixed"), - p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), - p(*[Robot.AnymalC, VelRepr.Body], id="AnymalC-Body"), - p(*[Robot.AnymalC, VelRepr.Mixed], id="AnymalC-Mixed"), - p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), - p(*[Robot.Cassie, VelRepr.Body], id="Cassie-Body"), - p(*[Robot.Cassie, VelRepr.Mixed], id="Cassie-Mixed"), - ], -) -def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None: - """Unit test of all the terms of the floating-base Equations of Motion.""" - - # Initialize the gravity - gravity = np.array([0, 0, -10.0]) - - # Get the URDF of the robot - urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot) - - # Build the high-level model - model_jaxsim = Model.build_from_model_description( - model_description=urdf_file_path, - vel_repr=vel_repr, - gravity=gravity, - is_urdf=True, - ).mutable(mutable=True, validate=True) - - # Initialize the model with a random state - model_jaxsim.data.model_state = utils_rng.random_physics_model_state( - physics_model=model_jaxsim.physics_model - ) - - # Initialize the model with a random input - model_jaxsim.data.model_input = utils_rng.random_physics_model_input( - physics_model=model_jaxsim.physics_model - ) - - # Get the joint torques - tau = model_jaxsim.joint_generalized_forces_targets() - - # ========================== - # Ground truth with iDynTree - # ========================== - - kin_dyn = utils_idyntree.KinDynComputations.build( - urdf=pathlib.Path(urdf_file_path), - considered_joints=list(model_jaxsim.joint_names()), - vel_repr=vel_repr, - gravity=gravity, - ) - - kin_dyn.set_robot_state( - joint_positions=np.array(model_jaxsim.joint_positions()), - joint_velocities=np.array(model_jaxsim.joint_velocities()), - base_transform=np.array(model_jaxsim.base_transform()), - base_velocity=np.array(model_jaxsim.base_velocity()), - ) - - assert kin_dyn.joint_names() == list(model_jaxsim.joint_names()) - assert kin_dyn.gravity == pytest.approx(model_jaxsim.physics_model.gravity[0:3]) - assert kin_dyn.joint_positions() == pytest.approx(model_jaxsim.joint_positions()) - assert kin_dyn.joint_velocities() == pytest.approx(model_jaxsim.joint_velocities()) - assert kin_dyn.base_velocity() == pytest.approx(model_jaxsim.base_velocity()) - assert kin_dyn.frame_transform(model_jaxsim.base_frame()) == pytest.approx( - model_jaxsim.base_transform() - ) - - M_idt = kin_dyn.mass_matrix() - h_idt = kin_dyn.bias_forces() - g_idt = kin_dyn.gravity_forces() - - J_idt = np.vstack( - [ - kin_dyn.jacobian_frame(frame_name=link_name) - for link_name in model_jaxsim.link_names() - ] - ) - - # ================================ - # Test individual terms of the EoM - # ================================ - - M_jaxsim = model_jaxsim.free_floating_mass_matrix() - g_jaxsim = model_jaxsim.free_floating_gravity_forces() - J_jaxsim = jnp.vstack([link.jacobian() for link in model_jaxsim.links()]) - h_jaxsim = model_jaxsim.free_floating_bias_forces() - - # Support both fixed-base and floating-base models by slicing the first six rows - sl = np.s_[0:] if model_jaxsim.floating_base() else np.s_[6:] - - assert M_jaxsim[sl, sl] == pytest.approx(M_idt[sl, sl], abs=1e-3) - assert g_jaxsim[sl] == pytest.approx(g_idt[sl], abs=1e-3) - assert h_jaxsim[sl] == pytest.approx(h_idt[sl], abs=1e-3) - assert J_jaxsim == pytest.approx(J_idt, abs=1e-3) - - # =========================================== - # Test the forward dynamics computed with CRB - # =========================================== - - J_ff = model_jaxsim.generalized_free_floating_jacobian() - f_ext = model_jaxsim.external_forces().flatten() - ν̇ = np.hstack(model_jaxsim.forward_dynamics_crb(tau=tau)) - S = np.block( - [np.zeros(shape=(model_jaxsim.dofs(), 6)), np.eye(model_jaxsim.dofs())] - ).T - - assert h_jaxsim[sl] == pytest.approx( - (S @ tau + J_ff.T @ f_ext - M_jaxsim @ ν̇)[sl], abs=1e-3 - ) diff --git a/tests/test_forward_dynamics.py b/tests/test_forward_dynamics.py deleted file mode 100644 index b4bb72a37..000000000 --- a/tests/test_forward_dynamics.py +++ /dev/null @@ -1,71 +0,0 @@ -import numpy as np -import pytest -from pytest import param as p - -from jaxsim.high_level.common import VelRepr -from jaxsim.high_level.model import Model - -from . import utils_models, utils_rng -from .utils_models import Robot - - -@pytest.mark.parametrize( - "robot, vel_repr", - [ - p(*[Robot.DoublePendulum, VelRepr.Inertial], id="DoublePendulum-Inertial"), - p(*[Robot.DoublePendulum, VelRepr.Body], id="DoublePendulum-Body"), - p(*[Robot.DoublePendulum, VelRepr.Mixed], id="DoublePendulum-Mixed"), - p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"), - p(*[Robot.Ur10, VelRepr.Body], id="Ur10-Body"), - p(*[Robot.Ur10, VelRepr.Mixed], id="Ur10-Mixed"), - p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"), - p(*[Robot.AnymalC, VelRepr.Body], id="AnymalC-Body"), - p(*[Robot.AnymalC, VelRepr.Mixed], id="AnymalC-Mixed"), - p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"), - p(*[Robot.Cassie, VelRepr.Body], id="Cassie-Body"), - p(*[Robot.Cassie, VelRepr.Mixed], id="Cassie-Mixed"), - ], -) -def test_aba(robot: utils_models.Robot, vel_repr: VelRepr) -> None: - """ - Unit test of the ABA algorithm against forward dynamics computed from the EoM. - """ - - # Initialize the gravity - gravity = np.array([0, 0, -10.0]) - - # Get the URDF of the robot - urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot) - - # Build the high-level model - model = Model.build_from_model_description( - model_description=urdf_file_path, - vel_repr=vel_repr, - gravity=gravity, - is_urdf=True, - ).mutable(mutable=True, validate=True) - - # Initialize the model with a random state - model.data.model_state = utils_rng.random_physics_model_state( - physics_model=model.physics_model - ) - - # Initialize the model with a random input - model.data.model_input = utils_rng.random_physics_model_input( - physics_model=model.physics_model - ) - - # Get the joint torques - tau = model.joint_generalized_forces_targets() - - # Compute model acceleration with ABA - v̇_WB_aba, s̈_aba = model.forward_dynamics_aba(tau=tau) - - # ============================================== - # Compute forward dynamics with dedicated method - # ============================================== - - v̇_WB, s̈ = model.forward_dynamics_crb(tau=tau) - - assert s̈.squeeze() == pytest.approx(s̈_aba.squeeze(), abs=0.5) - assert v̇_WB.squeeze() == pytest.approx(v̇_WB_aba.squeeze(), abs=0.2) diff --git a/tests/test_jax_oop.py b/tests/test_jax_oop.py deleted file mode 100644 index f887ecfae..000000000 --- a/tests/test_jax_oop.py +++ /dev/null @@ -1,422 +0,0 @@ -import dataclasses -import io -from contextlib import redirect_stdout -from typing import Any, Type - -import jax -import jax.numpy as jnp -import jax_dataclasses -import numpy as np -import pytest - -from jaxsim.utils import Mutability, Vmappable, oop - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - - -@jax_dataclasses.pytree_dataclass -class AlgoData(Vmappable): - """Class storing vmappable data of a given algorithm.""" - - counter: jax.Array = dataclasses.field( - default_factory=lambda: jnp.array(0, dtype=jnp.uint64) - ) - - @classmethod - def build(cls: Type[Self], counter: jax.typing.ArrayLike) -> Self: - """Builder method. Helpful for enforcing type and shape of fields.""" - - # Counter can be int / scalar numpy array / scalar jax array / etc. - if jnp.array(counter).squeeze().size != 1: - raise ValueError("The counter must be a scalar") - - # Create the object enforcing `counter` to be a scalar jax array - data = AlgoData( - counter=jnp.array(counter, dtype=jnp.uint64).squeeze(), - ) - - return data - - -def test_data(): - """Test AlgoData class.""" - - data1 = AlgoData.build(counter=0) - data2 = AlgoData.build(counter=np.array(10)) - data3 = AlgoData.build(counter=jnp.array(50)) - - assert isinstance(data1.counter, jax.Array) and data1.counter.dtype == jnp.uint64 - assert isinstance(data2.counter, jax.Array) and data2.counter.dtype == jnp.uint64 - assert isinstance(data3.counter, jax.Array) and data3.counter.dtype == jnp.uint64 - - assert data1.batch_size == 0 - assert data2.batch_size == 0 - assert data3.batch_size == 0 - - # ================== - # Vectorizing PyTree - # ================== - - for batch_size in (0, 10, 100): - data_vec = data1.vectorize(batch_size=batch_size) - - assert data_vec.batch_size == batch_size - - if batch_size > 0: - assert data_vec.counter.shape[0] == batch_size - - # ========================================= - # Extracting element from vectorized PyTree - # ========================================= - - data_vec = AlgoData.build_from_list(list_of_obj=[data1, data2, data3]) - assert data_vec.batch_size == 3 - assert data_vec.extract_element(index=0) == data1 - assert data_vec.extract_element(index=1) == data2 - assert data_vec.extract_element(index=2) == data3 - - with pytest.raises(ValueError): - _ = data_vec.extract_element(index=3) - - out = data1.extract_element(index=0) - assert out == data1 - assert id(out) != id(data1) - - with pytest.raises(RuntimeError): - _ = data1.extract_element(index=1) - - with pytest.raises(ValueError): - _ = AlgoData.build_from_list(list_of_obj=[data1, data2, data3, 42]) - - -@jax_dataclasses.pytree_dataclass -class MyClassWithAlgorithms(Vmappable): - """ - Class to demonstrate how to use `Vmappable`. - """ - - # Dynamic data of the algorithm - data: AlgoData = dataclasses.field(default=None) - - # Static attribute of the pytree (triggers recompilation if changed) - double_input: jax_dataclasses.Static[bool] = dataclasses.field(default=None) - - # Non-static attribute of the pytree that is not transparently vmap-able. - done: jax.typing.ArrayLike = dataclasses.field( - default_factory=lambda: jnp.array(False, dtype=bool) - ) - - # Additional leaves to test the behaviour of mutable and immutable python objects - my_tuple: tuple[int] = dataclasses.field(default=tuple(jnp.array([1, 2, 3]))) - my_list: list[int] = dataclasses.field( - default_factory=lambda: [4, 5, 6], init=False - ) - my_array: jax.Array = dataclasses.field( - default_factory=lambda: jnp.array([10, 20, 30]) - ) - - @classmethod - def build(cls: Type[Self], double_input: bool = False) -> Self: - """""" - - obj = MyClassWithAlgorithms() - - with obj.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - obj.data = AlgoData.build(counter=0) - obj.double_input = jnp.array(double_input) - - return obj - - @oop.jax_tf.method_ro - def algo_ro(self, advance: int | jax.typing.ArrayLike) -> Any: - """This is a read-only algorithm. It does not alter any pytree leaf.""" - - # This should be printed only the first execution since it is disabled - # in the execution of the JIT-compiled function. - print("__algo_ro__") - - # Use the dynamic condition that doubles the input value - mul = jax.lax.select(self.double_input, 2, 1) - - # Increase the counter - counter_old = jnp.atleast_1d(self.data.counter)[0] - counter_new = counter_old + mul * advance - - # Return the updated counter - return counter_new - - @oop.jax_tf.method_rw - def algo_rw(self, advance: int | jax.typing.ArrayLike) -> Any: - """ - This is a read-write algorithm. It may alter pytree leaves either belonging - to the vmappable data or generic non-static dataclass attributes. - """ - - print(self) - - # This should be printed only the first execution since it is disabled - # in the execution of the JIT-compiled function. - print("__algo_rw__") - - # Use the dynamic condition that doubles the input value - mul = jax.lax.select(self.double_input, 2, 1) - - # Increase the internal counter - counter_old = jnp.atleast_1d(self.data.counter)[0] - self.data.counter = jnp.array(counter_old + mul * advance, dtype=jnp.uint64) - - # Update the non-static and non-vmap-able attribute - self.done = jax.lax.cond( - pred=self.data.counter > 100, - true_fun=lambda _: jnp.array(True), - false_fun=lambda _: jnp.array(False), - operand=None, - ) - - print(self) - - # Return the updated counter - return self.data.counter - - -def test_mutability(): - """Test MyClassWithAlgorithms class.""" - - # Build the object - obj_ro = MyClassWithAlgorithms.build(double_input=True) - - # By default, pytrees built with jax_dataclasses are frozen (read-only) - assert obj_ro._mutability() == Mutability.FROZEN - with pytest.raises(dataclasses.FrozenInstanceError): - obj_ro.data.counter = 42 - - # Data can be changed through a context manager, in this case operating on a copy... - with obj_ro.editable(validate=True) as obj_ro_copy: - obj_ro_copy.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype) - assert obj_ro_copy.data.counter == pytest.approx(42) - assert obj_ro.data.counter != pytest.approx(42) - - # ... or a context manager that does not copy the pytree... - with obj_ro.mutable_context(mutability=Mutability.MUTABLE): - obj_ro.data.counter = jnp.array(42, dtype=obj_ro.data.counter.dtype) - assert obj_ro.data.counter == pytest.approx(42) - - # ... that raises if the leafs change type - with pytest.raises(AssertionError): - with obj_ro.mutable_context(mutability=Mutability.MUTABLE): - obj_ro.data.counter = 42 - - # Pytrees can be copied... - obj_ro_copy = obj_ro.copy() - assert id(obj_ro) != id(obj_ro_copy) - # ... operation that does not copy the leaves - # TODO describe - assert id(obj_ro.done) == id(obj_ro_copy.done) - assert id(obj_ro.data.counter) == id(obj_ro_copy.data.counter) - assert id(obj_ro.my_array) == id(obj_ro_copy.my_array) - assert id(obj_ro.my_tuple) != id(obj_ro_copy.my_tuple) - assert id(obj_ro.my_list) != id(obj_ro_copy.my_list) - - # They can be converted as mutable pytrees to update their values without - # using context managers (maybe useful for debugging or quick prototyping) - obj_rw = obj_ro.copy().mutable(validate=True) - assert obj_rw._mutability() == Mutability.MUTABLE - obj_rw.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype) - - # However, with validation enabled, this works only if the leaf does not - # change its type (shape, dtype, weakness, ...) - with pytest.raises(AssertionError): - obj_rw.data.counter = 100 - with pytest.raises(AssertionError): - obj_rw.data.counter = jnp.array(100, dtype=float) - with pytest.raises(AssertionError): - obj_rw.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype) - - # Instead, with validation disabled, the pytree structure can be altered - # (and this might cause JIT recompilations, so use it at your own risk) - obj_rw_noval = obj_ro.copy().mutable(validate=False) - assert obj_rw_noval._mutability() == Mutability.MUTABLE_NO_VALIDATION - obj_rw_noval.data.counter = jnp.array(42, dtype=obj_rw.data.counter.dtype) - - # Now this should work without exceptions - obj_rw_noval.data.counter = 100 - obj_rw_noval.data.counter = jnp.array(100, dtype=float) - obj_rw_noval.data.counter = jnp.array([100, 200], dtype=obj_rw.data.counter.dtype) - - # Build another object and check mutability changes - obj_ro = MyClassWithAlgorithms.build(double_input=True) - assert obj_ro.is_mutable(validate=True) is False - assert obj_ro.is_mutable(validate=False) is False - - obj_rw_val = obj_ro.mutable(validate=True) - assert id(obj_ro) == id(obj_rw_val) - assert obj_rw_val.is_mutable(validate=True) is True - assert obj_rw_val.is_mutable(validate=False) is False - - obj_rw_noval = obj_rw_val.mutable(validate=False) - assert id(obj_rw_noval) == id(obj_rw_val) - assert obj_rw_noval.is_mutable(validate=True) is False - assert obj_rw_noval.is_mutable(validate=False) is True - - # Checking mutable leaves behavior - obj_rw = MyClassWithAlgorithms.build(double_input=True).mutable(validate=True) - obj_rw_copy = obj_rw.copy() - - # Memory of JAX arrays cannot be altered in place so this is safe - obj_rw.my_array = obj_rw.my_array.at[1].set(-20) - assert obj_rw_copy.my_array[1] != -20 - - # Tuples are immutable so this should be safe too - obj_rw.my_tuple = tuple(jnp.array([1, -2, 3])) - assert obj_rw_copy.my_array[1] != -2 - - # Lists are treated as tuples (they are not leaves) but since they are mutable, - # their id changes - obj_rw.my_list[1] = -5 - assert obj_rw_copy.my_list[1] != -5 - - # Check that exceptions in mutable context do not alter the object - obj_ro = MyClassWithAlgorithms.build(double_input=True) - assert obj_ro.data.counter == 0 - assert obj_ro.double_input == jnp.array(True) - - with pytest.raises(RuntimeError): - with obj_ro.mutable_context(mutability=Mutability.MUTABLE): - obj_ro.double_input = jnp.array(False, dtype=obj_ro.double_input.dtype) - obj_ro.data.counter = jnp.array(33, dtype=obj_ro.data.counter.dtype) - raise RuntimeError - assert obj_ro.data.counter == 0 - assert obj_ro.double_input == jnp.array(True) - - -def test_decorators_jit_compilation(): - """Test JIT features of MyClassWithAlgorithms class.""" - - obj = MyClassWithAlgorithms.build(double_input=False) - assert obj.data.counter == 0 - assert obj.is_mutable(validate=True) is False - assert obj.is_mutable(validate=False) is False - - # JIT compilation should happen only the first function call. - # We test this by checking that the first execution prints some output. - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj.algo_ro(advance=1) - printed = buf.getvalue() - assert "__algo_ro__" in printed - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj.algo_ro(advance=1) - printed = buf.getvalue() - assert "__algo_ro__" not in printed - - # JIT compilation should happen only the first function call. - # We test this by checking that the first execution prints some output. - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj.algo_rw(advance=1) - printed = buf.getvalue() - assert "__algo_rw__" in printed - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj.algo_rw(advance=1) - printed = buf.getvalue() - assert "__algo_rw__" not in printed - - # Create a new object - obj = MyClassWithAlgorithms.build(double_input=False) - - # New objects should be able to re-use the JIT-compiled functions from other objects - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj.algo_ro(advance=1) - _ = obj.algo_rw(advance=1) - printed = buf.getvalue() - assert "__algo_ro__" not in printed - assert "__algo_rw__" not in printed - - # Create a new object - obj = MyClassWithAlgorithms.build(double_input=False) - - # Read-only methods can be called on r/o objects - out = obj.algo_ro(advance=1) - assert out == obj.data.counter + 1 - out = obj.algo_ro(advance=1) - assert out == obj.data.counter + 1 - - # Read-write methods can be called too on r/o objects since they are marked as r/w - out = obj.algo_rw(advance=1) - assert out == 1 - out = obj.algo_rw(advance=1) - assert out == 2 - out = obj.algo_rw(advance=2) - assert out == 4 - - # Create a new object with a different dynamic attribute - obj_dyn = MyClassWithAlgorithms.build(double_input=False).mutable(validate=True) - obj_dyn.done = jnp.array(not obj_dyn.done, dtype=bool) - - # New objects with different dynamic attributes should be able to re-use the - # JIT-compiled functions from other objects - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj.algo_ro(advance=1) - _ = obj.algo_rw(advance=1) - printed = buf.getvalue() - assert "__algo_ro__" not in printed - assert "__algo_rw__" not in printed - - # Create a new object with a different static attribute - obj_stat = MyClassWithAlgorithms.build(double_input=True) - - # New objects with different static attributes trigger the recompilation of the - # JIT-compiled functions... - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj_stat.algo_ro(advance=1) - _ = obj_stat.algo_rw(advance=1) - printed = buf.getvalue() - assert "__algo_ro__" in printed - assert "__algo_rw__" in printed - - # ... that are cached as well by jax - with io.StringIO() as buf, redirect_stdout(buf): - _ = obj_stat.algo_ro(advance=1) - _ = obj_stat.algo_rw(advance=1) - printed = buf.getvalue() - assert "__algo_ro__" not in printed - assert "__algo_rw__" not in printed - - -def test_decorators_vmap(): - """Test automatic vectorization features of MyClassWithAlgorithms class.""" - - # Create a new object with scalar data - obj = MyClassWithAlgorithms.build(double_input=False) - - # Vectorize the entire object - obj_vec = obj.vectorize(batch_size=10) - assert obj_vec.vectorized is True - assert obj_vec.batch_size == 10 - assert id(obj_vec) != id(obj) - - # Calling methods of vectorized objects with scalar arguments should raise an error - with pytest.raises(ValueError): - _ = obj_vec.algo_ro(advance=1) - with pytest.raises(ValueError): - _ = obj_vec.algo_rw(advance=1) - - # Check that the r/o method provides automatically vectorized output and accepts - # vectorized input - out_vec = obj_vec.algo_ro(advance=jnp.array([1] * obj_vec.batch_size)) - assert out_vec.shape[0] == 10 - assert set(out_vec.tolist()) == {1} - - # Check that the r/w method provides automatically vectorized output and accepts - # vectorized input - out_vec = obj_vec.algo_rw(advance=jnp.array([1] * obj_vec.batch_size)) - assert out_vec.shape[0] == 10 - assert set(out_vec.tolist()) == {1} - out_vec = obj_vec.algo_rw(advance=jnp.array([1] * obj_vec.batch_size)) - assert set(out_vec.tolist()) == {2} - - # Extract a single object from the vectorized object - obj = obj_vec.extract_element(index=5) - assert obj.vectorized is False - assert obj.data.counter == obj_vec.data.counter[5] diff --git a/tests/utils_models.py b/tests/utils_models.py deleted file mode 100644 index 4d7363d52..000000000 --- a/tests/utils_models.py +++ /dev/null @@ -1,56 +0,0 @@ -import enum -import pathlib - -import robot_descriptions.anymal_c_description -import robot_descriptions.cassie_description -import robot_descriptions.double_pendulum_description -import robot_descriptions.icub_description -import robot_descriptions.laikago_description -import robot_descriptions.panda_description -import robot_descriptions.ur10_description - - -class Robot(enum.IntEnum): - iCub = enum.auto() - Ur10 = enum.auto() - Panda = enum.auto() - Cassie = enum.auto() - AnymalC = enum.auto() - Laikago = enum.auto() - DoublePendulum = enum.auto() - - -class ModelFactory: - """Factory class providing URDF files used by the tests.""" - - @staticmethod - def get_model_description(robot: Robot) -> pathlib.Path: - """ - Get the URDF file of different robots. - - Args: - robot: Robot name of the desired URDF file. - - Returns: - Path to the URDF file of the robot. - """ - - match robot: - case Robot.iCub: - return pathlib.Path(robot_descriptions.icub_description.URDF_PATH) - case Robot.Ur10: - return pathlib.Path(robot_descriptions.ur10_description.URDF_PATH) - case Robot.Panda: - return pathlib.Path(robot_descriptions.panda_description.URDF_PATH) - case Robot.Cassie: - return pathlib.Path(robot_descriptions.cassie_description.URDF_PATH) - case Robot.AnymalC: - return pathlib.Path(robot_descriptions.anymal_c_description.URDF_PATH) - case Robot.Laikago: - return pathlib.Path(robot_descriptions.laikago_description.URDF_PATH) - case Robot.DoublePendulum: - return pathlib.Path( - robot_descriptions.double_pendulum_description.URDF_PATH - ) - case _: - raise ValueError(f"Unknown robot '{robot}'") diff --git a/tests/utils_rng.py b/tests/utils_rng.py deleted file mode 100644 index 17dd7acd0..000000000 --- a/tests/utils_rng.py +++ /dev/null @@ -1,96 +0,0 @@ -import numpy as np - -from jaxsim import logging, sixd -from jaxsim.physics.model.physics_model import PhysicsModel -from jaxsim.physics.model.physics_model_state import ( - PhysicsModelInput, - PhysicsModelState, -) -from jaxsim.utils import Mutability - -# Initialize a global RNG used by all tests -test_rng = None - - -def get_rng(seed: int = None) -> np.random.Generator: - """ - Get a random number generator that can be used to produce reproducibile sequences. - - Args: - seed: The optional seed of the RNG (ignored if the RNG - - Returns: - A random number generator. - """ - - global test_rng - - if test_rng is not None and seed is not None: - msg = "Seed was already configured globally, ignoring the new one" - logging.warning(msg=msg) - - seed = seed if seed is not None else 42 - test_rng = test_rng if test_rng is not None else np.random.default_rng(seed=seed) - - return test_rng - - -def random_physics_model_state(physics_model: PhysicsModel) -> PhysicsModelState: - """ - Generate a random `PhysicsModelState` object. - - Args: - physics_model: the physics model to which the random state refers to. - - Returns: - The random `PhysicsModelState` object. - """ - - rng = get_rng() - - with PhysicsModelState.zero(physics_model=physics_model).mutable_context( - mutability=Mutability.MUTABLE - ) as state: - # Generate random joint quantities - state.joint_positions = rng.uniform(size=physics_model.dofs(), low=-1) - state.joint_velocities = rng.uniform(size=physics_model.dofs(), low=-1) - - # Generate random base quantities - state.base_position = rng.uniform(size=3, low=-1) - state.base_quaternion = sixd.so3.SO3.from_rpy_radians( - *rng.uniform(low=0, high=2 * np.pi, size=3) - ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])] - - # If floating-base, generate random base velocities - if physics_model.is_floating_base: - state.base_linear_velocity = rng.uniform(size=3, low=-1) - state.base_angular_velocity = rng.uniform(size=3, low=-1) - - return state - - -def random_physics_model_input(physics_model: PhysicsModel) -> PhysicsModelInput: - """ - Generate a random `PhysicsModelInput` object. - - Args: - physics_model: the physics model to which the random state refers to. - - Returns: - The random `PhysicsModelInput` object. - """ - - rng = get_rng() - - with PhysicsModelInput.zero(physics_model=physics_model).mutable_context( - mutability=Mutability.MUTABLE - ) as model_input: - # Generate random joint torques and external forces - model_input.tau = 10 * rng.uniform(size=physics_model.dofs(), low=-1) - model_input.f_ext = 10 * rng.uniform(size=[physics_model.NB, 6], low=-1) - - # Zero the base force if the robot is fixed base - if not physics_model.is_floating_base and physics_model.NB > 0: - model_input.f_ext[0] = np.zeros(6) - - return model_input From 1b506de19e532e3c25fb14bfa85470d03e83983d Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 13:05:32 +0100 Subject: [PATCH 02/35] Add a pytest fixture to generate a PRNG key --- tests/conftest.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..ffa0c13d7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,35 @@ +import os + +import jax +import pytest + + +def pytest_configure() -> None: + """Pytest configuration hook.""" + + # This is a global variable that is updated by the `prng_key` fixture. + pytest.prng_key = jax.random.PRNGKey( + seed=int(os.environ.get("JAXSIM_TEST_SEED", 0)) + ) + + +# ================ +# Generic fixtures +# ================ + + +@pytest.fixture(scope="function") +def prng_key() -> jax.Array: + """ + Fixture to generate a new PRNG key for each test function. + + Returns: + The new PRNG key passed to the test. + + Note: + This fixture operates on a global variable initialized in the + `pytest_configure` hook. + """ + + pytest.prng_key, subkey = jax.random.split(pytest.prng_key, num=2) + return subkey From ff02a33164d64a6a0881a438067be0688a3ac8aa Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 11:02:02 +0100 Subject: [PATCH 03/35] Add fixture to parameterize tests over all velocity representations --- tests/conftest.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index ffa0c13d7..f72e6fe75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import jax import pytest +import jaxsim def pytest_configure() -> None: @@ -33,3 +34,22 @@ def prng_key() -> jax.Array: pytest.prng_key, subkey = jax.random.split(pytest.prng_key, num=2) return subkey + + +@pytest.fixture( + scope="function", + params=[ + pytest.param(jaxsim.VelRepr.Inertial, id="inertial"), + pytest.param(jaxsim.VelRepr.Body, id="body"), + pytest.param(jaxsim.VelRepr.Mixed, id="mixed"), + ], +) +def velocity_representation(request) -> jaxsim.VelRepr: + """ + Parametrized fixture providing all supported velocity representations. + + Returns: + A velocity representation. + """ + + return request.param From d70568ca44abdfb364e0a9827e5fa4cee938b0c5 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:06:42 +0100 Subject: [PATCH 04/35] Add session-wide fixtures to provide tested models --- tests/conftest.py | 187 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index f72e6fe75..7d91d3a4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,13 @@ import os +import pathlib import jax +import jax.numpy as jnp import pytest +import rod + import jaxsim +import jaxsim.api as js def pytest_configure() -> None: @@ -53,3 +58,185 @@ def velocity_representation(request) -> jaxsim.VelRepr: """ return request.param + + +# ================================ +# Fixtures providing JaxSim models +# ================================ + +# All the fixtures in this section must have "session" scope. +# In this way, the models are generated only once and shared among all the tests. + + +# This is not a fixture. +def build_jaxsim_model( + model_description: str | pathlib.Path | rod.Model, +) -> js.model.JaxSimModel: + """ + Helper to build a JaxSim model from a model description. + + Args: + model_description: A model description provided by any fixture provider. + + Returns: + A JaxSim model built from the provided description. + """ + + is_urdf = None + + # If the provided description is a string, automatically detect if it + # contains the content of a URDF or SDF file. + if isinstance(model_description, str): + if "" in model_description: + is_urdf = False + + else: + is_urdf = None + + # Build the JaxSim model. + model = js.model.JaxSimModel.build_from_model_description( + model_description=model_description, + gravity=jnp.array([0, 0, -10]), + is_urdf=is_urdf, + ) + + return model + + +@pytest.fixture(scope="session") +def jaxsim_model_box() -> js.model.JaxSimModel: + """ + Fixture providing the JaxSim model of a box. + + Returns: + The JaxSim model of a box. + """ + + import rod.builder.primitives + import rod.urdf.exporter + + # Create on-the-fly a ROD model of a box. + rod_model = ( + rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box") + .build_model() + .add_link() + .add_inertial() + .add_visual() + .add_collision() + .build() + ) + + # Export the URDF string. + urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string( + sdf=rod_model, pretty=True + ) + + return build_jaxsim_model(model_description=urdf_string) + + +@pytest.fixture(scope="session") +def jaxsim_model_sphere() -> js.model.JaxSimModel: + """ + Fixture providing the JaxSim model of a sphere. + + Returns: + The JaxSim model of a sphere. + """ + + import rod.builder.primitives + import rod.urdf.exporter + + # Create on-the-fly a ROD model of a sphere. + rod_model = ( + rod.builder.primitives.SphereBuilder(radius=0.1, mass=1.0, name="sphere") + .build_model() + .add_link() + .add_inertial() + .add_visual() + .add_collision() + .build() + ) + + # Export the URDF string. + urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string( + sdf=rod_model, pretty=True + ) + + return build_jaxsim_model(model_description=urdf_string) + + +@pytest.fixture(scope="session") +def jaxsim_model_ergocub() -> js.model.JaxSimModel: + """ + Fixture providing the JaxSim model of the ErgoCub robot. + + Returns: + The JaxSim model of the ErgoCub robot. + """ + + os_environ_original = os.environ.copy() + + try: + os.environ["ROBOT_DESCRIPTION_COMMIT"] = "v0.7.1" + + import robot_descriptions.ergocub_description + finally: + os.environ = os_environ_original + + model_urdf_path = pathlib.Path( + robot_descriptions.ergocub_description.URDF_PATH.replace( + "ergoCubSN000", "ergoCubSN001" + ) + ) + + return build_jaxsim_model(model_description=model_urdf_path) + + +@pytest.fixture(scope="session") +def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel: + """ + Fixture providing the JaxSim model of the ErgoCub robot with only locomotion joints. + + Returns: + The JaxSim model of the ErgoCub robot with only locomotion joints. + """ + + model_full = jaxsim_model_ergocub + + # Get the names of the joints to keep + reduced_joints = tuple( + j + for j in model_full.joint_names() + if "camera" not in j + # Remove head and hands + and "neck" not in j + and "wrist" not in j + and "thumb" not in j + and "index" not in j + and "middle" not in j + and "ring" not in j + and "pinkie" not in j + # Remove upper body + and "torso" not in j and "elbow" not in j and "shoulder" not in j + ) + + return js.model.reduce(model=model_full, considered_joints=reduced_joints) + + +@pytest.fixture(scope="session") +def jaxsim_model_ur10() -> js.model.JaxSimModel: + """ + Fixture providing the JaxSim model of the UR10 robot. + + Returns: + The JaxSim model of the UR10 robot. + """ + + import robot_descriptions.ur10_description + + model_urdf_path = pathlib.Path(robot_descriptions.ur10_description.URDF_PATH) + + return build_jaxsim_model(model_description=model_urdf_path) From 020ff88ca404421eaabbd4161e3b3380beb4fee5 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 15:45:41 +0100 Subject: [PATCH 05/35] Add collections of tested models --- tests/conftest.py | 123 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 7d91d3a4c..762cacd99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -240,3 +240,126 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel: model_urdf_path = pathlib.Path(robot_descriptions.ur10_description.URDF_PATH) return build_jaxsim_model(model_description=model_urdf_path) + + +# ============================ +# Collections of JaxSim models +# ============================ + + +def get_jaxsim_model_fixture( + model_name: str, request: pytest.FixtureRequest +) -> str | pathlib.Path: + """ + Factory to get the fixture providing the JaxSim model of a robot. + + Args: + model_name: The name of the model. + request: The request object. + + Returns: + The JaxSim model of the robot. + """ + + match model_name: + case "box": + return request.getfixturevalue(jaxsim_model_box.__name__) + case "sphere": + return request.getfixturevalue(jaxsim_model_sphere.__name__) + case "ergocub": + return request.getfixturevalue(jaxsim_model_ergocub.__name__) + case "ergocub_reduced": + return request.getfixturevalue(jaxsim_model_ergocub_reduced.__name__) + case "ur10": + return request.getfixturevalue(jaxsim_model_ur10.__name__) + case _: + raise ValueError(model_name) + + +@pytest.fixture( + scope="session", + params=[ + "box", + "sphere", + "ur10", + "ergocub", + "ergocub_reduced", + ], +) +def jaxsim_models_all(request) -> pathlib.Path | str: + """ + Fixture providing the JaxSim models of all supported robots. + """ + + model_name: str = request.param + return get_jaxsim_model_fixture(model_name=model_name, request=request) + + +@pytest.fixture( + scope="session", + params=[ + "box", + "ur10", + "ergocub_reduced", + ], +) +def jaxsim_models_types(request) -> pathlib.Path | str: + """ + Fixture providing JaxSim models of all types of supported robots. + + Note: + At the moment, most of our tests use this fixture. It provides: + - A robot with no joints. + - A fixed-base robot. + - A floating-base robot. + """ + + model_name: str = request.param + return get_jaxsim_model_fixture(model_name=model_name, request=request) + + +@pytest.fixture( + scope="session", + params=[ + "box", + "sphere", + ], +) +def jaxsim_models_no_joints(request) -> pathlib.Path | str: + """ + Fixture providing JaxSim models of robots with no joints. + """ + + model_name: str = request.param + return get_jaxsim_model_fixture(model_name=model_name, request=request) + + +@pytest.fixture( + scope="session", + params=[ + "ergocub", + "ergocub_reduced", + ], +) +def jaxsim_models_floating_base(request) -> pathlib.Path | str: + """ + Fixture providing JaxSim models of floating-base robots. + """ + + model_name: str = request.param + return get_jaxsim_model_fixture(model_name=model_name, request=request) + + +@pytest.fixture( + scope="session", + params=[ + "ur10", + ], +) +def jaxsim_models_fixed_base(request) -> pathlib.Path | str: + """ + Fixture providing JaxSim models of fixed-base robots. + """ + + model_name: str = request.param + return get_jaxsim_model_fixture(model_name=model_name, request=request) From 43fa2f9f2c6e3f83f0c5cc5baab86807de61eb3c Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 11:12:48 +0100 Subject: [PATCH 06/35] Add tests of jaxsim.api.data module --- tests/test_api_data.py | 145 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 tests/test_api_data.py diff --git a/tests/test_api_data.py b/tests/test_api_data.py new file mode 100644 index 000000000..70084fd03 --- /dev/null +++ b/tests/test_api_data.py @@ -0,0 +1,145 @@ +import jax +import jax.numpy as jnp +import pytest + +import jaxsim.api as js +from jaxsim import VelRepr +from jaxsim.utils import Mutability + +from . import utils_idyntree + + +def test_data_valid( + jaxsim_models_all: js.model.JaxSimModel, +): + + model = jaxsim_models_all + data = js.data.JaxSimModelData.build(model=model) + + assert data.valid(model=model) + + +def test_data_joint_indexing( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + assert data.joint_positions( + model=model, joint_names=model.joint_names() + ) == pytest.approx(data.joint_positions()) + + assert data.joint_positions() == pytest.approx( + data.state.physics_model.joint_positions + ) + + assert data.joint_velocities( + model=model, joint_names=model.joint_names() + ) == pytest.approx(data.joint_velocities()) + + assert data.joint_velocities() == pytest.approx( + data.state.physics_model.joint_velocities + ) + + +def test_data_switch_velocity_representation( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=VelRepr.Inertial + ) + + # ===== + # Tests + # ===== + + new_base_linear_velocity = jnp.array([1.0, -2.0, 3.0]) + old_base_linear_velocity = data.state.physics_model.base_linear_velocity + + # The following should not change the original `data` object since it raises + with pytest.raises(RuntimeError): + with data.switch_velocity_representation( + velocity_representation=VelRepr.Inertial + ): + with data.mutable_context(mutability=Mutability.MUTABLE): + data.state.physics_model.base_linear_velocity = new_base_linear_velocity + raise RuntimeError("This is raised on purpose inside this context") + + assert data.state.physics_model.base_linear_velocity == pytest.approx( + old_base_linear_velocity + ) + + # The following instead should result to an updated `data` object + with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): + with data.mutable_context(mutability=Mutability.MUTABLE): + data.state.physics_model.base_linear_velocity = new_base_linear_velocity + + assert data.state.physics_model.base_linear_velocity == pytest.approx( + new_base_linear_velocity + ) + + +def test_data_change_velocity_representation( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=VelRepr.Inertial + ) + + # ===== + # Tests + # ===== + + kin_dyn_inertial = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + with data.switch_velocity_representation(VelRepr.Mixed): + kin_dyn_mixed = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + with data.switch_velocity_representation(VelRepr.Body): + kin_dyn_body = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + assert data.base_velocity() == pytest.approx(kin_dyn_inertial.base_velocity()) + + if not model.floating_base(): + return + + with data.switch_velocity_representation(VelRepr.Mixed): + assert data.base_velocity() == pytest.approx(kin_dyn_mixed.base_velocity()) + assert data.base_velocity()[0:3] != pytest.approx( + data.state.physics_model.base_linear_velocity + ) + assert data.base_velocity()[3:6] == pytest.approx( + data.state.physics_model.base_angular_velocity + ) + + with data.switch_velocity_representation(VelRepr.Body): + assert data.base_velocity() == pytest.approx(kin_dyn_body.base_velocity()) + assert data.base_velocity()[0:3] != pytest.approx( + data.state.physics_model.base_linear_velocity + ) + assert data.base_velocity()[3:6] != pytest.approx( + data.state.physics_model.base_angular_velocity + ) From 5cdd9142021e594c551dc3a596c8be545b8e8e54 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 11:13:06 +0100 Subject: [PATCH 07/35] Add test of jaxsim.api.joint module --- tests/test_api_joint.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/test_api_joint.py diff --git a/tests/test_api_joint.py b/tests/test_api_joint.py new file mode 100644 index 000000000..e64fc1210 --- /dev/null +++ b/tests/test_api_joint.py @@ -0,0 +1,34 @@ +import jax.numpy as jnp +import pytest + +import jaxsim.api as js + + +def test_joint_index( + jaxsim_models_types: js.model.JaxSimModel, +): + + model = jaxsim_models_types + + # ===== + # Tests + # ===== + + for idx, joint_name in enumerate(model.joint_names()): + assert js.joint.name_to_idx(model=model, joint_name=joint_name) == idx + + assert js.joint.names_to_idxs( + model=model, joint_names=model.joint_names() + ) == pytest.approx(jnp.arange(model.number_of_joints())) + + assert ( + js.joint.idxs_to_names( + model=model, + joint_indices=tuple( + js.joint.names_to_idxs( + model=model, joint_names=model.joint_names() + ).tolist() + ), + ) + == model.joint_names() + ) From 446369a1f88d2129072e99d00d77fbf8552d8aea Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 11:13:22 +0100 Subject: [PATCH 08/35] Add test of jaxsim.api.link module --- tests/test_api_link.py | 151 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 tests/test_api_link.py diff --git a/tests/test_api_link.py b/tests/test_api_link.py new file mode 100644 index 000000000..51dd5d243 --- /dev/null +++ b/tests/test_api_link.py @@ -0,0 +1,151 @@ +import jax +import jax.numpy as jnp +import pytest + +import jaxsim.api as js +from jaxsim import VelRepr + +from . import utils_idyntree + + +def test_link_index( + jaxsim_models_types: js.model.JaxSimModel, +): + + model = jaxsim_models_types + + # ===== + # Tests + # ===== + + for idx, link_name in enumerate(model.link_names()): + assert js.link.name_to_idx(model=model, link_name=link_name) == idx + + assert js.link.names_to_idxs( + model=model, link_names=model.link_names() + ) == pytest.approx(jnp.arange(model.number_of_links())) + + assert ( + js.link.idxs_to_names( + model=model, + link_indices=tuple( + js.link.names_to_idxs( + model=model, link_names=model.link_names() + ).tolist() + ), + ) + == model.link_names() + ) + + +def test_link_inertial_properties( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, + key=subkey, + velocity_representation=VelRepr.Inertial, + ) + + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + # ===== + # Tests + # ===== + + for link_name, link_idx in zip( + model.link_names(), + js.link.names_to_idxs(model=model, link_names=model.link_names()), + ): + if link_name == model.base_link(): + continue + + assert js.link.mass(model=model, link_index=link_idx) == pytest.approx( + kin_dyn.link_mass(link_name=link_name) + ), link_name + + assert js.link.spatial_inertia( + model=model, link_index=link_idx + ) == pytest.approx(kin_dyn.link_spatial_inertia(link_name=link_name)), link_name + + +def test_link_transforms( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, + key=subkey, + velocity_representation=VelRepr.Inertial, + ) + + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + # ===== + # Tests + # ===== + + W_H_LL_model = js.model.forward_kinematics(model=model, data=data) + + W_H_LL_links = jax.vmap( + lambda idx: js.link.transform(model=model, data=data, link_index=idx) + )(jnp.arange(model.number_of_links())) + + assert W_H_LL_model == pytest.approx(W_H_LL_links) + + for W_H_L, link_name in zip(W_H_LL_links, model.link_names()): + + assert W_H_L == pytest.approx( + kin_dyn.frame_transform(frame_name=link_name) + ), link_name + + +def test_link_jacobians( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, + key=subkey, + velocity_representation=velocity_representation, + ) + + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + # ===== + # Tests + # ===== + + J_WL_links = jax.vmap( + lambda idx: js.link.jacobian(model=model, data=data, link_index=idx) + )(jnp.arange(model.number_of_links())) + + for J_WL, link_name in zip(J_WL_links, model.link_names()): + assert J_WL == pytest.approx( + kin_dyn.jacobian_frame(frame_name=link_name), abs=1e-9 + ), link_name + + # The following is true only in inertial-fixed representation. + if data.velocity_representation is VelRepr.Inertial: + J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data) + assert J_WL_model == pytest.approx(J_WL_links) From d95b589b121140dba2fcf7dc9290e13a96898063 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 13:04:41 +0100 Subject: [PATCH 09/35] Add test of jaxsim.api.model module --- tests/test_api_model.py | 311 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 tests/test_api_model.py diff --git a/tests/test_api_model.py b/tests/test_api_model.py new file mode 100644 index 000000000..d53e3595b --- /dev/null +++ b/tests/test_api_model.py @@ -0,0 +1,311 @@ +import pathlib + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import rod + +import jaxsim.api as js +from jaxsim import VelRepr + +from . import utils_idyntree + + +def test_model_creation_and_reduction( + jaxsim_model_ergocub: js.model.JaxSimModel, + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, +): + + model_full = jaxsim_model_ergocub + model_reduced = jaxsim_model_ergocub_reduced + + # Build the data of the full model. + data = js.data.JaxSimModelData.build( + model=model_full, + base_position=jnp.array([0, 0, 0.8]), + velocity_representation=VelRepr.Inertial, + ) + + # ===== + # Tests + # ===== + + # Check that the data of the full model is valid. + assert data.valid(model=model_full) + + # Build the ROD model from the original description. + assert isinstance(model_full.built_from, (str, pathlib.Path)) + rod_sdf = rod.Sdf.load(sdf=model_full.built_from) + assert len(rod_sdf.models()) == 1 + + # Get all non-fixed joint names from the description. + joint_names_in_description = [ + j.name for j in rod_sdf.models()[0].joints() if j.type != "fixed" + ] + + # Check that all non-fixed joints are in the full model. + assert set(joint_names_in_description) == set(model_full.joint_names()) + + # Build the data of the reduced model. + data_reduced = js.data.JaxSimModelData.build( + model=model_reduced, + base_position=jnp.array([0, 0, 0.8]), + velocity_representation=VelRepr.Inertial, + ) + + # Check that the reduced model data is valid. + assert not data_reduced.valid(model=model_full) + assert data_reduced.valid(model=model_reduced) + + # Check that the total mass is preserved. + assert js.model.total_mass(model=model_full) == pytest.approx( + js.model.total_mass(model=model_reduced) + ) + + # Check that the CoM position is preserved. + assert js.model.com_position(model=model_full, data=data) == pytest.approx( + js.model.com_position(model=model_reduced, data=data_reduced) + ) + + +def test_model_properties( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + # ===== + # Tests + # ===== + + m_idt = kin_dyn.total_mass() + m_js = js.model.total_mass(model=model) + assert pytest.approx(m_idt) == m_js + + p_com_idt = kin_dyn.com_position() + p_com_js = js.model.com_position(model=model, data=data) + assert pytest.approx(p_com_idt) == p_com_js + + h_tot_idt = kin_dyn.total_momentum() + h_tot_js = js.model.total_momentum(model=model, data=data) + assert pytest.approx(h_tot_idt) == h_tot_js + + Jh_idt = kin_dyn.total_momentum_jacobian() + Jh_js = js.model.free_floating_mass_matrix(model=model, data=data)[0:6] + assert pytest.approx(Jh_idt) == Jh_js + + +def test_model_rbda( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, + velocity_representation: VelRepr, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) + + # ===== + # Tests + # ===== + + # Support both fixed-base and floating-base models by slicing the first six rows + sl = np.s_[0:] if model.floating_base() else np.s_[6:] + + # Mass matrix + M_idt = kin_dyn.mass_matrix() + M_js = js.model.free_floating_mass_matrix(model=model, data=data) + assert pytest.approx(M_idt[sl, sl]) == M_js[sl, sl] + + # Gravity forces + g_idt = kin_dyn.gravity_forces() + g_js = js.model.free_floating_gravity_forces(model=model, data=data) + assert pytest.approx(g_idt[sl]) == g_js[sl] + + # Bias forces + h_idt = kin_dyn.bias_forces() + h_js = js.model.free_floating_bias_forces(model=model, data=data) + assert pytest.approx(h_idt[sl]) == h_js[sl] + + # Forward kinematics + HH_js = js.model.forward_kinematics(model=model, data=data) + HH_idt = jnp.stack( + [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()] + ) + assert pytest.approx(HH_idt) == HH_js + + +def test_model_jacobian( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=VelRepr.Inertial + ) + + # ===== + # Tests + # ===== + + # Create random references (joint torques and link forces) + key, subkey1, subkey2 = jax.random.split(key, num=3) + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), + link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), + data=data, + velocity_representation=data.velocity_representation, + ) + + # Remove the force applied to the base link if the model is fixed-base + if not model.floating_base(): + references = references.apply_link_forces( + forces=jnp.atleast_2d(jnp.zeros(6)), + model=model, + data=data, + link_names=(model.base_link(),), + additive=False, + ) + + # Get the J.T @ f product in inertial-fixed input/output representation. + # We use doubly right-trivialized jacobian with inertial-fixed 6D forces. + with references.switch_velocity_representation(VelRepr.Inertial): + with data.switch_velocity_representation(VelRepr.Inertial): + + f = references.link_forces(model=model, data=data) + assert f == pytest.approx(references.input.physics_model.f_ext) + + J = js.model.generalized_free_floating_jacobian(model=model, data=data) + JTf_inertial = jnp.einsum("l6g,l6->g", J, f) + + for vel_repr in [VelRepr.Body, VelRepr.Mixed]: + with references.switch_velocity_representation(vel_repr): + + # Get the jacobian having an inertial-fixed input representation (so that + # it computes the same quantity computed above) and an output representation + # compatible with the frame in which the external forces are expressed. + with data.switch_velocity_representation(VelRepr.Inertial): + + J = js.model.generalized_free_floating_jacobian( + model=model, data=data, output_vel_repr=vel_repr + ) + + # Get the forces in the tested representation and compute the product + # O_J_WL_W.T @ O_f, producing a generalized acceleration in W. + # The resulting acceleration can be tested again the one computed before. + with data.switch_velocity_representation(vel_repr): + + f = references.link_forces(model=model, data=data) + JTf_other = jnp.einsum("l6g,l6->g", J, f) + assert pytest.approx(JTf_inertial) == JTf_other, vel_repr.name + + +def test_model_fd_id_consistency( + jaxsim_models_types: js.model.JaxSimModel, + velocity_representation: VelRepr, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + # ===== + # Tests + # ===== + + # Create random references (joint torques and link forces) + key, subkey1, subkey2 = jax.random.split(key, num=3) + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), + link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), + data=data, + velocity_representation=data.velocity_representation, + ) + + # Remove the force applied to the base link if the model is fixed-base + if not model.floating_base(): + references = references.apply_link_forces( + forces=jnp.atleast_2d(jnp.zeros(6)), + model=model, + data=data, + link_names=(model.base_link(),), + additive=False, + ) + + # Compute forward dynamics with ABA + v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba( + model=model, + data=data, + joint_forces=references.joint_force_references(), + external_forces=references.link_forces(model=model, data=data), + ) + + # Compute forward dynamics with CRB + v̇_WB_crb, s̈_crb = js.model.forward_dynamics_crb( + model=model, + data=data, + joint_forces=references.joint_force_references(), + external_forces=references.link_forces(model=model, data=data), + ) + + assert pytest.approx(s̈_aba) == s̈_crb + assert pytest.approx(v̇_WB_aba) == v̇_WB_crb + + # Compute inverse dynamics with the quantities computed by forward dynamics + fB_id, τ_id = js.model.inverse_dynamics( + model=model, + data=data, + joint_accelerations=s̈_aba, + base_acceleration=v̇_WB_aba, + external_forces=references.link_forces(model=model, data=data), + ) + + # Check consistency between FD and ID + assert pytest.approx(τ_id) == references.joint_force_references(model=model) + assert pytest.approx(fB_id, abs=1e-9) == jnp.zeros(6) + + if model.floating_base(): + # If we remove the base 6D force from the inputs, we should find it as output. + fB_id, τ_id = js.model.inverse_dynamics( + model=model, + data=data, + joint_accelerations=s̈_aba, + base_acceleration=v̇_WB_aba, + external_forces=references.link_forces(model=model, data=data) + .at[0] + .set(jnp.zeros(6)), + ) + + assert pytest.approx(τ_id) == references.joint_force_references(model=model) + assert ( + pytest.approx(fB_id, abs=1e-9) + == references.link_forces(model=model, data=data)[0] + ) From 29ec05817c83dd67ad8698868f4f2adc302e1e21 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 12 Mar 2024 12:00:32 +0100 Subject: [PATCH 10/35] Add methods to wrapper utils_idyntree.KinDynComputations --- tests/utils_idyntree.py | 62 +++++++++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index a58391067..491987092 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -1,6 +1,5 @@ import dataclasses import pathlib -from typing import List, Union import idyntree.bindings as idt import numpy as np @@ -19,14 +18,13 @@ class KinDynComputations: @staticmethod def build( - urdf: Union[pathlib.Path, str], - considered_joints: List[str] = None, + urdf: pathlib.Path | str, + considered_joints: list[str] = None, vel_repr: VelRepr = VelRepr.Inertial, gravity: npt.NDArray = dataclasses.field( default_factory=lambda: np.array([0, 0, -10.0]) ), ) -> "KinDynComputations": - """""" # Read the URDF description urdf_string = urdf.read_text() if isinstance(urdf, pathlib.Path) else urdf @@ -40,7 +38,7 @@ def build( if considered_joints is None else mdl_loader.loadReducedModelFromString(urdf_string, considered_joints) ): - raise RuntimeError(f"Failed to load URDF description") + raise RuntimeError("Failed to load URDF description") # Create KinDynComputations and insert the model kindyn = idt.KinDynComputations() @@ -72,6 +70,7 @@ def set_robot_state( base_velocity: npt.NDArray = np.zeros(6), world_gravity: npt.NDArray | None = None, ) -> None: + joint_positions = ( joint_positions if joint_positions is not None else np.zeros(self.dofs()) ) @@ -114,21 +113,25 @@ def set_robot_state( raise RuntimeError("Failed to set the robot state") # Update stored gravity - self.world_gravity = gravity + self.gravity = gravity def dofs(self) -> int: + return self.kin_dyn.getNrOfDegreesOfFreedom() - def joint_names(self) -> List[str]: + def joint_names(self) -> list[str]: + model: idt.Model = self.kin_dyn.model() return [model.getJointName(i) for i in range(model.getNrOfJoints())] - def link_names(self) -> List[str]: + def link_names(self) -> list[str]: + return [ self.kin_dyn.getFrameName(i) for i in range(self.kin_dyn.getNrOfLinks()) ] def joint_positions(self) -> npt.NDArray: + vector = idt.VectorDynSize() if not self.kin_dyn.getJointPos(vector): @@ -137,6 +140,7 @@ def joint_positions(self) -> npt.NDArray: return vector.toNumPy() def joint_velocities(self) -> npt.NDArray: + vector = idt.VectorDynSize() if not self.kin_dyn.getJointVel(vector): @@ -145,6 +149,7 @@ def joint_velocities(self) -> npt.NDArray: return vector.toNumPy() def jacobian_frame(self, frame_name: str) -> npt.NDArray: + if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") @@ -156,23 +161,36 @@ def jacobian_frame(self, frame_name: str) -> npt.NDArray: return J.toNumPy() def total_mass(self) -> float: + model: idt.Model = self.kin_dyn.model() return model.getTotalMass() - def spatial_inertia(self, link_name: str) -> npt.NDArray: + def link_spatial_inertia(self, link_name: str) -> npt.NDArray: + if link_name not in self.link_names(): raise ValueError(link_name) model = self.kin_dyn.model() + link: idt.Link = model.getLink(model.getLinkIndex(link_name)) - return ( - model.getLink(model.getLinkIndex(link_name)).inertia().asMatrix().toNumPy() - ) + return link.inertia().asMatrix().toNumPy() + + def link_mass(self, link_name: str) -> float: + + if link_name not in self.link_names(): + raise ValueError(link_name) + + model = self.kin_dyn.model() + link: idt.Link = model.getLink(model.getLinkIndex(link_name)) + + return link.getInertia().asVector().toNumPy()[0] def floating_base_frame(self) -> str: + return self.kin_dyn.getFloatingBase() def frame_transform(self, frame_name: str) -> npt.NDArray: + if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") @@ -181,8 +199,6 @@ def frame_transform(self, frame_name: str) -> npt.NDArray: else: H_idt = self.kin_dyn.getWorldTransform(frame_name) - # return H_idt.asHomogeneousTransform().toNumPy() - H = np.eye(4) H[0:3, 3] = H_idt.getPosition().toNumPy() H[0:3, 0:3] = H_idt.getRotation().toNumPy() @@ -190,6 +206,7 @@ def frame_transform(self, frame_name: str) -> npt.NDArray: return H def base_velocity(self) -> npt.NDArray: + nu = idt.VectorDynSize() if not self.kin_dyn.getModelVel(nu): @@ -198,10 +215,12 @@ def base_velocity(self) -> npt.NDArray: return nu.toNumPy()[0:6] def com_position(self) -> npt.NDArray: + W_p_G = self.kin_dyn.getCenterOfMassPosition() return W_p_G.toNumPy() def mass_matrix(self) -> npt.NDArray: + M = idt.MatrixDynSize() if not self.kin_dyn.getFreeFloatingMassMatrix(M): @@ -210,6 +229,7 @@ def mass_matrix(self) -> npt.NDArray: return M.toNumPy() def bias_forces(self) -> npt.NDArray: + h = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model()) if not self.kin_dyn.generalizedBiasForces(h): @@ -223,6 +243,7 @@ def bias_forces(self) -> npt.NDArray: ) def gravity_forces(self) -> npt.NDArray: + g = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model()) if not self.kin_dyn.generalizedGravityForces(g): @@ -234,3 +255,16 @@ def gravity_forces(self) -> npt.NDArray: return np.hstack( [base_wrench.toNumPy().flatten(), joint_torques.toNumPy().flatten()] ) + + def total_momentum(self) -> npt.NDArray: + + return self.kin_dyn.getLinearAngularMomentum().toNumPy().flatten() + + def total_momentum_jacobian(self) -> npt.NDArray: + + Jh = idt.MatrixDynSize() + + if not self.kin_dyn.getLinearAngularMomentumJacobian(Jh): + raise RuntimeError("Failed to get the total momentum jacobian") + + return Jh.toNumPy() From 9b49549682e9ab771e4fa561fc3a92d4eebb7c7e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 11:19:14 +0100 Subject: [PATCH 11/35] Add other iDynTree testing helpers --- tests/utils_idyntree.py | 67 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 491987092..c022cf5f7 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import pathlib @@ -5,9 +7,74 @@ import numpy as np import numpy.typing as npt +import jaxsim.api as js from jaxsim.high_level.common import VelRepr +def build_kindyncomputations_from_jaxsim_model( + model: js.model.JaxSimModel, data: js.data.JaxSimModelData +) -> KinDynComputations: + """ + Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`. + + Args: + model: The `JaxSimModel` from which to build the `KinDynComputations`. + data: The `JaxSimModelData` from which to build the `KinDynComputations`. + + Returns: + The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`. + + Note: + Only `JaxSimModel` built from URDF files are supported. + """ + + if ( + isinstance(model.built_from, pathlib.Path) + and model.built_from.suffix != ".urdf" + ) or (isinstance(model.built_from, str) and " KinDynComputations: + """ + Store the state of a `JaxSimModelData` in `KinDynComputations`. + + Args: + data: + The `JaxSimModelData` providing the desired state to copy. + kin_dyn: + The `KinDynComputations` in which to store the state of `JaxSimModelData`. + + Returns: + The updated `KinDynComputations` with the state of `JaxSimModelData`. + """ + + with data.switch_velocity_representation(kin_dyn.vel_repr): + kin_dyn.set_robot_state( + joint_positions=np.array(data.joint_positions()), + joint_velocities=np.array(data.joint_velocities()), + base_transform=np.array(data.base_transform()), + base_velocity=np.array(data.base_velocity()), + ) + + return kin_dyn + + @dataclasses.dataclass class KinDynComputations: """High-level wrapper of the iDynTree KinDynComputations class.""" From 97e54d64b58c14ef7ccf774454c70cf0d3a45f62 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 13:01:20 +0100 Subject: [PATCH 12/35] Add test to check automatic differentiation of algorithms --- tests/test_automatic_differentiation.py | 406 ++++++++++++++++++++++++ 1 file changed, 406 insertions(+) create mode 100644 tests/test_automatic_differentiation.py diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py new file mode 100644 index 000000000..2a1859d7e --- /dev/null +++ b/tests/test_automatic_differentiation.py @@ -0,0 +1,406 @@ +import os + +import jax +import jax.numpy as jnp +from jax.test_util import check_grads + +import jaxsim.api as js +from jaxsim import VelRepr + +# All JaxSim algorithms, excluding the variable-step integrators, should support +# being automatically differentiated until second order, both in FWD and REV modes. +# However, checking the second-order derivatives is particularly slow and makes +# CI tests take too long. Therefore, we only check first-order derivatives. +AD_ORDER = os.environ.get("JAXSIM_TEST_AD_ORDER", 1) + + +def get_random_data_and_references( + model: js.model.JaxSimModel, + velocity_representation: VelRepr, + key: jax.Array, +) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: + + key, subkey = jax.random.split(key, num=2) + + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) + + key, subkey1, subkey2 = jax.random.split(key, num=3) + + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), + link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), + data=data, + velocity_representation=velocity_representation, + ) + + # Remove the force applied to the base link if the model is fixed-base. + if not model.floating_base(): + references = references.apply_link_forces( + forces=jnp.atleast_2d(jnp.zeros(6)), + model=model, + data=data, + link_names=(model.base_link(),), + additive=False, + ) + + return data, references + + +def test_ad_aba( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data, references = get_random_data_and_references( + model=model, velocity_representation=VelRepr.Inertial, key=key + ) + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # State in VelRepr.Inertial representation. + s = data.joint_positions() + ṡ = data.joint_velocities(model=model) + xfb = data.state.physics_model.xfb() + + # Inputs. + f = references.link_forces(model=model) + τ = references.joint_force_references(model=model) + + # ==== + # Test + # ==== + + import jaxsim.physics.algos.aba + + # Get a closure exposing only the parameters to be differentiated. + aba = lambda xfb, s, ṡ, tau, f_ext: jaxsim.physics.algos.aba.aba( + model=model.physics_model, xfb=xfb, q=s, qd=ṡ, tau=tau, f_ext=f_ext + ) + + # Check derivatives against finite differences. + check_grads( + f=aba, + args=(xfb, s, ṡ, τ, f), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + +def test_ad_rnea( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data, references = get_random_data_and_references( + model=model, velocity_representation=VelRepr.Inertial, key=key + ) + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # State in VelRepr.Inertial representation. + s = data.joint_positions() + ṡ = data.joint_velocities(model=model) + xfb = data.state.physics_model.xfb() + + # Inputs. + f = references.link_forces(model=model) + + # ==== + # Test + # ==== + + import jaxsim.physics.algos.rnea + + key, subkey1, subkey2 = jax.random.split(key, num=3) + W_v̇_WB = jax.random.uniform(subkey1, shape=(6,), minval=-1) + s̈ = jax.random.uniform(subkey2, shape=(model.dofs(),), minval=-1) + + # Get a closure exposing only the parameters to be differentiated. + rnea = lambda xfb, s, ṡ, s̈, W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea( + model=model.physics_model, xfb=xfb, q=s, qd=ṡ, qdd=s̈, a0fb=W_v̇_WB, f_ext=f_ext + ) + + # Check derivatives against finite differences. + check_grads( + f=rnea, + args=(xfb, s, ṡ, s̈, W_v̇_WB, f), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + +def test_ad_crba( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data, references = get_random_data_and_references( + model=model, velocity_representation=VelRepr.Inertial, key=key + ) + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # State in VelRepr.Inertial representation. + s = data.joint_positions() + + # ==== + # Test + # ==== + + import jaxsim.physics.algos.crba + + # Get a closure exposing only the parameters to be differentiated. + crba = lambda s: jaxsim.physics.algos.crba.crba(model=model.physics_model, q=s) + + # Check derivatives against finite differences. + check_grads( + f=crba, + args=(s,), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + +def test_ad_fk( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data, references = get_random_data_and_references( + model=model, velocity_representation=VelRepr.Inertial, key=key + ) + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # State in VelRepr.Inertial representation. + s = data.joint_positions() + xfb = data.state.physics_model.xfb() + + # ==== + # Test + # ==== + + import jaxsim.physics.algos.forward_kinematics + + # Get a closure exposing only the parameters to be differentiated. + fk = ( + lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( + model=model.physics_model, xfb=xfb, q=s + ) + ) + + # Check derivatives against finite differences. + check_grads( + f=fk, + args=(xfb, s), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + +def test_ad_jacobian( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data, references = get_random_data_and_references( + model=model, velocity_representation=VelRepr.Inertial, key=key + ) + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # State in VelRepr.Inertial representation. + s = data.joint_positions() + + # ==== + # Test + # ==== + + import jaxsim.physics.algos.jacobian + + # Get the link indices. + link_indices = js.link.names_to_idxs(model=model, link_names=model.link_names()) + + # Get a closure exposing only the parameters to be differentiated. + # We differentiate the jacobian of the last link, likely among those + # farther from the base. + jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian( + model=model.physics_model, q=s, body_index=link_indices[-1] + ) + + # Check derivatives against finite differences. + check_grads( + f=jacobian, + args=(s,), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + +def test_ad_soft_contacts( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + key, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) + p = jax.random.uniform(subkey1, shape=(3,), minval=-1) + v = jax.random.uniform(subkey2, shape=(3,), minval=-1) + m = jax.random.uniform(subkey3, shape=(3,), minval=-1) + + # Get the soft contacts parameters. + parameters = js.contact.estimate_good_soft_contacts_parameters(model=model) + + # ==== + # Test + # ==== + + import jaxsim.physics.algos.soft_contacts + + # Get a closure exposing only the parameters to be differentiated. + soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts( + parameters=parameters + ).contact_model(position=p, velocity=v, tangential_deformation=m) + + # Check derivatives against finite differences. + check_grads( + f=soft_contacts, + args=(p, v, m), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + ) + + +def test_ad_integration( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + key, subkey = jax.random.split(prng_key, num=2) + data, references = get_random_data_and_references( + model=model, velocity_representation=VelRepr.Inertial, key=key + ) + + # Perturbation used for computing finite differences. + ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3) + + # State in VelRepr.Inertial representation. + s = data.joint_positions() + ṡ = data.joint_velocities(model=model) + xfb = data.state.physics_model.xfb() + m = data.state.soft_contacts.tangential_deformation + + # Inputs. + f = references.link_forces(model=model) + τ = references.joint_force_references(model=model) + + # ==== + # Test + # ==== + + import jaxsim.integrators + + # Note that only fixes-step integrators support both FWD and RWD gradients. + # Select a second-order Heun scheme with quaternion integrated on SO(3). + # Note that it's always preferable using the SO(3) versions on AD applications so + # that the gradient of the integrated dynamics always considers unary quaternions. + integrator = jaxsim.integrators.fixed_step.Heun2SO3.build( + dynamics=js.ode.wrap_system_dynamics_for_integration( + model=model, + data=data, + system_dynamics=js.ode.system_dynamics, + ), + ) + + # Initialize the integrator. + t0, dt = 0.0, 0.001 + integrator_state = integrator.init(x0=data.state, t0=t0, dt=dt) + + # Function exposing only the parameters to be differentiated. + def step( + xfb: jax.typing.ArrayLike, + s: jax.typing.ArrayLike, + ṡ: jax.typing.ArrayLike, + m: jax.typing.ArrayLike, + tau: jax.typing.ArrayLike, + f_ext: jax.typing.ArrayLike, + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + + data_x0 = data.replace( + state=js.ode.ODEState.build( + physics_model_state=js.ode.PhysicsModelState.build( + joint_positions=s, + joint_velocities=ṡ, + base_position=xfb[4:7], + base_quaternion=xfb[0:4], + base_linear_velocity=xfb[7:10], + base_angular_velocity=xfb[10:13], + ), + soft_contacts_state=js.ode.SoftContactsState.build( + tangential_deformation=m + ), + ), + ) + + data_xf, _ = js.model.step( + dt=dt, + model=model, + data=data_x0, + integrator=integrator, + integrator_state=integrator_state, + joint_forces=tau, + external_forces=f_ext, + ) + + s_xf = data_xf.joint_positions() + ṡ_xf = data_xf.joint_velocities() + xfb_xf = data_xf.state.physics_model.xfb() + m_xf = data_xf.state.soft_contacts.tangential_deformation + + return xfb_xf, s_xf, ṡ_xf, m_xf + + # Check derivatives against finite differences. + check_grads( + f=step, + args=(xfb, s, ṡ, m, τ, f), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + # This check (at least on ErgoCub) needs larger tolerances + rtol=0.0001, + ) From b977b92fb6cc5bcccaa85d97b97fb8b45f2e541b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:40:59 +0100 Subject: [PATCH 13/35] Fix random generation of JaxSimModelData for fixed-base models --- src/jaxsim/api/data.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index f2abc4bd6..4115a525e 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -765,16 +765,17 @@ def random_model_data( model=model, key=k3 ) - physics_model_state.base_linear_velocity = jax.random.uniform( - key=k4, shape=(3,), minval=v_min, maxval=v_max - ) - - physics_model_state.base_angular_velocity = jax.random.uniform( - key=k5, shape=(3,), minval=ω_min, maxval=ω_max - ) - physics_model_state.joint_velocities = jax.random.uniform( key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max ) + if model.floating_base(): + physics_model_state.base_linear_velocity = jax.random.uniform( + key=k4, shape=(3,), minval=v_min, maxval=v_max + ) + + physics_model_state.base_angular_velocity = jax.random.uniform( + key=k5, shape=(3,), minval=ω_min, maxval=ω_max + ) + return random_data From 83147887c2efd8077aee34c592fad6b4415664cd Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 15:09:56 +0100 Subject: [PATCH 14/35] Fix random generation of JaxSimModelData for jointless models --- src/jaxsim/api/data.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 4115a525e..59294e6c2 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -761,13 +761,14 @@ def random_model_data( *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi) ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])] - physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions( - model=model, key=k3 - ) + if model.number_of_joints() > 0: + physics_model_state.joint_positions = ( + jaxsim.api.joint.random_joint_positions(model=model, key=k3) + ) - physics_model_state.joint_velocities = jax.random.uniform( - key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max - ) + physics_model_state.joint_velocities = jax.random.uniform( + key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max + ) if model.floating_base(): physics_model_state.base_linear_velocity = jax.random.uniform( From ccc34736561f89126a989c6cf9a4f17f20aa9309 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:41:29 +0100 Subject: [PATCH 15/35] Fix api.joint.name_to_idx --- src/jaxsim/api/joint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/joint.py b/src/jaxsim/api/joint.py index 1a0c2a564..1757925da 100644 --- a/src/jaxsim/api/joint.py +++ b/src/jaxsim/api/joint.py @@ -26,7 +26,7 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int: """ return jnp.array( - model.physics_model.description.joints_dict[joint_name].index, dtype=int + model.physics_model.description.joints_dict[joint_name].index - 1, dtype=int ) From 0932028dc9beddb80441aae99f39df0d68f8cb86 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:43:01 +0100 Subject: [PATCH 16/35] Fix ABA in Mixed representation for fixed-base models --- src/jaxsim/api/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 7a5c0ecc2..f4877436d 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -602,6 +602,12 @@ def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC): W_vl_WC=W_vl_WC, ) + # The ABA algorithm already returns a zero base 6D acceleration for + # fixed-based models. However, the to_active function introduces an + # additional acceleration component in Mixed representation. + # Here below we make sure that the base acceleration is zero. + C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6).astype(float) + # Adjust shape s̈ = jnp.atleast_1d(s̈.squeeze()) From c234320116125051c68b489de06d491bb6791428 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 15:24:19 +0100 Subject: [PATCH 17/35] Fix ABA when inputs are passed The low-level RBDA need the inputs to be in inertial-fixed representation. We use JaxSimModelReferences to automatically convert them. --- src/jaxsim/api/model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index f4877436d..5de4cc67a 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -549,14 +549,22 @@ def forward_dynamics_aba( else jnp.zeros((model.number_of_links(), 6)) ) + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=τ, + link_forces=f_ext, + data=data, + velocity_representation=data.velocity_representation, + ) + # Compute ABA W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba( model=model.physics_model, xfb=data.state.physics_model.xfb(), q=data.state.physics_model.joint_positions, qd=data.state.physics_model.joint_velocities, - tau=τ, - f_ext=f_ext, + tau=references.input.physics_model.tau, + f_ext=references.input.physics_model.f_ext, ) def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC): From 196426cfc230c392f68a5b107064750a1c2a2a30 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:43:59 +0100 Subject: [PATCH 18/35] Fix computation of bias forces for fixed-base models --- src/jaxsim/api/model.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 5de4cc67a..0cbc8dd0a 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -962,18 +962,20 @@ def free_floating_bias_forces( data.state.physics_model.joint_positions ) - data_rnea.state.physics_model.base_linear_velocity = ( - data.state.physics_model.base_linear_velocity - ) - - data_rnea.state.physics_model.base_angular_velocity = ( - data.state.physics_model.base_angular_velocity - ) - data_rnea.state.physics_model.joint_velocities = ( data.state.physics_model.joint_velocities ) + # Make sure that base velocity is zero for fixed-base model. + if model.floating_base(): + data_rnea.state.physics_model.base_linear_velocity = ( + data.state.physics_model.base_linear_velocity + ) + + data_rnea.state.physics_model.base_angular_velocity = ( + data.state.physics_model.base_angular_velocity + ) + return jnp.hstack( inverse_dynamics( model=model, From 5eeda351106d1e5439dc7b05ac928a930efe4750 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:45:01 +0100 Subject: [PATCH 19/35] Rename Heun integrator to Heun2 --- src/jaxsim/integrators/fixed_step.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/integrators/fixed_step.py b/src/jaxsim/integrators/fixed_step.py index 7cc32b14f..fa6352450 100644 --- a/src/jaxsim/integrators/fixed_step.py +++ b/src/jaxsim/integrators/fixed_step.py @@ -45,7 +45,7 @@ class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): @jax_dataclasses.pytree_dataclass -class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): +class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): A: ClassVar[jax.typing.ArrayLike] = jnp.array( [ @@ -149,7 +149,7 @@ class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]): @jax_dataclasses.pytree_dataclass -class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]): +class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[ODEState]): pass From f4563bfd335a19d8583bf64c64123e83e0974aa6 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 16:51:10 +0100 Subject: [PATCH 20/35] Fix ForwardEuler integrator --- src/jaxsim/integrators/fixed_step.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/jaxsim/integrators/fixed_step.py b/src/jaxsim/integrators/fixed_step.py index fa6352450..53a0975c2 100644 --- a/src/jaxsim/integrators/fixed_step.py +++ b/src/jaxsim/integrators/fixed_step.py @@ -20,25 +20,11 @@ @jax_dataclasses.pytree_dataclass class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]): - A: ClassVar[jax.typing.ArrayLike] = jnp.array( - [ - [0], - ] - ).astype(float) + A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float) - b: ClassVar[jax.typing.ArrayLike] = ( - jnp.array( - [ - [1], - ] - ) - .astype(float) - .transpose() - ) + b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose() - c: ClassVar[jax.typing.ArrayLike] = jnp.array( - [0], - ).astype(float) + c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float) row_index_of_solution: ClassVar[int] = 0 order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,) @@ -144,7 +130,7 @@ def post_process_state( @jax_dataclasses.pytree_dataclass -class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]): +class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]): pass From b81bc1d9554539dcf21df1b58cca9d78eb848ac3 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 12:46:22 +0100 Subject: [PATCH 21/35] Fix shape correction of the Butcher tableau b parameter --- src/jaxsim/integrators/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 89f6872f9..53f5605a8 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -211,7 +211,7 @@ def build( # Adjust the shape of the tableau coefficients. c = jnp.atleast_1d(cls.c.squeeze()) - b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze())) + b = jnp.atleast_2d(cls.b.T.squeeze()).transpose() A = jnp.atleast_2d(cls.A.squeeze()) # Check validity of the Butcher tableau. From 945f04b683c3519772ad4ec7bb916bacd4400a3f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 17:01:10 +0100 Subject: [PATCH 22/35] Fix non-jit execution of RBDAs for models with no joints --- src/jaxsim/physics/algos/aba.py | 42 +++++++++++++++-------- src/jaxsim/physics/algos/rnea.py | 24 ++++++++----- src/jaxsim/physics/algos/soft_contacts.py | 24 ++++++++----- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/src/jaxsim/physics/algos/aba.py b/src/jaxsim/physics/algos/aba.py index 378864d36..cb07a9c80 100644 --- a/src/jaxsim/physics/algos/aba.py +++ b/src/jaxsim/physics/algos/aba.py @@ -112,7 +112,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: i_X_λi = i_X_λi.at[i].set(i_X_λi_i) # Propagate link velocity - vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0 + vJ = S[i] * qd[ii] v_i = i_X_λi[i] @ v[λ[i]] + vJ v = v.at[i].set(v_i) @@ -134,10 +134,14 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: return (i_X_λi, v, c, MA, pA, i_X_0), None - (i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan( - f=loop_body_pass1, - init=pass_1_carry, - xs=np.arange(start=1, stop=model.NB), + (i_X_λi, v, c, MA, pA, i_X_0), _ = ( + jax.lax.scan( + f=loop_body_pass1, + init=pass_1_carry, + xs=np.arange(start=1, stop=model.NB), + ) + if model.NB > 1 + else [(i_X_λi, v, c, MA, pA, i_X_0), None] ) U = jnp.zeros_like(S) @@ -166,7 +170,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]: d_i = S[i].T @ U[i] d = d.at[i].set(d_i.squeeze()) - u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i] + u_i = tau[ii] - S[i].T @ pA[i] u = u.at[i].set(u_i.squeeze()) # Compute the articulated-body inertia and bias forces of this link @@ -196,10 +200,14 @@ def propagate( return (U, d, u, MA, pA), None - (U, d, u, MA, pA), _ = jax.lax.scan( - f=loop_body_pass2, - init=pass_2_carry, - xs=np.flip(np.arange(start=1, stop=model.NB)), + (U, d, u, MA, pA), _ = ( + jax.lax.scan( + f=loop_body_pass2, + init=pass_2_carry, + xs=np.flip(np.arange(start=1, stop=model.NB)), + ) + if model.NB > 1 + else [(U, d, u, MA, pA), None] ) if model.is_floating_base: @@ -226,15 +234,19 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]: qdd_ii = (u[i] - U[i].T @ a_i) / d[i] qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd - a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i + a_i = a_i + S[i] * qdd[ii] a = a.at[i].set(a_i) return (a, qdd), None - (a, qdd), _ = jax.lax.scan( - f=loop_body_pass3, - init=pass_3_carry, - xs=np.arange(1, model.NB), + (a, qdd), _ = ( + jax.lax.scan( + f=loop_body_pass3, + init=pass_3_carry, + xs=np.arange(1, model.NB), + ) + if model.NB > 1 + else [(a, qdd), None] ) # Handle 1 DoF models diff --git a/src/jaxsim/physics/algos/rnea.py b/src/jaxsim/physics/algos/rnea.py index d51f5d5a4..cd50039d5 100644 --- a/src/jaxsim/physics/algos/rnea.py +++ b/src/jaxsim/physics/algos/rnea.py @@ -130,10 +130,14 @@ def forward_pass( return (i_X_λi, v, a, i_X_0, f), None - (i_X_λi, v, a, i_X_0, f), _ = jax.lax.scan( - f=forward_pass, - init=forward_pass_carry, - xs=np.arange(start=1, stop=model.NB), + (i_X_λi, v, a, i_X_0, f), _ = ( + jax.lax.scan( + f=forward_pass, + init=forward_pass_carry, + xs=np.arange(start=1, stop=model.NB), + ) + if model.NB > 1 + else [(i_X_λi, v, a, i_X_0, f), None] ) tau = jnp.zeros_like(q) @@ -164,10 +168,14 @@ def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax: return (tau, f), None - (tau, f), _ = jax.lax.scan( - f=backward_pass, - init=backward_pass_carry, - xs=np.flip(np.arange(start=1, stop=model.NB)), + (tau, f), _ = ( + jax.lax.scan( + f=backward_pass, + init=backward_pass_carry, + xs=np.flip(np.arange(start=1, stop=model.NB)), + ) + if model.NB > 1 + else [(tau, f), None] ) # Handle 1 DoF models diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index 02ebda317..d6d2b824d 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -176,10 +176,14 @@ def propagate_transforms( # Pack and return the carry return (W_X_i,), None - (W_X_i,), _ = jax.lax.scan( - f=propagate_transforms, - init=propagate_transforms_carry, - xs=np.arange(start=1, stop=model.NB), + (W_X_i,), _ = ( + jax.lax.scan( + f=propagate_transforms, + init=propagate_transforms_carry, + xs=np.arange(start=1, stop=model.NB), + ) + if model.NB > 1 + else [(W_X_i,), None] ) # ==================== @@ -209,10 +213,14 @@ def propagate_velocities( # Pack and return the carry return (W_v_Wi,), None - (W_v_Wi,), _ = jax.lax.scan( - f=propagate_velocities, - init=propagate_velocities_carry, - xs=jnp.vstack([qd, jnp.arange(start=0, stop=qd.size)]).T, + (W_v_Wi,), _ = ( + jax.lax.scan( + f=propagate_velocities, + init=propagate_velocities_carry, + xs=jnp.vstack([qd, jnp.arange(start=0, stop=qd.size)]).T, + ) + if model.NB > 1 + else [(W_v_Wi,), None] ) # ================================================== From 1b35ea23c982a8473a68b76278f66030fe8e1299 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 12:32:28 +0100 Subject: [PATCH 23/35] Fix collidable_points_pos_vel for models with no collidable points --- src/jaxsim/physics/algos/soft_contacts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index d6d2b824d..ab86b3554 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -119,6 +119,9 @@ def collidable_points_pos_vel( Tuple[jtp.Matrix, jtp.Matrix]: A tuple containing the position and velocity of collidable points. """ + if len(model.gc.body) == 0: + return jnp.empty(0), jnp.empty(0) + # Make sure that shape and size are correct xfb, q, qd, _, _, _ = utils.process_inputs(physics_model=model, xfb=xfb, q=q, qd=qd) From d7f13176fe5ded6bdf59a29125f6db8d71a0d3f9 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 8 Mar 2024 17:07:31 +0100 Subject: [PATCH 24/35] Remove pytest-forked --- environment.yml | 1 - pyproject.toml | 2 +- setup.cfg | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/environment.yml b/environment.yml index 77510645d..654529e54 100644 --- a/environment.yml +++ b/environment.yml @@ -19,7 +19,6 @@ dependencies: # [testing] - idyntree - pytest - - pytest-forked - pytest-icdiff - robot_descriptions # [viz] diff --git a/pyproject.toml b/pyproject.toml index 3a74e3a49..f44fd14b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ multi_line_output = 3 [tool.pytest.ini_options] minversion = "6.0" -addopts = "-rsxX -v --strict-markers --forked" +addopts = "-rsxX -v --strict-markers" testpaths = [ "tests", ] diff --git a/setup.cfg b/setup.cfg index 88b0b69f8..099431d63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,8 +71,7 @@ style = pre-commit testing = idyntree - pytest >= 6.0 - pytest-forked + pytest >=6.0 pytest-icdiff robot-descriptions viz = From 9612045f1f9c2df92168ea8727eefebc1f8ad661 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 12:35:02 +0100 Subject: [PATCH 25/35] Require updated version of the rod dependency --- environment.yml | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 654529e54..5f8ce23ef 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ dependencies: - jaxlie >= 1.3.0 - jax-dataclasses >= 1.4.0 - pptree - - rod + - rod >= 0.2.0 - typing_extensions # python<3.12 # Optional dependencies from setup.cfg # [style] diff --git a/setup.cfg b/setup.cfg index 099431d63..7754644ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,7 @@ install_requires = jaxlie >= 1.3.0 jax_dataclasses >= 1.4.0 pptree - rod + rod >= 0.2.0 typing_extensions ; python_version < '3.12' [options.packages.find] From e96595ba031edb841fae7d2512b645ee68cb3bf3 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 12:51:00 +0100 Subject: [PATCH 26/35] Fix joint.position_limit for jointless models --- src/jaxsim/api/joint.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/joint.py b/src/jaxsim/api/joint.py index 1757925da..7c5668e8c 100644 --- a/src/jaxsim/api/joint.py +++ b/src/jaxsim/api/joint.py @@ -103,10 +103,13 @@ def position_limit( ) -> tuple[jtp.Float, jtp.Float]: """""" - min = model.physics_model._joint_position_limits_min[joint_index] - max = model.physics_model._joint_position_limits_max[joint_index] + if model.physics_model.NB <= 1: + return jnp.array([]), jnp.array([]) - return min.astype(float), max.astype(float) + s_min = model.physics_model._joint_position_limits_min[joint_index] + s_max = model.physics_model._joint_position_limits_max[joint_index] + + return s_min.astype(float), s_max.astype(float) @functools.partial(jax.jit, static_argnames=["joint_names"]) From 7a4e1ce33019562f55f8ebcaacd9e9729fbed00f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 15:36:27 +0100 Subject: [PATCH 27/35] Fix cast of gc.body in api.ode --- src/jaxsim/api/ode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index d3cada9c9..8432a6806 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -142,7 +142,7 @@ def system_velocity_dynamics( lambda nc: ( jnp.vstack( jnp.equal( - np.array(model.physics_model.gc.body, dtype=int), nc + jnp.array(model.physics_model.gc.body, dtype=int), nc ).astype(int) ) * W_f_Ci From 75cd229a5c6443a72cb76ff9fdac9787d7d21b8e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 16:05:24 +0100 Subject: [PATCH 28/35] Do not raise in simulation.utils.check_valid_shape --- src/jaxsim/simulation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/simulation/utils.py b/src/jaxsim/simulation/utils.py index 0571ee820..d03d8c39d 100644 --- a/src/jaxsim/simulation/utils.py +++ b/src/jaxsim/simulation/utils.py @@ -10,6 +10,6 @@ def check_valid_shape( if not valid_shape: logging.debug(f"Shape of {what} differs: {shape}, {expected_shape}") - raise + return False return valid From c7c86442cbe6f6d62b443c20f7b9e43ae223aa7d Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 16:08:34 +0100 Subject: [PATCH 29/35] Increase tolerance in test of reduced CoM position --- tests/test_api_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index d53e3595b..29192a115 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -65,7 +65,7 @@ def test_model_creation_and_reduction( # Check that the CoM position is preserved. assert js.model.com_position(model=model_full, data=data) == pytest.approx( - js.model.com_position(model=model_reduced, data=data_reduced) + js.model.com_position(model=model_reduced, data=data_reduced), abs=1e-6 ) From 866a6e39ba9590e84825104c67f94f1de2949542 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 14:53:16 +0100 Subject: [PATCH 30/35] Update api.__init__.py --- src/jaxsim/api/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/api/__init__.py b/src/jaxsim/api/__init__.py index 44436c5f6..a85176dd4 100644 --- a/src/jaxsim/api/__init__.py +++ b/src/jaxsim/api/__init__.py @@ -1 +1,2 @@ -from . import contact, data, joint, link, model, ode +from . import model, data # isort:skip +from . import common, contact, joint, link, ode, references From 2b4c2c5801666b608ebdbcd2f80ac0cbb30cece7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 16:11:19 +0100 Subject: [PATCH 31/35] Add again support for latest jax and jaxlib --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7754644ce..ddc76c923 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,8 +53,8 @@ package_dir = python_requires = >=3.11 install_requires = coloredlogs - jax >= 0.4.13,< 0.4.25 - jaxlib >= 0.4.13,< 0.4.25 + jax >= 0.4.13 + jaxlib >= 0.4.13 jaxlie >= 1.3.0 jax_dataclasses >= 1.4.0 pptree From 6397446a316b7cc10f3244659fffa5bf4c1324ec Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 11 Mar 2024 23:47:02 +0100 Subject: [PATCH 32/35] Install Gazebo Sim instead of Gazebo Classic in CI --- .github/workflows/ci_cd.yml | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index 200fce627..eeb3c1154 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -100,11 +100,22 @@ jobs: with: fetch-depth: 0 - - name: Install Gazebo Classic +# - name: Install Gazebo Classic +# if: contains(matrix.os, 'ubuntu') +# run: | +# sudo apt-get update +# sudo apt-get install gazebo + + # https://gazebosim.org/docs/harmonic/install_ubuntu + - name: Install Gazebo Sim if: contains(matrix.os, 'ubuntu') run: | sudo apt-get update - sudo apt-get install gazebo + sudo apt-get install lsb-release wget gnupg + sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null + sudo apt-get update + sudo apt-get install gz-harmonic - name: Run the Python tests if: contains(matrix.os, 'ubuntu') From 67486e089037d1ee9b9c3ab2d12cc5dff6d84778 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 12 Mar 2024 09:06:59 +0100 Subject: [PATCH 33/35] Add test of jit compiling functions taking JaxSimModel as input --- tests/test_pytree.py | 62 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/test_pytree.py diff --git a/tests/test_pytree.py b/tests/test_pytree.py new file mode 100644 index 000000000..9c91671f8 --- /dev/null +++ b/tests/test_pytree.py @@ -0,0 +1,62 @@ +import io +from contextlib import redirect_stdout + +import jax +import jax.numpy as jnp +import rod.builder.primitives +import rod.urdf.exporter + +import jaxsim.api as js + + +# https://github.com/ami-iit/jaxsim/issues/103 +def test_call_jit_compiled_function_passing_different_objects(): + + # Create on-the-fly a ROD model of a box. + rod_model = ( + rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box") + .build_model() + .add_link() + .add_inertial() + .add_visual() + .add_collision() + .build() + ) + + # Export the URDF string. + urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string( + sdf=rod_model, pretty=True + ) + + model1 = js.model.JaxSimModel.build_from_model_description( + model_description=urdf_string, + gravity=jnp.array([0, 0, -10]), + is_urdf=True, + ) + + model2 = js.model.JaxSimModel.build_from_model_description( + model_description=urdf_string, + gravity=jnp.array([0, 0, -10]), + is_urdf=True, + ) + + assert model1 == model2 + assert hash(model1) == hash(model2) + + # If this function has never been compiled by any other test, JAX will + # jit-compile it here. + _ = js.contact.estimate_good_soft_contacts_parameters(model=model1) + + # Now JAX should not compile it again. + with jax.log_compiles(): + with io.StringIO() as buf, redirect_stdout(buf): + # Beyond running without any JIT recompilations, the following function + # should work on different JaxSimModel objects without raising any errors + # related to the comparison of Static fields. + _ = js.contact.estimate_good_soft_contacts_parameters(model=model2) + stdout = buf.getvalue() + + assert ( + f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}" + not in stdout + ) From 5dcfeb9b9cc4546c9c1a6dd03a9fca02ed264891 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 12 Mar 2024 10:33:22 +0100 Subject: [PATCH 34/35] Use ordered split keys in random data generation --- src/jaxsim/api/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 59294e6c2..5a018840d 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -767,16 +767,16 @@ def random_model_data( ) physics_model_state.joint_velocities = jax.random.uniform( - key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max + key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max ) if model.floating_base(): physics_model_state.base_linear_velocity = jax.random.uniform( - key=k4, shape=(3,), minval=v_min, maxval=v_max + key=k5, shape=(3,), minval=v_min, maxval=v_max ) physics_model_state.base_angular_velocity = jax.random.uniform( - key=k5, shape=(3,), minval=ω_min, maxval=ω_max + key=k6, shape=(3,), minval=ω_min, maxval=ω_max ) return random_data From a54bfaeb5bf008fb52a20dbe0000079caebfa315 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 12 Mar 2024 12:09:10 +0100 Subject: [PATCH 35/35] Do not alter during runtime Butcher tableau coefficients --- src/jaxsim/integrators/common.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 53f5605a8..1eb231f0b 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -209,20 +209,10 @@ def build( The integrator object. """ - # Adjust the shape of the tableau coefficients. - c = jnp.atleast_1d(cls.c.squeeze()) - b = jnp.atleast_2d(cls.b.T.squeeze()).transpose() - A = jnp.atleast_2d(cls.A.squeeze()) - # Check validity of the Butcher tableau. - if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c): + if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c): raise ValueError("The Butcher tableau of this class is not valid.") - # Store the adjusted shapes of the tableau coefficients. - cls.c = c - cls.b = b - cls.A = A - # Check that b.T has enough rows based on the configured index of the solution. if cls.row_index_of_solution >= cls.b.T.shape[0]: msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."