-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Finalize the functional APIs and replace the OOP classes #108
Conversation
The low-level RBDA need the inputs to be in inertial-fixed representation. We use JaxSimModelReferences to automatically convert them.
Enhance parametric hardware models
Remove old OOP classes
@diegoferigo which CUDA version is used to build jaxlib in your case? |
Today I built a new Docker image with a fresh environment. You can find details below. env.yml
|
LGTM, apart from the fact that for some reason, some tests fail with the following environment: conda envname: testJAX
channels:
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- absl-py=2.1.0=pyhd8ed1ab_0
- ampl-mp=3.1.0=h2cc385e_1006
- assimp=5.3.1=h8343317_3
- attr=2.5.1=h166bdaf_1
- black=24.3.0=py312h7900ff3_0
- bzip2=1.0.8=hd590300_5
- c-ares=1.27.0=hd590300_0
- ca-certificates=2024.2.2=hbcca054_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- cuda-nvcc=12.4.99=0
- cuda-version=11.8=h70ddcb2_3
- cudatoolkit=11.8.0=h4ba93d1_13
- cudnn=8.9.7.29=hbc23b4c_3
- dbus=1.13.6=h5008d03_3
- eigen=3.4.0=h00ab1b0_0
- etils=1.8.0=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- expat=2.6.2=h59595ed_0
- fsspec=2024.3.1=pyhca7485f_0
- gettext=0.21.1=h27087fc_0
- glfw=3.4=hd590300_0
- icu=73.2=h59595ed_0
- idyntree=12.0.0=py312h7eadc23_0
- importlib-metadata=7.1.0=pyha770c72_0
- importlib_metadata=7.1.0=hd8ed1ab_0
- importlib_resources=6.4.0=pyhd8ed1ab_0
- iniconfig=2.0.0=pyhd8ed1ab_0
- ipopt=3.14.14=h04b96a2_1
- irrlicht=1.8.5=h2a6caf8_4
- isort=5.13.2=pyhd8ed1ab_0
- jax=0.4.25=pyhd8ed1ab_0
- jaxlib=0.4.23=cuda118py312h5cced90_200
- lame=3.100=h166bdaf_1003
- ld_impl_linux-64=2.40=h41732ed_0
- libabseil=20240116.1=cxx17_h59595ed_2
- libblas=3.9.0=21_linux64_openblas
- libboost=1.84.0=h8013b2b_2
- libcap=2.69=h0f662aa_0
- libcblas=3.9.0=21_linux64_openblas
- libccd-double=2.1=h59595ed_3
- libedit=3.1.20191231=he28a2e2_2
- libexpat=2.6.2=h59595ed_0
- libffi=3.4.2=h7f98852_5
- libflac=1.4.3=h59595ed_0
- libgcc-ng=13.2.0=h807b86a_5
- libgcrypt=1.10.3=hd590300_0
- libgfortran-ng=13.2.0=h69a702a_5
- libgfortran5=13.2.0=ha4646dd_5
- libglib=2.80.0=hf2295e7_1
- libglu=9.0.0=hac7e632_1003
- libgomp=13.2.0=h807b86a_5
- libgpg-error=1.48=h71f35ed_0
- libgrpc=1.62.1=h15f2491_0
- libhwloc=2.9.3=default_h554bfaf_1009
- libiconv=1.17=hd590300_2
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=21_linux64_openblas
- libmujoco=3.1.3=hfbbffa6_1
- libnsl=2.0.1=hd590300_0
- libogg=1.3.4=h7f98852_1
- libopenblas=0.3.26=pthreads_h413a1c8_0
- libopus=1.3.1=h7f98852_1
- libosqp=0.6.3=h59595ed_0
- libpng=1.6.43=h2797004_0
- libprotobuf=4.25.3=h08a7969_0
- libqdldl=0.1.5=h27087fc_1
- libre2-11=2023.09.01=h5a48ba9_2
- libscotch=7.0.4=h91e35bf_1
- libsndfile=1.2.2=hc60ed4a_1
- libspral=2023.09.07=h6aa6db2_2
- libsqlite=3.45.2=h2797004_0
- libstdcxx-ng=13.2.0=h7e041cc_5
- libsystemd0=255=h3516f8a_1
- libuuid=2.38.1=h0b41bf4_0
- libvorbis=1.3.7=h9c3ff4c_0
- libxcb=1.15=h0b41bf4_0
- libxcrypt=4.4.36=hd590300_1
- libxkbcommon=1.7.0=h662e7e4_0
- libxml2=2.12.6=h232c23b_1
- libxslt=1.1.39=h76b75d6_0
- libzlib=1.2.13=hd590300_5
- lodepng=20220109=h924138e_0
- lxml=5.1.0=py312h37b5203_0
- lz4-c=1.9.4=hcb278e6_0
- metis=5.1.0=h59595ed_1007
- ml_dtypes=0.3.2=py312hfb8ada1_0
- mpg123=1.32.4=h59595ed_0
- mujoco=3.1.3=ha770c72_1
- mujoco-python=3.1.3=py312h276ad9d_1
- mujoco-samples=3.1.3=h59595ed_1
- mujoco-simulate=3.1.3=h59595ed_1
- mumps-include=5.6.2=ha770c72_4
- mumps-seq=5.6.2=hfef103a_4
- mypy_extensions=1.0.0=pyha770c72_0
- nccl=2.20.5.1=h6103f9b_0
- ncurses=6.4.20240210=h59595ed_0
- numpy=1.26.4=py312heda63a1_0
- openssl=3.2.1=hd590300_1
- opt-einsum=3.3.0=hd8ed1ab_2
- opt_einsum=3.3.0=pyhc1e730c_2
- osqp-eigen=0.8.1=hdd734ac_1
- packaging=24.0=pyhd8ed1ab_0
- pathspec=0.12.1=pyhd8ed1ab_0
- pcre2=10.43=hcad00b1_0
- pip=24.0=pyhd8ed1ab_0
- platformdirs=4.2.0=pyhd8ed1ab_0
- pluggy=1.4.0=pyhd8ed1ab_0
- pptree=3.1=pyhd8ed1ab_0
- pthread-stubs=0.4=h36c2ea0_1001
- pulseaudio-client=17.0=hb77b528_0
- pybind11-abi=4=hd8ed1ab_3
- pyglfw=2.7.0=py312h7900ff3_0
- pyopengl=3.1.6=pyhd8ed1ab_1
- pytest=8.1.1=pyhd8ed1ab_0
- python=3.12.2=hab00c5b_0_cpython
- python_abi=3.12=4_cp312
- qhull=2020.2=h4bd325d_2
- re2=2023.09.01=h7f4b329_2
- readline=8.2=h8228510_1
- scipy=1.12.0=py312heda63a1_2
- scotch=7.0.4=h23d43cc_1
- sdl=1.2.68=h293081c_0
- sdl2=2.28.5=hdbcbe63_1
- setuptools=69.2.0=pyhd8ed1ab_0
- tinyxml2=10.0.0=h59595ed_0
- tk=8.6.13=noxft_h4845f30_101
- tomli=2.0.1=pyhd8ed1ab_0
- typing_extensions=4.10.0=pyha770c72_0
- tzdata=2024a=h0c530f3_0
- unixodbc=2.3.12=h661eb56_0
- wayland=1.22.0=h8c25dac_1
- wheel=0.43.0=pyhd8ed1ab_1
- xkeyboard-config=2.41=hd590300_0
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libx11=1.8.7=h8ee46fc_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xorg-libxext=1.3.4=h0b41bf4_2
- xorg-libxinerama=1.1.5=h27087fc_0
- xorg-xextproto=7.3.0=h0b41bf4_1003
- xorg-xproto=7.0.31=h7f98852_1007
- xz=5.2.6=h166bdaf_0
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
- pip:
- asttokens==2.4.1
- cfgv==3.4.0
- coloredlogs==15.0.1
- contourpy==1.2.0
- cycler==0.12.1
- decorator==5.1.1
- distlib==0.3.8
- docstring-parser==0.16
- executing==2.0.1
- filelock==3.13.3
- fonttools==4.50.0
- gitdb==4.0.11
- gitpython==3.1.42
- humanfriendly==10.0
- icdiff==2.0.7
- identify==2.5.35
- ipython==8.22.2
- jax-dataclasses==1.6.0
- jaxlie==1.3.4
- jaxsim==0.2.dev363
- jedi==0.19.1
- kiwisolver==1.4.5
- markdown-it-py==3.0.0
- mashumaro==3.12
- matplotlib==3.8.3
- matplotlib-inline==0.1.6
- mdurl==0.1.2
- mediapy==1.2.0
- nodeenv==1.8.0
- parso==0.8.3
- pexpect==4.9.0
- pillow==10.2.0
- pprintpp==0.4.0
- pre-commit==3.7.0
- prompt-toolkit==3.0.43
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pygments==2.17.2
- pyparsing==3.1.2
- pytest-icdiff==0.9
- python-dateutil==2.9.0.post0
- pyyaml==6.0.1
- resolve-robotics-uri-py==0.2.0
- rich==13.7.1
- robot-descriptions==1.9.0
- rod==0.2.0
- shtab==1.7.1
- six==1.16.0
- smmap==5.0.1
- stack-data==0.6.3
- tokenize-rt==5.2.0
- tqdm==4.66.2
- traitlets==5.14.2
- tyro==0.7.3
- virtualenv==20.25.1
- wcwidth==0.2.13
- xmltodict==0.13.0 I can finalize #124 now if you want to include it into this PR |
I'm not really sure what's going on there, I see that there's a mixture of cuda11, cuda12, together with stuff from both the nvidia and conda-forge channel. As first thought, I blame that to be the culprit. My environment uses the latest jaxlib rebuild1 with the new official cuda12 all coming from conda-forge. Then, CI install everything from PyPI. I guess this is enough for merging, we can dig deeper on that particular environment later.
No problem, we can merge it separately. I'd like to give docstring a last round of review before the release. Footnotes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some minor comments, which can be addressed in a following PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably a good idea for a future release might be to use pixi
for the CI, so we don't need to install the whole gazebo package from apt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While we nowadays use widely conda-forge, I still think that most people (especially those that want to give JaxSim a quick try) will use PyPI in a virtualenv. I'd agree to add a pixi
or conda
test in parallel.
from .simulation.ode_integration import IntegratorType | ||
from .simulation.simulator import JaxSim | ||
from . import terrain # isort:skip | ||
from . import api, integrators, logging, math, rbda |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reimporting logging
from line 1
from . import api, integrators, logging, math, rbda | |
from . import api, integrators, math, rbda |
params_next_accepted = params_next | dict( | ||
dt0=jnp.clip( | ||
jax.lax.select( | ||
pred=break_loop, | ||
on_true=params["dt0"], | ||
on_false=Δt_next, | ||
), | ||
self.dt_min, | ||
self.dt_max, | ||
) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calls to dict
have been shown to be slower than just using the dict literal. We might prefer to use those
params_next_accepted = params_next | dict( | |
dt0=jnp.clip( | |
jax.lax.select( | |
pred=break_loop, | |
on_true=params["dt0"], | |
on_false=Δt_next, | |
), | |
self.dt_min, | |
self.dt_max, | |
) | |
) | |
params_next_accepted = params_next | { | |
"dt0": jnp.clip( | |
jax.lax.select( | |
pred=break_loop, | |
on_true=params["dt0"], | |
on_false=Δt_next, | |
), | |
self.dt_min, | |
self.dt_max, | |
) | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remember that this is true only in plain Python, not in jit-compiled code.
@@ -0,0 +1,93 @@ | |||
import jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused import?
import jax |
**( | ||
dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same suggestion for using dict literal instead of calling dict
**( | |
dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs | |
), | |
**( | |
{"joint_forces": joint_forces, "link_forces": link_forces} | |
| integrator_kwargs | |
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In jit-compiled code, this change has no effect.
from jaxsim.simulation.ode_data import ODEState | ||
from jaxsim.integrators import Time | ||
from jaxsim.math import Quaternion | ||
from jaxsim.utils import Mutability |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused import?
from jaxsim.utils import Mutability |
@@ -2,7 +2,9 @@ | |||
import contextlib | |||
import copy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused import
import copy |
This is a container PR merging the new structure of upcoming JaxSim release.
Extends the initial draft of the functional APIs introduced in #88 and #99 with the following PRs:
jaxsim.api
package away fromPhysicsModel
#112JaxsimDataclass
#116api.contact.in_contact
#119Note that tests running on CPU work fine, as in CI. For some reason, on my NVIDIA GeForce GTX 1650 Ti, testing the AD of a simulation step on a reduced ErgoCub model fails with the following error:
The test suite on GPU can be executed as follows:
pytest -k "not test_ad_integration[ergocub_reduced]"
📚 Documentation preview 📚: https://jaxsim--108.org.readthedocs.build//108/