Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
dfm authored and jax authors committed Aug 20, 2024
1 parent 71a93d0 commit bd90968
Showing 10 changed files with 150 additions and 78 deletions.
38 changes: 22 additions & 16 deletions jax/_src/lax/linalg.py
@@ -505,20 +505,26 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector):
r_matrix.shape, w_vector.shape))
return ShapedArray(r_matrix.shape, r_matrix.dtype)

def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector):
r_matrix_aval, _ = ctx.avals_in
try:
[platform] = ctx.module_context.platforms
except ValueError:
raise ValueError(
"Can only lower cholesky_update on a single platform."
) from None
if platform != "cuda":
raise NotImplementedError(
"Can only lower fast cholesky_update on CUDA."
)
return gpu_linalg.cuda_cholesky_update(
r_matrix, w_vector, r_matrix_aval.dtype)
def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector):
# TODO(b/360781533): Remove guard after 3 week forward compatibility period.
if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32):
r_matrix_aval, _ = ctx.avals_in
try:
[platform] = ctx.module_context.platforms
except ValueError:
raise ValueError(
"Can only lower cholesky_update on a single platform."
) from None
if platform != "cuda":
raise NotImplementedError(
"Can only lower fast cholesky_update on CUDA."
)
return gpu_linalg.cuda_cholesky_update(
r_matrix, w_vector, r_matrix_aval.dtype)
rule = ffi.ffi_lowering(f"{target_name_prefix}_cholesky_update_ffi",
operand_output_aliases={0: 0, 1: 1})
sub_ctx = ctx.replace(avals_out=ctx.avals_in)
return rule(sub_ctx, r_matrix, w_vector)[:1]


def _cholesky_update_jax_fn(R, z):
@@ -557,8 +563,8 @@ def _drot(
cholesky_update_p.def_impl(partial(dispatch.apply_primitive, cholesky_update_p))

mlir.register_lowering(
cholesky_update_p, _cholesky_update_cuda_lowering_rule, platform='cuda')

cholesky_update_p, partial(_cholesky_update_gpu_lowering_rule, "cu"),
platform='cuda')
mlir.register_lowering(
cholesky_update_p,
mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False))
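For orientation, here is a minimal usage sketch of the primitive whose lowering changes above. It is illustrative only: it assumes the public `jax.lax.linalg.cholesky_update` wrapper and the upper-triangular convention (`A = R.T @ R`) used by `_cholesky_update_jax_fn`.

```python
import jax
import jax.numpy as jnp

key_a, key_w = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (8, 8))
A = a @ a.T + 8.0 * jnp.eye(8)    # symmetric positive definite
R = jnp.linalg.cholesky(A).T      # upper-triangular factor, A = R.T @ R
w = jax.random.normal(key_w, (8,))

# Rank-1 update: the factor of A + w w^T, without refactorizing from scratch.
R_new = jax.lax.linalg.cholesky_update(R, w)
R_ref = jnp.linalg.cholesky(A + jnp.outer(w, w)).T
print(jnp.allclose(R_new, R_ref, atol=1e-4))  # expected: True
```

On CUDA this now lowers to the `cu_cholesky_update_ffi` custom call (or the legacy custom call during the forward-compatibility window); on other platforms it falls back to the pure-JAX rotation-based implementation registered via `mlir.lower_fun`.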
10 changes: 2 additions & 8 deletions jaxlib/cuda/BUILD
@@ -357,7 +357,6 @@ cc_library(
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
@@ -367,18 +366,13 @@

cuda_library(
name = "cuda_linalg_kernels_impl",
srcs = [
"//jaxlib/gpu:linalg_kernels.cu.cc",
],
hdrs = [
"//jaxlib/gpu:linalg_kernels.h",
],
srcs = ["//jaxlib/gpu:linalg_kernels.cu.cc"],
hdrs = ["//jaxlib/gpu:linalg_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)

2 changes: 2 additions & 0 deletions jaxlib/gpu/gpu_kernels.cc
@@ -57,6 +57,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");

XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA",
CholeskyUpdateFfi);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation",
"CUDA", LuPivotsToPermutation);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA",
2 changes: 2 additions & 0 deletions jaxlib/gpu/linalg.cc
@@ -41,6 +41,8 @@ NB_MODULE(_linalg, m) {
EncapsulateFfiHandler(LuPivotsToPermutation);
dict[JAX_GPU_PREFIX "_cholesky_update"] =
EncapsulateFunction(CholeskyUpdate);
dict[JAX_GPU_PREFIX "_cholesky_update_ffi"] =
EncapsulateFunction(CholeskyUpdateFfi);
return dict;
});
m.def("build_cholesky_update_descriptor", &BuildCholeskyUpdateDescriptor);
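For context, a hedged sketch of how the Python side typically consumes the `registrations()` dictionary exposed here (the real loop lives in jaxlib's GPU linalg module and may differ by version). Handlers whose names end in `_ffi` are registered with the typed FFI API:

```python
# Illustrative only: turning the capsules from _linalg.registrations() into
# XLA custom-call targets. Names ending in "_ffi" use the typed FFI
# (api_version=1); legacy handlers use the untyped API (api_version=0).
from jaxlib import xla_client

def register_linalg_targets(linalg_module, platform="CUDA"):
    for name, capsule in linalg_module.registrations().items():
        api_version = 1 if name.endswith("_ffi") else 0
        xla_client.register_custom_call_target(
            name, capsule, platform=platform, api_version=api_version)
```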
81 changes: 59 additions & 22 deletions jaxlib/gpu/linalg_kernels.cc
@@ -16,13 +16,9 @@ limitations under the License.
#include "jaxlib/gpu/linalg_kernels.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <string_view>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
@@ -60,32 +56,73 @@ void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
}

namespace {
absl::StatusOr<std::pair<std::int64_t, std::int32_t>> GetDimensions(
ffi::Span<const std::int64_t> dims, const std::string& arg_name) {
if (dims.size() < 1) {
return absl::InvalidArgumentError(
absl::StrFormat("%s must have at least one dimension", arg_name));
ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in,
ffi::AnyBuffer vector_in,
ffi::Result<ffi::AnyBuffer> matrix_out,
ffi::Result<ffi::AnyBuffer> vector_out) {
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(matrix_in.dimensions()));
if (rows != cols) {
return ffi::Error::InvalidArgument(
"The matrix input to Cholesky update must be square.");
}
std::int64_t batch_size = 1;
if (dims.size() >= 2) {
batch_size =
absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>());
FFI_RETURN_IF_ERROR(CheckShape(vector_in.dimensions(), {batch, cols},
"vector", "cholesky_update"));
FFI_RETURN_IF_ERROR(CheckShape(matrix_out->dimensions(), {batch, rows, cols},
"matrix_out", "cholesky_update"));
FFI_RETURN_IF_ERROR(CheckShape(vector_out->dimensions(), {batch, cols},
"vector_out", "cholesky_update"));
FFI_ASSIGN_OR_RETURN(auto size, MaybeCastNoOverflow<int>(cols));
auto dtype = matrix_in.element_type();
if (dtype != ffi::F32 && dtype != ffi::F64) {
return ffi::Error::InvalidArgument(
"Invalid input type for Cholesky update; must be float32 or float64.");
}
JAX_ASSIGN_OR_RETURN(auto size,
MaybeCastNoOverflow<std::int32_t>(dims.back()));
return std::make_pair(batch_size, size);
if (vector_in.element_type() != dtype ||
matrix_out->element_type() != dtype ||
vector_out->element_type() != dtype) {
return ffi::Error::InvalidArgument(
"All input and output types for Cholesky update must match.");
}
bool is_single_precision = dtype == ffi::F32;
auto matrix = matrix_out->untyped_data();
if (matrix_in.untyped_data() != matrix) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
gpuMemcpyAsync(matrix, matrix_in.untyped_data(), matrix_in.size_bytes(),
gpuMemcpyDeviceToDevice, stream)));
}
auto vector = vector_out->untyped_data();
if (vector_in.untyped_data() != vector) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
gpuMemcpyAsync(vector, vector_in.untyped_data(), vector_in.size_bytes(),
gpuMemcpyDeviceToDevice, stream)));
}
for (auto n = 0; n < batch; ++n) {
LaunchCholeskyUpdateFfiKernel(stream, matrix, vector, size,
is_single_precision);
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
}
return ffi::Error::Success();
}
} // namespace

XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Arg<ffi::AnyBuffer>()
.Arg<ffi::AnyBuffer>()
.Ret<ffi::AnyBuffer>()
.Ret<ffi::AnyBuffer>());

namespace {
ffi::Error LuPivotsToPermutationImpl(
gpuStream_t stream, ffi::Dictionary /* unused */,
ffi::Buffer<ffi::DataType::S32> pivots,
ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation) {
FFI_ASSIGN_OR_RETURN(auto pivots_dims,
GetDimensions(pivots.dimensions(), "pivots"));
FFI_ASSIGN_OR_RETURN(auto permutation_dims,
GetDimensions(permutation->dimensions(), "permutation"));
auto [batch_size, pivot_size] = pivots_dims;
auto [permutation_batch, permutation_size] = permutation_dims;
FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]),
SplitBatch1D(pivots.dimensions()));
FFI_ASSIGN_OR_RETURN((auto [permutation_batch, permutation_size]),
SplitBatch1D(permutation->dimensions()));
if (permutation_batch != batch_size) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"pivots and permutation must have the same batch size.");
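For readers new to the XLA FFI, a stripped-down sketch of the handler pattern used by `CholeskyUpdateFfi` above. The names (`EchoImpl`, `Echo`) are illustrative, and the body assumes host-visible buffers so a plain `memcpy` suffices; the GPU handler above uses `gpuMemcpyAsync` on the platform stream instead. The `Ffi::Bind()` chain must declare context, arguments, and results in the same order as the implementation's parameters.

```cpp
#include <cstring>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Minimal illustrative handler: validates sizes, then copies input to output.
static ffi::Error EchoImpl(ffi::AnyBuffer in, ffi::Result<ffi::AnyBuffer> out) {
  if (in.size_bytes() != out->size_bytes()) {
    return ffi::Error::InvalidArgument("Input and output sizes must match.");
  }
  std::memcpy(out->untyped_data(), in.untyped_data(), in.size_bytes());
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(Echo, EchoImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()    // in
                                  .Ret<ffi::AnyBuffer>());  // out
```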
62 changes: 39 additions & 23 deletions jaxlib/gpu/linalg_kernels.cu.cc
@@ -15,18 +15,11 @@ limitations under the License.

#include "jaxlib/gpu/linalg_kernels.h"

#include <array>
#include <algorithm>
#include <cstdint>
#include <iostream>

#include "jaxlib/gpu/vendor.h"

#ifdef JAX_GPU_HIP
#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h"
#else // JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cooperative_groups.h"
#endif

namespace cg = cooperative_groups;

namespace jax {
@@ -47,7 +40,6 @@ __device__ void drotg(T* da, T* db, T* c, T* s) {
T rh = rhypot(a, b);
*c = a * rh;
*s = -(b * rh);
return;
}

template <typename T>
@@ -85,15 +77,9 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers,
reinterpret_cast<void*>(&uVector),
reinterpret_cast<void*>(&nSize),
};
#ifdef JAX_GPU_HIP
hipLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/0, stream);
#else // JAX_GPU_CUDA
cudaLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/0, stream);
#endif
}

void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
@@ -102,13 +88,8 @@ void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
LinalgType type = descriptor.linalg_type;

int dev = 0;
#ifdef JAX_GPU_HIP
hipDeviceProp_t deviceProp;
hipGetDeviceProperties(&deviceProp, dev);
#else // JAX_GPU_CUDA
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
#endif
gpuDeviceProp deviceProp;
gpuGetDeviceProperties(&deviceProp, dev);

int block_dim = deviceProp.maxThreadsPerBlock;
int grid_dim = deviceProp.multiProcessorCount;
@@ -125,6 +106,41 @@
}
}

template <typename T>
void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix,
void* vector, int grid_dim,
int block_dim, int nSize) {
T* rMatrix = reinterpret_cast<T*>(matrix);
T* uVector = reinterpret_cast<T*>(vector);

void* arg_ptrs[3] = {
reinterpret_cast<void*>(&rMatrix),
reinterpret_cast<void*>(&uVector),
reinterpret_cast<void*>(&nSize),
};
gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel<T>, grid_dim,
block_dim, arg_ptrs,
/*dynamic_shared_mem_bytes=*/0, stream);
}

void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
void* vector, int size,
bool is_single_precision) {
int dev = 0;
gpuDeviceProp deviceProp;
gpuGetDeviceProperties(&deviceProp, dev);
int block_dim = deviceProp.maxThreadsPerBlock;
int grid_dim = deviceProp.multiProcessorCount;

if (is_single_precision) {
LaunchCholeskyUpdateFfiKernelBody<float>(stream, matrix, vector, grid_dim,
block_dim, size);
} else {
LaunchCholeskyUpdateFfiKernelBody<double>(stream, matrix, vector, grid_dim,
block_dim, size);
}
}

namespace {

__device__ void ComputePermutation(const std::int32_t* pivots,
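A design note on the launchers above: `CholeskyUpdateKernel` is started with `gpuLaunchCooperativeKernel` rather than a plain launch because cooperative groups' grid-wide barrier is only valid in cooperatively launched kernels. A hedged CUDA sketch of the pattern (the kernel body is illustrative, not the actual update kernel):

```cpp
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// grid.sync() below requires launch via cudaLaunchCooperativeKernel
// (gpuLaunchCooperativeKernel in jaxlib's vendor-neutral spelling).
__global__ void TwoPhase(float* data, unsigned long long n) {
  cg::grid_group grid = cg::this_grid();
  for (unsigned long long i = grid.thread_rank(); i < n; i += grid.size()) {
    data[i] *= 2.0f;  // phase 1
  }
  grid.sync();  // grid-wide barrier: every block reaches here before any continues
  // Phase 2 may now safely read values written by other blocks.
}
```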
8 changes: 5 additions & 3 deletions jaxlib/gpu/linalg_kernels.h
@@ -26,8 +26,6 @@ limitations under the License.
namespace jax {
namespace JAX_GPU_NAMESPACE {

namespace ffi = xla::ffi;

enum LinalgType {
F32 = 0,
F64 = 1,
@@ -44,13 +42,17 @@ void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers,
void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);

void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix,
void* vector, int size,
bool is_single_precision);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi);

void LaunchLuPivotsToPermutationKernel(gpuStream_t stream,
std::int64_t batch_size,
std::int32_t pivot_size,
std::int32_t permutation_size,
const std::int32_t* pivots,
std::int32_t* permutation);

XLA_FFI_DECLARE_HANDLER_SYMBOL(LuPivotsToPermutation);

} // namespace JAX_GPU_NAMESPACE
14 changes: 12 additions & 2 deletions jaxlib/gpu/vendor.h
@@ -23,6 +23,7 @@ limitations under the License.
#if defined(JAX_GPU_CUDA)

#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cooperative_groups.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export
@@ -31,8 +32,8 @@
#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export
#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export

#if CUDA_VERSION < 11080
#error "JAX requires CUDA 11.8 or newer."
@@ -292,6 +293,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuStreamWaitEvent cudaStreamWaitEvent
#define gpuSuccess cudaSuccess

#define gpuDeviceProp cudaDeviceProp
#define gpuGetDeviceProperties cudaGetDeviceProperties
#define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel

namespace jax::JAX_GPU_NAMESPACE {
namespace {
constexpr uint32_t kNumThreadsPerWarp = 32;
@@ -300,6 +305,7 @@

#elif defined(JAX_GPU_HIP)

#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas/hipblas.h"
#include "rocm/include/hipsolver/hipsolver.h"
@@ -541,6 +547,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
#define GPU_EVENT_DEFAULT hipEventDefault

#define gpuDeviceProp hipDeviceProp_t
#define gpuGetDeviceProperties hipGetDeviceProperties
#define gpuLaunchCooperativeKernel hipLaunchCooperativeKernel

namespace jax::JAX_GPU_NAMESPACE {
namespace {
constexpr uint32_t kNumThreadsPerWarp = 64;
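The net effect of the new `vendor.h` aliases is that launch-configuration code can be written once for both backends, as in the rewritten launchers above. A minimal sketch (the helper itself is illustrative, not part of the diff):

```cpp
#include "jaxlib/gpu/vendor.h"

// gpuDeviceProp / gpuGetDeviceProperties expand to the cuda* or hip*
// equivalents depending on whether JAX_GPU_CUDA or JAX_GPU_HIP is defined.
static void DefaultLaunchDims(int device, int* grid_dim, int* block_dim) {
  gpuDeviceProp props;
  gpuGetDeviceProperties(&props, device);
  *grid_dim = props.multiProcessorCount;  // one block per SM / compute unit
  *block_dim = props.maxThreadsPerBlock;
}
```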

