From 4b7c198a1c99049e2b41c18a9e06a5173e4e54f2 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Thu, 25 Jul 2024 19:02:33 +0000 Subject: [PATCH] [ROCm]: Add get_arch_details for triton kernel call --- jaxlib/gpu/triton.cc | 12 ++++++++++++ jaxlib/gpu_triton.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 1274eeba466b..500034af3ebb 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -132,6 +132,18 @@ NB_MODULE(_triton, m) { return major * 10 + minor; })); + m.def( + "get_arch_details", + ValueOrThrowWrapper([](int device) -> absl::StatusOr { +#ifdef JAX_GPU_HIP + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, 0); + return prop.gcnArchName; +#else + return absl::UnimplementedError("Not a HIP GPU"); +#endif + })); + m.def("get_serialized_metadata", ValueOrThrowWrapper( [](nb::bytes opaque) -> absl::StatusOr { diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index f2d37bfec03d..77f315e5b4b1 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -35,6 +35,7 @@ create_array_parameter = _cuda_triton.create_array_parameter create_scalar_parameter = _cuda_triton.create_scalar_parameter get_compute_capability = _cuda_triton.get_compute_capability + get_arch_details = _cuda_triton.get_arch_details get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata @@ -58,5 +59,6 @@ create_array_parameter = _hip_triton.create_array_parameter create_scalar_parameter = _hip_triton.create_scalar_parameter get_compute_capability = _hip_triton.get_compute_capability + get_arch_details = _hip_triton.get_arch_details get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata