Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port Schur Decomposition to XLA's FFI #21609

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pybind_extension(
module_name = "_lapack",
pytype_srcs = [
"_lapack/__init__.pyi",
"_lapack/schur.pyi",
"_lapack/svd.pyi",
"_lapack/eig.pyi",
],
Expand Down
26 changes: 26 additions & 0 deletions jaxlib/cpu/_lapack/schur.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from typing import ClassVar


class ComputationMode(enum.Enum):
  """Whether the Schur decomposition also computes Schur vectors.

  Type stub mirroring the nanobind-exported C++ enum
  ``jax::schur::ComputationMode`` (bound in ``jaxlib/cpu/lapack.cc``).
  """
  kComputeSchurVectors: ClassVar[ComputationMode]
  kNoComputeSchurVectors: ClassVar[ComputationMode]


class Sort(enum.Enum):
  """Whether eigenvalues are sorted on the diagonal of the Schur form.

  Type stub mirroring the nanobind-exported C++ enum ``jax::schur::Sort``
  (bound in ``jaxlib/cpu/lapack.cc``). NOTE(review): the CPU kernels
  currently reject ``kSortEigenvalues`` as unimplemented.
  """
  kNoSortEigenvalues: ClassVar[Sort]
  kSortEigenvalues: ClassVar[Sort]
4 changes: 4 additions & 0 deletions jaxlib/cpu/cpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi);
Expand Down
18 changes: 18 additions & 0 deletions jaxlib/cpu/lapack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
AssignKernelFn<ComplexGees<std::complex<float>>>(lapack_ptr("cgees"));
AssignKernelFn<ComplexGees<std::complex<double>>>(lapack_ptr("zgees"));
AssignKernelFn<SchurDecomposition<DataType::F32>>(lapack_ptr("sgees"));
AssignKernelFn<SchurDecomposition<DataType::F64>>(lapack_ptr("dgees"));
AssignKernelFn<SchurDecompositionComplex<DataType::C64>>(lapack_ptr("cgees"));
AssignKernelFn<SchurDecompositionComplex<DataType::C128>>(
lapack_ptr("zgees"));

AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
Expand Down Expand Up @@ -265,6 +270,10 @@ nb::dict Registrations() {
dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi);
dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi);
dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi);
dict["lapack_sgees_ffi"] = EncapsulateFunction(lapack_sgees_ffi);
dict["lapack_dgees_ffi"] = EncapsulateFunction(lapack_dgees_ffi);
dict["lapack_cgees_ffi"] = EncapsulateFunction(lapack_cgees_ffi);
dict["lapack_zgees_ffi"] = EncapsulateFunction(lapack_zgees_ffi);
dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi);
dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi);
dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi);
Expand All @@ -280,6 +289,7 @@ NB_MODULE(_lapack, m) {
// Submodules
auto svd = m.def_submodule("svd");
auto eig = m.def_submodule("eig");
auto schur = m.def_submodule("schur");
// Enums
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
// kComputeVtOverwriteXPartialU is not implemented
Expand All @@ -289,6 +299,14 @@ NB_MODULE(_lapack, m) {
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);
nb::enum_<schur::ComputationMode>(schur, "ComputationMode")
.value("kNoComputeSchurVectors",
schur::ComputationMode::kNoComputeSchurVectors)
.value("kComputeSchurVectors",
schur::ComputationMode::kComputeSchurVectors);
nb::enum_<schur::Sort>(schur, "Sort")
.value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues)
.value("kSortEigenvalues", schur::Sort::kSortEigenvalues);

// Old-style LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
Expand Down
210 changes: 210 additions & 0 deletions jaxlib/cpu/lapack_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort);

#undef REGISTER_CHAR_ENUM_ATTR_DECODING

Expand Down Expand Up @@ -1573,6 +1575,180 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;

// FFI Kernel

// FFI kernel for the real-valued LAPACK Schur decomposition family
// (sgees/dgees — see the AssignKernelFn registrations in lapack.cc). For each
// matrix in the batch it writes the Schur form into `x_out`, the Schur
// vectors into `schur_vectors` (meaningful only when `mode` requests them),
// the eigenvalues split into `eigvals_real`/`eigvals_imag`, and the per-matrix
// LAPACK status code into `info` (0 == success).
template <ffi::DataType dtype>
ffi::Error SchurDecomposition<dtype>::Kernel(
    ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
    ffi::ResultBuffer<dtype> eigvals_real,
    ffi::ResultBuffer<dtype> eigvals_imag,
    // TODO(paruzelp): Sort is not implemented because select function is not
    // supplied. For that reason, this parameter will always be zero!
    ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
    ffi::ResultBuffer<LapackIntDtype> info) {
  FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
                       SplitBatch2D(x.dimensions()));
  // Eigenvalue ordering needs a user-supplied `select` callback that this FFI
  // call has no way to receive yet, so reject such requests up front.
  if (sort != schur::Sort::kNoSortEigenvalues) {
    return ffi::Error(
        ffi::ErrorCode::kUnimplemented,
        "Ordering eigenvalues on the diagonal is not implemented");
  }

  // LAPACK overwrites the input matrix in place, so seed `x_out` with `x`
  // when the two are distinct buffers.
  CopyIfDiffBuffer(x, x_out);

  // TODO(paruzelp): `select` should be passed as an execution context
  bool (*select)(ValueType, ValueType) = nullptr;
  ValueType* x_out_data = x_out->typed_data();
  ValueType* eigvals_real_data = eigvals_real->typed_data();
  ValueType* eigvals_imag_data = eigvals_imag->typed_data();
  ValueType* schur_vectors_data = schur_vectors->typed_data();
  lapack_int* selected_data = selected_eigvals->typed_data();
  lapack_int* info_data = info->typed_data();

  // LAPACK takes the mode/sort flags as single characters.
  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));

  // Prepare LAPACK workspaces. `bwork` is only needed when sorting; with
  // sorting rejected above it is always nullptr for now.
  std::unique_ptr<bool[]> bwork =
      sort != schur::Sort::kNoSortEigenvalues
          ? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
          : nullptr;
  // Single workspace-size query; the scratch buffer is reused for every
  // matrix in the batch.
  auto work_size = GetWorkspaceSize(x_cols, mode, sort);
  FFI_ASSIGN_OR_RETURN(auto work_size_v,
                       MaybeCastNoOverflow<lapack_int>(work_size));
  auto work_data = AllocateScratchMemory<dtype>(work_size);

  const int64_t x_size{x_cols * x_cols};
  // Byte counts are consumed only by the sanitizer annotations below.
  [[maybe_unused]] const auto x_size_bytes =
      static_cast<unsigned long>(x_size) * sizeof(ValueType);
  [[maybe_unused]] const auto x_cols_bytes =
      static_cast<unsigned long>(x_cols) * sizeof(ValueType);
  for (int64_t i = 0; i < batch_count; ++i) {
    fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
       selected_data, eigvals_real_data, eigvals_imag_data, schur_vectors_data,
       &x_cols_v, work_data.get(), &work_size_v, bwork.get(), info_data);
    // Tell sanitizers the Fortran routine initialized these output regions.
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, x_size_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));

    // Advance all output cursors to the next matrix in the batch.
    x_out_data += x_size;
    eigvals_real_data += x_cols;
    eigvals_imag_data += x_cols;
    schur_vectors_data += x_size;
    ++selected_data;
    ++info_data;
  }

  return ffi::Error::Success();
}

// FFI kernel for the complex-valued LAPACK Schur decomposition family
// (cgees/zgees — see the AssignKernelFn registrations in lapack.cc). Unlike
// the real variant, eigenvalues come back in a single complex `eigvals`
// buffer, and the routine requires an additional real-valued `rwork`
// workspace.
template <ffi::DataType dtype>
ffi::Error SchurDecompositionComplex<dtype>::Kernel(
    ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
    ffi::ResultBuffer<dtype> eigvals,
    // TODO(paruzelp): Sort is not implemented because select function is not
    // supplied. For that reason, this parameter will always be zero!
    ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
    ffi::ResultBuffer<LapackIntDtype> info) {
  FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
                       SplitBatch2D(x.dimensions()));
  // Eigenvalue ordering needs a user-supplied `select` callback that this FFI
  // call has no way to receive yet, so reject such requests up front.
  if (sort != schur::Sort::kNoSortEigenvalues) {
    return ffi::Error(
        ffi::ErrorCode::kUnimplemented,
        "Ordering eigenvalues on the diagonal is not implemented");
  }

  // LAPACK overwrites the input matrix in place, so seed `x_out` with `x`
  // when the two are distinct buffers.
  CopyIfDiffBuffer(x, x_out);

  // TODO(paruzelp): `select` should be passed as an execution context
  bool (*select)(ValueType) = nullptr;
  ValueType* x_out_data = x_out->typed_data();
  ValueType* eigvals_data = eigvals->typed_data();
  ValueType* schur_vectors_data = schur_vectors->typed_data();
  lapack_int* selected_data = selected_eigvals->typed_data();
  lapack_int* info_data = info->typed_data();

  // LAPACK takes the mode/sort flags as single characters.
  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));

  // Prepare LAPACK workspaces. `bwork` is only needed when sorting; with
  // sorting rejected above it is always nullptr for now.
  std::unique_ptr<bool[]> bwork =
      sort != schur::Sort::kNoSortEigenvalues
          ? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
          : nullptr;
  // Single workspace-size query; scratch buffers are reused for every matrix
  // in the batch. `rwork` holds the real-valued scratch the complex routine
  // needs.
  auto work_size = GetWorkspaceSize(x_cols, mode, sort);
  FFI_ASSIGN_OR_RETURN(auto work_size_v,
                       MaybeCastNoOverflow<lapack_int>(work_size));
  auto work_data = AllocateScratchMemory<dtype>(work_size);
  auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(x_cols);

  const int64_t x_size{x_cols * x_cols};
  // Byte counts are consumed only by the sanitizer annotations below.
  [[maybe_unused]] const auto x_size_bytes =
      static_cast<unsigned long>(x_size) * sizeof(ValueType);
  [[maybe_unused]] const auto x_cols_bytes =
      static_cast<unsigned long>(x_cols) * sizeof(ValueType);
  for (int64_t i = 0; i < batch_count; ++i) {
    fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
       selected_data, eigvals_data, schur_vectors_data, &x_cols_v,
       work_data.get(), &work_size_v, rwork_data.get(), bwork.get(), info_data);
    // Tell sanitizers the Fortran routine initialized these output regions.
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));

    // Advance all output cursors to the next matrix in the batch.
    x_out_data += x_size;
    eigvals_data += x_cols;
    schur_vectors_data += x_size;
    ++selected_data;
    ++info_data;
  }

  return ffi::Error::Success();
}

// Queries the optimal `work` buffer length for the real gees routine using
// the standard LAPACK workspace-query convention (lwork == -1): the routine
// performs no computation and writes the optimal size into the first element
// of the work array. Returns -1 when the query itself fails (info != 0).
template <ffi::DataType dtype>
int64_t SchurDecomposition<dtype>::GetWorkspaceSize(lapack_int x_cols,
                                                    schur::ComputationMode mode,
                                                    schur::Sort sort) {
  ValueType optimal_size = {};
  lapack_int workspace_query = -1;  // lwork == -1 triggers the size query.
  lapack_int info = 0;

  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
     nullptr, nullptr, &x_cols, &optimal_size, &workspace_query, nullptr,
     &info);
  // The size comes back through the (real part of the) work scalar.
  return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
}  // NOTE: stray trailing ';' removed (was an empty declaration).

// Queries the optimal `work` buffer length for the complex gees routine using
// the standard LAPACK workspace-query convention (lwork == -1). The complex
// calling sequence differs from the real one (single eigenvalue array plus an
// rwork slot). Returns -1 when the query itself fails (info != 0).
template <ffi::DataType dtype>
int64_t SchurDecompositionComplex<dtype>::GetWorkspaceSize(
    lapack_int x_cols, schur::ComputationMode mode, schur::Sort sort) {
  ValueType optimal_size = {};
  lapack_int workspace_query = -1;  // lwork == -1 triggers the size query.
  lapack_int info = 0;

  auto mode_v = static_cast<char>(mode);
  auto sort_v = static_cast<char>(sort);
  fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
     nullptr, &x_cols, &optimal_size, &workspace_query, nullptr, nullptr,
     &info);
  // The size comes back through the real part of the complex work scalar.
  return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
}  // NOTE: stray trailing ';' removed (was an empty declaration).

// Explicit instantiations for every dtype exposed through the FFI handlers
// below (s/d gees -> real, c/z gees -> complex).
template struct SchurDecomposition<ffi::DataType::F32>;
template struct SchurDecomposition<ffi::DataType::F64>;
template struct SchurDecompositionComplex<ffi::DataType::C64>;
template struct SchurDecompositionComplex<ffi::DataType::C128>;

//== Hessenberg Decomposition ==//

// lapack gehrd
Expand Down Expand Up @@ -1926,6 +2102,33 @@ template struct TridiagonalReduction<ffi::DataType::C128>;
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// Declares an FFI handler symbol for the real Schur decomposition: one input
// matrix buffer, the `mode`/`sort` char-enum attributes, and the six result
// buffers in the exact order SchurDecomposition<>::Kernel expects.
#define JAX_CPU_DEFINE_GEES(name, data_type)                             \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                         \
      name, SchurDecomposition<data_type>::Kernel,                       \
      ::xla::ffi::Ffi::Bind()                                            \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)                     \
          .Attr<schur::ComputationMode>("mode")                          \
          .Attr<schur::Sort>("sort")                                     \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)                 \
          .Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/)         \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_real*/)          \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_imag*/)          \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/) \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// Declares an FFI handler symbol for the complex Schur decomposition: same
// shape as JAX_CPU_DEFINE_GEES but with a single complex `eigvals` result
// buffer, matching SchurDecompositionComplex<>::Kernel.
#define JAX_CPU_DEFINE_GEES_COMPLEX(name, data_type)                     \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                         \
      name, SchurDecompositionComplex<data_type>::Kernel,                \
      ::xla::ffi::Ffi::Bind()                                            \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)                     \
          .Attr<schur::ComputationMode>("mode")                          \
          .Attr<schur::Sort>("sort")                                     \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)                 \
          .Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/)         \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/)               \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/) \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

#define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, TridiagonalReduction<data_type>::Kernel, \
Expand Down Expand Up @@ -1998,6 +2201,11 @@ JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128);

JAX_CPU_DEFINE_GEES(lapack_sgees_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEES(lapack_dgees_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_cgees_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_zgees_ffi, ::xla::ffi::DataType::C128);

JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64);
Expand All @@ -2015,6 +2223,8 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_GEEV
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
#undef JAX_CPU_DEFINE_SYTRD_HETRD
#undef JAX_CPU_DEFINE_GEES
#undef JAX_CPU_DEFINE_GEES_COMPLEX
#undef JAX_CPU_DEFINE_GEHRD

} // namespace jax
Loading
Loading