diff --git a/README.md b/README.md index 4b09777..7ead531 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# JAX-tqdm +# JAX-Tqdm Add a [tqdm](https://github.com/tqdm/tqdm) progress bar to your JAX scans and loops. @@ -10,9 +10,9 @@ Install with pip: pip install jax-tqdm ``` -## Example usage +## Example Usage -### in `jax.lax.scan` +### In `jax.lax.scan` ```python from jax_tqdm import scan_tqdm @@ -28,7 +28,7 @@ def step(carry, x): last_number, all_numbers = lax.scan(step, 0, jnp.arange(n)) ``` -### in `jax.lax.fori_loop` +### In `jax.lax.fori_loop` ```python from jax_tqdm import loop_tqdm @@ -43,7 +43,7 @@ def step(i, val): last_number = lax.fori_loop(0, n, step, 0) ``` -### Scans & Loops Inside VMAP +### Scans & Loops Inside Vmap For scans and loops inside a map, jax-tqdm can print stacked progress bars showing the individual progress of each process. To do this you can wrap @@ -101,7 +101,7 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n)) will update every other step. -### Progress bar type +### Progress Bar Type You can select the [tqdm](https://github.com/tqdm/tqdm) [submodule](https://github.com/tqdm/tqdm/tree/master?tab=readme-ov-file#submodules) manually with the `tqdm_type` option. The options are `'std'`, `'notebook'`, or `'auto'`. ```python @@ -118,7 +118,7 @@ def step(carry, x): last_number, all_numbers = lax.scan(step, 0, jnp.arange(n)) ``` -### Progress bar options +### Progress Bar Options Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm) progress bar constructor. For example: @@ -137,7 +137,7 @@ def step(carry, x): last_number, all_numbers = lax.scan(step, 0, jnp.arange(n)) ``` -## Why JAX-tqdm? +## Why JAX-Tqdm? JAX functions are [pure](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions), so side effects such as printing progress when running scans and loops are not allowed. diff --git a/jax_tqdm/pbar.py b/jax_tqdm/pbar.py index d0a72db..dc4a363 100644 --- a/jax_tqdm/pbar.py +++ b/jax_tqdm/pbar.py @@ -41,7 +41,7 @@ def scan_tqdm( Progress bar wrapping function. """ - _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs) + update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs) def _scan_tqdm(func): """Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`. @@ -59,12 +59,12 @@ def wrapper_progress_bar(carry, x): if isinstance(carry, PBar): bar_id = carry.id carry = carry.carry - _update_progress_bar(iter_num, bar_id=bar_id) + carry, x = update_progress_bar((carry, x), iter_num, bar_id=bar_id) result = func(carry, x) result = (PBar(id=bar_id, carry=result[0]), result[1]) return close_tqdm(result, iter_num, bar_id=bar_id) else: - _update_progress_bar(iter_num) + carry, x = update_progress_bar((carry, x), iter_num) result = func(carry, x) return close_tqdm(result, iter_num) @@ -84,7 +84,7 @@ def loop_tqdm( Parameters ---------- - n : int + n: int Number of iterations. print_rate: int Optional integer rate at which the progress bar will be updated, @@ -100,7 +100,7 @@ def loop_tqdm( Progress bar wrapping function. """ - _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs) + update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs) def _loop_tqdm(func): """ @@ -112,12 +112,12 @@ def wrapper_progress_bar(i, val): if isinstance(val, PBar): bar_id = val.id val = val.carry - _update_progress_bar(i, bar_id=bar_id) + i, val = update_progress_bar((i, val), i, bar_id=bar_id) result = func(i, val) result = PBar(id=bar_id, carry=result) return close_tqdm(result, i, bar_id=bar_id) else: - _update_progress_bar(i) + i, val = update_progress_bar((i, val), i) result = func(i, val) return close_tqdm(result, i) @@ -134,6 +134,18 @@ def build_tqdm( ) -> typing.Tuple[typing.Callable, typing.Callable]: """ Build the tqdm progress bar on the host + + Parameters + ---------- + n: int + Number of updates + print_rate: int + Optional integer rate at which the progress bar will be updated, + If ``None`` the print rate will 1/20th of the total number of steps. + tqdm_type: str + Type of progress-bar, should be one of "auto", "std", or "notebook". + **kwargs + Extra keyword arguments to pass to tqdm. """ if tqdm_type not in ("auto", "std", "notebook"): @@ -167,50 +179,57 @@ def build_tqdm( ) remainder = n % print_rate + remainder = remainder if remainder > 0 else print_rate - def _define_tqdm(_arg, bar_id: int): + def _define_tqdm(bar_id: int): bar_id = int(bar_id) - tqdm_bars[bar_id] = pbar(range(n), position=bar_id + position_offset, **kwargs) - tqdm_bars[bar_id].set_description(message, refresh=False) + tqdm_bars[bar_id] = pbar( + total=n, + position=bar_id + position_offset, + desc=message, + **kwargs, + ) + + def _update_tqdm(bar_id: int): + tqdm_bars[int(bar_id)].update(print_rate) - def _update_tqdm(arg, bar_id: int): - tqdm_bars[int(bar_id)].update(int(arg)) + def _close_tqdm(bar_id: int): + _pbar = tqdm_bars.pop(int(bar_id)) + _pbar.update(remainder) + _pbar.clear() + _pbar.close() - def _update_progress_bar(iter_num, bar_id: int = 0): + def update_progress_bar(carry: typing.Any, iter_num: int, bar_id: int = 0): """Updates tqdm from a JAX scan or loop""" - _ = jax.lax.cond( - iter_num == 0, - lambda _: callback(_define_tqdm, None, bar_id, ordered=True), - lambda _: None, - operand=None, - ) - _ = jax.lax.cond( - # update tqdm every multiple of `print_rate` except at the end - (iter_num % print_rate == 0) & (iter_num != n - remainder), - lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True), - lambda _: None, - operand=None, - ) + def _inner_init(_i, _carry): + callback(_define_tqdm, bar_id, ordered=True) + return _carry - _ = jax.lax.cond( - # update tqdm by `remainder` - iter_num == n - remainder, - lambda _: callback(_update_tqdm, remainder, bar_id, ordered=True), - lambda _: None, - operand=None, + def _inner_update(i, _carry): + _ = jax.lax.cond( + i % print_rate == 0, + lambda: callback(_update_tqdm, bar_id, ordered=True), + lambda: None, + ) + return _carry + + carry = jax.lax.cond( + iter_num == 0, + _inner_init, + _inner_update, + iter_num, + carry, ) - def _close_tqdm(_arg, bar_id: int): - tqdm_bars[int(bar_id)].close() + return carry - def close_tqdm(result, iter_num, bar_id: int = 0): - _ = jax.lax.cond( - iter_num == n - 1, - lambda _: callback(_close_tqdm, None, bar_id, ordered=True), - lambda _: None, - operand=None, - ) + def close_tqdm(result: typing.Any, iter_num: int, bar_id: int = 0): + def _inner_close(_result): + callback(_close_tqdm, bar_id, ordered=True) + return _result + + result = jax.lax.cond(iter_num + 1 == n, _inner_close, lambda r: r, result) return result - return _update_progress_bar, close_tqdm + return update_progress_bar, close_tqdm diff --git a/pyproject.toml b/pyproject.toml index 73adebb..c7dea17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "jax-tqdm" -version = "0.3.0" +version = "0.3.1" description = "Tqdm progress bar for JAX scans and loops" authors = [ "Jeremie Coullon ",