diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 714b1279037b..e2d2e2f0ab8d 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -210,6 +210,20 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: return cholesky_update_p.bind(r_matrix, w_vector) +def symmetric_product( + a_matrix: ArrayLike, c_matrix: ArrayLike, + alpha: float = 1., beta: float = 0., + symmetrize_output=False): + """Computes C = alpha * A @ A.T + beta * C (where C is symmetric).""" + result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) + if symmetrize_output: + upper_half = lax.transpose( + _tril(result, k=-1), + (*range(result.ndim - 2), result.ndim - 1, result.ndim - 2)) + result = _tril(result, k=0) + upper_half + return result + + def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array: """Converts the pivots (row swaps) returned by LU to a permutation. @@ -592,6 +606,7 @@ def _drot( R = R.at[k, :].set(row_k) return R + cholesky_update_p = Primitive('cholesky_update') cholesky_update_p.multiple_results = False cholesky_update_p.def_abstract_eval(_cholesky_update_abstract_eval) @@ -604,6 +619,67 @@ def _drot( cholesky_update_p, mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False)) +# symmetric_update + +def _symmetric_product_abstract_eval(a, c, *, alpha, beta): + a_dtype = dtypes.canonicalize_dtype(a.dtype) + c_dtype = dtypes.canonicalize_dtype(c.dtype) + if not (a_dtype == c_dtype and a_dtype in (np.float32, np.float64)): + raise NotImplementedError( + "Symmetric update is only implemented for float32 and float64.") + if not (a.ndim >= 2 and c.ndim >= 2 + and a.shape[-2] == c.shape[-1] + and c.shape[-1] == c.shape[-2]): + raise ValueError( + "Symmetric update takes (maybe batched) matrices of matching shapes. " + "Got shapes {}, {} instead".format(a.shape, c.shape)) + return ShapedArray(c.shape, c.dtype) + + +def _symmetric_product_batching_rule(batched_args, batch_dims, *, alpha, beta): + a_tensor, c_tensor = batched_args + a_bd, c_bd = batch_dims + a_tensor = batching.moveaxis(a_tensor, a_bd, 0) + c_tensor = batching.moveaxis(c_tensor, c_bd, 0) + return ( + symmetric_product_p.bind(a_tensor, c_tensor, alpha=alpha, beta=beta), 0) + +symmetric_product_p = Primitive('symmetric_update') +symmetric_product_p.multiple_results = False +symmetric_product_p.def_abstract_eval(_symmetric_product_abstract_eval) +symmetric_product_p.def_impl( + partial(dispatch.apply_primitive, symmetric_product_p)) +batching.primitive_batchers[ + symmetric_product_p] = _symmetric_product_batching_rule + + +def _symmetric_product_gpu_lowering( + platform, ctx, a_tensor, c_tensor, alpha, beta): + a_aval, c_aval = ctx.avals_in[:2] + dtype = a_aval.dtype + alpha_aval = beta_aval = ShapedArray((), dtype) + + alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval) + beta_array = mlir.full_like_aval(ctx, beta, beta_aval) + + rule = ffi.ffi_lowering(f"{platform}_syrk_ffi", operand_output_aliases={1: 0}) + ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval]) + return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False) + + +def _symmetric_product_jax_fn(a, c, *, alpha, beta): + a_T = lax.transpose(a, (*range(a.ndim - 2), a.ndim - 1, a.ndim - 2)) + return alpha * lax.batch_matmul( + a, a_T, precision=lax.Precision.HIGHEST) + beta * c + + +mlir.register_lowering( + symmetric_product_p, + partial(_symmetric_product_gpu_lowering, 'cu'), platform='cuda') +mlir.register_lowering( + symmetric_product_p, + mlir.lower_fun(_symmetric_product_jax_fn, multiple_results=False)) + # Asymmetric eigendecomposition def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 741887abf24b..24dee390f398 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1548,6 +1548,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "consume", "ragged_dot", "cholesky_update", + "symmetric_update", # Pallas TPU primitives "bitcast", "repeat", diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 72db9868e427..5cf85f3697c7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -351,6 +351,7 @@ cc_library( hdrs = ["//jaxlib/gpu:linalg_kernels.h"], features = ["-use_header_modules"], deps = [ + ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", ":cuda_linalg_kernels_impl", ":cuda_vendor", @@ -363,6 +364,7 @@ cc_library( "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cublas", ], ) @@ -373,6 +375,8 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "//jaxlib:ffi_helpers", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", ], diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 3841393654a8..c17d0ac9fd5a 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -46,6 +46,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_syrk_ffi", "CUDA", + SyrkFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index fee1c1014c75..2a006d033afe 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -473,12 +473,16 @@ nb::dict Registrations() { #ifdef JAX_GPU_CUDA dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); + #endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); + dict[JAX_GPU_PREFIX "_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); + + return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 6e988a6ca5e6..b757d303510f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -483,5 +483,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, #undef SOLVER_DISPATCH_IMPL + +#define SYRK_KERNEL_IMPL(type, fn) \ + template <> \ + struct SyrkKernel { \ + static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \ + std::int64_t k, bool transpose, \ + const type* alpha, const type* beta, \ + const type* a_matrix, type* c_matrix) { \ + gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \ + gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \ + int lda = transpose ? n : k; \ + return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \ + alpha, a_matrix, lda, beta, \ + c_matrix, n)); \ + } \ + } + +template +struct SyrkKernel; + +SYRK_KERNEL_IMPL(float, gpublasSsyrk); +SYRK_KERNEL_IMPL(double, gpublasDsyrk); +SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk); +SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk); +#undef SYRK_KERNEL_IMPL + +template +ffi::Error SyrkImpl(gpuStream_t stream, + ffi::AnyBuffer a_matrix, + ffi::AnyBuffer c_matrix, + bool transpose, + ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_matrix_out) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a_matrix.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]), + SplitBatch2D(c_matrix.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]), + SplitBatch2D(c_matrix_out->dimensions())); + if (batch != batch_c || batch != batch_out) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "a_matrix, c_matrix and c_matrix_out must have the same " + "batch size."); + } + int n = transpose ? cols : rows; + int k = transpose ? rows : cols; + + FFI_RETURN_IF_ERROR( + CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk")); + FFI_RETURN_IF_ERROR( + CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk")); + + const T* a_data = static_cast(a_matrix.untyped_data()); + T* c_data = static_cast(c_matrix.untyped_data()); + T* c_out_data = static_cast(c_matrix_out->untyped_data()); + + // with alpha or beta provided as device_pointers, cublassyrk will SIGSEGV + T host_alpha; + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + &host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, + stream))); + + T host_beta; + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + &host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, + stream))); + + if (c_data != c_out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice, + stream))); + } + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + for (int i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(SyrkKernel::Run( + handle.get(), n, k, transpose, &host_alpha, &host_beta, + a_data + i * k * n, c_out_data + i * n * n)); + } + return ffi::Error::Success(); +} + +ffi::Error SyrkDispatch( + gpuStream_t stream, + ffi::AnyBuffer a_matrix, + ffi::AnyBuffer c_matrix, + bool transpose, + ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_matrix_out) { + auto dataType = a_matrix.element_type(); + SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose, + alpha, beta, c_matrix_out); + return ffi::Error::InvalidArgument("Unsupported element type for Syrk"); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Arg() // a_matrix + .Arg() // c_matrix + .Attr("transpose") // transpose + .Arg() // alpha + .Arg() // beta + .Ret()); // c_matrix_out + + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 7dbc7454c2e6..4d9b6d1371fa 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -25,6 +25,8 @@ namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 077d3bb54185..bc61d58181ab 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -57,6 +57,9 @@ typedef cuDoubleComplex gpublasDoubleComplex; typedef cublasFillMode_t gpusolverFillMode_t; typedef cublasStatus_t gpublasStatus_t; typedef cublasHandle_t gpublasHandle_t; +typedef cublasOperation_t gpublasOperation_t; +typedef cublasFillMode_t gpublasFillMode_t; + typedef CUcontext gpuContext_t; typedef CUstreamCaptureMode gpustreamCaptureMode_t; typedef CUstreamCaptureStatus gpustreamCaptureStatus_t; @@ -101,6 +104,11 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpublasCgetrfBatched cublasCgetrfBatched #define gpublasZgetrfBatched cublasZgetrfBatched +#define gpublasSsyrk cublasSsyrk +#define gpublasDsyrk cublasDsyrk +#define gpublasCsyrk cublasCsyrk +#define gpublasZsyrk cublasZsyrk + #define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS #define gpudnnCreate cudnnCreate @@ -190,6 +198,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR #define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS +#define GPUBLAS_OP_N CUBLAS_OP_N +#define GPUBLAS_OP_T CUBLAS_OP_T +#define GPUBLAS_OP_C CUBLAS_OP_C + #define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch #define gpusparseCreate cusparseCreate #define gpusparseCreateCoo cusparseCreateCoo @@ -330,6 +342,7 @@ typedef hipsolverHandle_t gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; typedef hipblasHandle_t gpublasHandle_t; +typedef hipblasOperation_t gpublasOperation_t; typedef hipblasStatus_t gpublasStatus_t; typedef hipCtx_t gpuContext_t; typedef hipStreamCaptureMode gpustreamCaptureMode_t; @@ -372,6 +385,11 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpublasCgetrfBatched hipblasCgetrfBatched #define gpublasZgetrfBatched hipblasZgetrfBatched +#define gpublasSsyrk hipblasSsyrk +#define gpublasDsyrk hipblasDsyrk +#define gpublasCsyrk hipblasCsyrk +#define gpublasZsyrk hipblasZsyrk + #define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define gpusolverDnCreate hipsolverCreate @@ -456,6 +474,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR #define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS +#define GPUBLAS_OP_N HIPBLAS_OP_N +#define GPUBLAS_OP_T HIPBLAS_OP_T +#define GPUBLAS_OP_C HIPBLAS_OP_C + #define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch #define gpusparseCreate hipsparseCreate #define gpusparseSetStream hipsparseSetStream diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 944880066437..0eb4f800309b 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2187,6 +2187,44 @@ def testHilbert(self, n): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker) self._CompileAndCheck(jsp_fun, args_maker) + @jtu.sample_product( + shape=[ + (128, 12), + (128, 64), + (2048, 128), + ], + dtype=[jnp.float32, jnp.float64], + symmetrize_output=[True, False], + ) + @jtu.skip_on_devices("tpu") + def testSymmetricProduct(self, shape, dtype, symmetrize_output): + if dtype is jnp.float64 and not config.enable_x64.value: + self.skipTest("Test disabled for x32 mode") + + rng = jtu.rand_default(self.rng()) + batch_size = 10 + atol = 1e-6 if dtype == jnp.float64 else 1e-3 + + a_matrix = rng((batch_size,) + shape, dtype) + c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],) + c_matrix = jnp.zeros(c_shape, dtype) + + old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix) + new_product = lax_linalg.symmetric_product( + a_matrix, c_matrix, symmetrize_output=symmetrize_output) + new_product_with_batching = jax.vmap( + lambda a, c: lax_linalg.symmetric_product( + a, c, symmetrize_output=symmetrize_output), + in_axes=(0, 0))(a_matrix, c_matrix) + + if not symmetrize_output: + old_product = jnp.tril(old_product) + new_product = jnp.tril(new_product) + new_product_with_batching = jnp.tril(new_product_with_batching) + self.assertAllClose(new_product, old_product, atol=atol) + self.assertAllClose( + new_product_with_batching, old_product, atol=atol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())