Skip to content

Commit

Permalink
Merge pull request #22972 from mgoldfarb-nvidia:mgoldfarb-nvidia/pgo_…
Browse files Browse the repository at this point in the history
…nsys_converter_update

PiperOrigin-RevId: 668147156
  • Loading branch information
jax authors committed Aug 27, 2024
2 parents db9e44f + d2b1ebd commit 88a2008
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions jax/tools/pgo_nsys_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@
profile_folder = os.path.join(os.path.split(args.profile_path)[0], '')

assert isinstance(nsys_path, str)
stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"]

# Older versions of nsys use `nvtxsum` for the report name so determine which is available.
query_reports_command = [nsys_path, "stats", "--help-reports"]
reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout
report_name = "nvtx_sum" if "nvtx_sum" in reports_list else "nvtxsum"

assert isinstance(nsys_path, str)
stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"]

print(f"""
******Starting stats command******
Expand All @@ -49,10 +56,10 @@

thunk_re = re.compile("hlo_op=(.*)#")
with open(f"{args.pgle_output_path}", 'w', newline='') as protofile:
with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile:
with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
name = row['NVTX Range']
name = row['Range']
time_ns = float(row['Avg (ns)'])
m = thunk_re.search(name)
if m is not None:
Expand Down

0 comments on commit 88a2008

Please sign in to comment.