-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve compilation time of RBDAs for models with many DoFs #153
Conversation
BenchmarksTL;DR:
Current VersionIn [2]: %time _ = js.model.free_floating_mass_matrix(model, data0)
CPU times: user 8.09 s, sys: 24 ms, total: 8.12 s
Wall time: 8.07 s
In [3]: %timeit _ = js.model.free_floating_mass_matrix(model, data0)
293 µs ± 9.05 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [4]: %time _ = js.model.free_floating_bias_forces(model, data0)
CPU times: user 9.67 s, sys: 109 ms, total: 9.78 s
Wall time: 9.7 s
In [5]: %timeit _ = js.model.free_floating_bias_forces(model, data0)
210 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [6]: %time _ = js.model.link_bias_accelerations(model, data0)
CPU times: user 23.3 s, sys: 28.7 ms, total: 23.3 s
Wall time: 23.2 s
In [7]: %timeit _ = js.model.link_bias_accelerations(model, data0)
250 µs ± 2.01 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) ImprovedIn [2]: %time _ = js.model.free_floating_mass_matrix(model, data0)
CPU times: user 1.04 s, sys: 20.8 ms, total: 1.06 s
Wall time: 1.01 s
In [3]: %timeit _ = js.model.free_floating_mass_matrix(model, data0)
265 µs ± 7.79 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [4]: %time _ = js.model.free_floating_bias_forces(model, data0)
CPU times: user 1.87 s, sys: 87.3 ms, total: 1.95 s
Wall time: 1.88 s
In [5]: %timeit _ = js.model.free_floating_bias_forces(model, data0)
144 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [6]: %time _ = js.model.link_bias_accelerations(model, data0)
CPU times: user 1.66 s, sys: 23.7 ms, total: 1.68 s
Wall time: 1.62 s
In [7]: %timeit _ = js.model.link_bias_accelerations(model, data0)
140 µs ± 708 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) Test Scriptimport os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import jax.numpy as jnp
import jaxsim.api as js
import numpy as np
import resolve_robotics_uri_py
import rod
from jaxsim import VelRepr, integrators
import jaxsim.api as js
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
uri="model://ergoCubSN001/model.urdf"
)
rod_sdf = rod.Sdf.load(sdf=urdf_path)
model = js.model.JaxSimModel.build_from_model_description(
model_description=rod_sdf.model,
)
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0, 0, 0.85]),
velocity_representation=VelRepr.Inertial,
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, nicely done! The benchmarks look great, this was the last remarkable speed-up that was still pending (mainly after #121) after the development of the new functional APIs (#88 and #108) 🚀 It solves a longstanding TODO comment left in the code.
Here my first comments, let's start with these ones before moving on with the remaining.
f774b3d
to
4d14208
Compare
Deprecate `JointDescriptor`, introduce new `JointType`, separate `JointType` from `JointGenericAxis`
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
- Use explicit joint type names - Use `jnp.newaxis` instead of `None` - Use dataclass for `JointType` instead of metaclass Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
4d14208
to
4f142bc
Compare
Thanks @diegoferigo! I've noticed that we saved 12 minutes on the CI 🌱 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Last comment to enable future extensions.
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
e3d0f16
to
16edc05
Compare
Great! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Good to go 🚀
I believe we suffer again the problems of #103 (static attributes that are jax.Array
). However, it's not entirely dependent on this PR and we need to properly address all of them.
This PR drastically reduces the compilation time of RBDAs and of functions that involve joint-related calculations. In particular:
enum.IntEnum
for theJointType
has been reprecated in favor of a leaner class that simply instanciates subclasses for which the only value is an integer representing the joint typeJointDescriptor
class has been deprecated, separating theJointGenericAxis
from theJointType
completely. In this way, not only we avoid a circular interdependency between attributes ofJointDescription
, but we are able to pass the two objects as separate arguments to functions, making the usevmap
more straightforward and less error-pronejoint_transforms_and_motion_subspaces
, which is used in every RBDA, gets compiled fasterCloses #151
📚 Documentation preview 📚: https://jaxsim--153.org.readthedocs.build//153/