diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index d3d15c4fc939..bcc36aa8c37c 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -65,6 +65,7 @@ pybind_extension( module_name = "_lapack", pytype_srcs = [ "_lapack/__init__.pyi", + "_lapack/schur.pyi", "_lapack/svd.pyi", "_lapack/eig.pyi", ], diff --git a/jaxlib/cpu/_lapack/schur.pyi b/jaxlib/cpu/_lapack/schur.pyi new file mode 100644 index 000000000000..add01b049390 --- /dev/null +++ b/jaxlib/cpu/_lapack/schur.pyi @@ -0,0 +1,26 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from typing import ClassVar + + +class ComputationMode(enum.Enum): + kComputeSchurVectors: ClassVar[ComputationMode] + kNoComputeSchurVectors: ClassVar[ComputationMode] + + +class Sort(enum.Enum): + kNoSortEigenvalues: ClassVar[Sort] + kSortEigenvalues: ClassVar[Sort] diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index eb3029c628f6..b10fabcff255 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -153,6 +153,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_sgees_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgees_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgees_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgees_ffi); JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index fc04963a0c81..21fc79eba92c 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -137,6 +137,11 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dgees")); AssignKernelFn>>(lapack_ptr("cgees")); AssignKernelFn>>(lapack_ptr("zgees")); + AssignKernelFn>(lapack_ptr("sgees")); + AssignKernelFn>(lapack_ptr("dgees")); + AssignKernelFn>(lapack_ptr("cgees")); + AssignKernelFn>( + lapack_ptr("zgees")); AssignKernelFn>(lapack_ptr("sgehrd")); AssignKernelFn>(lapack_ptr("dgehrd")); @@ -265,6 +270,10 @@ nb::dict Registrations() { dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi); dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi); dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi); + dict["lapack_sgees_ffi"] = EncapsulateFunction(lapack_sgees_ffi); + dict["lapack_dgees_ffi"] = EncapsulateFunction(lapack_dgees_ffi); + dict["lapack_cgees_ffi"] = EncapsulateFunction(lapack_cgees_ffi); + dict["lapack_zgees_ffi"] = EncapsulateFunction(lapack_zgees_ffi); dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi); dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi); dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi); @@ -280,6 +289,7 @@ NB_MODULE(_lapack, m) { // Submodules auto svd = m.def_submodule("svd"); auto eig = m.def_submodule("eig"); + auto schur = m.def_submodule("schur"); // Enums nb::enum_(svd, "ComputationMode") // kComputeVtOverwriteXPartialU is not implemented @@ -289,6 +299,14 @@ NB_MODULE(_lapack, m) { nb::enum_(eig, "ComputationMode") .value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors) .value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors); + nb::enum_(schur, "ComputationMode") + .value("kNoComputeSchurVectors", + schur::ComputationMode::kNoComputeSchurVectors) + .value("kComputeSchurVectors", + schur::ComputationMode::kComputeSchurVectors); + nb::enum_(schur, "Sort") + .value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues) + .value("kSortEigenvalues", schur::Sort::kSortEigenvalues); // Old-style LAPACK Workspace Size Queries m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 4b25540332fa..19b82a5ce149 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -64,6 +64,8 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); #undef REGISTER_CHAR_ENUM_ATTR_DECODING @@ -1573,6 +1575,180 @@ template struct RealGees; template struct ComplexGees>; template struct ComplexGees>; +// FFI Kernel + +template +ffi::Error SchurDecomposition::Kernel( + ffi::Buffer x, schur::ComputationMode mode, schur::Sort sort, + ffi::ResultBuffer x_out, ffi::ResultBuffer schur_vectors, + ffi::ResultBuffer eigvals_real, + ffi::ResultBuffer eigvals_imag, + // TODO(paruzelp): Sort is not implemented because select function is not + // supplied. For that reason, this parameter will always be zero! + ffi::ResultBuffer selected_eigvals, + ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + if (sort != schur::Sort::kNoSortEigenvalues) { + return ffi::Error( + ffi::ErrorCode::kUnimplemented, + "Ordering eigenvalues on the diagonal is not implemented"); + } + + CopyIfDiffBuffer(x, x_out); + + // TODO(paruzelp): `select` should be passed as an execution context + bool (*select)(ValueType, ValueType) = nullptr; + ValueType* x_out_data = x_out->typed_data(); + ValueType* eigvals_real_data = eigvals_real->typed_data(); + ValueType* eigvals_imag_data = eigvals_imag->typed_data(); + ValueType* schur_vectors_data = schur_vectors->typed_data(); + lapack_int* selected_data = selected_eigvals->typed_data(); + lapack_int* info_data = info->typed_data(); + + auto mode_v = static_cast(mode); + auto sort_v = static_cast(sort); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + + // Prepare LAPACK workspaces. + std::unique_ptr bwork = + sort != schur::Sort::kNoSortEigenvalues + ? AllocateScratchMemory(x_cols) + : nullptr; + auto work_size = GetWorkspaceSize(x_cols, mode, sort); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + auto work_data = AllocateScratchMemory(work_size); + + const int64_t x_size{x_cols * x_cols}; + [[maybe_unused]] const auto x_size_bytes = + static_cast(x_size) * sizeof(ValueType); + [[maybe_unused]] const auto x_cols_bytes = + static_cast(x_cols) * sizeof(ValueType); + for (int64_t i = 0; i < batch_count; ++i) { + fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v, + selected_data, eigvals_real_data, eigvals_imag_data, schur_vectors_data, + &x_cols_v, work_data.get(), &work_size_v, bwork.get(), info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int)); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); + + x_out_data += x_size; + eigvals_real_data += x_cols; + eigvals_imag_data += x_cols; + schur_vectors_data += x_size; + ++selected_data; + ++info_data; + } + + return ffi::Error::Success(); +} + +template +ffi::Error SchurDecompositionComplex::Kernel( + ffi::Buffer x, schur::ComputationMode mode, schur::Sort sort, + ffi::ResultBuffer x_out, ffi::ResultBuffer schur_vectors, + ffi::ResultBuffer eigvals, + // TODO(paruzelp): Sort is not implemented because select function is not + // supplied. For that reason, this parameter will always be zero! + ffi::ResultBuffer selected_eigvals, + ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + if (sort != schur::Sort::kNoSortEigenvalues) { + return ffi::Error( + ffi::ErrorCode::kUnimplemented, + "Ordering eigenvalues on the diagonal is not implemented"); + } + + CopyIfDiffBuffer(x, x_out); + + // TODO(paruzelp): `select` should be passed as an execution context + bool (*select)(ValueType) = nullptr; + ValueType* x_out_data = x_out->typed_data(); + ValueType* eigvals_data = eigvals->typed_data(); + ValueType* schur_vectors_data = schur_vectors->typed_data(); + lapack_int* selected_data = selected_eigvals->typed_data(); + lapack_int* info_data = info->typed_data(); + + auto mode_v = static_cast(mode); + auto sort_v = static_cast(sort); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + + // Prepare LAPACK workspaces. + std::unique_ptr bwork = + sort != schur::Sort::kNoSortEigenvalues + ? AllocateScratchMemory(x_cols) + : nullptr; + auto work_size = GetWorkspaceSize(x_cols, mode, sort); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + auto work_data = AllocateScratchMemory(work_size); + auto rwork_data = AllocateScratchMemory(x_cols); + + const int64_t x_size{x_cols * x_cols}; + [[maybe_unused]] const auto x_size_bytes = + static_cast(x_size) * sizeof(ValueType); + [[maybe_unused]] const auto x_cols_bytes = + static_cast(x_cols) * sizeof(ValueType); + for (int64_t i = 0; i < batch_count; ++i) { + fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v, + selected_data, eigvals_data, schur_vectors_data, &x_cols_v, + work_data.get(), &work_size_v, rwork_data.get(), bwork.get(), info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int)); + + x_out_data += x_size; + eigvals_data += x_cols; + schur_vectors_data += x_size; + ++selected_data; + ++info_data; + } + + return ffi::Error::Success(); +} + +template +int64_t SchurDecomposition::GetWorkspaceSize(lapack_int x_cols, + schur::ComputationMode mode, + schur::Sort sort) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + + auto mode_v = static_cast(mode); + auto sort_v = static_cast(sort); + fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr, + nullptr, nullptr, &x_cols, &optimal_size, &workspace_query, nullptr, + &info); + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +}; + +template +int64_t SchurDecompositionComplex::GetWorkspaceSize( + lapack_int x_cols, schur::ComputationMode mode, schur::Sort sort) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + + auto mode_v = static_cast(mode); + auto sort_v = static_cast(sort); + fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr, + nullptr, &x_cols, &optimal_size, &workspace_query, nullptr, nullptr, + &info); + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +}; + +template struct SchurDecomposition; +template struct SchurDecomposition; +template struct SchurDecompositionComplex; +template struct SchurDecompositionComplex; + //== Hessenberg Decomposition ==// // lapack gehrd @@ -1926,6 +2102,33 @@ template struct TridiagonalReduction; .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) +#define JAX_CPU_DEFINE_GEES(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SchurDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("mode") \ + .Attr("sort") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*schur_vectors*/) \ + .Ret<::xla::ffi::Buffer>(/*eigvals_real*/) \ + .Ret<::xla::ffi::Buffer>(/*eigvals_imag*/) \ + .Ret<::xla::ffi::Buffer>(/*selected_eigvals*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +#define JAX_CPU_DEFINE_GEES_COMPLEX(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SchurDecompositionComplex::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("mode") \ + .Attr("sort") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*schur_vectors*/) \ + .Ret<::xla::ffi::Buffer>(/*eigvals*/) \ + .Ret<::xla::ffi::Buffer>(/*selected_eigvals*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + #define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ name, TridiagonalReduction::Kernel, \ @@ -1998,6 +2201,11 @@ JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_GEES(lapack_sgees_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GEES(lapack_dgees_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GEES_COMPLEX(lapack_cgees_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GEES_COMPLEX(lapack_zgees_ffi, ::xla::ffi::DataType::C128); + JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32); JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64); @@ -2015,6 +2223,8 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_GEEV #undef JAX_CPU_DEFINE_GEEV_COMPLEX #undef JAX_CPU_DEFINE_SYTRD_HETRD +#undef JAX_CPU_DEFINE_GEES +#undef JAX_CPU_DEFINE_GEES_COMPLEX #undef JAX_CPU_DEFINE_GEHRD } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index e5fa9d354f6d..7d15e494fffc 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -67,7 +67,18 @@ enum class ComputationMode : char { kComputeEigenvectors = 'V', }; -} +} // namespace eig + +namespace schur { + +enum class ComputationMode : char { + kNoComputeSchurVectors = 'N', + kComputeSchurVectors = 'V', +}; + +enum class Sort : char { kNoSortEigenvalues = 'N', kSortEigenvalues = 'S' }; + +} // namespace schur template void AssignKernelFn(void* func) { @@ -96,6 +107,8 @@ DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); #undef DEFINE_CHAR_ENUM_ATTR_DECODING @@ -551,6 +564,64 @@ struct ComplexGees { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct SchurDecomposition { + static_assert(!::xla::ffi::IsComplexType(), + "There exists a separate implementation for Complex types"); + + using ValueType = ::xla::ffi::NativeType; + using FnType = void(char* jobvs, char* sort, + bool (*select)(ValueType, ValueType), lapack_int* n, + ValueType* a, lapack_int* lda, lapack_int* sdim, + ValueType* wr, ValueType* wi, ValueType* vs, + lapack_int* ldvs, ValueType* work, lapack_int* lwork, + bool* bwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, schur::ComputationMode mode, + schur::Sort sort, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer schur_vectors, + ::xla::ffi::ResultBuffer eigvals_real, + ::xla::ffi::ResultBuffer eigvals_imag, + ::xla::ffi::ResultBuffer selected_eigvals, + ::xla::ffi::ResultBuffer info); + + static int64_t GetWorkspaceSize(lapack_int x_cols, + schur::ComputationMode mode, + schur::Sort sort); +}; + +template <::xla::ffi::DataType dtype> +struct SchurDecompositionComplex { + static_assert(::xla::ffi::IsComplexType()); + + using ValueType = ::xla::ffi::NativeType; + using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>; + using FnType = void(char* jobvs, char* sort, bool (*select)(ValueType), + lapack_int* n, ValueType* a, lapack_int* lda, + lapack_int* sdim, ValueType* w, ValueType* vs, + lapack_int* ldvs, ValueType* work, lapack_int* lwork, + RealType* rwork, bool* bwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, schur::ComputationMode mode, + schur::Sort sort, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer schur_vectors, + ::xla::ffi::ResultBuffer eigvals, + ::xla::ffi::ResultBuffer selected_eigvals, + ::xla::ffi::ResultBuffer info); + + static int64_t GetWorkspaceSize(lapack_int x_cols, + schur::ComputationMode mode, + schur::Sort sort); +}; + //== Hessenberg Decomposition ==// //== Reduces a non-symmetric square matrix to upper Hessenberg form ==// @@ -677,6 +748,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssytrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsytrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_chetrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zhetrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgees_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgees_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgees_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgees_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi); diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 48efcedeba9b..ad64069a2499 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -66,10 +66,10 @@ jax::EigenvalueDecomposition::FnType dgeev_; jax::EigenvalueDecompositionComplex::FnType cgeev_; jax::EigenvalueDecompositionComplex::FnType zgeev_; -jax::RealGees::FnType sgees_; -jax::RealGees::FnType dgees_; -jax::ComplexGees>::FnType cgees_; -jax::ComplexGees>::FnType zgees_; +jax::SchurDecomposition::FnType sgees_; +jax::SchurDecomposition::FnType dgees_; +jax::SchurDecompositionComplex::FnType cgees_; +jax::SchurDecompositionComplex::FnType zgees_; jax::HessenbergDecomposition::FnType sgehrd_; jax::HessenbergDecomposition::FnType dgehrd_; @@ -227,6 +227,22 @@ static_assert( std::is_same_v::FnType, jax::Sytrd>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::RealGees::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::RealGees::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::ComplexGees>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::ComplexGees>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); static_assert( std::is_same_v::FnType, jax::Gehrd::FnType>, @@ -352,6 +368,11 @@ static auto init = []() -> int { AssignKernelFn>(chetrd_); AssignKernelFn>(zhetrd_); + AssignKernelFn>(sgees_); + AssignKernelFn>(dgees_); + AssignKernelFn>(cgees_); + AssignKernelFn>(zgees_); + AssignKernelFn>(sgehrd_); AssignKernelFn>(dgehrd_); AssignKernelFn>(cgehrd_);