Skip to content

Commit

Permalink
Port Schur Decomposition to XLA's FFI
Browse files Browse the repository at this point in the history
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 639830807
  • Loading branch information
pparuzel authored and Google-ML-Automation committed Oct 14, 2024
1 parent ec68d42 commit 738ab4e
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 5 deletions.
1 change: 1 addition & 0 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pybind_extension(
module_name = "_lapack",
pytype_srcs = [
"_lapack/__init__.pyi",
"_lapack/schur.pyi",
"_lapack/svd.pyi",
"_lapack/eig.pyi",
],
Expand Down
26 changes: 26 additions & 0 deletions jaxlib/cpu/_lapack/schur.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from typing import ClassVar


class ComputationMode(enum.Enum):
    # Whether the ?gees kernel computes the Schur vectors in addition to the
    # Schur form. Presumably maps to LAPACK's `jobvs` flag ('V'/'N') — the
    # char values live in the C++ enum definition; confirm there.
    kComputeSchurVectors: ClassVar[ComputationMode]
    kNoComputeSchurVectors: ClassVar[ComputationMode]


class Sort(enum.Enum):
    # Whether eigenvalues are ordered on the diagonal of the Schur form.
    # NOTE(review): the C++ kernel currently rejects kSortEigenvalues with
    # an unimplemented error (no select callback is supplied).
    kNoSortEigenvalues: ClassVar[Sort]
    kSortEigenvalues: ClassVar[Sort]
4 changes: 4 additions & 0 deletions jaxlib/cpu/cpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi);
Expand Down
18 changes: 18 additions & 0 deletions jaxlib/cpu/lapack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
AssignKernelFn<ComplexGees<std::complex<float>>>(lapack_ptr("cgees"));
AssignKernelFn<ComplexGees<std::complex<double>>>(lapack_ptr("zgees"));
AssignKernelFn<SchurDecomposition<DataType::F32>>(lapack_ptr("sgees"));
AssignKernelFn<SchurDecomposition<DataType::F64>>(lapack_ptr("dgees"));
AssignKernelFn<SchurDecompositionComplex<DataType::C64>>(lapack_ptr("cgees"));
AssignKernelFn<SchurDecompositionComplex<DataType::C128>>(
lapack_ptr("zgees"));

AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
Expand Down Expand Up @@ -265,6 +270,10 @@ nb::dict Registrations() {
dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi);
dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi);
dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi);
dict["lapack_sgees_ffi"] = EncapsulateFunction(lapack_sgees_ffi);
dict["lapack_dgees_ffi"] = EncapsulateFunction(lapack_dgees_ffi);
dict["lapack_cgees_ffi"] = EncapsulateFunction(lapack_cgees_ffi);
dict["lapack_zgees_ffi"] = EncapsulateFunction(lapack_zgees_ffi);
dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi);
dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi);
dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi);
Expand All @@ -280,6 +289,7 @@ NB_MODULE(_lapack, m) {
// Submodules
auto svd = m.def_submodule("svd");
auto eig = m.def_submodule("eig");
auto schur = m.def_submodule("schur");
// Enums
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
// kComputeVtOverwriteXPartialU is not implemented
Expand All @@ -289,6 +299,14 @@ NB_MODULE(_lapack, m) {
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);
nb::enum_<schur::ComputationMode>(schur, "ComputationMode")
.value("kNoComputeSchurVectors",
schur::ComputationMode::kNoComputeSchurVectors)
.value("kComputeSchurVectors",
schur::ComputationMode::kComputeSchurVectors);
nb::enum_<schur::Sort>(schur, "Sort")
.value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues)
.value("kSortEigenvalues", schur::Sort::kSortEigenvalues);

// Old-style LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
Expand Down
210 changes: 210 additions & 0 deletions jaxlib/cpu/lapack_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort);

#undef REGISTER_CHAR_ENUM_ATTR_DECODING

Expand Down Expand Up @@ -1573,6 +1575,180 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;

// FFI Kernel

// FFI kernel for the real-valued Schur decomposition via LAPACK ?gees
// (sgees/dgees): factors each batched square matrix X as X = Z * T * Z**T.
// Outputs: the quasi-upper-triangular Schur form T in `x_out`, the Schur
// vectors Z in `schur_vectors` (contents meaningful only when `mode` requests
// them), eigenvalues split into `eigvals_real`/`eigvals_imag`, LAPACK's SDIM
// in `selected_eigvals` (always 0 — sorting is unimplemented), and the
// per-matrix LAPACK status in `info` (0 == success).
template <ffi::DataType dtype>
ffi::Error SchurDecomposition<dtype>::Kernel(
    ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
    ffi::ResultBuffer<dtype> eigvals_real,
    ffi::ResultBuffer<dtype> eigvals_imag,
    // TODO(paruzelp): Sort is not implemented because select function is not
    // supplied. For that reason, this parameter will always be zero!
    ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
    ffi::ResultBuffer<LapackIntDtype> info) {
  FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
                       SplitBatch2D(x.dimensions()));
  // ?gees is only defined for square matrices, and the pointer arithmetic
  // below assumes x_cols * x_cols elements per batch entry; fail fast instead
  // of letting LAPACK read out of bounds.
  if (x_rows != x_cols) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "Schur decomposition expects square matrices");
  }
  if (sort != schur::Sort::kNoSortEigenvalues) {
    return ffi::Error(
        ffi::ErrorCode::kUnimplemented,
        "Ordering eigenvalues on the diagonal is not implemented");
  }

  // ?gees factorizes in place; seed the output buffer with the input.
  CopyIfDiffBuffer(x, x_out);

  // TODO(paruzelp): `select` should be passed as an execution context
  bool (*select)(ValueType, ValueType) = nullptr;
  ValueType* x_out_data = x_out->typed_data();
  ValueType* eigvals_real_data = eigvals_real->typed_data();
  ValueType* eigvals_imag_data = eigvals_imag->typed_data();
  ValueType* schur_vectors_data = schur_vectors->typed_data();
  lapack_int* selected_data = selected_eigvals->typed_data();
  lapack_int* info_data = info->typed_data();

  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));

  // Prepare LAPACK workspaces. `bwork` is only consumed when sorting, which
  // the early return above rules out, so it is currently always null; the
  // conditional is kept for when a `select` callback lands.
  std::unique_ptr<bool[]> bwork =
      sort != schur::Sort::kNoSortEigenvalues
          ? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
          : nullptr;
  auto work_size = GetWorkspaceSize(x_cols_v, mode, sort);
  // A negative size means the lwork == -1 workspace query itself failed;
  // without this guard it would be used as an allocation size below.
  if (work_size < 0) {
    return ffi::Error(ffi::ErrorCode::kInternal,
                      "Schur decomposition workspace size query failed");
  }
  FFI_ASSIGN_OR_RETURN(auto work_size_v,
                       MaybeCastNoOverflow<lapack_int>(work_size));
  auto work_data = AllocateScratchMemory<dtype>(work_size);

  const int64_t x_size{x_cols * x_cols};
  [[maybe_unused]] const auto x_size_bytes =
      static_cast<unsigned long>(x_size) * sizeof(ValueType);
  [[maybe_unused]] const auto x_cols_bytes =
      static_cast<unsigned long>(x_cols) * sizeof(ValueType);
  for (int64_t i = 0; i < batch_count; ++i) {
    // ?gees(jobvs, sort, select, n, a, lda, sdim, wr, wi, vs, ldvs,
    //       work, lwork, bwork, info)
    fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
       selected_data, eigvals_real_data, eigvals_imag_data, schur_vectors_data,
       &x_cols_v, work_data.get(), &work_size_v, bwork.get(), info_data);
    // The outputs are written by uninstrumented LAPACK code; tell the
    // sanitizer they are initialized.
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, x_size_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));

    // Advance all cursors to the next matrix in the batch.
    x_out_data += x_size;
    eigvals_real_data += x_cols;
    eigvals_imag_data += x_cols;
    schur_vectors_data += x_size;
    ++selected_data;
    ++info_data;
  }

  return ffi::Error::Success();
}

// FFI kernel for the complex-valued Schur decomposition via LAPACK ?gees
// (cgees/zgees): factors each batched square matrix X as X = Z * T * Z**H.
// Outputs: the upper-triangular Schur form T in `x_out`, the Schur vectors Z
// in `schur_vectors` (contents meaningful only when `mode` requests them),
// the eigenvalues (diagonal of T) in `eigvals`, LAPACK's SDIM in
// `selected_eigvals` (always 0 — sorting is unimplemented), and the
// per-matrix LAPACK status in `info` (0 == success).
template <ffi::DataType dtype>
ffi::Error SchurDecompositionComplex<dtype>::Kernel(
    ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
    ffi::ResultBuffer<dtype> eigvals,
    // TODO(paruzelp): Sort is not implemented because select function is not
    // supplied. For that reason, this parameter will always be zero!
    ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
    ffi::ResultBuffer<LapackIntDtype> info) {
  FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
                       SplitBatch2D(x.dimensions()));
  // ?gees is only defined for square matrices, and the pointer arithmetic
  // below assumes x_cols * x_cols elements per batch entry; fail fast instead
  // of letting LAPACK read out of bounds.
  if (x_rows != x_cols) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "Schur decomposition expects square matrices");
  }
  if (sort != schur::Sort::kNoSortEigenvalues) {
    return ffi::Error(
        ffi::ErrorCode::kUnimplemented,
        "Ordering eigenvalues on the diagonal is not implemented");
  }

  // ?gees factorizes in place; seed the output buffer with the input.
  CopyIfDiffBuffer(x, x_out);

  // TODO(paruzelp): `select` should be passed as an execution context
  bool (*select)(ValueType) = nullptr;
  ValueType* x_out_data = x_out->typed_data();
  ValueType* eigvals_data = eigvals->typed_data();
  ValueType* schur_vectors_data = schur_vectors->typed_data();
  lapack_int* selected_data = selected_eigvals->typed_data();
  lapack_int* info_data = info->typed_data();

  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));

  // Prepare LAPACK workspaces. `bwork` is only consumed when sorting, which
  // the early return above rules out, so it is currently always null; the
  // conditional is kept for when a `select` callback lands.
  std::unique_ptr<bool[]> bwork =
      sort != schur::Sort::kNoSortEigenvalues
          ? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
          : nullptr;
  auto work_size = GetWorkspaceSize(x_cols_v, mode, sort);
  // A negative size means the lwork == -1 workspace query itself failed;
  // without this guard it would be used as an allocation size below.
  if (work_size < 0) {
    return ffi::Error(ffi::ErrorCode::kInternal,
                      "Schur decomposition workspace size query failed");
  }
  FFI_ASSIGN_OR_RETURN(auto work_size_v,
                       MaybeCastNoOverflow<lapack_int>(work_size));
  auto work_data = AllocateScratchMemory<dtype>(work_size);
  // Complex ?gees additionally needs a real-typed `rwork` array of length n.
  auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(x_cols);

  const int64_t x_size{x_cols * x_cols};
  [[maybe_unused]] const auto x_size_bytes =
      static_cast<unsigned long>(x_size) * sizeof(ValueType);
  [[maybe_unused]] const auto x_cols_bytes =
      static_cast<unsigned long>(x_cols) * sizeof(ValueType);
  for (int64_t i = 0; i < batch_count; ++i) {
    // ?gees(jobvs, sort, select, n, a, lda, sdim, w, vs, ldvs,
    //       work, lwork, rwork, bwork, info)
    fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
       selected_data, eigvals_data, schur_vectors_data, &x_cols_v,
       work_data.get(), &work_size_v, rwork_data.get(), bwork.get(), info_data);
    // The outputs are written by uninstrumented LAPACK code; tell the
    // sanitizer they are initialized.
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));

    // Advance all cursors to the next matrix in the batch.
    x_out_data += x_size;
    eigvals_data += x_cols;
    schur_vectors_data += x_size;
    ++selected_data;
    ++info_data;
  }

  return ffi::Error::Success();
}

// Queries LAPACK for the optimal `work` buffer length for real ?gees by
// invoking it with lwork == -1 (the standard LAPACK workspace-query
// protocol); the routine writes the optimal length into work[0].
// Returns -1 if the query itself fails (info != 0).
template <ffi::DataType dtype>
int64_t SchurDecomposition<dtype>::GetWorkspaceSize(lapack_int x_cols,
                                                    schur::ComputationMode mode,
                                                    schur::Sort sort) {
  ValueType optimal_size = {};
  lapack_int workspace_query = -1;  // lwork == -1 selects the size query.
  lapack_int info = 0;

  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  // Array arguments may be null during a workspace query; only jobvs, sort,
  // n, the leading dimensions, work, lwork and info are inspected.
  fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
     nullptr, nullptr, &x_cols, &optimal_size, &workspace_query, nullptr,
     &info);
  // std::real is a no-op for the real ValueType; kept for symmetry with the
  // complex variant.
  return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
}

// Queries LAPACK for the optimal `work` buffer length for complex ?gees by
// invoking it with lwork == -1 (the standard LAPACK workspace-query
// protocol); the routine writes the optimal length into work[0].
// Returns -1 if the query itself fails (info != 0).
template <ffi::DataType dtype>
int64_t SchurDecompositionComplex<dtype>::GetWorkspaceSize(
    lapack_int x_cols, schur::ComputationMode mode, schur::Sort sort) {
  ValueType optimal_size = {};
  lapack_int workspace_query = -1;  // lwork == -1 selects the size query.
  lapack_int info = 0;

  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  // Array arguments (including rwork/bwork) may be null during a workspace
  // query; the optimal length is reported through work[0].
  fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
     nullptr, &x_cols, &optimal_size, &workspace_query, nullptr, nullptr,
     &info);
  // The length is returned in the real part of the complex work[0].
  return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
}

template struct SchurDecomposition<ffi::DataType::F32>;
template struct SchurDecomposition<ffi::DataType::F64>;
template struct SchurDecompositionComplex<ffi::DataType::C64>;
template struct SchurDecompositionComplex<ffi::DataType::C128>;

//== Hessenberg Decomposition ==//

// lapack gehrd
Expand Down Expand Up @@ -1926,6 +2102,33 @@ template struct TridiagonalReduction<ffi::DataType::C128>;
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// Defines an XLA FFI handler symbol `name` bound to the real-valued Schur
// decomposition kernel (s/dgees): one input matrix buffer plus `mode`/`sort`
// attributes, producing the Schur form, Schur vectors, split real/imaginary
// eigenvalue buffers, the selected-eigenvalue count, and the LAPACK info.
#define JAX_CPU_DEFINE_GEES(name, data_type)                     \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                 \
      name, SchurDecomposition<data_type>::Kernel,               \
      ::xla::ffi::Ffi::Bind()                                    \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)             \
          .Attr<schur::ComputationMode>("mode")                  \
          .Attr<schur::Sort>("sort")                             \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)         \
          .Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/) \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_real*/)  \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_imag*/)  \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/) \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// Defines an XLA FFI handler symbol `name` bound to the complex-valued Schur
// decomposition kernel (c/zgees): same binding as JAX_CPU_DEFINE_GEES except
// eigenvalues come back in a single complex buffer instead of real/imag
// pairs.
#define JAX_CPU_DEFINE_GEES_COMPLEX(name, data_type)             \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                 \
      name, SchurDecompositionComplex<data_type>::Kernel,        \
      ::xla::ffi::Ffi::Bind()                                    \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)             \
          .Attr<schur::ComputationMode>("mode")                  \
          .Attr<schur::Sort>("sort")                             \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)         \
          .Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/) \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/)       \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/) \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

#define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, TridiagonalReduction<data_type>::Kernel, \
Expand Down Expand Up @@ -1998,6 +2201,11 @@ JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128);

JAX_CPU_DEFINE_GEES(lapack_sgees_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEES(lapack_dgees_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_cgees_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_zgees_ffi, ::xla::ffi::DataType::C128);

JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64);
Expand All @@ -2015,6 +2223,8 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_GEEV
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
#undef JAX_CPU_DEFINE_SYTRD_HETRD
#undef JAX_CPU_DEFINE_GEES
#undef JAX_CPU_DEFINE_GEES_COMPLEX
#undef JAX_CPU_DEFINE_GEHRD

} // namespace jax
Loading

0 comments on commit 738ab4e

Please sign in to comment.