Skip to content

Commit

Permalink
Add additional type checks
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 14, 2024
1 parent 8a8219b commit 2e733fe
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 26 deletions.
2 changes: 2 additions & 0 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import jaxsim.rbda
import jaxsim.typing as jtp

from .common import VelRepr

# =======================
# Index-related functions
# =======================
Expand Down
76 changes: 50 additions & 26 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
except ImportError:
from typing_extensions import Self

from .data import JaxSimModelData


@jax_dataclasses.pytree_dataclass
class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
Expand Down Expand Up @@ -179,7 +181,7 @@ def link_forces(
# serialization.
if model is None:

def inertial():
def inertial() -> jtp.Array:
if link_names is not None:
raise ValueError("Link names cannot be provided without a model")

Expand All @@ -202,7 +204,7 @@ def inertial():
link_names = link_names if link_names is not None else model.link_names()
link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

def not_inertial():
def check_not_inertial() -> None:
if data is None:
raise ValueError(
"Missing model data to use a representation different from `VelRepr.Inertial`"
Expand All @@ -213,33 +215,44 @@ def not_inertial():
):
raise ValueError("The provided data is not valid for the model")

# If not inertial-fixed representation, we need the model data.
jax.lax.cond(
pred=(self.velocity_representation != VelRepr.Inertial),
true_fun=lambda: jax.pure_callback(
callback=check_not_inertial,
result_shape_dtypes=None,
),
false_fun=lambda: None,
)

def not_inertial(velocity_representation: int) -> jtp.Matrix:
# Helper function to convert a single 6D force to the active representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:

return jax.vmap(
lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(
array=W_f_L,
other_representation=self.velocity_representation,
other_representation=velocity_representation,
transform=W_H_L,
is_force=True,
)
)(W_f_L, W_H_L)

# The f_L output is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = js.model.forward_kinematics(
model=model, data=data or JaxSimModelData.zero(model=model)
)
f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])

return f_L

# In inertial-fixed representation, we already have the link forces.
return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=lambda: W_f_L[link_idxs, :],
false_fun=lambda: jax.pure_callback(
callback=not_inertial,
result_shape_dtypes=W_f_L[link_idxs, :],
),
true_fun=lambda _: W_f_L[link_idxs, :],
false_fun=not_inertial,
operand=self.velocity_representation,
)

def joint_force_references(
Expand Down Expand Up @@ -388,7 +401,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
# using the implicit link serialization.
if model is None:

def inertial():
def inertial() -> JaxSimModelReferences:
if link_names is not None:
raise ValueError("Link names cannot be provided without a model")

Expand Down Expand Up @@ -435,16 +448,7 @@ def inertial():
else self.input.physics_model.f_ext[link_idxs, :]
)

# If inertial-fixed representation, we can directly store the link forces.
def inertial():
W_f_L = f_L
return replace(
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(
W_f0_L + W_f_L
)
)

def not_inertial(data):
def check_not_inertial() -> None:
if data is None:
raise ValueError(
"Missing model data to use a representation different from `VelRepr.Inertial`"
Expand All @@ -453,6 +457,26 @@ def not_inertial(data):
if not_tracing(forces) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# If not inertial-fixed representation, we need the model data.
jax.lax.cond(
pred=(self.velocity_representation != VelRepr.Inertial),
true_fun=lambda: jax.pure_callback(
callback=check_not_inertial,
result_shape_dtypes=None,
),
false_fun=lambda: None,
)

# If inertial-fixed representation, we can directly store the link forces.
def inertial(velocity_representation: int) -> JaxSimModelReferences:
W_f_L = f_L
return replace(
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(
W_f0_L + W_f_L
)
)

def not_inertial(velocity_representation: int) -> JaxSimModelReferences:
# Helper function to convert a single 6D force to the inertial representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert_using_link_frame(
Expand All @@ -462,14 +486,16 @@ def convert_using_link_frame(
return jax.vmap(
lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
other_representation=velocity_representation,
transform=W_H_L,
is_force=True,
)
)(f_L, W_H_L)

# The f_L input is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = js.model.forward_kinematics(
model=model, data=data or JaxSimModelData.zero(model=model)
)
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])

return replace(
Expand All @@ -481,8 +507,6 @@ def convert_using_link_frame(
return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=inertial,
false_fun=lambda: jax.experimental.io_callback(
callback=not_inertial,
result_shape_dtypes=self,
),
false_fun=not_inertial,
operand=self.velocity_representation,
)

0 comments on commit 2e733fe

Please sign in to comment.