Skip to content

Commit

Permalink
Activate QR Factorization to XLA's FFI
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666722604
  • Loading branch information
pparuzel authored and jax authors committed Aug 23, 2024
1 parent 07767e8 commit c430b0c
Show file tree
Hide file tree
Showing 5 changed files with 556 additions and 38 deletions.
19 changes: 16 additions & 3 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,9 +923,22 @@ def _check_lowering(lowering) -> None:
"\n".join(not_implemented_msgs))

_CPU_FFI_KERNELS = [
"lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi",
"lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi",
"lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi",
"lapack_spotrf_ffi",
"lapack_dpotrf_ffi",
"lapack_cpotrf_ffi",
"lapack_zpotrf_ffi",
"lapack_sgeqrf_ffi",
"lapack_dgeqrf_ffi",
"lapack_cgeqrf_ffi",
"lapack_zgeqrf_ffi",
"lapack_sgesdd_ffi",
"lapack_dgesdd_ffi",
"lapack_cgesdd_ffi",
"lapack_zgesdd_ffi",
"lapack_sgetrf_ffi",
"lapack_dgetrf_ffi",
"lapack_cgetrf_ffi",
"lapack_zgetrf_ffi",
]
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
Expand Down
Loading

0 comments on commit c430b0c

Please sign in to comment.