diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 5e87220be606..e4a3be953172 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -38,7 +38,14 @@ profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') assert isinstance(nsys_path, str) - stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + + # Older versions of nsys use `nvtxkernsum` for the report name so determine which is available. + query_reports_command = [nsys_path, "stats", "--help-reports"] + reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout + report_name = "nvtx_kern_sum" if "nvtx_kern_sum" in reports_list else "nvtxkernsum" + + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] print(f""" ******Starting stats command****** @@ -49,7 +56,7 @@ thunk_re = re.compile("hlo_op=(.*)#") with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: - with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) for row in reader: name = row['NVTX Range']