Skip to content

Commit

Permalink
[jax] Add computation name to cache hit logging.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 414697336
  • Loading branch information
trevorcai authored and jax authors committed Dec 7, 2021
1 parent 803d3f2 commit 56f029f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ def compile_or_get_cached(backend, computation, compile_options):
if cc.is_initialized() and backend.platform == 'tpu':
cached_executable = cc.get_executable(computation, compile_options, backend)
if cached_executable is not None:
logging.info('Persistent compilation cache hit')
logging.info('Persistent compilation cache hit for %s.',
computation.name())
return cached_executable
else:
compiled = backend_compile(backend, computation, compile_options)
Expand Down

0 comments on commit 56f029f

Please sign in to comment.