From 3c6103f2dfd8c8a56c5f9104eef52c4320a3d1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Wed, 28 Aug 2024 03:53:07 -0700 Subject: [PATCH] Activate Eigenvalue Decompositions to XLA's FFI Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD. PiperOrigin-RevId: 668381949 --- jax/_src/export/_export.py | 2 + .../cpu_eig_lapack_geev.py | 265 +++++++++++ .../cpu_eigh_lapack_syev.py | 443 ++++++++++++++++++ jax/_src/lax/linalg.py | 22 +- jaxlib/lapack.py | 294 +++++++----- tests/export_back_compat_test.py | 18 + 6 files changed, 932 insertions(+), 112 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 54a4f6d5498e..54defa0e9c54 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -925,6 +925,8 @@ def _check_lowering(lowering) -> None: _CPU_FFI_KERNELS = [ "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi", + "lapack_ssyevd_ffi", "lapack_dsyevd_ffi", "lapack_cheevd_ffi", "lapack_zheevd_ffi", + "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", ] diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index 1e4b6428556b..bc28857fa325 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -283,3 +283,268 @@ mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00", xla_call_module_version=6, ) # End paste + + +data_2024_08_19 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729794e+00+0.j, + -1.4596915295025735e-15+0.j, 4.7403016698320490e-16+0.j]), array([[ 0.40377749076862324+0.j, 0.8288327563197503 +0.j, + -0.5409014947846461 +0.j, 0.10917005482608667-0.j], + [ 0.4648073711584899 +0.j, 0.43714638836388775-0.j, + 0.7854306338527134 +0.j, -0.5456169434539783 +0.j], + [ 0.5258372515483575 +0.j, 0.04546002040802463-0.j, + 0.05184321664851461-0.j, 0.7637237224296971 +0.j], + [ 0.5868671319382249 +0.j, -0.34622634754783843+0.j, + -0.296372355716581 +0.j, -0.32727683380180517+0.j]]), array([[ 0.11417645138733866+0.j, 0.7327780959803557 +0.j, + -0.5367326141844461 +0.j, -0.08617176416747369+0.j], + [ 0.33000459866554754+0.j, 0.28974835239692603-0.j, + 0.6342729310130916 +0.j, -0.28826848493327445+0.j], + [ 0.5458327459437569 +0.j, -0.15328139118650222+0.j, + 0.34165198052715445-0.j, 0.83505226236897 +0.j], + [ 0.7616608932219664 +0.j, -0.5963111347699301 +0.j, + -0.4391922973557999 +0.j, -0.460612013268222 +0.j]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:4 = stablehlo.custom_call @lapack_zgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %3 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.compare EQ, %2#3, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %8 = stablehlo.select %7, %2#0, %6 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %12 = stablehlo.select %11, %2#1, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %13 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %16 = stablehlo.select %15, %2#2, %14 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %8, %12, %16 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xef\xa57\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1b/\x0f\x1f\x0f\x0b\x0bO/O\x01\x05\x0b\x0f\x033\x17\x07\x07\x13\x0b\x0f\x0f\x0f\x07\x17\x17\x1b\x07\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xa6\x06\x1d;=\x03\x03\t\x99\x1f\x05\x19\x05\x1b\x03\x03\x07\x9f\x11\x03\x05\x05\x1d\x17\x13J\x03\x1d\x05\x1f\x03\x03\x07\x7f\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05!\x11\x01\x00\x05#\x05%\x05\'\x03\x0b\'])i+k\x0fy-{\x05)\x05+\x05-\x05/\x03\x031}\x051\x1d5\x11\x053\x1d9\x11\x055\x057\x17\x13N\x03\x1b\x03\x13A\x81C\x83E\x85G]I\x87K\x89M\x8fO]Q\x91\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05K\x05M\x03\x03\t\xa1\x03\x01\x1dO\x1dQ\x1dS\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13#V#\x1b\x03\x07mqu\r\x05_oac\x1dU\r\x05_sac\x1dW\r\x05_wac\x1dY\x1d[\x1d]\x13\x07\x01\x1f\x13\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d_\x1da\x05\x01\r\x05\x8bg\x8dg\x1dc\x1de\x03\x03e\x03\t\x93ee\x95\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1f\x0f\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\r\x1d\x01)\x03\x11\r\x03\x1d)\x01!)\x01\r)\x01\x07\x13)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05\x0b)\x03A\r\x1b!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04"\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x035a\x0b\x033/\x03\x1f\r\x067\x03\x05\x03\x01\x05\x03\x01\x15\x03\x13\x05\x03\x01\x15\x03\x13\x0f\x07\x01?\t\x0b\x05\x05\x0f\x03\x03\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x11\x11\x07\x01U\x03-\x05\x0f\x13\x03\x07\x01\x03\x03/\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x19\x03\x07\x01[\x031\x03\x17\x07\x06\x01\x03\x0b\x07\x1d\t\x1b\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x17\x03\x19\x03!\x07\x06\x01\x03\x05\x07\'\x0b%\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03-\x03\x07\x01\x17\x03\x19\x03+\x07\x06\x01\x03\x05\x071\r/\x13\x04\x05\x07\x1f)3\x06\x03\x01\x05\x01\x00^\x0eg\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev_ffi\x00compute_left\x00compute_right\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249e+01+0.j, -2.4642491e+00+0.j, -8.1492220e-07+0.j, + 3.0721142e-07+0.j], dtype=complex64), array([[ 0.40377736 +0.j, 0.8288328 +0.j, -0.53676015 +0.j, + 0.07707452 -0.j], + [ 0.4648074 +0.j, 0.43714643 -0.j, 0.79694915 +0.j, + -0.5069523 +0.j], + [ 0.52583736 +0.j, 0.04545992 -0.j, 0.016383484+0.j, + 0.7826807 +0.j], + [ 0.5868672 +0.j, -0.34622622 +0.j, -0.2765721 +0.j, + -0.35280296 +0.j]], dtype=complex64), array([[ 0.114176415+0.j, 0.73277825 +0.j, -0.54227245 +0.j, + -0.109032825+0.j], + [ 0.3300045 +0.j, 0.2897482 -0.j, 0.6655821 +0.j, + -0.25470036 +0.j], + [ 0.5458329 +0.j, -0.15328139 +0.j, 0.29565343 +0.j, + 0.83649963 +0.j], + [ 0.7616609 +0.j, -0.59631103 +0.j, -0.4189632 +0.j, + -0.47276634 +0.j]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:4 = stablehlo.custom_call @lapack_cgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %3 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.compare EQ, %2#3, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %8 = stablehlo.select %7, %2#0, %6 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %12 = stablehlo.select %11, %2#1, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %13 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %16 = stablehlo.select %15, %2#2, %14 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %8, %12, %16 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xef\xa57\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1b/\x0f\x1f\x0f\x0b\x0b//O\x01\x05\x0b\x0f\x033\x17\x07\x07\x13\x0b\x0f\x0f\x0f\x07\x17\x17\x1b\x07\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x86\x06\x1d;=\x03\x03\t\x99\x1f\x05\x19\x05\x1b\x03\x03\x07\x9f\x11\x03\x05\x05\x1d\x17\x13J\x03\x1d\x05\x1f\x03\x03\x07\x7f\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05!\x11\x01\x00\x05#\x05%\x05\'\x03\x0b\'])i+k\x0fy-{\x05)\x05+\x05-\x05/\x03\x031}\x051\x1d5\x11\x053\x1d9\x11\x055\x057\x17\x13N\x03\x1b\x03\x13A\x81C\x83E\x85G]I\x87K\x89M\x8fO]Q\x91\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05K\x05M\x03\x03\t\xa1\x03\x01\x1dO\x1dQ\x1dS\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13#V#\x1b\x03\x07mqu\r\x05_oac\x1dU\r\x05_sac\x1dW\r\x05_wac\x1dY\x1d[\x1d]\x13\x07\x01\x1f\x13\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d_\x1da\x05\x01\r\x05\x8bg\x8dg\x1dc\x1de\x03\x03e\x03\t\x93ee\x95\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1f\x0f\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\r\x1d\x01)\x03\x11\r\x03\x1d)\x01!)\x01\r)\x01\x07\x13)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05\t)\x03A\r\x1b!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04"\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x035a\x0b\x033/\x03\x1f\r\x067\x03\x05\x03\x01\x05\x03\x01\x15\x03\x13\x05\x03\x01\x15\x03\x13\x0f\x07\x01?\t\x0b\x05\x05\x0f\x03\x03\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x11\x11\x07\x01U\x03-\x05\x0f\x13\x03\x07\x01\x03\x03/\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x19\x03\x07\x01[\x031\x03\x17\x07\x06\x01\x03\x0b\x07\x1d\t\x1b\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x17\x03\x19\x03!\x07\x06\x01\x03\x05\x07\'\x0b%\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03-\x03\x07\x01\x17\x03\x19\x03+\x07\x06\x01\x03\x05\x071\r/\x13\x04\x05\x07\x1f)3\x06\x03\x01\x05\x01\x00Z\x0eg\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev_ffi\x00compute_left\x00compute_right\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642482e+00+0.j, -4.5555478e-07+0.j, + 2.9215252e-07+0.j], dtype=complex64), array([[-0.40377742+0.j, 0.8288328 +0.j, -0.5253654 +0.j, + -0.11065983+0.j], + [-0.46480736+0.j, 0.43714654+0.j, 0.8159359 +0.j, + 0.547376 +0.j], + [-0.52583736+0.j, 0.04545998+0.j, -0.0557748 +0.j, + -0.7627722 +0.j], + [-0.5868672 +0.j, -0.34622627+0.j, -0.23479532+0.j, + 0.32605612+0.j]], dtype=complex64), array([[-0.114176415+0.j, 0.7327782 +0.j, -0.5364275 +0.j, + 0.15489015 +0.j], + [-0.33000445 +0.j, 0.28974816 +0.j, 0.6327556 +0.j, + 0.18506403 +0.j], + [-0.54583275 +0.j, -0.15328142 +0.j, 0.34377125 +0.j, + -0.83479893 +0.j], + [-0.761661 +0.j, -0.5963111 +0.j, -0.44009918 +0.j, + 0.49484456 +0.j]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:5 = stablehlo.custom_call @lapack_sgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %3 = stablehlo.complex %2#0, %2#1 : tensor<4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %4 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %5 = stablehlo.compare EQ, %2#4, %4, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %9 = stablehlo.select %8, %3, %7 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %13 = stablehlo.select %12, %2#2, %11 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %17 = stablehlo.select %16, %2#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %9, %13, %17 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xf3\xa5;\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f/\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1f\x0f\x1f\x0f\x0b\x0b//O\x01\x05\x0b\x0f\x037\x17\x07\x07\x13\x07\x0f\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x17\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xae\x06\x1d;=\x03\x03\t\x99\x1f\x05\x1b\x05\x1d\x03\x03\x07\x9f\x11\x03\x05\x05\x1f\x17\x13J\x03\x1d\x05!\x03\x03\x07\x81\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05#\x11\x01\x00\x05%\x05'\x05)\x03\x0b'])k+m\x0f{-}\x05+\x05-\x05/\x051\x03\x031\x7f\x053\x1d5\x11\x055\x1d9\x11\x057\x059\x17\x13N\x03\x1b\x03\x13A\x83C\x85E\x87G]I\x89K\x8bM\x91O]Q\x93\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x05K\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05M\x05O\x03\x03\t\xa1\x03\x01\x1dQ\x1dS\x1dU\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13'V\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07osw\r\x05_qac\x1dW\r\x05_uac\x1dY\r\x05_yac\x1d[\x1d]\x1d_\x13\x07\x01\x1f\x15\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1da\x1dc\x05\x01\r\x05\x8dg\x8fg\x1de\x1dg\x03\x03e\x03\x0biiee\x95\x1f-\x01\x1f\x0f\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x13\x1d\x01)\x03\x11\x13\t)\x01%)\x01\x13\x03\r)\x01\x07\x13)\x03\x11\r)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05)\x03A\r)\x05\x11\x11\r\x1b!)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04F\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x039e\x0b\x033/\x03!\r\x067\x03#\x03\x01\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x0b\x19\x19\x05\x05\x0f\x03\x03\x11\x06\x01\x03\x0b\x05\t\x0b\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x15\x13\x07\x01U\x031\x05\x11\x17\x03\x07\x01\x03\x033\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x1d\x03\x07\x01[\x035\x03\x1b\x07\x06\x01\x03\x0b\x07!\x13\x1f\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03'\x03\x07\x01\x17\x03\x1d\x03%\x07\x06\x01\x03\x05\x07+\r)\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x0f3\x15\x04\x05\x07#-7\x06\x03\x01\x05\x01\x00\x82\x0ei\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev_ffi\x00compute_left\x00compute_right\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729789e+00+0.j, + -1.4885423746029788e-15+0.j, 4.7495173217146935e-16+0.j]), array([[-0.40377749076862246 +0.j, -0.8288327563197503 +0.j, + -0.541090767303977 +0.j, 0.10767692008040902 +0.j], + [-0.4648073711584901 +0.j, -0.43714638836388775 +0.j, + 0.7847911174458492 +0.j, -0.5438508504687168 +0.j], + [-0.5258372515483576 +0.j, -0.045460020408024666+0.j, + 0.05369006702023438 +0.j, 0.7646709406962073 +0.j], + [-0.5868671319382248 +0.j, 0.34622634754783854 +0.j, + -0.2973904171621061 +0.j, -0.32849701030789913 +0.j]]), array([[-0.11417645138733848+0.j, -0.7327780959803556 +0.j, + -0.5370341524353898 +0.j, -0.0849751818967924 +0.j], + [-0.33000459866554754+0.j, -0.2897483523969262 +0.j, + 0.6357878989446506 +0.j, -0.29000500336734825+0.j], + [-0.545832745943757 +0.j, 0.15328139118650214+0.j, + 0.33952665941686755+0.j, 0.8349355524250736 +0.j], + [-0.7616608932219664 +0.j, 0.5963111347699303 +0.j, + -0.43828040592612855+0.j, -0.45995536716093305+0.j]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:5 = stablehlo.custom_call @lapack_dgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %3 = stablehlo.complex %2#0, %2#1 : tensor<4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %4 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %5 = stablehlo.compare EQ, %2#4, %4, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %9 = stablehlo.select %8, %3, %7 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %13 = stablehlo.select %12, %2#2, %11 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %17 = stablehlo.select %16, %2#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %9, %13, %17 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xf3\xa5;\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f/\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1f\x0f\x1f\x0f\x0b\x0bO/O\x01\x05\x0b\x0f\x037\x17\x07\x07\x13\x07\x0f\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x17\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xce\x06\x1d;=\x03\x03\t\x99\x1f\x05\x1b\x05\x1d\x03\x03\x07\x9f\x11\x03\x05\x05\x1f\x17\x13J\x03\x1d\x05!\x03\x03\x07\x81\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05#\x11\x01\x00\x05%\x05'\x05)\x03\x0b'])k+m\x0f{-}\x05+\x05-\x05/\x051\x03\x031\x7f\x053\x1d5\x11\x055\x1d9\x11\x057\x059\x17\x13N\x03\x1b\x03\x13A\x83C\x85E\x87G]I\x89K\x8bM\x91O]Q\x93\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x05K\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05M\x05O\x03\x03\t\xa1\x03\x01\x1dQ\x1dS\x1dU\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13'V\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07osw\r\x05_qac\x1dW\r\x05_uac\x1dY\r\x05_yac\x1d[\x1d]\x1d_\x13\x07\x01\x1f\x15\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1da\x1dc\x05\x01\r\x05\x8dg\x8fg\x1de\x1dg\x03\x03e\x03\x0biiee\x95\x1f-\x01\x1f\x0f\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x13\x1d\x01)\x03\x11\x13\x0b)\x01%)\x01\x13\x03\r)\x01\x07\x13)\x03\x11\r)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05)\x03A\r)\x05\x11\x11\r\x1b!)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04F\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x039e\x0b\x033/\x03!\r\x067\x03#\x03\x01\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x0b\x19\x19\x05\x05\x0f\x03\x03\x11\x06\x01\x03\x0b\x05\t\x0b\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x15\x13\x07\x01U\x031\x05\x11\x17\x03\x07\x01\x03\x033\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x1d\x03\x07\x01[\x035\x03\x1b\x07\x06\x01\x03\x0b\x07!\x13\x1f\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03'\x03\x07\x01\x17\x03\x1d\x03%\x07\x06\x01\x03\x05\x07+\r)\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x0f3\x15\x04\x05\x07#-7\x06\x03\x01\x05\x01\x00\x82\x0ei\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev_ffi\x00compute_left\x00compute_right\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py index fcc32058bbee..f0696db1aeda 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py @@ -383,3 +383,446 @@ xla_call_module_version=4, ), # End paste ) + +data_2024_08_19 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zheevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, + 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, + 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, + 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], + [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, + -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, + 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, + 2.1415294457221759e-01+0.j, -2.2856669794666584e-01+0.j], + [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j, + 2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j, + -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j, + 3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j], + [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j, + -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j, + -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j, + -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j], + [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j, + 2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j, + 6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j, + -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j], + [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j, + -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j, + 6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j, + 8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j], + [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j, + 6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j, + -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j, + 1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j], + [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j, + 2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j, + 1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j, + 2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, + -1.9932120610662194e-14, -5.7323356091157378e-15, + -4.5459724251334835e-16, 4.0479851042511616e-14, + 9.2325194924982089e-14, 2.7659880477613365e+02])), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc18 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<8x8xf64> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> loc(#loc16) + %cst = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<8x8xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc18) + %c = stablehlo.constant dense<8> : tensor loc(#loc19) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc19) + %11:3 = stablehlo.custom_call @lapack_zheevd_ffi(%10) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xcomplex>) -> (tensor<8x8xcomplex>, tensor<8xf64>, tensor) loc(#loc19) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> loc(#loc19) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf64> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<8xi1>, tensor<8xf64> loc(#loc19) + return %17, %21 : tensor<8x8xcomplex>, tensor<8xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7))) -> (tensor<8x8xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc20) + %c = stablehlo.constant dense<0> : tensor loc(#loc18) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc22) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc23) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc18) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc24) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc25) + return %6 : tensor<8x8xcomplex> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc9 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc7)) +#loc25 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\'\x03\xda\x02*\x02?\x01\xab\x0f\x0b\x13\x17\x0f\x0b\x07\x17\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03a\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO/\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0bOO//\x01\x0b\x1f\x17\x17\x17\x17\x01\x05\x0b\x0f\x03;\x17\x07\x0f\x0f\x07\x07\x13\x17\x0b\x17\x0f\x07\x07\x17\x13\x07\x0f\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa6\n\x1d\x93\x95\x05)\x03\x03\x13\xd5\x17\x03V\x047\x1d?\x07\x05+\x1f\x17\x03>\x043\x05-\x05/\x11\x03\x05\x051\x053\x055\x057\x03\x03!\xd1\x059\x03\x03\x0b\xd3\x1dE\x07\x05;\x05=\x1d\x8b\x8d\x03\x03\x0b\xe1\x03\t135\x157\x15\x119\x05?\x11\x01\x00\x05A\x05C\x05E\x03\x0b\x17\xaf\x19\xbb\x1b\xbd\x11\xc7\x1d\xc9\x03\x0b\x17\xb3\x19\xcd\x1b\xb3\x11\xb5\x1d\xcf\x05G\x1dC\x07\x05I\x05K\x03\x03!\xd7\x1dK\x07\x05M\x03\x05\'\xb7)\xd9\x1dQ\x07\x05O\x03\x03\x0b\xdb\x1dW\x07\x05Q\x1d[\x07\x05S\x1d_a\x05U\x17\x036\x045\x1deg\x05W\x17\x036\x04\x1d\x03\x03k\xdd\x05Y\x1doq\x05[\x17\x03>\x04E\x1du\x0f\x05]\x1dy\x0f\x05_\x1d}\x0f\x05a\x1d\x81\x0f\x05c\x1d\x85\x87\x05e\x17\x03>\x04\x1f\x03\x03\x0b\xdf\x05g\x17\x03>\x04\x1d\x03\x03\x91\xb5\x05i\x05k\x17\x03V\x04\x17\x03\x13\x99\xe3\x9b\xe5\x9d\xe7\x9f\xaf\xa1\xe9\xa3\xeb\xa5\xf5\xa7\xf7\xa9\xfb\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x05}\x1d\x7f\x1d\x81\x03\x01\x1d\x83\x03\x03\xcb\x1d\x85\t\x07\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\'\x03\x05\xbf\xc3\r\x05\xb1\xc1\xab\xad\x1d\x87\r\x05\xb1\xc5\xab\xad\x1d\x89\x1d\x8b\x1d\x8d\r\x03\xab\xad#)\x1d\x8f\x13\x07\x01\x1f\x0b\t\x00\x00\x00\x00\x1f+\x01\x13\x07\x05\x07\x05\x1f\t!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x91\x1d\x93\x05\x01\r\x05\xed\xef\xf1\xf3\x1d\x95\x13#V\x1d\x97\x13#L\x03\x03\xb9\x03\x03\xf9\x15\x03\x01\x01\x01\x03\x07\xb9\xfd\xff\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x01\x07\x01\x1f\t!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05\'\xb7)\x02\x02\x03\x03\x0b\x06\x02\x03\x03\x13\n\x02\x03\x03\x0b\x0e\x02\x03\x03\x13\x12\x02\x01\t\x01\x02\x02)\x05!!\x15\x1d)\x01\x15)\x01\x1d\x01\x0b)\x03!\x0f)\x05!!\x1d\x03\x0f)\x05!!\x0f)\x01\x07\x13\x1b)\x05!!\r)\x03\t\x07!)\x01\x0f\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x07)\x03\x02\x02\x15)\x03\t\x1b)\x03\x05\x1b)\x03\x01\x1b)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\x07\x04\x06\x05\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x03=u\x07\x03]\x1f\x03-\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\x17\x06s\x03\x17\x03\x05\x19\x06w\x03\x17\x03\x05\x1b\x06{\x03\x17\x03\t\x1d\x06\x7f\x03\x05\x05\x07\x0b\r\x06\x83\x03\x05\x05\x03\r\x05\x03\r\x89\x03\t\x03\x07+\x05\x03\x05\x03\x11\x1f\x06+\x03\x05\x05\x0f\x13!\x07\t\x8f\x03\x05\x03\x15\x05\x03\x01-\x03\x19\x05\x03\x01-\x03\x19#\x07\x01\x97\x07\x05\x11\x0b\x03\x17\x05\x03\x01#\x03\x0b\x03\x07\x01\x05\x03\x0b\x03#\x0f\x07\x01\x16\x02\x035\x05!%\x03\x07\x01\x05\x037\x03\'\x05\x03\x01\x1a\x02\x03\t\x03\x07\x01\x05\x03\x05\x03+\x03\x07\x01\x1e\x02\x03\x1f\x03)\t\x06\x01\x03\x05\x07/\x1d-\x03\x07\x01\x05\x039\x03\'\x05\x03\x01"\x02\x03%\x03\x07\x01\x05\x03\x11\x035\x03\x07\x01&\x02\x03;\x033\t\x06\x01\x03\x11\x079\x1f7\x11\x04\r\x051;\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1f\x03\x13\x05\x03\t#\x03\x0b\x03\x07%\x05\x03\x13\x03\x05\r\x06%\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1f\x05\t\x0b\x05\x03\tS\x03\t\x03\x07U\x05\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe2\x1c\x99\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd_ffi\x00mode\x00uplo\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cheevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, + 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, + -0.019093221 +0.j, -0.18446997 +0.j], + [-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j, + -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j, + 0.28112718 +0.j, -0.22856665 +0.j], + [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j, + -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j, + -0.6373532 +0.j, -0.27266347 +0.j], + [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j, + -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j, + -0.083263926 +0.j, -0.31676024 +0.j], + [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j, + 0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j, + 0.6926741 +0.j, -0.360857 +0.j], + [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j, + 0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j, + 0.00091664493+0.j, -0.40495378 +0.j], + [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j, + 0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j, + -0.113144726 +0.j, -0.4490505 +0.j], + [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j, + -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j, + -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, + -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], + dtype=float32)), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc18 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<8x8xf32> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> loc(#loc16) + %cst = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<8x8xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc18) + %c = stablehlo.constant dense<8> : tensor loc(#loc19) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc19) + %11:3 = stablehlo.custom_call @lapack_cheevd_ffi(%10) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xcomplex>) -> (tensor<8x8xcomplex>, tensor<8xf32>, tensor) loc(#loc19) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> loc(#loc19) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf32> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<8xi1>, tensor<8xf32> loc(#loc19) + return %17, %21 : tensor<8x8xcomplex>, tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7))) -> (tensor<8x8xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc20) + %c = stablehlo.constant dense<0> : tensor loc(#loc18) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc22) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc23) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc18) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc24) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc25) + return %6 : tensor<8x8xcomplex> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc9 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc7)) +#loc25 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\'\x03\xda\x02*\x02?\x01\xab\x0f\x0b\x13\x17\x0f\x0b\x07\x17\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03a\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O//\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b/O\x1f/\x01\x0b\x1f\x17\x17\x17\x17\x01\x05\x0b\x0f\x03;\x17\x07\x0f\x0f\x07\x07\x13\x17\x0b\x17\x0f\x07\x07\x17\x13\x07\x0f\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x026\n\x1d\x93\x95\x05)\x03\x03\x13\xd5\x17\x03V\x047\x1d?\x07\x05+\x1f\x17\x03>\x043\x05-\x05/\x11\x03\x05\x051\x053\x055\x057\x03\x03!\xd1\x059\x03\x03\x0b\xd3\x1dE\x07\x05;\x05=\x1d\x8b\x8d\x03\x03\x0b\xe1\x03\t135\x157\x15\x119\x05?\x11\x01\x00\x05A\x05C\x05E\x03\x0b\x17\xaf\x19\xbb\x1b\xbd\x11\xc7\x1d\xc9\x03\x0b\x17\xb3\x19\xcd\x1b\xb3\x11\xb5\x1d\xcf\x05G\x1dC\x07\x05I\x05K\x03\x03!\xd7\x1dK\x07\x05M\x03\x05\'\xb7)\xd9\x1dQ\x07\x05O\x03\x03\x0b\xdb\x1dW\x07\x05Q\x1d[\x07\x05S\x1d_a\x05U\x17\x036\x045\x1deg\x05W\x17\x036\x04\x1d\x03\x03k\xdd\x05Y\x1doq\x05[\x17\x03>\x04E\x1du\x0f\x05]\x1dy\x0f\x05_\x1d}\x0f\x05a\x1d\x81\x0f\x05c\x1d\x85\x87\x05e\x17\x03>\x04\x1f\x03\x03\x0b\xdf\x05g\x17\x03>\x04\x1d\x03\x03\x91\xb5\x05i\x05k\x17\x03V\x04\x17\x03\x13\x99\xe3\x9b\xe5\x9d\xe7\x9f\xaf\xa1\xe9\xa3\xeb\xa5\xf5\xa7\xf7\xa9\xfb\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x05}\x1d\x7f\x1d\x81\x03\x01\x1d\x83\x03\x03\xcb\x1d\x85\t\x07\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\'\x03\x05\xbf\xc3\r\x05\xb1\xc1\xab\xad\x1d\x87\r\x05\xb1\xc5\xab\xad\x1d\x89\x1d\x8b\x1d\x8d\r\x03\xab\xad#)\x1d\x8f\x13\x07\x01\x1f\x0b\t\x00\x00\x00\x00\x1f+\x01\x13\x07\x05\x07\x05\x1f\t\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x19\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x91\x1d\x93\x05\x01\r\x05\xed\xef\xf1\xf3\x1d\x95\x13#V\x1d\x97\x13#L\x03\x03\xb9\x03\x03\xf9\x15\x03\x01\x01\x01\x03\x07\xb9\xfd\xff\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x01\x07\x01\x1f\t\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f%\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05\'\xb7)\x02\x02\x03\x03\x0b\x06\x02\x03\x03\x13\n\x02\x03\x03\x0b\x0e\x02\x03\x03\x13\x12\x02\x01\t\x01\x02\x02)\x05!!\x15\x1d)\x01\x15)\x01\x1d\x01\t)\x03!\x0f)\x05!!\x1d\x03\x0f)\x05!!\x0f)\x01\x07\x13\x1b)\x05!!\r)\x03\t\x07!)\x01\x0f\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x07)\x03\x02\x02\x15)\x03\t\x1b)\x03\x05\x1b)\x03\x01\x1b)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\x07\x04\x06\x05\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x03=u\x07\x03]\x1f\x03-\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\x17\x06s\x03\x17\x03\x05\x19\x06w\x03\x17\x03\x05\x1b\x06{\x03\x17\x03\t\x1d\x06\x7f\x03\x05\x05\x07\x0b\r\x06\x83\x03\x05\x05\x03\r\x05\x03\r\x89\x03\t\x03\x07+\x05\x03\x05\x03\x11\x1f\x06+\x03\x05\x05\x0f\x13!\x07\t\x8f\x03\x05\x03\x15\x05\x03\x01-\x03\x19\x05\x03\x01-\x03\x19#\x07\x01\x97\x07\x05\x11\x0b\x03\x17\x05\x03\x01#\x03\x0b\x03\x07\x01\x05\x03\x0b\x03#\x0f\x07\x01\x16\x02\x035\x05!%\x03\x07\x01\x05\x037\x03\'\x05\x03\x01\x1a\x02\x03\t\x03\x07\x01\x05\x03\x05\x03+\x03\x07\x01\x1e\x02\x03\x1f\x03)\t\x06\x01\x03\x05\x07/\x1d-\x03\x07\x01\x05\x039\x03\'\x05\x03\x01"\x02\x03%\x03\x07\x01\x05\x03\x11\x035\x03\x07\x01&\x02\x03;\x033\t\x06\x01\x03\x11\x079\x1f7\x11\x04\r\x051;\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1f\x03\x13\x05\x03\t#\x03\x0b\x03\x07%\x05\x03\x13\x03\x05\r\x06%\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1f\x05\t\x0b\x05\x03\tS\x03\t\x03\x07U\x05\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xde\x1c\x99\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd_ffi\x00mode\x00uplo\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssyevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 , + -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ], + [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 , + 0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ], + [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 , + 0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ], + [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 , + -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ], + [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 , + 0.30305928 , -0.33998996 , 0.6926741 , -0.360857 ], + [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 , + 0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ], + [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 , + -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ], + [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 , + -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]], + dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, + -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], + dtype=float32)), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc13 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xf32> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<8x8xf32> loc(#loc11) + %cst = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<8x8xf32> loc(#loc12) + %6 = call @tril(%5) : (tensor<8x8xf32>) -> tensor<8x8xf32> loc(#loc13) + %c = stablehlo.constant dense<8> : tensor loc(#loc14) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc14) + %7:3 = stablehlo.custom_call @lapack_ssyevd_ffi(%6) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xf32>) -> (tensor<8x8xf32>, tensor<8xf32>, tensor) loc(#loc14) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<8x8xi1>, tensor<8x8xf32> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf32> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<8xi1>, tensor<8xf32> loc(#loc14) + return %13, %17 : tensor<8x8xf32>, tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6))) -> (tensor<8x8xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc15) + %c = stablehlo.constant dense<0> : tensor loc(#loc13) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc17) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc18) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc13) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc19) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xf32> loc(#loc20) + return %6 : tensor<8x8xf32> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc8 = loc("jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc6)) +#loc20 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff9\x01\xa1\x0f\x13\x17\x0b\x0f\x0b\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x03_\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f/\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b\x1fO/\x01\x05\x0b\x0f\x035\x17\x0f\x07\x0f\x07\x07\x13\x17\x0f\x07\x07\x17\x13\x07\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1d\x83\x85\x03\x03\x11\xcb\x17\x07V\x047\x05!\x1d?\x05\x05#\x1f\x05%\x05'\x11\x03\x05\x05)\x05+\x05-\x05/\x03\x03\x1f\xc7\x051\x03\x03\x0b\xc9\x1dE\x05\x053\x055\x1d{}\x03\x03\x0b\xd7\x03\x03\x0b\xf9\x03\t135\x137\x13\x0f9\x057\x11\x01\x00\x059\x05;\x05=\x03\x0b\x15\xa5\x17\xb1\x19\xb3\x0f\xbd\x1b\xbf\x03\x0b\x15\xa9\x17\xc3\x19\xa9\x0f\xab\x1b\xc5\x05?\x1dC\x05\x05A\x05C\x03\x03\x1f\xcd\x1dK\x05\x05E\x03\x05%\xad'\xcf\x1dQ\x05\x05G\x03\x03\x0b\xd1\x1dW\x05\x05I\x1d[\x05\x05K\x1d_a\x05M\x17\x076\x045\x1deg\x05O\x17\x076\x04\x1d\x03\x03k\xd3\x05Q\x1doq\x05S\x17\x07>\x04E\x1duw\x05U\x17\x07>\x04\x1f\x03\x03\x0b\xd5\x05W\x17\x07>\x04\x1d\x03\x03\x81\xab\x05Y\x05[\x17\x07V\x04\x17\x03\x13\x89\xd9\x8b\xdb\x8d\xdd\x8f\xa5\x91\xdf\x93\xe1\x95\xeb\x97\xed\x99\xf1\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x05k\x05m\x03\x05%\xad'\xf7\x03\x03\x11\xfb\x03\x03\x11\xfd\x1do\x1dq\x03\x01\x1ds\x03\x03\xc1\x1du\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x05\xb5\xb9\r\x05\xa7\xb7\xa1\xa3\x1dw\r\x05\xa7\xbb\xa1\xa3\x1dy\x1d{\x1d}\r\x03\xa1\xa3##\x1d\x7f\x13\t\x01\x1f\x0b\t\x00\x00\x00\x00\x1f%\x01\x13\t\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x15\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x81\x1d\x83\x05\x01\r\x05\xe3\xe5\xe7\xe9\x1d\x85\x13\x1fV\x1d\x87\x13\x1fL\x03\x03\xaf\x03\x03\xef\x15\x03\x01\x01\x01\x03\x07\xaf\xf3\xf5\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05!!\x0f)\x01\x0f\x1d)\x01\x19\x01\t)\x03!\x0f)\x05!!\x19)\x01\t\x13\x1b)\x05!!\r)\x03\t\t!\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\t)\x03\x02\x02\x0f)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\t\x04~\x04\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x035e\x07\x03]\x1d\x03'\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\r\x06s\x03\x05\x05\x03\x05\x05\x03\ry\x03\x07\x03\x07)\x03\x03\x05\x03\t\x17\x06)\x03\x05\x05\x07\x0b\x19\x07\t\x7f\x03\x05\x03\r\x05\x03\x01+\x03\x15\x05\x03\x01+\x03\x15\x1b\x07\x01\x87\x07\x05\x11\x0b\x03\x0f\x05\x03\x01!\x03\x0b\x03\x07\x01\x03\x03\x0b\x03\x1b\x0f\x07\x01\x9b\x03/\x05\x19\x1d\x03\x07\x01\x03\x031\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x9d\x03\x1b\x03!\t\x06\x01\x03\x05\x07'\x15%\x03\x07\x01\x03\x033\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x11\x03-\x03\x07\x01\x9f\x035\x03+\t\x06\x01\x03\x11\x071\x17/\x11\x04\r\x05)3\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1d\x03\x13\x05\x03\t!\x03\x0b\x03\x07#\x03\x03\x13\x03\x05\r\x06#\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1b\x05\t\x0b\x05\x03\tS\x03\x07\x03\x07U\x03\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00J\x1a\x89\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99m\x19\x85\x89W\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd_ffi\x00mode\x00uplo\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsyevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, + 3.5662489253627483e-01, -6.3034019033669797e-01, + 1.0043483479985752e-16, -2.8842036081919542e-02, + 7.7164692943283169e-25, -1.8446994643771725e-01], + [-4.7070881487314614e-01, 4.7473787464450845e-01, + -4.8036836210243367e-01, 4.3802686872516400e-01, + 1.7961797619639258e-01, 8.3080980076741355e-03, + 2.1415294457221756e-01, -2.2856669794666584e-01], + [-3.2284062926217072e-01, -5.4336490915553370e-01, + 2.2181041859724990e-01, 2.9947877954402297e-01, + -3.6491813600134632e-01, 3.2867679819727436e-01, + 3.8223299448843473e-01, -2.7266344945561438e-01], + [-1.7497244365119530e-01, -8.9251550609769414e-02, + -6.3518515114898394e-02, 1.9162997359209971e-01, + -2.2087281326110139e-01, 5.9957027043505064e-02, + -8.7632498908241274e-01, -3.1676020096456303e-01], + [-2.7104258040220038e-02, -3.3772873786627672e-01, + 2.5901386593721748e-01, 1.7032650752287815e-01, + 6.7521217612940332e-01, -4.5036136532965476e-01, + -1.2279030059078447e-02, -3.6085695247351163e-01], + [ 1.2076392757075530e-01, -3.3834734096469254e-01, + -6.5506827461665540e-01, -5.0472498521116749e-01, + 6.9987430903492118e-02, 1.0595648906599275e-01, + 8.3443844143082022e-02, -4.0495370398246017e-01], + [ 2.6863211318173097e-01, 2.2958613191407318e-01, + 6.3952843755683941e-02, 1.8776775771084137e-02, + -5.3523731432241317e-01, -5.9199531677602002e-01, + 1.7916671834524248e-01, -4.4905045549140887e-01], + [ 4.1650029879270661e-01, 3.6355449432857079e-01, + 2.9755313100756142e-01, 1.6826270392615944e-02, + 1.9621068035557282e-01, 5.6830030587314817e-01, + 2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, + -1.9932120610662194e-14, -5.7323356091157378e-15, + -4.5459724251334835e-16, 4.0479851042511616e-14, + 9.2325194924982089e-14, 2.7659880477613365e+02])), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc13 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xf64> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<8x8xf64> loc(#loc11) + %cst = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<8x8xf64> loc(#loc12) + %6 = call @tril(%5) : (tensor<8x8xf64>) -> tensor<8x8xf64> loc(#loc13) + %c = stablehlo.constant dense<8> : tensor loc(#loc14) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc14) + %7:3 = stablehlo.custom_call @lapack_dsyevd_ffi(%6) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8xf64>, tensor) loc(#loc14) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<8x8xi1>, tensor<8x8xf64> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf64> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<8xi1>, tensor<8xf64> loc(#loc14) + return %13, %17 : tensor<8x8xf64>, tensor<8xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6))) -> (tensor<8x8xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc15) + %c = stablehlo.constant dense<0> : tensor loc(#loc13) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc17) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc18) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc13) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc19) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xf64> loc(#loc20) + return %6 : tensor<8x8xf64> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc8 = loc("jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc6)) +#loc20 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff9\x01\xa1\x0f\x13\x17\x0b\x0f\x0b\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x03_\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O//\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b/O/\x01\x05\x0b\x0f\x035\x17\x0f\x07\x0f\x07\x07\x13\x17\x0f\x07\x07\x17\x13\x07\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02j\t\x1d\x83\x85\x03\x03\x11\xcb\x17\x07V\x047\x05!\x1d?\x05\x05#\x1f\x05%\x05'\x11\x03\x05\x05)\x05+\x05-\x05/\x03\x03\x1f\xc7\x051\x03\x03\x0b\xc9\x1dE\x05\x053\x055\x1d{}\x03\x03\x0b\xd7\x03\x03\x0b\xf9\x03\t135\x137\x13\x0f9\x057\x11\x01\x00\x059\x05;\x05=\x03\x0b\x15\xa5\x17\xb1\x19\xb3\x0f\xbd\x1b\xbf\x03\x0b\x15\xa9\x17\xc3\x19\xa9\x0f\xab\x1b\xc5\x05?\x1dC\x05\x05A\x05C\x03\x03\x1f\xcd\x1dK\x05\x05E\x03\x05%\xad'\xcf\x1dQ\x05\x05G\x03\x03\x0b\xd1\x1dW\x05\x05I\x1d[\x05\x05K\x1d_a\x05M\x17\x076\x045\x1deg\x05O\x17\x076\x04\x1d\x03\x03k\xd3\x05Q\x1doq\x05S\x17\x07>\x04E\x1duw\x05U\x17\x07>\x04\x1f\x03\x03\x0b\xd5\x05W\x17\x07>\x04\x1d\x03\x03\x81\xab\x05Y\x05[\x17\x07V\x04\x17\x03\x13\x89\xd9\x8b\xdb\x8d\xdd\x8f\xa5\x91\xdf\x93\xe1\x95\xeb\x97\xed\x99\xf1\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x05k\x05m\x03\x05%\xad'\xf7\x03\x03\x11\xfb\x03\x03\x11\xfd\x1do\x1dq\x03\x01\x1ds\x03\x03\xc1\x1du\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x05\xb5\xb9\r\x05\xa7\xb7\xa1\xa3\x1dw\r\x05\xa7\xbb\xa1\xa3\x1dy\x1d{\x1d}\r\x03\xa1\xa3##\x1d\x7f\x13\t\x01\x1f\x0b\t\x00\x00\x00\x00\x1f%\x01\x13\t\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x15\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x81\x1d\x83\x05\x01\r\x05\xe3\xe5\xe7\xe9\x1d\x85\x13\x1fV\x1d\x87\x13\x1fL\x03\x03\xaf\x03\x03\xef\x15\x03\x01\x01\x01\x03\x07\xaf\xf3\xf5\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05!!\x0f)\x01\x0f\x1d)\x01\x19\x01\x0b)\x03!\x0f)\x05!!\x19)\x01\t\x13\x1b)\x05!!\r)\x03\t\t!\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\t)\x03\x02\x02\x0f)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\t\x04~\x04\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x035e\x07\x03]\x1d\x03'\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\r\x06s\x03\x05\x05\x03\x05\x05\x03\ry\x03\x07\x03\x07)\x03\x03\x05\x03\t\x17\x06)\x03\x05\x05\x07\x0b\x19\x07\t\x7f\x03\x05\x03\r\x05\x03\x01+\x03\x15\x05\x03\x01+\x03\x15\x1b\x07\x01\x87\x07\x05\x11\x0b\x03\x0f\x05\x03\x01!\x03\x0b\x03\x07\x01\x03\x03\x0b\x03\x1b\x0f\x07\x01\x9b\x03/\x05\x19\x1d\x03\x07\x01\x03\x031\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x9d\x03\x1b\x03!\t\x06\x01\x03\x05\x07'\x15%\x03\x07\x01\x03\x033\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x11\x03-\x03\x07\x01\x9f\x035\x03+\t\x06\x01\x03\x11\x071\x17/\x11\x04\r\x05)3\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1d\x03\x13\x05\x03\t!\x03\x0b\x03\x07#\x03\x03\x13\x03\x05\r\x06#\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1b\x05\t\x0b\x05\x03\tS\x03\x07\x03\x07U\x03\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00J\x1a\x89\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99m\x19\x85\x89W\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd_ffi\x00mode\x00uplo\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 45eed43e0b4f..1a792e3adc0c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -615,7 +615,9 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand, + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, jobvr=compute_right_eigenvectors) @@ -801,7 +803,8 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): def _eigh_cpu_gpu_lowering( - syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index + syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index, + platform=None ): del sort_eigenvalues # The CPU/GPU implementations always sort. operand_aval, = ctx.avals_in @@ -821,7 +824,12 @@ def _eigh_cpu_gpu_lowering( raise NotImplementedError("subset_by_index not implemented for CPU and GPU") op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - v, w, info = syevd_impl(operand_aval.dtype, operand, + cpu_args = [] + if platform == "cpu": + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + cpu_args.extend(ctx_args) + v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, a_shape_vals=op_shape_vals, lower=lower) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) @@ -955,15 +963,17 @@ def _eigh_batching_rule( batching.primitive_batchers[eigh_p] = _eigh_batching_rule mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo), + eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'), platform='cpu') if gpu_solver is not None: mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd), + eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd, + platform='cuda'), platform='cuda') mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd), + eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd, + platform='rocm'), platform='rocm') mlir.register_lowering( diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index a71f219acd1d..a389380a61ec 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -27,6 +27,7 @@ from jaxlib import xla_client from .cpu import _lapack +from .cpu._lapack import eig from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, ensure_hlo_s32, hlo_add, hlo_min, @@ -523,10 +524,9 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, # # syevd: Symmetric eigendecomposition -def syevd_hlo(dtype, a: ir.Value, +def syevd_hlo(ctx, dtype, a: ir.Value, a_shape_vals: tuple[DimensionSize, ...], lower=False): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -535,76 +535,110 @@ def syevd_hlo(dtype, a: ir.Value, batch_dims_vals = a_shape_vals[:-2] num_bd = len(a_shape_vals) - 2 + mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_ssyevd" - eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.float64: - fn = "lapack_dsyevd" - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex64: - fn = "lapack_cheevd" + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + # Hermitian is for complex square matrices, symmetric otherwise. + fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" + fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype) + if ctx.is_forward_compat(): + fn = fn_base + if dtype == np.float32: + eigvals_type = ir.F32Type.get() + workspace = [ + ([_lapack.syevd_work_size(n)], a_type.element_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.float64: + eigvals_type = ir.F64Type.get() + workspace = [ + ([_lapack.syevd_work_size(n)], a_type.element_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.complex64: + eigvals_type = ir.F32Type.get() + workspace = [ + ([_lapack.heevd_work_size(n)], a_type.element_type), + ([_lapack.heevd_rwork_size(n)], eigvals_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.complex128: + eigvals_type = ir.F64Type.get() + workspace = [ + ([_lapack.heevd_work_size(n)], a_type.element_type), + ([_lapack.heevd_rwork_size(n)], eigvals_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + + scalar_layout = [] + shape_layout = [0] + workspace_layouts = [shape_layout] * len(workspace) + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + + result_types, result_shapes = mk_result_types_and_shapes( + [(a_shape_vals, a_type.element_type), + (batch_dims_vals + (n,), eigvals_type), + (batch_dims_vals, i32_type)] + workspace + ) + + return custom_call( + fn, + result_types=result_types, + operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], + operand_layouts=[scalar_layout] * 3 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ] + workspace_layouts, + operand_output_aliases={3: 0}, + result_shapes=result_shapes, + ).results[:3] + fn = fn_base + "_ffi" + if dtype == np.float32 or dtype == np.complex64: eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex128: - fn = "lapack_zheevd" + elif dtype == np.float64 or dtype == np.complex128: eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] else: raise NotImplementedError(f"Unsupported dtype {dtype}") - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - scalar_layout = [] - shape_layout = [0] - workspace_layouts = [shape_layout] * len(workspace) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - result_types, result_shapes = mk_result_types_and_shapes( - [(a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type)] + workspace - ) + result_types, result_shapes = mk_result_types_and_shapes([ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (n,), eigvals_type), + (batch_dims_vals, i32_type), + ]) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ layout, tuple(range(num_bd, -1, -1)), tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={3: 0}, + ], + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={ + "uplo": _matrix_uplo_attr(lower=lower), + "mode": mode, + }, + api_version=4, ).results - return out[:3] # # geev: Nonsymmetric eigendecomposition (eig) -def geev_hlo(dtype, input, *, +def geev_hlo(ctx, dtype, input, *, input_shape_vals: tuple[DimensionSize, ...], # input.shape as ir.Values jobvl=True, jobvr=True): # input_shape_vals are used for when input has dynamic shapes. @@ -617,80 +651,128 @@ def geev_hlo(dtype, input, *, layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - jobvl_c = ord('V' if jobvl else 'N') - jobvr_c = ord('V' if jobvr else 'N') + compute_left = ( + eig.ComputationMode.kComputeEigenvectors + if jobvl + else eig.ComputationMode.kNoEigenvectors + ) + + compute_right = ( + eig.ComputationMode.kComputeEigenvectors + if jobvr + else eig.ComputationMode.kNoEigenvectors + ) + fn_base = build_lapack_fn_target(fn_base="geev", dtype=dtype) i32_type = ir.IntegerType.get_signless(32) f32_type = ir.F32Type.get() f64_type = ir.F64Type.get() c64_type = ir.ComplexType.get(ir.F32Type.get()) c128_type = ir.ComplexType.get(ir.F64Type.get()) + if ctx.is_forward_compat(): + fn = fn_base + workspaces: list[ShapeTypePair] + eigvals: list[ShapeTypePair] + if dtype == np.float32: + real = True + eigvecs_type = c64_type + workspaces = [([n, n], f32_type)] * 3 + workspace_layouts = [[0, 1]] * 3 + eigvals = [(batch_dims_vals + (n,), f32_type)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + elif dtype == np.float64: + real = True + eigvecs_type = c128_type + workspaces = [([n, n], f64_type)] * 3 + workspace_layouts = [[0, 1]] * 3 + eigvals = [(batch_dims_vals + (n,), f64_type)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + elif dtype == np.complex64: + real = False + eigvecs_type = c64_type + workspaces = [([n, n], c64_type), ([hlo_add(n, n)], f32_type)] + workspace_layouts = [[0, 1], [0]] + eigvals = [(batch_dims_vals + (n,), c64_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + elif dtype == np.complex128: + real = False + eigvecs_type = c128_type + workspaces = [([n, n], c128_type), ([hlo_add(n, n)], f64_type)] + workspace_layouts = [[0, 1], [0]] + eigvals = [(batch_dims_vals + (n,), c128_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") - workspaces: list[ShapeTypePair] - eigvals: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_sgeev" - real = True - eigvecs_type = c64_type - workspaces = [([n, n], f32_type)] * 3 - workspace_layouts = [[0, 1]] * 3 - eigvals = [(batch_dims_vals + (n,), f32_type)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - elif dtype == np.float64: - fn = "lapack_dgeev" - real = True - eigvecs_type = c128_type - workspaces = [([n, n], f64_type)] * 3 - workspace_layouts = [[0, 1]] * 3 - eigvals = [(batch_dims_vals + (n,), f64_type)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - elif dtype == np.complex64: - fn = "lapack_cgeev" - real = False - eigvecs_type = c64_type - workspaces = [([n, n], c64_type), ([hlo_add(n, n)], f32_type)] - workspace_layouts = [[0, 1], [0]] - eigvals = [(batch_dims_vals + (n,), c64_type)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] - elif dtype == np.complex128: - fn = "lapack_zgeev" - real = False - eigvecs_type = c128_type - workspaces = [([n, n], c128_type), ([hlo_add(n, n)], f64_type)] - workspace_layouts = [[0, 1], [0]] - eigvals = [(batch_dims_vals + (n,), c128_type)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - info_layout = tuple(range(num_bd - 1, -1, -1)) + scalar_layout = [] + info_layout = tuple(range(num_bd - 1, -1, -1)) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ + shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ + (input_shape_vals, eigvecs_type), + (input_shape_vals, eigvecs_type), + (batch_dims_vals, i32_type)] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + out = custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, ensure_hlo_s32(n), + hlo_u8(compute_left.value), + hlo_u8(compute_right.value), + input], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + + [info_layout]), + result_shapes=result_shapes, + ).results + if real: + return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) + else: + return out[2:6] + fn = fn_base + "_ffi" + real = dtype == np.float32 or dtype == np.float64 + eigvecs_type = ( + c64_type if dtype == np.float32 or dtype == np.complex64 else c128_type + ) + input_type = ir.RankedTensorType(input.type) + eigvals = [(batch_dims_vals + (n,), input_type.element_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + if real: + eigvals = eigvals * 2 + eigvals_layouts = eigvals_layouts * 2 + info_layout = tuple(range(num_bd - 1, -1, -1)) + shape_type_pairs: Sequence[ShapeTypePair] = [ + *eigvals, (input_shape_vals, eigvecs_type), (input_shape_vals, eigvecs_type), - (batch_dims_vals, i32_type)] + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) out = custom_call( fn, result_types=result_types, - operands=[batch_size_val, ensure_hlo_s32(n), - hlo_u8(jobvl_c), - hlo_u8(jobvr_c), - input], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + - [info_layout]), + operands=[input], + operand_layouts=[layout], + result_layouts=( + *eigvals_layouts, + layout, + layout, + info_layout, + ), result_shapes=result_shapes, + backend_config={ + "compute_left": _enum_to_char_attr(compute_left), + "compute_right": _enum_to_char_attr(compute_right), + }, + api_version=4, ).results if real: - return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) + return (hlo.complex(out[0], out[1]), out[2], out[3], out[4]) else: - return out[2:6] + return out[:4] # # gees : Schur factorization diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 26eb34088460..4e7898d57fe0 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -115,6 +115,8 @@ def test_custom_call_coverage(self): cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_eig_lapack_geev.data_2024_08_19, + cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, cpu_svd_lapack_gesdd.data_2024_08_13, ] @@ -256,6 +258,14 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 32) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_eig_results) @staticmethod def eigh_input(shape, dtype): @@ -306,6 +316,14 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 32) + if has_xla_ffi_support: + # FFI Kernel test + with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand)) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}",