Fix Callback Sequence Bug #28

Merged
merged 8 commits into from
Oct 14, 2024

Conversation

zombie-einstein
Collaborator

@zombie-einstein zombie-einstein commented Oct 12, 2024

Fix a bug where the tqdm callbacks were not correctly ordered on either side of the computation.

Also tweak how updates are performed:

  • Initialise progress bar before first step
  • Check for update every step before inner-function
  • On final step, after inner-function, complete and close progress bar
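The schedule described in the list above can be modelled in plain Python. This is only a sketch of the intended event order, not the jax_tqdm implementation: the function name `simulate_progress`, the event tuples, and the exact update condition (`i > 0 and i % print_rate == 0`) are assumptions made for illustration.

```python
def simulate_progress(n, print_rate):
    """Return the ordered list of host-side events for n steps (illustrative model)."""
    events = [("init", 0)]  # the bar is created before the first step
    shown = 0               # progress currently displayed on the bar
    for i in range(n):
        # Check for an update before running the wrapped function.
        if i > 0 and i % print_rate == 0:
            events.append(("update", print_rate))
            shown += print_rate
        events.append(("compute", i))  # stands in for the inner function
        if i == n - 1:
            # Final step: fill the remainder after the computation, then close.
            events.append(("close", n - shown))
            shown = n
    return events


for event in simulate_progress(10, 3):
    print(event)
```

In this model the final `close` event always follows the last `compute` event, which is the ordering the callbacks are supposed to respect.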

@andrewlesak

andrewlesak commented Oct 12, 2024

I still think the ordering is not guaranteed. I am still seeing the progress bars initialize after the first iteration, and the final update is skipped. I swear that when I tested the code earlier it seemed to work. I have made sure that I am up to date with the fix_iter_bug branch and still observe incorrect progress bar behavior, so I'm not entirely sure what's going wrong. I suspect the ordering of the callbacks is not fixed.

Testing with this code:

from jax_tqdm import PBar, scan_tqdm
import jax
import jax.numpy as jnp
import jax.random as jr

n = 10
print_rate = 3
arr_size = 9000
    
@scan_tqdm(n, print_rate, message="")
def step(carry, iter_num):
    val, *_ = carry
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    val += 1e-8 * jnp.sum(mat_mul)
    # jax.debug.print("iter = {}", iter_num)
    # jax.debug.print("carry = {}",carry)
    return (val,), val

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init_carry = PBar(id=i, carry=(0,))
    final_val, all_vals = jax.lax.scan(
        step, init_carry, jax.numpy.arange(n)
    )
    return (
        final_val.carry,
        all_vals,
    )

n_bars = 2
final_val, all_vals = jax.vmap(map_func)(jnp.arange(n_bars))

I observe the following:
[Animated GIF: test_new_branch_n10_r3]

@andrewlesak

andrewlesak commented Oct 13, 2024

Interestingly, if I print

jax.debug.print("carry = {}",carry, ordered=False)

in the step function of my example, the code works properly... I think this was why it was working for me earlier but not in the example code where I commented it out. It doesn't work if ordered=True.

It seems like the body function needs a callback to correctly order the rest? Could this be why testing with time.sleep() as mentioned in #27 seemed to produce the correct result?

The progress bar functions incorrectly if I only print the iter regardless of ordered, however. So the behavior seems to depend on callbacks to carry.

@andrewlesak

Another observation:

If I instead modify my test example above to directly operate on carry (i.e. I no longer pass a tuple), then the placement of the callback affects how the progress bar functions.

For example, this results in the correct progress bar behavior:

@scan_tqdm(n, print_rate, message="")
def step(carry, iter_num):
    jax.debug.print("carry = {}",carry, ordered=False)  # works if false
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    carry += 1e-8 * jnp.sum(mat_mul)
    return carry, carry

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init_carry = PBar(id=i, carry=0)
    final_val, all_vals = jax.lax.scan(
        step, init_carry, jax.numpy.arange(n)
    )
    return (
        final_val.carry,
        all_vals,
    )

while this does not:

@scan_tqdm(n, print_rate, message="")
def step(carry, iter_num):
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    carry += 1e-8 * jnp.sum(mat_mul)
    jax.debug.print("carry = {}",iter_num)  # doesn't work regardless of ordered
    return carry, carry

@zombie-einstein
Collaborator Author

zombie-einstein commented Oct 13, 2024

This is a strange one. Thinking about the final iteration of the loop/scan, if the print-rate is 1, it should

  • Update the bar to the penultimate step
  • Wait (while doing computation)
  • Finalise the bar and close

and these all happen inside the same scan/loop iteration. The previous example

@scan_tqdm(n, print_rate)
def step(carry, stuff):
    rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))
    mat_mul = rand_mat @ rand_mat.T
    return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)

works correctly in this case, but this example

@scan_tqdm(n, print_rate)
def step(carry, iter_num):
    val, *_ = carry
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    val += 1e-8 * jnp.sum(mat_mul)
    return (val,), val

does not (i.e. the final step jumps in one go, indicating the call-backs are not ordered correctly).

Explicitly passing both the carry and mapped values (rather than just the carry) seems to fix this. i.e. going from

carry = update_progress_bar(carry, iter_num, bar_id=bar_id)
result = func(carry, x)

to

carry, x = update_progress_bar((carry, x), iter_num, bar_id=bar_id)
result = func(carry, x)

Just pushed these changes.
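A minimal sketch of the idea behind this change, assuming the callback is fired inside a `jax.lax.cond` whose branches hand the values back. The helpers `update_then_return` and `wrap` and the host-side `log` are hypothetical names, not jax_tqdm's code. Because both `carry` and `x` are returned by the `cond`, whichever argument the wrapped function actually uses is an output of the operation that contains the callback.

```python
from functools import partial

import jax
import jax.numpy as jnp

log = []  # host-side record of callback events (illustration only)


def _record(tag, i):
    log.append((tag, int(i)))


def update_then_return(values, i, print_rate):
    """Fire a host callback every print_rate steps and return values unchanged."""

    def _fire(vals):
        jax.debug.callback(partial(_record, "update"), i, ordered=True)
        return vals

    def _skip(vals):
        return vals

    return jax.lax.cond(i % print_rate == 0, _fire, _skip, values)


def wrap(func, print_rate):
    def wrapped(carry, x):
        # Thread BOTH arguments through the cond, so func consumes
        # outputs of the cond regardless of which argument it uses.
        carry, x = update_then_return((carry, x), x, print_rate)
        return func(carry, x)

    return wrapped


def step(carry, x):
    return carry + 1, x * 2


final, ys = jax.lax.scan(wrap(step, 3), jnp.float32(0.0), jnp.arange(10))
jax.effects_barrier()  # wait for outstanding callbacks before reading the log
print(final, ys, log)
```

Whether this kind of threading is enough to pin the callback relative to the surrounding computation in every case is the open question in this thread, so treat it as an illustration of the approach rather than a guarantee.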

I'm struggling to see the logic in this case. It's as if the step in the second case has no dependence on the carry argument, but it clearly does. I will try looking at the JAX-generated AST to see if there is something obvious.

@zombie-einstein
Collaborator Author

Oh actually looking at it, in the first example, the long computation depends on the carry

rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))

but in the second only the iteration

rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))

So they must be getting re-ordered differently with respect to the callback.
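One way to check which input feeds the expensive part of each step is to print the traced program with `jax.make_jaxpr`. The sketch below uses simplified copies of the two step functions, without the progress-bar wrapper and with a small `arr_size`; the names `step_uses_carry` and `step_uses_iter` are made up for this illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

arr_size = 4  # small size, only the structure of the jaxpr matters here


def step_uses_carry(carry, x):
    # The random matrix is seeded from the carry.
    rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))
    return carry + 1, x + 1e-8 * jnp.sum(rand_mat @ rand_mat.T)


def step_uses_iter(carry, iter_num):
    # The random matrix is seeded from the scanned value instead.
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
    val = carry + 1e-8 * jnp.sum(rand_mat @ rand_mat.T)
    return val, val


jaxpr_a = jax.make_jaxpr(step_uses_carry)(0, 0.0)
jaxpr_b = jax.make_jaxpr(step_uses_iter)(0.0, 0)
print(jaxpr_a)
print(jaxpr_b)
```

In the first jaxpr the key-generation equations consume the first input (the carry), while in the second they consume the second input (the iteration number) and the carry only enters at the final addition.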

@zombie-einstein
Collaborator Author

Going to merge and release this, as it seems to fix what must have been a long-running bug!

@zombie-einstein zombie-einstein changed the title Fix iter bug Fix Callback Sequence Bug Oct 14, 2024
@zombie-einstein zombie-einstein merged commit f824beb into main Oct 14, 2024
3 checks passed
@zombie-einstein zombie-einstein deleted the fix_iter_bug branch October 14, 2024 23:36