Skip to content

Commit

Permalink
Update pgo_nsys_converter.py to use the NVTX kern sum report when ava…
Browse files Browse the repository at this point in the history
…ilable.
  • Loading branch information
mgoldfarb-nvidia committed Aug 23, 2024
1 parent f54e220 commit f9dba24
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions jax/tools/pgo_nsys_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@
profile_folder = os.path.join(os.path.split(args.profile_path)[0], '')

assert isinstance(nsys_path, str)
stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"]

# Older versions of nsys use `nvtxkernsum` for the report name so determine which is available.
query_reports_command = [nsys_path, "stats", "--help-reports"]
reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout
report_name = "nvtx_kern_sum" if "nvtx_kern_sum" in reports_list else "nvtxkernsum"

assert isinstance(nsys_path, str)
stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"]

print(f"""
******Starting stats command******
Expand All @@ -49,7 +56,7 @@

thunk_re = re.compile("hlo_op=(.*)#")
with open(f"{args.pgle_output_path}", 'w', newline='') as protofile:
with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile:
with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
name = row['NVTX Range']
Expand Down

0 comments on commit f9dba24

Please sign in to comment.