Skip to content

Commit

Permalink
Fix Callback Sequence Bug (#28)
Browse files Browse the repository at this point in the history
* Fix initial update and print-rate=1 bugs

* Cleanup

* Enforce dependency between callbacks in  scan/loop updates

* Fix bug when closing multiple progress bars

* Pop progress-bars from storage

* Formatting and docstring tweaks

* Pass all scan/loop arguments through update callback

* Increment patch version
  • Loading branch information
zombie-einstein authored Oct 14, 2024
1 parent f216dcb commit f824beb
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 51 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# JAX-tqdm
# JAX-Tqdm

Add a [tqdm](https://github.com/tqdm/tqdm) progress bar to your JAX scans and loops.

Expand All @@ -10,9 +10,9 @@ Install with pip:
pip install jax-tqdm
```

## Example usage
## Example Usage

### in `jax.lax.scan`
### In `jax.lax.scan`

```python
from jax_tqdm import scan_tqdm
Expand All @@ -28,7 +28,7 @@ def step(carry, x):
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

### in `jax.lax.fori_loop`
### In `jax.lax.fori_loop`

```python
from jax_tqdm import loop_tqdm
Expand All @@ -43,7 +43,7 @@ def step(i, val):
last_number = lax.fori_loop(0, n, step, 0)
```

### Scans & Loops Inside VMAP
### Scans & Loops Inside Vmap

For scans and loops inside a map, jax-tqdm can print stacked progress bars
showing the individual progress of each process. To do this you can wrap
Expand Down Expand Up @@ -101,7 +101,7 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

will update every other step.

### Progress bar type
### Progress Bar Type

You can select the [tqdm](https://github.com/tqdm/tqdm) [submodule](https://github.com/tqdm/tqdm/tree/master?tab=readme-ov-file#submodules) manually with the `tqdm_type` option. The options are `'std'`, `'notebook'`, or `'auto'`.
```python
Expand All @@ -118,7 +118,7 @@ def step(carry, x):
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

### Progress bar options
### Progress Bar Options

Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm)
progress bar constructor. For example:
Expand All @@ -137,7 +137,7 @@ def step(carry, x):
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

## Why JAX-tqdm?
## Why JAX-Tqdm?

JAX functions are [pure](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions),
so side effects such as printing progress when running scans and loops are not allowed.
Expand Down
103 changes: 61 additions & 42 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def scan_tqdm(
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)
update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def _scan_tqdm(func):
"""Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
Expand All @@ -59,12 +59,12 @@ def wrapper_progress_bar(carry, x):
if isinstance(carry, PBar):
bar_id = carry.id
carry = carry.carry
_update_progress_bar(iter_num, bar_id=bar_id)
carry, x = update_progress_bar((carry, x), iter_num, bar_id=bar_id)
result = func(carry, x)
result = (PBar(id=bar_id, carry=result[0]), result[1])
return close_tqdm(result, iter_num, bar_id=bar_id)
else:
_update_progress_bar(iter_num)
carry, x = update_progress_bar((carry, x), iter_num)
result = func(carry, x)
return close_tqdm(result, iter_num)

Expand All @@ -84,7 +84,7 @@ def loop_tqdm(
Parameters
----------
n : int
n: int
Number of iterations.
print_rate: int
Optional integer rate at which the progress bar will be updated,
Expand All @@ -100,7 +100,7 @@ def loop_tqdm(
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)
update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def _loop_tqdm(func):
"""
Expand All @@ -112,12 +112,12 @@ def wrapper_progress_bar(i, val):
if isinstance(val, PBar):
bar_id = val.id
val = val.carry
_update_progress_bar(i, bar_id=bar_id)
i, val = update_progress_bar((i, val), i, bar_id=bar_id)
result = func(i, val)
result = PBar(id=bar_id, carry=result)
return close_tqdm(result, i, bar_id=bar_id)
else:
_update_progress_bar(i)
i, val = update_progress_bar((i, val), i)
result = func(i, val)
return close_tqdm(result, i)

Expand All @@ -134,6 +134,18 @@ def build_tqdm(
) -> typing.Tuple[typing.Callable, typing.Callable]:
"""
Build the tqdm progress bar on the host
Parameters
----------
n: int
Number of updates
print_rate: int
Optional integer rate at which the progress bar will be updated,
If ``None`` the print rate will 1/20th of the total number of steps.
tqdm_type: str
Type of progress-bar, should be one of "auto", "std", or "notebook".
**kwargs
Extra keyword arguments to pass to tqdm.
"""

if tqdm_type not in ("auto", "std", "notebook"):
Expand Down Expand Up @@ -167,50 +179,57 @@ def build_tqdm(
)

remainder = n % print_rate
remainder = remainder if remainder > 0 else print_rate

def _define_tqdm(_arg, bar_id: int):
def _define_tqdm(bar_id: int):
bar_id = int(bar_id)
tqdm_bars[bar_id] = pbar(range(n), position=bar_id + position_offset, **kwargs)
tqdm_bars[bar_id].set_description(message, refresh=False)
tqdm_bars[bar_id] = pbar(
total=n,
position=bar_id + position_offset,
desc=message,
**kwargs,
)

def _update_tqdm(bar_id: int):
tqdm_bars[int(bar_id)].update(print_rate)

def _update_tqdm(arg, bar_id: int):
tqdm_bars[int(bar_id)].update(int(arg))
def _close_tqdm(bar_id: int):
_pbar = tqdm_bars.pop(int(bar_id))
_pbar.update(remainder)
_pbar.clear()
_pbar.close()

def _update_progress_bar(iter_num, bar_id: int = 0):
def update_progress_bar(carry: typing.Any, iter_num: int, bar_id: int = 0):
"""Updates tqdm from a JAX scan or loop"""
_ = jax.lax.cond(
iter_num == 0,
lambda _: callback(_define_tqdm, None, bar_id, ordered=True),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) & (iter_num != n - remainder),
lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
lambda _: None,
operand=None,
)
def _inner_init(_i, _carry):
callback(_define_tqdm, bar_id, ordered=True)
return _carry

_ = jax.lax.cond(
# update tqdm by `remainder`
iter_num == n - remainder,
lambda _: callback(_update_tqdm, remainder, bar_id, ordered=True),
lambda _: None,
operand=None,
def _inner_update(i, _carry):
_ = jax.lax.cond(
i % print_rate == 0,
lambda: callback(_update_tqdm, bar_id, ordered=True),
lambda: None,
)
return _carry

carry = jax.lax.cond(
iter_num == 0,
_inner_init,
_inner_update,
iter_num,
carry,
)

def _close_tqdm(_arg, bar_id: int):
tqdm_bars[int(bar_id)].close()
return carry

def close_tqdm(result, iter_num, bar_id: int = 0):
_ = jax.lax.cond(
iter_num == n - 1,
lambda _: callback(_close_tqdm, None, bar_id, ordered=True),
lambda _: None,
operand=None,
)
def close_tqdm(result: typing.Any, iter_num: int, bar_id: int = 0):
def _inner_close(_result):
callback(_close_tqdm, bar_id, ordered=True)
return _result

result = jax.lax.cond(iter_num + 1 == n, _inner_close, lambda r: r, result)
return result

return _update_progress_bar, close_tqdm
return update_progress_bar, close_tqdm
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jax-tqdm"
version = "0.3.0"
version = "0.3.1"
description = "Tqdm progress bar for JAX scans and loops"
authors = [
"Jeremie Coullon <jeremie.coullon@gmail.com>",
Expand Down

0 comments on commit f824beb

Please sign in to comment.