Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new test suite of functional APIs #106

Merged
merged 35 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
688664b
Remove existing tests of OOP APIs
diegoferigo Mar 12, 2024
1b506de
Add a pytest fixture to generate a PRNG key
diegoferigo Mar 8, 2024
ff02a33
Add fixture to parameterize tests over all velocity representations
diegoferigo Mar 8, 2024
d70568c
Add session-wide fixtures to provide tested models
diegoferigo Mar 8, 2024
020ff88
Add collections of tested models
diegoferigo Mar 8, 2024
43fa2f9
Add tests of jaxsim.api.data module
diegoferigo Mar 8, 2024
5cdd914
Add test of jaxsim.api.joint module
diegoferigo Mar 8, 2024
446369a
Add test of jaxsim.api.link module
diegoferigo Mar 8, 2024
d95b589
Add test of jaxsim.api.model module
diegoferigo Mar 8, 2024
29ec058
Add methods to wrapper utils_idyntree.KinDynComputations
diegoferigo Mar 12, 2024
9b49549
Add other iDynTree testing helpers
diegoferigo Mar 8, 2024
97e54d6
Add test to check automatic differentiation of algorithms
diegoferigo Mar 8, 2024
b977b92
Fix random generation of JaxSimModelData for fixed-base models
diegoferigo Mar 8, 2024
8314788
Fix random generation of JaxSimModelData for jointless models
diegoferigo Mar 11, 2024
ccc3473
Fix api.joint.name_to_idx
diegoferigo Mar 8, 2024
0932028
Fix ABA in Mixed representation for fixed-base models
diegoferigo Mar 8, 2024
c234320
Fix ABA when inputs are passed
diegoferigo Mar 11, 2024
196426c
Fix computation of bias forces for fixed-base models
diegoferigo Mar 8, 2024
5eeda35
Rename Heun integrator to Heun2
diegoferigo Mar 8, 2024
f4563bf
Fix ForwardEuler integrator
diegoferigo Mar 8, 2024
b81bc1d
Fix shape correction of the Butcher tableau b parameter
diegoferigo Mar 11, 2024
945f04b
Fix non-jit execution of RBDAs for models with no joints
diegoferigo Mar 8, 2024
1b35ea2
Fix collidable_points_pos_vel for models with no collidable points
diegoferigo Mar 11, 2024
d7f1317
Remove pytest-forked
diegoferigo Mar 8, 2024
9612045
Require updated version of the rod dependency
diegoferigo Mar 11, 2024
e96595b
Fix joint.position_limit for jointless models
diegoferigo Mar 11, 2024
7a4e1ce
Fix cast of gc.body in api.ode
diegoferigo Mar 11, 2024
75cd229
Do not raise in simulation.utils.check_valid_shape
diegoferigo Mar 11, 2024
c7c8644
Increase tolerance in test of reduced CoM position
diegoferigo Mar 11, 2024
866a6e3
Update api.__init__.py
diegoferigo Mar 11, 2024
2b4c2c5
Add again support for latest jax and jaxlib
diegoferigo Mar 11, 2024
6397446
Install Gazebo Sim instead of Gazebo Classic in CI
diegoferigo Mar 11, 2024
67486e0
Add test of jit compiling functions taking JaxSimModel as input
diegoferigo Mar 12, 2024
5dcfeb9
Use ordered split keys in random data generation
diegoferigo Mar 12, 2024
a54bfae
Do not alter during runtime Butcher tableau coefficients
diegoferigo Mar 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,22 @@ jobs:
with:
fetch-depth: 0

- name: Install Gazebo Classic
# - name: Install Gazebo Classic
# if: contains(matrix.os, 'ubuntu')
# run: |
# sudo apt-get update
# sudo apt-get install gazebo

# https://gazebosim.org/docs/harmonic/install_ubuntu
- name: Install Gazebo Sim
if: contains(matrix.os, 'ubuntu')
run: |
sudo apt-get update
sudo apt-get install gazebo
sudo apt-get install lsb-release wget gnupg
sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
sudo apt-get update
sudo apt-get install gz-harmonic

- name: Run the Python tests
if: contains(matrix.os, 'ubuntu')
Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- pptree
- rod
- rod >= 0.2.0
- typing_extensions # python<3.12
# Optional dependencies from setup.cfg
# [style]
Expand All @@ -19,7 +19,6 @@ dependencies:
# [testing]
- idyntree
- pytest
- pytest-forked
- pytest-icdiff
- robot_descriptions
# [viz]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ multi_line_output = 3

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-rsxX -v --strict-markers --forked"
addopts = "-rsxX -v --strict-markers"
testpaths = [
"tests",
]
9 changes: 4 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ package_dir =
python_requires = >=3.11
install_requires =
coloredlogs
jax >= 0.4.13,< 0.4.25
jaxlib >= 0.4.13,< 0.4.25
jax >= 0.4.13
jaxlib >= 0.4.13
jaxlie >= 1.3.0
jax_dataclasses >= 1.4.0
pptree
rod
rod >= 0.2.0
typing_extensions ; python_version < '3.12'

[options.packages.find]
Expand All @@ -71,8 +71,7 @@ style =
pre-commit
testing =
idyntree
pytest >= 6.0
pytest-forked
pytest >=6.0
pytest-icdiff
robot-descriptions
viz =
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import contact, data, joint, link, model, ode
from . import model, data # isort:skip
from . import common, contact, joint, link, ode, references
26 changes: 14 additions & 12 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,20 +761,22 @@ def random_model_data(
*jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions(
model=model, key=k3
)
if model.number_of_joints() > 0:
physics_model_state.joint_positions = (
jaxsim.api.joint.random_joint_positions(model=model, key=k3)
)

physics_model_state.base_linear_velocity = jax.random.uniform(
key=k4, shape=(3,), minval=v_min, maxval=v_max
)
physics_model_state.joint_velocities = jax.random.uniform(
key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
)

physics_model_state.base_angular_velocity = jax.random.uniform(
key=k5, shape=(3,), minval=ω_min, maxval=ω_max
)
if model.floating_base():
physics_model_state.base_linear_velocity = jax.random.uniform(
key=k5, shape=(3,), minval=v_min, maxval=v_max
)

physics_model_state.joint_velocities = jax.random.uniform(
key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
)
physics_model_state.base_angular_velocity = jax.random.uniform(
key=k6, shape=(3,), minval=ω_min, maxval=ω_max
)

return random_data
11 changes: 7 additions & 4 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
"""

return jnp.array(
model.physics_model.description.joints_dict[joint_name].index, dtype=int
model.physics_model.description.joints_dict[joint_name].index - 1, dtype=int
)


Expand Down Expand Up @@ -103,10 +103,13 @@ def position_limit(
) -> tuple[jtp.Float, jtp.Float]:
""""""

min = model.physics_model._joint_position_limits_min[joint_index]
max = model.physics_model._joint_position_limits_max[joint_index]
if model.physics_model.NB <= 1:
return jnp.array([]), jnp.array([])
diegoferigo marked this conversation as resolved.
Show resolved Hide resolved

return min.astype(float), max.astype(float)
s_min = model.physics_model._joint_position_limits_min[joint_index]
s_max = model.physics_model._joint_position_limits_max[joint_index]

return s_min.astype(float), s_max.astype(float)


@functools.partial(jax.jit, static_argnames=["joint_names"])
Expand Down
36 changes: 26 additions & 10 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,22 @@ def forward_dynamics_aba(
else jnp.zeros((model.number_of_links(), 6))
)

references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=τ,
link_forces=f_ext,
data=data,
velocity_representation=data.velocity_representation,
)

# Compute ABA
W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba(
model=model.physics_model,
xfb=data.state.physics_model.xfb(),
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
tau=τ,
f_ext=f_ext,
tau=references.input.physics_model.tau,
f_ext=references.input.physics_model.f_ext,
)

def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
Expand Down Expand Up @@ -602,6 +610,12 @@ def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
W_vl_WC=W_vl_WC,
)

# The ABA algorithm already returns a zero base 6D acceleration for
# fixed-based models. However, the to_active function introduces an
# additional acceleration component in Mixed representation.
# Here below we make sure that the base acceleration is zero.
C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6).astype(float)

# Adjust shape
s̈ = jnp.atleast_1d(s̈.squeeze())

Expand Down Expand Up @@ -948,18 +962,20 @@ def free_floating_bias_forces(
data.state.physics_model.joint_positions
)

data_rnea.state.physics_model.base_linear_velocity = (
data.state.physics_model.base_linear_velocity
)

data_rnea.state.physics_model.base_angular_velocity = (
data.state.physics_model.base_angular_velocity
)

data_rnea.state.physics_model.joint_velocities = (
data.state.physics_model.joint_velocities
)

# Make sure that base velocity is zero for fixed-base model.
if model.floating_base():
data_rnea.state.physics_model.base_linear_velocity = (
data.state.physics_model.base_linear_velocity
)

data_rnea.state.physics_model.base_angular_velocity = (
data.state.physics_model.base_angular_velocity
)

return jnp.hstack(
inverse_dynamics(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def system_velocity_dynamics(
lambda nc: (
jnp.vstack(
jnp.equal(
np.array(model.physics_model.gc.body, dtype=int), nc
jnp.array(model.physics_model.gc.body, dtype=int), nc
).astype(int)
)
* W_f_Ci
Expand Down
12 changes: 1 addition & 11 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,10 @@ def build(
The integrator object.
"""

# Adjust the shape of the tableau coefficients.
c = jnp.atleast_1d(cls.c.squeeze())
b = jnp.atleast_2d(jnp.vstack(cls.b.squeeze()))
A = jnp.atleast_2d(cls.A.squeeze())

# Check validity of the Butcher tableau.
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=A, b=b, c=c):
if not ExplicitRungeKutta.butcher_tableau_is_valid(A=cls.A, b=cls.b, c=cls.c):
raise ValueError("The Butcher tableau of this class is not valid.")

# Store the adjusted shapes of the tableau coefficients.
cls.c = c
cls.b = b
cls.A = A

# Check that b.T has enough rows based on the configured index of the solution.
if cls.row_index_of_solution >= cls.b.T.shape[0]:
msg = "The index of the solution ({}-th row of `b.T`) is out of range ({})."
Expand Down
26 changes: 6 additions & 20 deletions src/jaxsim/integrators/fixed_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,18 @@
@jax_dataclasses.pytree_dataclass
class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):

A: ClassVar[jax.typing.ArrayLike] = jnp.array(
[
[0],
]
).astype(float)
A: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(0).astype(float)

b: ClassVar[jax.typing.ArrayLike] = (
jnp.array(
[
[1],
]
)
.astype(float)
.transpose()
)
b: ClassVar[jax.typing.ArrayLike] = jnp.atleast_2d(1).astype(float).transpose()

c: ClassVar[jax.typing.ArrayLike] = jnp.array(
[0],
).astype(float)
c: ClassVar[jax.typing.ArrayLike] = jnp.atleast_1d(0).astype(float)

row_index_of_solution: ClassVar[int] = 0
order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)


@jax_dataclasses.pytree_dataclass
class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):

A: ClassVar[jax.typing.ArrayLike] = jnp.array(
[
Expand Down Expand Up @@ -144,12 +130,12 @@ def post_process_state(


@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]):
pass


@jax_dataclasses.pytree_dataclass
class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[ODEState]):
pass


Expand Down
42 changes: 27 additions & 15 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

# Propagate link velocity
vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0
vJ = S[i] * qd[ii]

v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)
Expand All @@ -134,10 +134,14 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:

return (i_X_λi, v, c, MA, pA, i_X_0), None

(i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan(
f=loop_body_pass1,
init=pass_1_carry,
xs=np.arange(start=1, stop=model.NB),
(i_X_λi, v, c, MA, pA, i_X_0), _ = (
jax.lax.scan(
f=loop_body_pass1,
init=pass_1_carry,
xs=np.arange(start=1, stop=model.NB),
)
if model.NB > 1
else [(i_X_λi, v, c, MA, pA, i_X_0), None]
)

U = jnp.zeros_like(S)
Expand Down Expand Up @@ -166,7 +170,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
d_i = S[i].T @ U[i]
d = d.at[i].set(d_i.squeeze())

u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
u_i = tau[ii] - S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())

# Compute the articulated-body inertia and bias forces of this link
Expand Down Expand Up @@ -196,10 +200,14 @@ def propagate(

return (U, d, u, MA, pA), None

(U, d, u, MA, pA), _ = jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=np.flip(np.arange(start=1, stop=model.NB)),
(U, d, u, MA, pA), _ = (
jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=np.flip(np.arange(start=1, stop=model.NB)),
)
if model.NB > 1
else [(U, d, u, MA, pA), None]
)

if model.is_floating_base:
Expand All @@ -226,15 +234,19 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
qdd_ii = (u[i] - U[i].T @ a_i) / d[i]
qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
a_i = a_i + S[i] * qdd[ii]
a = a.at[i].set(a_i)

return (a, qdd), None

(a, qdd), _ = jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=np.arange(1, model.NB),
(a, qdd), _ = (
jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=np.arange(1, model.NB),
)
if model.NB > 1
else [(a, qdd), None]
)

# Handle 1 DoF models
Expand Down
Loading
Loading