Skip to content

Commit

Permalink
Adding JAX_LOGGING_LEVEL configuration option
Browse files Browse the repository at this point in the history
  • Loading branch information
rdyro committed Oct 7, 2024
1 parent e8cea0d commit bddedd5
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 21 deletions.
25 changes: 15 additions & 10 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,7 @@ def _update_disable_jit_thread_local(val):
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(default_matmul_precision=val))


traceback_filtering = enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
Expand Down Expand Up @@ -1717,23 +1718,27 @@ def transfer_guard(new_val: str) -> Iterator[None]:
stack.enter_context(_transfer_guard(new_val))
yield


def _update_debug_log_modules(module_names_str: str | None) -> None:
  """Enables debug logging for exactly the modules named in the given string.

  Debug logging enabled by an earlier call is disabled first, so the set of
  debug-enabled loggers always mirrors the most recent value.

  Args:
    module_names_str: comma-separated logger names, e.g.
      "jax._src.xla_bridge,jax._src.dispatch". Whitespace around names is
      ignored. ``None`` or an empty string only disables debug logging.
  """
  logging_config.disable_all_debug_logging()
  if not module_names_str:
    return
  for module_name in module_names_str.split(','):
    module_name = module_name.strip()
    # Skip empty entries (e.g. from a trailing comma): an empty logger name
    # would resolve to the root logger instead of a jax module.
    if module_name:
      logging_config.enable_debug_logging(module_name)

# Don't define a context manager since this isn't threadsafe.
string_state(
name='jax_debug_log_modules',
default='',
help=('Comma-separated list of module names (e.g. "jax" or '
'"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging '
'for.'),
update_global_hook=_update_debug_log_modules)
update_global_hook=logging_config._update_debug_log_modules)

# Don't define a context manager since this isn't threadsafe.
optional_enum_state(
    name='jax_logging_level',
    enum_values=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    # Start from whatever level the "jax" logger already has at import time.
    default=logging.getLevelName(logging.getLogger("jax").level),
    # Keep the values listed here in sync with `enum_values` above; numeric
    # strings such as "10" are not part of that list.
    help=('Set the corresponding logging level on all jax loggers. Only string'
          ' values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR",'
          ' "CRITICAL"] are accepted. If None, the logging level will not be'
          ' set.'),
    update_global_hook=lambda logging_level: \
      logging_config._update_logging_level_global(logging_level=logging_level)
)

pmap_no_rank_reduction = bool_state(
name='jax_pmap_no_rank_reduction',
Expand Down
85 changes: 74 additions & 11 deletions jax/_src/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,76 @@
# limitations under the License.

import logging
import os
import sys

logging_formatter = logging.Formatter(
"{levelname}:{asctime}:{name}:{lineno}: {message}", style='{')

_logging_level_handler_set: dict[str, tuple[logging.Handler, int]] = {}


_nameToLevel = {
'CRITICAL': logging.CRITICAL,
'FATAL': logging.FATAL,
'ERROR': logging.ERROR,
'WARN': logging.WARNING,
'WARNING': logging.WARNING,
'INFO': logging.INFO,
'DEBUG': logging.DEBUG,
'NOTSET': logging.NOTSET,
}
def _getLevelNamesMapping():
return _nameToLevel.copy()


def _update_logging_level_global(logging_level: str | None) -> None:
# remove previous handlers
for logger_name, (handler, level) in _logging_level_handler_set.items():
logger = logging.getLogger(logger_name)
logger.removeHandler(handler)
logger.setLevel(level)
_logging_level_handler_set.clear()

if logging_level is None:
return

# attempt to convert the logging level to integer
try:
# logging level is a string representation of an integer
logging_level_num = int(logging_level)
except ValueError:
# logging level is a name string
logging_level_num = _getLevelNamesMapping()[logging_level]

# configure the CPP logging level 0 - debug, 1 - info, 2 - warning, 3 - error
cpp_log_level = min(max(0, (logging_level_num // 10) - 1), 3)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(cpp_log_level)

handler = logging.StreamHandler()
handler.setLevel(logging_level_num)
handler.setFormatter(logging_formatter)

# update jax and jaxlib root loggers for propagation
root_loggers = [logging.getLogger("jax"), logging.getLogger("jaxlib")]
for logger in root_loggers:
logger.setLevel(logging_level_num)
logger.addHandler(handler)
logger.propagate
logger.parent
_logging_level_handler_set[logger.name] = (handler, logger.level)


_debug_handler = logging.StreamHandler(sys.stderr)
_debug_handler.setLevel(logging.DEBUG)
# Example log message:
# DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu'
_debug_handler.setFormatter(logging.Formatter(
"{levelname}:{asctime}:{name}:{lineno}: {message}", style='{'))
_debug_handler.setFormatter(logging_formatter)

_debug_enabled_loggers = []


def enable_debug_logging(logger_name):
def _enable_debug_logging(logger_name):
"""Makes the specified logger log everything to stderr.
Also adds more useful debug information to the log messages, e.g. the time.
Expand All @@ -34,21 +91,27 @@ def enable_debug_logging(logger_name):
logger_name: the name of the logger, e.g. "jax._src.xla_bridge".
"""
logger = logging.getLogger(logger_name)
_debug_enabled_loggers.append((logger, logger.level))
logger.addHandler(_debug_handler)
logger.setLevel(logging.DEBUG)
_debug_enabled_loggers.append(logger)


def disable_all_debug_logging():
def _disable_all_debug_logging():
"""Disables all debug logging enabled via `enable_debug_logging`.
The default logging behavior will still be in effect, i.e. WARNING and above
will be logged to stderr without extra message formatting.
"""
for logger in _debug_enabled_loggers:
for logger, prev_level in _debug_enabled_loggers:
logger: logging.Logger
logger.removeHandler(_debug_handler)
# Assume that the default non-debug log level is always WARNING. In theory
# we could keep track of what it was set to before. This shouldn't make a
# difference if not other handlers are attached, but set it back in case
# something else gets attached (e.g. absl logger) and for consistency.
logger.setLevel(logging.WARNING)
logger.setLevel(prev_level)
_debug_enabled_loggers.clear()

def _update_debug_log_modules(module_names_str: str | None) -> None:
  """Enables debug logging for exactly the modules named in the given string.

  Debug logging enabled by an earlier call is disabled first, so the set of
  debug-enabled loggers always mirrors the most recent value.

  Args:
    module_names_str: comma-separated logger names, e.g.
      "jax._src.xla_bridge,jax._src.dispatch". Whitespace around names is
      ignored. ``None`` or an empty string only disables debug logging.
  """
  _disable_all_debug_logging()
  if not module_names_str:
    return
  for module_name in module_names_str.split(','):
    module_name = module_name.strip()
    # Skip empty entries (e.g. from a trailing comma): an empty logger name
    # would resolve to the root logger instead of a jax module.
    if module_name:
      _enable_debug_logging(module_name)
191 changes: 191 additions & 0 deletions tests/logging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import logging
import os
import platform
import re
import shlex
import subprocess
import sys
import tempfile
import textwrap
import time
import unittest

import jax
Expand Down Expand Up @@ -49,6 +52,20 @@ def jax_debug_log_modules(value):
finally:
jax.config.update("jax_debug_log_modules", original_value)

@contextlib.contextmanager
def jax_logging_level(value):
# jax_logging_level doesn't have a context manager, because it's
# not thread-safe. But since tests are always single-threaded, we
# can define one here.
original_value = jax.config.jax_logging_level
jax.config.update("jax_logging_level", value)
try:
yield
finally:
# in case original_value is None, which skips setting logging value
# we also set the logging value directly pulled from logger.level
jax.config.update("jax_logging_level", original_value)


@contextlib.contextmanager
def capture_jax_logs():
Expand All @@ -62,6 +79,13 @@ def capture_jax_logs():
finally:
logger.removeHandler(handler)

def _get_repeated_log_fraction(logs: list[str]):
  """Returns the fraction of adjacent line pairs where one contains the other.

  A pair counts as a repeat when either line is a substring of its
  neighbour. With fewer than two lines the result is 0.0.
  """
  neighbours = zip(logs, logs[1:])
  repeats = sum(1 for prev, cur in neighbours if prev in cur or cur in prev)
  return repeats / max(len(logs) - 1, 1)


class LoggingTest(jtu.JaxTestCase):

Expand Down Expand Up @@ -154,7 +178,174 @@ def test_debug_logging(self):
with capture_jax_logs() as log_output:
jax.jit(lambda x: x + 1)(1)
self.assertEmpty(log_output.getvalue())

def test_double_logging_not_present(self):
  """Global level plus per-module debug together must not duplicate logs."""
  jax_logger = logging.getLogger("jax")

  # Turn on the global DEBUG level *and* per-module debug logging together.
  with jax_logging_level("DEBUG"), jax_debug_log_modules("jax._src.cache_key"):
    jitted = jax.jit(lambda x: x)
    with self.assertLogs(logger=jax_logger, level="DEBUG") as captured:
      _ = jitted(jax.numpy.ones(10))
    lines = captured.output
    self.assertTrue(any("jax._src.cache_key" in line for line in lines))
    # Neighbouring lines that contain each other (perhaps without a prefix)
    # would indicate that messages are printed twice.
    self.assertLess(_get_repeated_log_fraction(lines), 0.2) # less than 20%

def test_none_means_notset(self):
  """A logging level of None must leave jax silent."""
  with jax_logging_level(None), capture_jax_logs() as log_output:
    jax.jit(lambda x: x)(1.)
  # Nothing may have been written to the captured stream.
  self.assertLen(log_output.getvalue(), 0)

def test_debug_log_modules_overrides_logging_level(self):
  """Per-module debug logging wins over a less verbose global level."""
  jax_logger = logging.getLogger("jax")

  def assert_cache_key_debug_logged():
    # Compile a trivial function and expect cache_key messages in the logs.
    with self.assertLogs(logger=jax_logger, level="DEBUG") as captured:
      _ = jax.jit(lambda x: x)(1.0)
    self.assertTrue(
        any("jax._src.cache_key" in line for line in captured.output))

  # debug_log_modules applied after the global level ...
  with jax_logging_level("INFO"), jax_debug_log_modules("jax._src.cache_key"):
    assert_cache_key_debug_logged()

  # ... and with the two options applied in the reverse order.
  with jax_debug_log_modules("jax._src.cache_key"), jax_logging_level("INFO"):
    assert_cache_key_debug_logged()

def test_debug_log_modules_of_jax_does_not_silence_future_modules(self):
# Checks that the cache_key and compiler debug messages keep showing up while
# jax_debug_log_modules is toggled, and that a child logger of "jax" obtained
# afterwards is still captured.
logger = logging.getLogger("jax")

# Asserts that `log_lines` contain both the cache_key hash message and the
# compiler's persistent-compilation-cache-miss message.
def _check_compiler_and_cache_key_logs(log_lines):
self.assertTrue(any(re.search(
r"jax._src.cache_key.*get_cache_key hash after serializing",
line) is not None for line in log_lines))
self.assertTrue(any(re.search(
r"jax._src.compiler.*PERSISTENT COMPILATION CACHE MISS", line)
is not None for line in log_lines))

# baseline: with jax_logging_level at DEBUG both messages must be logged
with jax_logging_level("DEBUG"):
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = jax.jit(lambda x: x)(jax.numpy.ones(10))
_check_compiler_and_cache_key_logs(cm.output)

# additionally enable per-module debug logging for cache_key
with jax_debug_log_modules("jax._src.cache_key"):
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = jax.jit(lambda x: x)(jax.numpy.ones(10))
# assert logs are not repeatedly printed (perhaps without a prefix)
log_repeat_fraction = _get_repeated_log_fraction(cm.output)
self.assertLess(log_repeat_fraction, 0.2) # less than 20%
_check_compiler_and_cache_key_logs(cm.output)

# enable debug logging for the whole "jax" package and leave it again at once
with jax_debug_log_modules("jax"):
_ = 1 # noop

# both messages must still be logged after that toggle
with self.assertLogs(logger=logger, level="DEBUG") as cm:
_ = jax.jit(lambda x: x)(jax.numpy.ones(10))
_check_compiler_and_cache_key_logs(cm.output)

# a child logger of "jax" standing in for a module that does not exist yet
# must be captured as well: exactly one record, carrying the test message
with self.assertLogs(logger=logger, level="DEBUG") as cm:
logger_ = logging.getLogger("jax.some_future_downstream_module")
logger_.debug("Test message")
self.assertLen(cm.output, 1)
self.assertIn("Test message", cm.output[0])

@unittest.skipIf(platform.system() == "Windows",
                 "Subprocess test doesn't work on Windows")
def test_subprocess_stderr_logging(self):
  """JAX_LOGGING_LEVEL from the environment controls what reaches stderr.

  Runs the same small program in a subprocess at INFO and at DEBUG and
  checks that DEBUG produces the INFO output plus additional debug lines.
  """
  if sys.executable is None:
    self.skipTest("test requires access to python binary")
  program = """
  import jax # this prints INFO logging from backend imports
  jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
  """
  program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) # strip indent

  def run_with_level(level):
    # Pass argv as a list and the variable through `env`, so neither the
    # interpreter path nor the program text needs shell-style quoting.
    env = dict(os.environ, JAX_LOGGING_LEVEL=level)
    proc = subprocess.run([sys.executable, "-c", program], env=env,
                          capture_output=True, text=True)
    return proc.stderr

  # test INFO
  info_output = run_with_level("INFO")
  info_lines = info_output.split("\n")
  self.assertNotEmpty(info_output)
  self.assertIn("INFO", info_output)

  # test DEBUG
  debug_output = run_with_level("DEBUG")
  debug_lines = debug_output.split("\n")
  self.assertNotEmpty(debug_output)
  self.assertIn("INFO", debug_output)
  self.assertIn("DEBUG", debug_output)
  self.assertIn("Finished tracing + transforming <lambda> for pjit",
                debug_output)
  # DEBUG must be strictly more verbose than INFO.
  self.assertGreater(len(debug_lines), len(info_lines))

@unittest.skipIf(platform.system() == "Windows",
                 "Subprocess test doesn't work on Windows")
def test_subprocess_toggling_logging_level(self):
  """Resetting `jax_logging_level` to None at runtime silences DEBUG output.

  The subprocess writes a marker line to stderr right after the reset, so
  output from before and after can be told apart without relying on timing.
  """
  if sys.executable is None:
    self.skipTest("test requires access to python binary")
  marker = "=== jax_logging_level reset ==="
  program = "\n".join([
      "import sys",
      "import jax # this prints INFO logging from backend imports",
      "jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)",
      'jax.config.update("jax_logging_level", None)',
      f"print({marker!r}, file=sys.stderr, flush=True)",
      "jax.jit(lambda x: x)(1) # this must not log anything any more",
  ])

  env = dict(os.environ, JAX_LOGGING_LEVEL="DEBUG")
  p = subprocess.run([sys.executable, "-c", program], env=env,
                     capture_output=True, text=True)
  # A missing marker means the program died before reaching the reset.
  self.assertIn(marker, p.stderr)
  log_output_verbose, log_output_silent = p.stderr.split(marker, 1)

  # the first part of the program prints DEBUG messages
  self.assertIn("Finished tracing + transforming <lambda> for pjit",
                log_output_verbose)
  # the second part of the program does NOT print anything
  self.assertEqual(log_output_silent.strip(), "")

@unittest.skipIf(platform.system() == "Windows",
                 "Subprocess test doesn't work on Windows")
def test_subprocess_double_logging_absent(self):
  """With JAX_LOGGING_LEVEL=DEBUG, stderr must not contain doubled lines."""
  if sys.executable is None:
    self.skipTest("test requires access to python binary")
  program = """
  import jax # this prints INFO logging from backend imports
  jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
  """
  program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) # strip indent

  # Pass argv as a list and the variable through `env`, so neither the
  # interpreter path nor the program text needs shell-style quoting.
  env = dict(os.environ, JAX_LOGGING_LEVEL="DEBUG")
  p = subprocess.run([sys.executable, "-c", program], env=env,
                     capture_output=True, text=True)
  log_output = p.stderr
  self.assertNotEmpty(log_output)
  log_lines = log_output.strip().split("\n")
  # Neighbouring lines that contain each other indicate double logging.
  self.assertLess(_get_repeated_log_fraction(log_lines), 0.2)

# extra subprocess tests for doubled logging in JAX_DEBUG_MODULES

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit bddedd5

Please sign in to comment.