Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve compilation time of RBDAs for models with many DoFs #153

Merged
merged 6 commits into from
May 15, 2024

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented May 15, 2024

This PR drastically reduces the compilation time of RBDAs and of functions that involve joint-related calculations. In particular:

  • the use of enum.IntEnum for the JointType has been reprecated in favor of a leaner class that simply instanciates subclasses for which the only value is an integer representing the joint type
  • the JointDescriptor class has been deprecated, separating the JointGenericAxis from the JointType completely. In this way, not only we avoid a circular interdependency between attributes of JointDescription, but we are able to pass the two objects as separate arguments to functions, making the use vmap more straightforward and less error-prone
  • joint_transforms_and_motion_subspaces, which is used in every RBDA, gets compiled faster

Closes #151


📚 Documentation preview 📚: https://jaxsim--153.org.readthedocs.build//153/

@flferretti flferretti self-assigned this May 15, 2024
@flferretti flferretti requested a review from diegoferigo as a code owner May 15, 2024 13:42
@flferretti
Copy link
Collaborator Author

Benchmarks

TL;DR:

Method Compilation Execution
js.model.free_floating_mass_matrix -87.5 % -9.55 %
js.model.free_floating_bias_forces -80.7 % -31.4 %
js.model.link_bias_accelerations -92.8 % -44.0 %

Current Version

In [2]: %time _ = js.model.free_floating_mass_matrix(model, data0)
CPU times: user 8.09 s, sys: 24 ms, total: 8.12 s
Wall time: 8.07 s

In [3]: %timeit _ = js.model.free_floating_mass_matrix(model, data0)
293 µs ± 9.05 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]: %time _ = js.model.free_floating_bias_forces(model, data0)
CPU times: user 9.67 s, sys: 109 ms, total: 9.78 s
Wall time: 9.7 s

In [5]: %timeit _ = js.model.free_floating_bias_forces(model, data0)
210 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]: %time _ = js.model.link_bias_accelerations(model, data0)
CPU times: user 23.3 s, sys: 28.7 ms, total: 23.3 s
Wall time: 23.2 s

In [7]: %timeit _ = js.model.link_bias_accelerations(model, data0)
250 µs ± 2.01 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Improved

In [2]: %time _ = js.model.free_floating_mass_matrix(model, data0)
CPU times: user 1.04 s, sys: 20.8 ms, total: 1.06 s
Wall time: 1.01 s

In [3]: %timeit _ = js.model.free_floating_mass_matrix(model, data0)
265 µs ± 7.79 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]: %time _ = js.model.free_floating_bias_forces(model, data0)
CPU times: user 1.87 s, sys: 87.3 ms, total: 1.95 s
Wall time: 1.88 s

In [5]: %timeit _ = js.model.free_floating_bias_forces(model, data0)
144 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]: %time _ = js.model.link_bias_accelerations(model, data0)
CPU times: user 1.66 s, sys: 23.7 ms, total: 1.68 s
Wall time: 1.62 s

In [7]: %timeit _ = js.model.link_bias_accelerations(model, data0)
140 µs ± 708 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Test Script
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""

import jax.numpy as jnp
import jaxsim.api as js
import numpy as np
import resolve_robotics_uri_py
import rod
from jaxsim import VelRepr, integrators
import jaxsim.api as js

urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

rod_sdf = rod.Sdf.load(sdf=urdf_path)


model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
)

data0 = js.data.JaxSimModelData.build(
    model=model,
    base_position=jnp.array([0, 0, 0.85]),
    velocity_representation=VelRepr.Inertial,
)

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, nicely done! The benchmarks look great, this was the last remarkable speed-up that was still pending (mainly after #121) after the development of the new functional APIs (#88 and #108) 🚀 It solves a longstanding TODO comment left in the code.

Here my first comments, let's start with these ones before moving on with the remaining.

src/jaxsim/api/kin_dyn_parameters.py Outdated Show resolved Hide resolved
src/jaxsim/parsers/descriptions/joint.py Outdated Show resolved Hide resolved
src/jaxsim/parsers/descriptions/joint.py Outdated Show resolved Hide resolved
@flferretti flferretti force-pushed the feature/vmap_joint_computations branch 2 times, most recently from f774b3d to 4d14208 Compare May 15, 2024 15:08
@flferretti flferretti requested a review from diegoferigo May 15, 2024 15:08
flferretti and others added 5 commits May 15, 2024 17:10
Deprecate `JointDescriptor`, introduce new `JointType`, separate
`JointType` from `JointGenericAxis`
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
- Use explicit joint type names
- Use `jnp.newaxis` instead of `None`
- Use dataclass for `JointType` instead of metaclass

Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
@flferretti flferretti force-pushed the feature/vmap_joint_computations branch from 4d14208 to 4f142bc Compare May 15, 2024 15:10
@flferretti
Copy link
Collaborator Author

Thanks @diegoferigo! I've noticed that we saved 12 minutes on the CI 🌱

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last comment to enable future extensions.

src/jaxsim/math/joint_model.py Outdated Show resolved Hide resolved
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
@flferretti flferretti force-pushed the feature/vmap_joint_computations branch from e3d0f16 to 16edc05 Compare May 15, 2024 16:35
@traversaro
Copy link
Contributor

Great!

@flferretti flferretti requested a review from diegoferigo May 15, 2024 16:42
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Good to go 🚀

I believe we suffer again the problems of #103 (static attributes that are jax.Array). However, it's not entirely dependent on this PR and we need to properly address all of them.

@flferretti flferretti merged commit c14205a into main May 15, 2024
27 checks passed
@flferretti flferretti deleted the feature/vmap_joint_computations branch May 15, 2024 16:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Address long compilation time in RBDAs for models with many DoFs
3 participants