diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 4f479928c..6b108281c 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -11,6 +11,8 @@ import jaxsim.rbda import jaxsim.typing as jtp +from .common import VelRepr + # ======================= # Index-related functions # ======================= diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 34a24c896..ddc350eba 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -18,6 +18,8 @@ except ImportError: from typing_extensions import Self +from .data import JaxSimModelData + @jax_dataclasses.pytree_dataclass class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): @@ -179,7 +181,7 @@ def link_forces( # serialization. if model is None: - def inertial(): + def inertial() -> jtp.Array: if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -202,7 +204,7 @@ def inertial(): link_names = link_names if link_names is not None else model.link_names() link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) - def not_inertial(): + def check_not_inertial() -> None: if data is None: raise ValueError( "Missing model data to use a representation different from `VelRepr.Inertial`" @@ -213,6 +215,17 @@ def not_inertial(): ): raise ValueError("The provided data is not valid for the model") + # If not inertial-fixed representation, we need the model data. + jax.lax.cond( + pred=(self.velocity_representation != VelRepr.Inertial), + true_fun=lambda: jax.pure_callback( + callback=check_not_inertial, + result_shape_dtypes=None, + ), + false_fun=lambda: None, + ) + + def not_inertial(velocity_representation: int) -> jtp.Matrix: # Helper function to convert a single 6D force to the active representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: @@ -220,14 +233,16 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: return jax.vmap( lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( array=W_f_L, - other_representation=self.velocity_representation, + other_representation=velocity_representation, transform=W_H_L, is_force=True, ) )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ) f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L @@ -235,11 +250,9 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: # In inertial-fixed representation, we already have the link forces. return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=lambda: W_f_L[link_idxs, :], - false_fun=lambda: jax.pure_callback( - callback=not_inertial, - result_shape_dtypes=W_f_L[link_idxs, :], - ), + true_fun=lambda _: W_f_L[link_idxs, :], + false_fun=not_inertial, + operand=self.velocity_representation, ) def joint_force_references( @@ -388,7 +401,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # using the implicit link serialization. if model is None: - def inertial(): + def inertial() -> JaxSimModelReferences: if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -435,16 +448,7 @@ def inertial(): else self.input.physics_model.f_ext[link_idxs, :] ) - # If inertial-fixed representation, we can directly store the link forces. - def inertial(): - W_f_L = f_L - return replace( - forces=self.input.physics_model.f_ext.at[link_idxs, :].set( - W_f0_L + W_f_L - ) - ) - - def not_inertial(data): + def check_not_inertial() -> None: if data is None: raise ValueError( "Missing model data to use a representation different from `VelRepr.Inertial`" @@ -453,6 +457,26 @@ def not_inertial(data): if not_tracing(forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") + # If not inertial-fixed representation, we need the model data. + jax.lax.cond( + pred=(self.velocity_representation != VelRepr.Inertial), + true_fun=lambda: jax.pure_callback( + callback=check_not_inertial, + result_shape_dtypes=None, + ), + false_fun=lambda: None, + ) + + # If inertial-fixed representation, we can directly store the link forces. + def inertial(velocity_representation: int) -> JaxSimModelReferences: + W_f_L = f_L + return replace( + forces=self.input.physics_model.f_ext.at[link_idxs, :].set( + W_f0_L + W_f_L + ) + ) + + def not_inertial(velocity_representation: int) -> JaxSimModelReferences: # Helper function to convert a single 6D force to the inertial representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert_using_link_frame( @@ -462,14 +486,16 @@ def convert_using_link_frame( return jax.vmap( lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( array=f_L, - other_representation=self.velocity_representation, + other_representation=velocity_representation, transform=W_H_L, is_force=True, ) )(f_L, W_H_L) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ) W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace( @@ -481,8 +507,6 @@ def convert_using_link_frame( return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), true_fun=inertial, - false_fun=lambda: jax.experimental.io_callback( - callback=not_inertial, - result_shape_dtypes=self, - ), + false_fun=not_inertial, + operand=self.velocity_representation, )