
Printing iteration number in jitted functions/loops using host_callback.id_tap #4763

Answered by shoyer
jeremiecoullon asked this question in Q&A

You do need a data dependency to ensure the callback runs at the correct time. With `id_tap`, pass the loop value through the `result` argument and use the returned value in the next step.

Here are two versions that work with jit, with and without scan:

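from jax.experimental import host_callback
from jax import jit

def progress_bar(arg, transforms):
    # Runs on the host. `arg` is the tuple passed to id_tap; `transforms`
    # lists any JAX transformations applied to the tap (unused here).
    i, n_iter, print_rate = arg
    if i % print_rate == 0:
        print(f"Iteration {i}/{n_iter}")

@jit
def my_python_loop2(a):
    """
    Python loop that increments `a` 100 times
    """
    n_iter, print_rate = 100, 10
    # The Python loop is unrolled at trace time, so 100 taps are staged.
    for i in range(n_iter):
        # id_tap returns `result` (here `a`) unchanged; reusing that value
        # below is the data dependency that pins the tap to this step.
        a = host_callback.id_tap(progress_bar, (i, n_iter, print_rate), result=a)
        a += 1
    return a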


@jit
def my_python_loop3(a):
    """


This discussion was converted from issue #4763 on November 04, 2020 19:08.