Skip to content

Commit

Permalink
Adding JAX_LOGGING_LEVEL configuration option
Browse files Browse the repository at this point in the history
  • Loading branch information
rdyro committed Sep 6, 2024
1 parent 3a1567f commit bc55b7e
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 3 deletions.
51 changes: 51 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,57 @@ def _update_disable_jit_thread_local(val):
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(default_matmul_precision=val))

_logging_level_handler_set: dict[str, tuple[logging.Handler, int]] = {}

def _update_logging_level_global(logging_level: str | None):
# remove previous handlers
for logger_name, (handler, level) in _logging_level_handler_set.items():
logger = logging.getLogger(logger_name)
logger.removeHandler(handler)
logger.setLevel(level)
_logging_level_handler_set.clear()

if logging_level is None:
return

# attempt to convert the logging level to integer
try:
# logging level is a string representation of an integer
logging_level_num = int(logging_level)
except ValueError:
# logging level is a name string
logging_level_num = logging.getLevelName(logging_level)

# configure the CPP logging level 0 - debug, 1 - info, 2 - warning, 3 - error
cpp_log_level = min(max(0, (logging_level_num // 10) - 1), 3)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(cpp_log_level)

handler = logging.StreamHandler()
handler.setLevel(logging_level_num)
handler.setFormatter(logging_config.logging_formatter)

# update jax and jaxlib root loggers for propagation
root_loggers = [logging.getLogger("jax"), logging.getLogger("jaxlib")]
for logger in root_loggers:
logger.setLevel(logging_level_num)
logger.addHandler(handler)
_logging_level_handler_set[logger.name] = (handler, logger.level)


# Don't define a context manager since this isn't threadsafe.
optional_enum_state(
name='jax_logging_level',
enum_values=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL',
'0', '10', '20', '30', '40', '50'],
default=None,
help=('Set the correspoding logging level on all jax loggers. Only string'
' values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR",'
' "CRITICAL", "0", "10", "20", "30", "40", "50"] are accepted. If'
' None, the logging level will not be set.'),
update_global_hook=lambda logging_level: \
_update_logging_level_global(logging_level=logging_level)
)

traceback_filtering = enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import logging
import sys

logging_formatter = logging.Formatter(
"{levelname}:{asctime}:{name}:{lineno}: {message}", style='{')

_debug_handler = logging.StreamHandler(sys.stderr)
_debug_handler.setLevel(logging.DEBUG)
# Example log message:
# DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu'
_debug_handler.setFormatter(logging.Formatter(
"{levelname}:{asctime}:{name}:{lineno}: {message}", style='{'))
_debug_handler.setFormatter(logging_formatter)

_debug_enabled_loggers = []

Expand All @@ -34,9 +36,9 @@ def enable_debug_logging(logger_name):
logger_name: the name of the logger, e.g. "jax._src.xla_bridge".
"""
logger = logging.getLogger(logger_name)
_debug_enabled_loggers.append(logger)
logger.addHandler(_debug_handler)
logger.setLevel(logging.DEBUG)
_debug_enabled_loggers.append(logger)


def disable_all_debug_logging():
Expand All @@ -52,3 +54,4 @@ def disable_all_debug_logging():
# difference if not other handlers are attached, but set it back in case
# something else gets attached (e.g. absl logger) and for consistency.
logger.setLevel(logging.WARNING)
_debug_enabled_loggers.clear()
103 changes: 103 additions & 0 deletions tests/logging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ def jax_debug_log_modules(value):
finally:
jax.config.update("jax_debug_log_modules", original_value)

@contextlib.contextmanager
def jax_logging_level(value):
# jax_logging_level doesn't have a context manager, because it's
# not thread-safe. But since tests are always single-threaded, we
# can define one here.
original_value = jax.config.jax_logging_level
original_logging_value = logging.getLogger("jax").level
jax.config.update("jax_logging_level", value)
try:
yield
finally:
# in case original_value is None, which skips setting logging value
# we also set the logging value directly pulled from logger.level
jax.config.update("jax_logging_level", str(original_logging_value))
#print("Restoring original value of", original_value)
#import pdb; pdb.set_trace()
jax.config.update("jax_logging_level", original_value)


@contextlib.contextmanager
def capture_jax_logs():
Expand Down Expand Up @@ -142,6 +160,91 @@ def test_debug_logging(self):
jax.jit(lambda x: x + 1)(1)
self.assertEmpty(log_output.getvalue())

def test_simple_logging_level_changes(self):
logger = logging.getLogger("jax")
num_lvls = [logging.getLevelName(10 * i) for i in range(5, 0, -1)]
str_lvls = [logging.getLevelName(lvl) for lvl in num_lvls]
for i, level in enumerate(num_lvls):
with jax_logging_level(str(level)):
with self.assertLogs(logger=logger, level=level) as cm:
[logger.log(lvl, "%s log", lvl) for lvl in str_lvls]
self.assertLen(cm.output, i + 1)
for i, level in enumerate(str_lvls):
with jax_logging_level(str(level)):
with self.assertLogs(logger=logger, level=level) as cm:
[logger.log(lvl, "%s log", lvl) for lvl in str_lvls]
self.assertLen(cm.output, i + 1)

def test_double_logging_not_present(self):
logger = logging.getLogger("jax")

# test for simple double logging (is absent)
with jax_logging_level("DEBUG"):
with jax_debug_log_modules("jax._src.cache_key"):
f = jax.jit(lambda x: x)
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = f(jax.numpy.ones(10))
self.assertTrue(any("jax._src.cache_key" in line for line in cm.output))
# assert logs are not repeatedly printed (perhaps without a prefix)
repeats = 0
for i in range(len(cm.output) - 1):
if cm.output[i] in cm.output[i+1] or cm.output[i+1] in cm.output[i]:
repeats += 1
self.assertLess(repeats / (2 * len(cm.output)), 0.2) # less than 20%

def test_debug_log_modules_overrides_logging_level(self):
logger = logging.getLogger("jax")

# tests that logs are present (debug_log_modules overrides logging_level)
with jax_logging_level("INFO"):
with jax_debug_log_modules("jax._src.cache_key"):
f = jax.jit(lambda x: x)
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = f(jax.numpy.ones(10))
self.assertTrue(any("jax._src.cache_key" in line for line in cm.output))
# assert logs are not repeatedly printed (perhaps without a prefix)
repeats = 0
for i in range(len(cm.output) - 1):
if cm.output[i] in cm.output[i+1] or cm.output[i+1] in cm.output[i]:
repeats += 1
self.assertLess(repeats / (2 * len(cm.output)), 0.2) # less than 20%

def test_debug_log_modules_of_jax_does_not_silence_future_modules(self):
logger = logging.getLogger("jax")

# tests that logs are present (debug_log_modules overrides logging_level)
with jax_logging_level("DEBUG"):
f = jax.jit(lambda x: x)
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = f(jax.numpy.ones(10))
self.assertTrue(any("jax._src.compiler" in line for line in cm.output))

with jax_debug_log_modules("jax._src.cache_key"):
f = jax.jit(lambda x: x)
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = f(jax.numpy.ones(10))
self.assertTrue(any("jax._src.cache_key" in line for line in cm.output))
# assert logs are not repeatedly printed (perhaps without a prefix)
repeats = 0
for i in range(len(cm.output) - 1):
if cm.output[i] in cm.output[i+1] or cm.output[i+1] in cm.output[i]:
repeats += 1
self.assertLess(repeats / (2 * len(cm.output)), 0.2) # less than 20%

with jax_debug_log_modules("jax"):
a = 1 # noop

f = jax.jit(lambda x: x)
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = f(jax.numpy.ones(10))
self.assertTrue(any("jax._src.compiler" in line for line in cm.output))

with self.assertLogs(logger=logger, level="DEBUG") as cm:
logger_ = logging.getLogger("jax.some_future_downstream_module")
logger_.debug("Test message")
self.assertLen(cm.output, 1)
self.assertIn("Test message", cm.output[0])


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit bc55b7e

Please sign in to comment.