From f4517ba6ddafd6e636537a2b38321bc78a1841cb Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Thu, 22 Aug 2024 14:22:01 +0000 Subject: [PATCH] query for report --- jax/tools/pgo_nsys_converter.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 3f09e569e575..e4a3be953172 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -38,7 +38,14 @@ profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') assert isinstance(nsys_path, str) - stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtx_kern_sum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + + # Older versions of nsys use `nvtxkernsum` for the report name so determine which is available. + query_reports_command = [nsys_path, "stats", "--help-reports"] + reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout + report_name = "nvtx_kern_sum" if "nvtx_kern_sum" in reports_list else "nvtxkernsum" + + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] print(f""" ******Starting stats command****** @@ -49,7 +56,7 @@ thunk_re = re.compile("hlo_op=(.*)#") with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: - with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtx_kern_sum.csv", newline='') as csvfile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) for row in reader: name = row['NVTX Range']