Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement symmetric_product() to produce a symmetric matrix: C = alpha * X @ X.T + beta * C #23062

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,20 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
return cholesky_update_p.bind(r_matrix, w_vector)


def symmetric_product(
a_matrix: ArrayLike, c_matrix: ArrayLike,
alpha: float = 1., beta: float = 0.,
symmetrize_output=False):
"""Computes C = alpha * A @ A.T + beta * C (where C is symmetric)."""
result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta)
if symmetrize_output:
upper_half = lax.transpose(
_tril(result, k=-1),
(*range(result.ndim - 2), result.ndim - 1, result.ndim - 2))
result = _tril(result, k=0) + upper_half
return result


def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
"""Converts the pivots (row swaps) returned by LU to a permutation.

Expand Down Expand Up @@ -592,6 +606,7 @@ def _drot(
R = R.at[k, :].set(row_k)
return R


cholesky_update_p = Primitive('cholesky_update')
cholesky_update_p.multiple_results = False
cholesky_update_p.def_abstract_eval(_cholesky_update_abstract_eval)
Expand All @@ -604,6 +619,67 @@ def _drot(
cholesky_update_p,
mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False))

# symmetric_update

def _symmetric_product_abstract_eval(a, c, *, alpha, beta):
a_dtype = dtypes.canonicalize_dtype(a.dtype)
c_dtype = dtypes.canonicalize_dtype(c.dtype)
if not (a_dtype == c_dtype and a_dtype in (np.float32, np.float64)):
raise NotImplementedError(
"Symmetric update is only implemented for float32 and float64.")
if not (a.ndim >= 2 and c.ndim >= 2
and a.shape[-2] == c.shape[-1]
and c.shape[-1] == c.shape[-2]):
raise ValueError(
"Symmetric update takes (maybe batched) matrices of matching shapes. "
"Got shapes {}, {} instead".format(a.shape, c.shape))
return ShapedArray(c.shape, c.dtype)


def _symmetric_product_batching_rule(batched_args, batch_dims, *, alpha, beta):
a_tensor, c_tensor = batched_args
a_bd, c_bd = batch_dims
a_tensor = batching.moveaxis(a_tensor, a_bd, 0)
c_tensor = batching.moveaxis(c_tensor, c_bd, 0)
return (
symmetric_product_p.bind(a_tensor, c_tensor, alpha=alpha, beta=beta), 0)

symmetric_product_p = Primitive('symmetric_update')
symmetric_product_p.multiple_results = False
symmetric_product_p.def_abstract_eval(_symmetric_product_abstract_eval)
symmetric_product_p.def_impl(
partial(dispatch.apply_primitive, symmetric_product_p))
batching.primitive_batchers[
symmetric_product_p] = _symmetric_product_batching_rule


def _symmetric_product_gpu_lowering(
platform, ctx, a_tensor, c_tensor, alpha, beta):
a_aval, c_aval = ctx.avals_in[:2]
dtype = a_aval.dtype
alpha_aval = beta_aval = ShapedArray((), dtype)

alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval)
beta_array = mlir.full_like_aval(ctx, beta, beta_aval)

rule = ffi.ffi_lowering(f"{platform}_syrk_ffi", operand_output_aliases={1: 0})
ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval])
return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False)


def _symmetric_product_jax_fn(a, c, *, alpha, beta):
a_T = lax.transpose(a, (*range(a.ndim - 2), a.ndim - 1, a.ndim - 2))
return alpha * lax.batch_matmul(
a, a_T, precision=lax.Precision.HIGHEST) + beta * c


mlir.register_lowering(
symmetric_product_p,
partial(_symmetric_product_gpu_lowering, 'cu'), platform='cuda')
mlir.register_lowering(
symmetric_product_p,
mlir.lower_fun(_symmetric_product_jax_fn, multiple_results=False))

# Asymmetric eigendecomposition

def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"consume",
"ragged_dot",
"cholesky_update",
"symmetric_update",
# Pallas TPU primitives
"bitcast",
"repeat",
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ cc_library(
hdrs = ["//jaxlib/gpu:linalg_kernels.h"],
features = ["-use_header_modules"],
deps = [
":cuda_blas_handle_pool",
":cuda_gpu_kernel_helpers",
":cuda_linalg_kernels_impl",
":cuda_vendor",
Expand All @@ -363,6 +364,7 @@ cc_library(
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cublas",
],
)

Expand All @@ -373,6 +375,8 @@ cuda_library(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/gpu/gpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA",
GetrfFfi);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_syrk_ffi", "CUDA",
SyrkFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
GeqrfFfi);
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/gpu/solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,16 @@ nb::dict Registrations() {
#ifdef JAX_GPU_CUDA
dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);
dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj);

#endif // JAX_GPU_CUDA

dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi);
dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi);
dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);

dict[JAX_GPU_PREFIX "_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);


return dict;
}

Expand Down
107 changes: 107 additions & 0 deletions jaxlib/gpu/solver_kernels_ffi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,5 +483,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,

#undef SOLVER_DISPATCH_IMPL


#define SYRK_KERNEL_IMPL(type, fn) \
template <> \
struct SyrkKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \
std::int64_t k, bool transpose, \
const type* alpha, const type* beta, \
const type* a_matrix, type* c_matrix) { \
gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \
gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \
int lda = transpose ? n : k; \
return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \
alpha, a_matrix, lda, beta, \
c_matrix, n)); \
} \
}

template <typename T>
struct SyrkKernel;

SYRK_KERNEL_IMPL(float, gpublasSsyrk);
SYRK_KERNEL_IMPL(double, gpublasDsyrk);
SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk);
SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk);
#undef SYRK_KERNEL_IMPL

template <typename T>
ffi::Error SyrkImpl(gpuStream_t stream,
ffi::AnyBuffer a_matrix,
ffi::AnyBuffer c_matrix,
bool transpose,
ffi::AnyBuffer alpha,
ffi::AnyBuffer beta,
ffi::Result<ffi::AnyBuffer> c_matrix_out) {
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(a_matrix.dimensions()));
FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]),
SplitBatch2D(c_matrix.dimensions()));
FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]),
SplitBatch2D(c_matrix_out->dimensions()));
if (batch != batch_c || batch != batch_out) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"a_matrix, c_matrix and c_matrix_out must have the same "
"batch size.");
}
int n = transpose ? cols : rows;
int k = transpose ? rows : cols;

FFI_RETURN_IF_ERROR(
CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk"));
FFI_RETURN_IF_ERROR(
CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk"));

const T* a_data = static_cast<const T*>(a_matrix.untyped_data());
T* c_data = static_cast<T*>(c_matrix.untyped_data());
T* c_out_data = static_cast<T*>(c_matrix_out->untyped_data());

// with alpha or beta provided as device_pointers, cublas<T>syrk will SIGSEGV
T host_alpha;
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
&host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
stream)));

T host_beta;
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
&host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
stream)));

if (c_data != c_out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice,
stream)));
}
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
for (int i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(SyrkKernel<T>::Run(
handle.get(), n, k, transpose, &host_alpha, &host_beta,
a_data + i * k * n, c_out_data + i * n * n));
}
return ffi::Error::Success();
}

ffi::Error SyrkDispatch(
gpuStream_t stream,
ffi::AnyBuffer a_matrix,
ffi::AnyBuffer c_matrix,
bool transpose,
ffi::AnyBuffer alpha,
ffi::AnyBuffer beta,
ffi::Result<ffi::AnyBuffer> c_matrix_out) {
auto dataType = a_matrix.element_type();
SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose,
alpha, beta, c_matrix_out);
return ffi::Error::InvalidArgument("Unsupported element type for Syrk");
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Arg<ffi::AnyBuffer>() // a_matrix
.Arg<ffi::AnyBuffer>() // c_matrix
.Attr<bool>("transpose") // transpose
.Arg<ffi::AnyBuffer>() // alpha
.Arg<ffi::AnyBuffer>() // beta
.Ret<ffi::AnyBuffer>()); // c_matrix_out


} // namespace JAX_GPU_NAMESPACE
} // namespace jax
2 changes: 2 additions & 0 deletions jaxlib/gpu/solver_kernels_ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace JAX_GPU_NAMESPACE {
XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi);


} // namespace JAX_GPU_NAMESPACE
} // namespace jax
Expand Down
22 changes: 22 additions & 0 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ typedef cuDoubleComplex gpublasDoubleComplex;
typedef cublasFillMode_t gpusolverFillMode_t;
typedef cublasStatus_t gpublasStatus_t;
typedef cublasHandle_t gpublasHandle_t;
typedef cublasOperation_t gpublasOperation_t;
typedef cublasFillMode_t gpublasFillMode_t;

typedef CUcontext gpuContext_t;
typedef CUstreamCaptureMode gpustreamCaptureMode_t;
typedef CUstreamCaptureStatus gpustreamCaptureStatus_t;
Expand Down Expand Up @@ -101,6 +104,11 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpublasCgetrfBatched cublasCgetrfBatched
#define gpublasZgetrfBatched cublasZgetrfBatched

#define gpublasSsyrk cublasSsyrk
#define gpublasDsyrk cublasDsyrk
#define gpublasCsyrk cublasCsyrk
#define gpublasZsyrk cublasZsyrk

#define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS

#define gpudnnCreate cudnnCreate
Expand Down Expand Up @@ -190,6 +198,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR
#define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS

#define GPUBLAS_OP_N CUBLAS_OP_N
#define GPUBLAS_OP_T CUBLAS_OP_T
#define GPUBLAS_OP_C CUBLAS_OP_C

#define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch
#define gpusparseCreate cusparseCreate
#define gpusparseCreateCoo cusparseCreateCoo
Expand Down Expand Up @@ -330,6 +342,7 @@ typedef hipsolverHandle_t gpusolverDnHandle_t;
typedef hipblasFillMode_t gpublasFillMode_t;
typedef hipsolverFillMode_t gpusolverFillMode_t;
typedef hipblasHandle_t gpublasHandle_t;
typedef hipblasOperation_t gpublasOperation_t;
typedef hipblasStatus_t gpublasStatus_t;
typedef hipCtx_t gpuContext_t;
typedef hipStreamCaptureMode gpustreamCaptureMode_t;
Expand Down Expand Up @@ -372,6 +385,11 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpublasCgetrfBatched hipblasCgetrfBatched
#define gpublasZgetrfBatched hipblasZgetrfBatched

#define gpublasSsyrk hipblasSsyrk
#define gpublasDsyrk hipblasDsyrk
#define gpublasCsyrk hipblasCsyrk
#define gpublasZsyrk hipblasZsyrk

#define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS

#define gpusolverDnCreate hipsolverCreate
Expand Down Expand Up @@ -456,6 +474,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR
#define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS

#define GPUBLAS_OP_N HIPBLAS_OP_N
#define GPUBLAS_OP_T HIPBLAS_OP_T
#define GPUBLAS_OP_C HIPBLAS_OP_C

#define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch
#define gpusparseCreate hipsparseCreate
#define gpusparseSetStream hipsparseSetStream
Expand Down
38 changes: 38 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,6 +2187,44 @@ def testHilbert(self, n):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker)
self._CompileAndCheck(jsp_fun, args_maker)

@jtu.sample_product(
shape=[
(128, 12),
(128, 64),
(2048, 128),
],
dtype=[jnp.float32, jnp.float64],
symmetrize_output=[True, False],
)
@jtu.skip_on_devices("tpu")
def testSymmetricProduct(self, shape, dtype, symmetrize_output):
if dtype is jnp.float64 and not config.enable_x64.value:
self.skipTest("Test disabled for x32 mode")

rng = jtu.rand_default(self.rng())
batch_size = 10
atol = 1e-6 if dtype == jnp.float64 else 1e-3

a_matrix = rng((batch_size,) + shape, dtype)
c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],)
c_matrix = jnp.zeros(c_shape, dtype)

old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix)
new_product = lax_linalg.symmetric_product(
a_matrix, c_matrix, symmetrize_output=symmetrize_output)
new_product_with_batching = jax.vmap(
lambda a, c: lax_linalg.symmetric_product(
a, c, symmetrize_output=symmetrize_output),
in_axes=(0, 0))(a_matrix, c_matrix)

if not symmetrize_output:
old_product = jnp.tril(old_product)
new_product = jnp.tril(new_product)
new_product_with_batching = jnp.tril(new_product_with_batching)
self.assertAllClose(new_product, old_product, atol=atol)
self.assertAllClose(
new_product_with_batching, old_product, atol=atol)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())