Fix Callback Sequence Bug #28
I still think the ordering is not guaranteed. I am still seeing the progress bars initialize after the first iter and the final update is skipped. I swear that when I tested the code earlier it seemed to work. I have made sure that I am up to date with the

Testing with this code:

    from jax_tqdm import PBar, scan_tqdm
    import jax
    import jax.numpy as jnp
    import jax.random as jr

    n = 10
    print_rate = 3
    arr_size = 9000

    @scan_tqdm(n, print_rate, message="")
    def step(carry, iter_num):
        val, *_ = carry
        rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
        mat_mul = rand_mat @ rand_mat.T
        val += 1e-8 * jnp.sum(mat_mul)
        # jax.debug.print("iter = {}", iter_num)
        # jax.debug.print("carry = {}",carry)
        return (val,), val

    def map_func(i):
        # Wrap the initial value and pass the
        # progress bar index
        init_carry = PBar(id=i, carry=(0,))
        final_val, all_vals = jax.lax.scan(
            step, init_carry, jax.numpy.arange(n)
        )
        return (
            final_val.carry,
            all_vals,
        )

    n_bars = 2
    final_val, all_vals = jax.vmap(map_func)(jnp.arange(n_bars))
Interestingly, if I print jax.debug.print("carry = {}",carry, ordered=False) in the step function of my example, the code works properly... I think this was why it was working for me earlier but not in the example code where I commented it out. It doesn't work if It seems like the body function needs a callback to correctly order the rest? Could this be why testing with The progress bar functions incorrectly if I only print the iter regardless of
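For reference, here is a minimal sketch of what the `ordered` flag controls, using only `jax.debug.callback` inside `jax.lax.scan` (no jax_tqdm involved). The names `seen`, `record`, and `step` are made up for illustration. As far as I understand the JAX docs, `ordered=True` sequences a callback relative to other ordered callbacks; it does not by itself say where the callback lands relative to computation that has no data dependence on it, which may be part of what is observed above.

```python
import jax
import jax.numpy as jnp

seen = []  # host-side log, filled in by the callback


def record(i):
    # Runs on the host; receives the iteration index as an array scalar.
    seen.append(int(i))


def step(carry, i):
    # ordered=True asks JAX to keep these callbacks in program order
    # relative to each other.
    jax.debug.callback(record, i, ordered=True)
    return carry + i, carry


final, _ = jax.lax.scan(step, jnp.int32(0), jnp.arange(5))
jax.effects_barrier()  # wait for any pending callbacks to finish
print(seen, int(final))
```

This only checks the order of the callbacks among themselves; it says nothing about how they interleave with a long-running matmul in the same loop body.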
Another observation: If I instead modify my test example above to directly operate on For example, this results in the correct progress bar behavior:

    @scan_tqdm(n, print_rate, message="")
    def step(carry, iter_num):
        jax.debug.print("carry = {}",carry, ordered=False) # works if false
        rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
        mat_mul = rand_mat @ rand_mat.T
        carry += 1e-8 * jnp.sum(mat_mul)
        return carry, carry

    def map_func(i):
        # Wrap the initial value and pass the
        # progress bar index
        init_carry = PBar(id=i, carry=0)
        final_val, all_vals = jax.lax.scan(
            step, init_carry, jax.numpy.arange(n)
        )
        return (
            final_val.carry,
            all_vals,
        )

while this does not:

    @scan_tqdm(n, print_rate, message="")
    def step(carry, iter_num):
        rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
        mat_mul = rand_mat @ rand_mat.T
        carry += 1e-8 * jnp.sum(mat_mul)
        jax.debug.print("carry = {}",iter_num) # doesn't work regardless of ordered
        return carry, carry
This is a strange one. Thinking about the final iteration of the loop/scan, if the print-rate is 1, it should
and these all happen inside the same scan/loop iteration. The previous example

    @scan_tqdm(n, print_rate)
    def step(carry, stuff):
        rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))
        mat_mul = rand_mat @ rand_mat.T
        return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)

works correctly in this case, but this example

    @scan_tqdm(n, print_rate)
    def step(carry, iter_num):
        val, *_ = carry
        rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
        mat_mul = rand_mat @ rand_mat.T
        val += 1e-8 * jnp.sum(mat_mul)
        return (val,), val

does not (i.e. the final step jumps in one go, indicating the callbacks are not ordered correctly). Explicitly passing both the carry and mapped values (rather than just the carry) seems to fix this, i.e. going from

    carry = update_progress_bar(carry, iter_num, bar_id=bar_id)
    result = func(carry, x)

to

    carry, x = update_progress_bar((carry, x), iter_num, bar_id=bar_id)
    result = func(carry, x)

Just pushed these changes. I'm struggling to see the logic in this case; it's like the step in the second case has no dependence on the carry argument, but it clearly does. Will try looking at the JAX generated AST to see if there is something obvious.
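To make the change above concrete, here is a rough, self-contained sketch of the idea of routing both the carry and the mapped value through the update step. It is not the jax_tqdm implementation: `update_progress_bar` here is a hypothetical stand-in that just logs the iteration on the host, and `func`, `wrapped`, `updates`, and `_host_update` are made-up names. The assumption is that returning `(carry, x)` from the `jax.lax.cond` that holds the callback makes everything downstream depend on that `cond`.

```python
import jax
import jax.numpy as jnp

print_rate = 2
updates = []  # host-side record of which iterations triggered an update


def _host_update(i):
    updates.append(int(i))


def update_progress_bar(args, iter_num):
    # Hypothetical stand-in: run a host callback every `print_rate` steps
    # and hand `args` back unchanged from whichever branch runs.
    def _update(a):
        jax.debug.callback(_host_update, iter_num, ordered=True)
        return a

    return jax.lax.cond(iter_num % print_rate == 0, _update, lambda a: a, args)


def func(carry, x):
    # The user's step function; here it just accumulates x.
    carry = carry + x
    return carry, carry


def wrapped(carry, x):
    # Both carry and x pass through the update, so func consumes its outputs.
    carry, x = update_progress_bar((carry, x), x)
    return func(carry, x)


final, ys = jax.lax.scan(wrapped, jnp.int32(0), jnp.arange(6))
jax.effects_barrier()  # flush pending callbacks before reading `updates`
print(updates, int(final))
```

Whether XLA actually keeps the callback ahead of the heavy computation is something the sketch cannot demonstrate; it only shows the wiring.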
Oh actually looking at it, in the first example, the long computation depends on the carry

    rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))

but in the second only the iteration

    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))

So they must be getting re-ordered differently with respect to the callback.
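One way to compare the two step functions is to print what JAX traces for each with `jax.make_jaxpr`. The sketch below drops the `scan_tqdm` decorator and uses a tiny `arr_size` so it runs quickly; `step_on_carry` and `step_on_iter` are just labels for the two examples from this thread. Note that the jaxpr shows data dependencies in the traced program, not the schedule XLA finally picks, so it can only hint at why the callback gets placed differently.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

arr_size = 4  # small on purpose; the thread uses 9000


def step_on_carry(carry, stuff):
    # The expensive part is seeded from the carry.
    rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))
    mat_mul = rand_mat @ rand_mat.T
    return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)


def step_on_iter(carry, iter_num):
    # The expensive part is seeded from the mapped value only.
    val, *_ = carry
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
    mat_mul = rand_mat @ rand_mat.T
    val += 1e-8 * jnp.sum(mat_mul)
    return (val,), val


jaxpr_carry = jax.make_jaxpr(
    lambda: jax.lax.scan(step_on_carry, jnp.int32(0), jnp.zeros(3))
)()
jaxpr_iter = jax.make_jaxpr(
    lambda: jax.lax.scan(step_on_iter, (jnp.float32(0.0),), jnp.arange(3))
)()

print(jaxpr_carry)
print(jaxpr_iter)
```

In the first printout the random-key computation inside the scan body should read from the carry input, while in the second it should read from the per-iteration input.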
Going to merge and release this, as it seems to fix what must have been a long-running bug!
Fix bug where tqdm callbacks were not being correctly ordered on either side of the computation.
Also tweak how updates are performed: