Skip to content

Commit

Permalink
Use pathlib for profiler log_dir
Browse files Browse the repository at this point in the history
  • Loading branch information
cool-RR committed Jul 28, 2024
1 parent dab15d6 commit 538c64b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
35 changes: 16 additions & 19 deletions jax/_src/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from collections.abc import Callable
from contextlib import contextmanager
from functools import wraps
import glob
import gzip
import http.server
import json
import logging
import os
import pathlib
import socketserver
import threading
from typing import Any
Expand Down Expand Up @@ -88,7 +88,7 @@ def reset(self):
_profile_state = _ProfileState()


def start_trace(log_dir, create_perfetto_link: bool = False,
def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False,
create_perfetto_trace: bool = False) -> None:
"""Starts a profiler trace.
Expand Down Expand Up @@ -129,20 +129,18 @@ def start_trace(log_dir, create_perfetto_link: bool = False,
_profile_state.create_perfetto_link = create_perfetto_link
_profile_state.create_perfetto_trace = (
create_perfetto_trace or create_perfetto_link)
_profile_state.log_dir = str(log_dir)
_profile_state.log_dir = pathlib.Path(log_dir)


def _write_perfetto_trace_file(log_dir):
def _write_perfetto_trace_file(log_dir: os.PathLike | str):
# Navigate to folder with the latest trace dump to find `trace.json.jz`
curr_path = os.path.abspath(log_dir)
root_trace_folder = os.path.join(curr_path, "plugins", "profile")
trace_folders = [os.path.join(root_trace_folder, trace_folder) for
trace_folder in os.listdir(root_trace_folder)]
latest_folder = max(trace_folders, key=os.path.getmtime)
trace_jsons = glob.glob(os.path.join(latest_folder, "*.trace.json.gz"))
if len(trace_jsons) != 1:
raise ValueError(f"Invalid trace folder: {latest_folder}")
trace_json, = trace_jsons
trace_folders = (pathlib.Path(log_dir).absolute() / "plugins" / "profile").iterdir()
latest_trace_folder = max(trace_folders, key=os.path.getmtime)
trace_jsons = latest_trace_folder.glob("*.trace.json.gz")
try:
trace_json, = trace_jsons
except ValueError as value_error:
raise ValueError(f"Invalid trace folder: {latest_trace_folder}") from value_error

logger.info("Loading trace.json.gz and removing its metadata...")
# Perfetto doesn't like the `metadata` field in `trace.json` so we remove
Expand All @@ -152,8 +150,7 @@ def _write_perfetto_trace_file(log_dir):
with gzip.open(trace_json, "rb") as fp:
trace = json.load(fp)
del trace["metadata"]
filename = "perfetto_trace.json.gz"
perfetto_trace = os.path.join(latest_folder, filename)
perfetto_trace = latest_trace_folder / "perfetto_trace.json.gz"
logger.info("Writing perfetto_trace.json.gz...")
with gzip.open(perfetto_trace, "w") as fp:
fp.write(json.dumps(trace).encode("utf-8"))
Expand All @@ -173,11 +170,11 @@ def do_GET(self):
def do_POST(self):
self.send_error(404, "File not found")

def _host_perfetto_trace_file(path):
def _host_perfetto_trace_file(path: os.PathLike | str):
# ui.perfetto.dev looks for files hosted on `127.0.0.1:9001`. We set up a
# TCP server that is hosting the `perfetto_trace.json.gz` file.
port = 9001
orig_directory = os.path.abspath(os.getcwd())
orig_directory = pathlib.Path.cwd()
directory, filename = os.path.split(path)
try:
os.chdir(directory)
Expand All @@ -203,7 +200,7 @@ def stop_trace():
if _profile_state.profile_session is None:
raise RuntimeError("No profile started")
sess = _profile_state.profile_session
sess.export(sess.stop(), _profile_state.log_dir)
sess.export(sess.stop(), str(_profile_state.log_dir))
if _profile_state.create_perfetto_trace:
abs_filename = _write_perfetto_trace_file(_profile_state.log_dir)
if _profile_state.create_perfetto_link:
Expand All @@ -227,7 +224,7 @@ def stop_and_get_fdo_profile() -> bytes | str:


@contextmanager
def trace(log_dir, create_perfetto_link=False, create_perfetto_trace=False):
def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfetto_trace=False):
"""Context manager to take a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python
Expand Down
4 changes: 2 additions & 2 deletions jax/collect_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
help="Profiler Python tracer level", type=int)

def collect_profile(port: int, duration_in_ms: int, host: str,
log_dir: str | None, host_tracer_level: int,
log_dir: os.PathLike | str | None, host_tracer_level: int,
device_tracer_level: int, python_tracer_level: int,
no_perfetto_link: bool):
options = profiler.ProfilerOptions(
Expand Down Expand Up @@ -97,7 +97,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str,
fp.write(result.encode("utf-8"))

if not no_perfetto_link:
path = jax_profiler._write_perfetto_trace_file(str(log_dir_))
path = jax_profiler._write_perfetto_trace_file(log_dir_)
jax_profiler._host_perfetto_trace_file(path)

def main(args):
Expand Down

0 comments on commit 538c64b

Please sign in to comment.