Releases: jeremiecoullon/jax-tqdm
Bar Update Bug Fix
Fixes a bug where progress bar initialisation and updates were not called in the correct order on either side of a scan/loop step function. As a result, every progress bar call ran after all of the computation inside the step. For longer-running functions this looked like:
- The progress bar being initialised/displayed after the first iteration (i.e. delaying the display of the bar)
- Some updates being run directly after each other, causing some updates to look instant/jumpy
Thanks to @andrewlesak for spotting this!
Progress bars will now be updated in the correct sequence (with no API change):
- Bars will be displayed before the first iteration is computed
- Steps between updates will be spaced appropriately
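The intended sequencing can be sketched in plain JAX. This is not jax-tqdm's code: it is a minimal sketch, assuming `jax.debug.callback(..., ordered=True)` is used to pin the host calls to program order, with a hypothetical `events` list standing in for a real bar. The bar is initialised before iteration 0's work and updated after each iteration's work.

```python
import jax
import jax.numpy as jnp
from jax import lax

events = []  # hypothetical host-side log standing in for a real progress bar


def record(tag):
    """Return a host function that logs (tag, iteration) when called."""
    def _callback(i):
        events.append((tag, int(i)))
    return _callback


def user_step(carry, i):
    """Stand-in for the user's step body; the callback marks when its work runs."""
    jax.debug.callback(record("work"), i, ordered=True)
    return carry + i, None


def step_with_bar(carry, i):
    # Initialise the bar *before* the first iteration's work ...
    lax.cond(i == 0,
             lambda: jax.debug.callback(record("init"), i, ordered=True),
             lambda: None)
    carry, y = user_step(carry, i)
    # ... and update it *after* each iteration's work.
    jax.debug.callback(record("update"), i, ordered=True)
    return carry, y


run = jax.jit(lambda xs: lax.scan(step_with_bar, jnp.int32(0), xs))
total, _ = run(jnp.arange(3))
jax.effects_barrier()  # wait for all host callbacks to finish
```

With `ordered=True`, JAX runs these callbacks in program order, so the log reads init, work, update, work, update, and so on. Ordered effects are not supported under every transformation (`pmap` is one example), so a library that also needs vmapped bars may well sequence its calls differently.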
Multiple Progress Bars
Adds the ability to update multiple progress bars, to show the progress of individual loops and scans inside a vmapped function.
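A minimal sketch of the idea behind multiple bars, assuming each vmapped instance carries its own integer bar id into the host callback. Under `vmap`, `jax.debug.callback` is invoked once per batch element, so the host side can route each update to the right bar. The `progress` counter is a hypothetical stand-in for a collection of tqdm bars keyed by id; jax-tqdm's own API for this is not shown.

```python
import collections

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical host-side state: number of updates received per bar id.
progress = collections.Counter()


def _update(bar_id, i):
    # Called once per batch element, with unbatched (scalar) arguments.
    progress[int(bar_id)] += 1


def run_one(bar_id, n_steps=5):
    def step(carry, i):
        # bar_id is batched under vmap; i comes from an unbatched arange.
        jax.debug.callback(_update, bar_id, i)
        return carry + i * (bar_id + 1), None

    total, _ = lax.scan(step, jnp.int32(0), jnp.arange(n_steps))
    return total


totals = jax.vmap(run_one)(jnp.arange(3))  # three "bars", one per instance
jax.effects_barrier()
```

Because these callbacks are unordered, updates from different bars may interleave in any order; only the per-bar counts are deterministic.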
Select Tqdm Submodule
Adds the ability to select the tqdm submodule, for example manually selecting 'std' or 'notebook' for progress bar creation (thanks @mdmould)
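Choosing the submodule amounts to picking which `tqdm` class creates the bar. A minimal sketch with a hypothetical `select_tqdm` helper: `tqdm.auto`, `tqdm.std` and `tqdm.notebook` are real tqdm submodules that each expose a `tqdm` class, but the helper and the names it accepts are assumptions, not jax-tqdm's interface.

```python
import importlib

_SUBMODULES = ("auto", "std", "notebook")


def select_tqdm(tqdm_type: str = "auto"):
    """Hypothetical helper: return the tqdm class from the named submodule."""
    if tqdm_type not in _SUBMODULES:
        raise ValueError(f"tqdm_type must be one of {_SUBMODULES}, got {tqdm_type!r}")
    # e.g. "std" -> tqdm.std.tqdm, "notebook" -> tqdm.notebook.tqdm
    return importlib.import_module(f"tqdm.{tqdm_type}").tqdm


bar_cls = select_tqdm("std")  # force the plain text bar, even inside Jupyter
```

`auto` leaves the choice to tqdm, which uses the notebook widget when running in Jupyter with widget support available; forcing `std` is a fallback when the widget does not render.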
Bug Fixes
JAX Update
- Use `jax.debug.callback` instead of the deprecated `jax.experimental.host_callback` (thanks @BirkhoffG)
- Update min JAX version to 0.4.12
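For reference, `jax.debug.callback` calls a Python function on the host while a compiled computation runs and returns nothing to the traced program, which is all a progress bar needs. A minimal sketch; `on_host` and `received` are illustrative names.

```python
import jax
import jax.numpy as jnp

received = []  # filled on the host as the compiled function executes


def on_host(x):
    # Runs in Python at execution time, once per call of the jitted function.
    received.append(float(x))


@jax.jit
def f(x):
    y = jnp.sin(x) ** 2
    jax.debug.callback(on_host, y)  # side effect only; returns None
    return 2.0 * y


f(jnp.float32(0.0))
f(jnp.float32(0.0))  # reuses the compiled code but still calls back
jax.effects_barrier()  # block until all pending callbacks have run
```

The callback fires on every execution, whereas a plain Python `print` inside `f` would run only while the function is being traced.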
Progress Bar Options
Adds the ability to pass tqdm keyword options as arguments (thanks to @mdmould)
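Forwarding options usually means accepting `**kwargs` and passing them unchanged to the `tqdm` constructor. A minimal sketch with a hypothetical `make_bar` helper; `desc`, `leave` and `file` are standard tqdm parameters, while the exact set of options jax-tqdm forwards is not specified here.

```python
import io

from tqdm.std import tqdm


def make_bar(n, **tqdm_kwargs):
    """Hypothetical helper: create a bar for n iterations, forwarding any tqdm options."""
    return tqdm(total=n, **tqdm_kwargs)


# desc, leave and file are ordinary tqdm keyword arguments.
buf = io.StringIO()
bar = make_bar(100, desc="Sampling", leave=False, file=buf)
bar.update(10)
bar.close()
```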
v0.1.1
Adds the ability to manually set the print rate
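A print rate mainly limits how often the compiled loop calls back to the host. A minimal sketch, assuming the update is guarded with `lax.cond` so the callback only fires every `print_rate` iterations; `with_print_rate` and `updates` are hypothetical names, and a real bar would advance by `print_rate` on each call.

```python
import jax
import jax.numpy as jnp
from jax import lax

updates = []  # iterations at which the host was notified


def _update(i):
    updates.append(int(i))


def with_print_rate(step, print_rate):
    """Hypothetical wrapper: call back to the host every `print_rate` iterations."""
    def wrapped(carry, i):
        lax.cond(i % print_rate == 0,
                 lambda: jax.debug.callback(_update, i),
                 lambda: None)
        return step(carry, i)
    return wrapped


def step(carry, i):
    return carry + i, None


total, _ = lax.scan(with_print_rate(step, 4), jnp.int32(0), jnp.arange(10))
jax.effects_barrier()
```

Filtering inside the compiled code means skipped iterations never cross to the host at all; filtering inside the Python callback instead would still pay for a host round-trip on every step.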
Initial Release
Implements tqdm progress bars for JAX scans and for loops.
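Putting the pieces together for a `lax.fori_loop`, whose body takes `(i, val)` rather than the `(carry, x)` of a scan step. A minimal sketch, assuming a hypothetical `loop_progress` decorator that opens a real `tqdm` bar before iteration 0, advances it after each iteration and closes it on the last one; it is not jax-tqdm's implementation.

```python
import io

import jax
import jax.numpy as jnp
from jax import lax
from tqdm.std import tqdm


def loop_progress(n, **tqdm_kwargs):
    """Hypothetical decorator driving a tqdm bar from a lax.fori_loop body."""
    state = {}  # host-side storage for the bar

    def _open(i):
        state["bar"] = tqdm(total=n, **tqdm_kwargs)

    def _advance(i):
        state["bar"].update(1)
        if int(i) == n - 1:
            state["bar"].close()

    def decorator(body_fun):
        def wrapped(i, val):
            # Open the bar before the first iteration's work ...
            lax.cond(i == 0,
                     lambda: jax.debug.callback(_open, i, ordered=True),
                     lambda: None)
            val = body_fun(i, val)
            # ... and advance it once that iteration's work is done.
            jax.debug.callback(_advance, i, ordered=True)
            return val

        wrapped.state = state  # exposed only so the bar can be inspected
        return wrapped

    return decorator


@loop_progress(5, desc="demo", file=io.StringIO())
def body(i, val):
    return val + i


result = lax.fori_loop(0, 5, body, jnp.int32(0))
jax.effects_barrier()
```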