From d2b1ebd0aa504d2d0dddc3e2ffdb7240be7bdff6 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Fri, 9 Aug 2024 13:45:44 -0500 Subject: [PATCH] Update pgo_nsys_converter.py to use the NVTX kern sum report when available. --- jax/tools/pgo_nsys_converter.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 5e87220be606..5460edd960f5 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -38,7 +38,14 @@ profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') assert isinstance(nsys_path, str) - stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + + # Older versions of nsys use `nvtxsum` for the report name so determine which is available. + query_reports_command = [nsys_path, "stats", "--help-reports"] + reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout + report_name = "nvtx_sum" if "nvtx_sum" in reports_list else "nvtxsum" + + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] print(f""" ******Starting stats command****** @@ -49,10 +56,10 @@ thunk_re = re.compile("hlo_op=(.*)#") with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: - with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) for row in reader: - name = row['NVTX Range'] + name = row['Range'] time_ns = float(row['Avg (ns)']) m = thunk_re.search(name) if m is not None: