diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 7f7773acbd39..e869953fa1a1 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -961,11 +961,6 @@ def _check_lowering(lowering) -> None: "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", # svd on CPU "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", - # qr on GPU - "cusolver_geqrf", "cublas_geqrf_batched", - "cusolver_orgqr", - "hipsolver_geqrf", "hipblas_geqrf_batched", - "hipsolver_orgqr", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", # triangular_solve on CPU @@ -977,9 +972,10 @@ def _check_lowering(lowering) -> None: "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", # lu on GPU "cu_lu_pivots_to_permutation", - # "cublas_getrf_batched", "cusolver_getrf", - # "hipblas_getrf_batched", "hipsolver_getrf", "cusolver_getrf_ffi", + # qr on GPU + "cusolver_geqrf_ffi", "cusolver_orgqr_ffi", + # svd on GPU # lu on TPU "LuDecomposition", # ApproxTopK on TPU diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py index 82e5d2bd479a..be5c6e01f8d8 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py @@ -15,7 +15,7 @@ # ruff: noqa import datetime -from numpy import array, float32 +from numpy import array, float32, float64, complex64, complex128 data_2023_03_18 = {} @@ -155,3 +155,271 @@ mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x0
0\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) 
dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", xla_call_module_version=4, ) # End paste + +data_2024_09_26 = {} + +data_2024_09_26["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. , 0.91287094, 0.40824842], + [-0.4472136 , 0.36514843, -0.8164965 ], + [-0.8944272 , -0.18257415, 0.40824836]], + + [[-0.42426407, 0.80828977, 0.4082495 ], + [-0.5656854 , 0.11547142, -0.8164964 ], + [-0.7071068 , -0.57735085, 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], + [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], + [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], + + [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], + [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], + [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tensor<2x3x3xf32> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, 
tensor<2x3x3xf32> loc(#loc13) + return %4, %12 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc7\x89+\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1f\x1f\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03'\x1b\x07\x07\x17\x0f\x07\x0f\x07\x07\x17\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02F\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1d1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f#\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11\t\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f\x1f!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x19\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\x05F\x11\x17\x03'\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/a
dd\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_26["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. , 0.9128709291752773 , + 0.408248290463862 ], + [-0.447213595499958 , 0.36514837167011005, + -0.8164965809277264 ], + [-0.894427190999916 , -0.1825741858350547 , + 0.40824829046386335]], + + [[-0.42426406871192857, 0.8082903768654768 , + 0.40824829046386124], + [-0.565685424949238 , 0.11547005383792364, + -0.8164965809277263 ], + [-0.7071067811865476 , -0.577350269189625 , + 0.4082482904638641 ]]]), array([[[-6.7082039324993694e+00, -8.0498447189992444e+00, + -9.3914855054991175e+00], + [ 0.0000000000000000e+00, 1.0954451150103344e+00, + 2.1908902300206661e+00], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.7577018578317312e-15]], + + [[-2.1213203435596427e+01, -2.2910259710444144e+01, + -2.4607315985291855e+01], + [ 0.0000000000000000e+00, 3.4641016151377924e-01, + 6.9282032302755281e-01], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.8103038069914667e-15]]])), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xf64> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xf64>) -> tensor<2x3x3xf64> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xf64>) -> (tensor<2x3x3xf64>, tensor<2x3xf64>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf64>, tensor) -> tensor<2x3x3xf64> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xf64>, tensor<2x3xf64>) -> tensor<2x3x3xf64> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> 
tensor<2x3x3xf64> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xf64> loc(#loc13) + return %4, %12 : tensor<2x3x3xf64>, tensor<2x3x3xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc7\x89+\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1f/\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03'\x1b\x07\x07\x17\x0f\x07\x0f\x07\x07\x17\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02V\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1d1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f#\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f\x1f!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\x0b)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x19\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\x05F\x11\x17\x03'\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\
x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_26["c64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. +0.j, 0.91287094+0.j, 0.40824836+0.j], + [-0.4472136 +0.j, 0.36514843+0.j, -0.81649655+0.j], + [-0.8944272 +0.j, -0.18257417+0.j, 0.4082483 +0.j]], + + [[-0.42426407+0.j, 0.8082899 +0.j, 0.40824962+0.j], + [-0.5656854 +0.j, 0.11547136+0.j, -0.8164964 +0.j], + [-0.7071068 +0.j, -0.57735085+0.j, 0.4082474 +0.j]]], + dtype=complex64), array([[[-6.7082038e+00+0.j, -8.0498447e+00+0.j, -9.3914852e+00+0.j], + [ 0.0000000e+00+0.j, 1.0954450e+00+0.j, 2.1908898e+00+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 4.8374091e-08+0.j]], + + [[-2.1213203e+01+0.j, -2.2910259e+01+0.j, -2.4607319e+01+0.j], + [ 0.0000000e+00+0.j, 3.4641042e-01+0.j, 6.9282258e-01+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 1.5032538e-06+0.j]]], + dtype=complex64)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xcomplex>) -> (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xcomplex>, tensor>) -> tensor<2x3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : 
(tensor>) -> tensor<2x3x3xcomplex> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xcomplex> loc(#loc13) + return %4, %12 : tensor<2x3x3xcomplex>, tensor<2x3x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc9\x89-\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1f/\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03)\x1b\x07\x0b\x17\x0f\x07\x0f\x07\x07\x17\x07\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02^\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f%\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\x03\x19)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05\t)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x1b\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1d\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03'\x05\x15\x17\x05F\x11\x17\x03)\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(m
ain)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_09_26["c128"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusolver_geqrf_ffi', 'cusolver_orgqr_ffi'], + serialized_date=datetime.date(2024, 9, 26), + inputs=(), + expected_outputs=(array([[[ 0. +0.j, 0.9128709291752773 +0.j, + 0.4082482904638621 +0.j], + [-0.447213595499958 +0.j, 0.36514837167011005+0.j, + -0.8164965809277263 +0.j], + [-0.894427190999916 +0.j, -0.18257418583505472+0.j, + 0.40824829046386335+0.j]], + + [[-0.42426406871192857+0.j, 0.808290376865477 +0.j, + 0.4082482904638615 +0.j], + [-0.565685424949238 +0.j, 0.11547005383792353+0.j, + -0.8164965809277263 +0.j], + [-0.7071067811865476 +0.j, -0.577350269189625 +0.j, + 0.4082482904638641 +0.j]]]), array([[[-6.7082039324993694e+00+0.j, -8.0498447189992444e+00+0.j, + -9.3914855054991175e+00+0.j], + [ 0.0000000000000000e+00+0.j, 1.0954451150103344e+00+0.j, + 2.1908902300206661e+00+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + -1.7577018578317312e-15+0.j]], + + [[-2.1213203435596427e+01+0.j, -2.2910259710444144e+01+0.j, + -2.4607315985291855e+01+0.j], + [ 0.0000000000000000e+00+0.j, 3.4641016151377924e-01+0.j, + 6.9282032302755292e-01+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + -1.7201790115224914e-15+0.j]]])), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<18xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<18xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc5) + %2:2 = stablehlo.custom_call @cusolver_geqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>]} : (tensor<2x3x3xcomplex>) -> (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xcomplex>, tensor>) -> tensor<2x3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @cusolver_orgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>]} : (tensor<2x3x3xcomplex>, tensor<2x3xcomplex>) -> tensor<2x3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> 
loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc12) + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<2x3x3xcomplex> loc(#loc12) + %12 = stablehlo.select %10, %11, %2#0 : tensor<2x3x3xi1>, tensor<2x3x3xcomplex> loc(#loc13) + return %4, %12 : tensor<2x3x3xcomplex>, tensor<2x3x3xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":390:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":391:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.5.0\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#'+/37\x03\xc9\x89-\x01C\x17\x07\x0b\x0f\x0b\x13\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x03G\x0b/\x0b\x0b\x0b\x0f\x0b\x0b\x0b\x0fo\x13\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x1fO\x0b\x0b\x0f\x17O\x0b\x0f\x13\x0f\x0b\x0bO\x01\x05\x0b\x0f\x03)\x1b\x07\x0b\x17\x0f\x07\x0f\x07\x07\x17\x07\x13\x17\x13\x13\x13\x13\x17\x1b\x13\x02~\x05\x17\x05\x1e\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x03)q\x1d\t\x01\x1d7\x01\x1d=\x01\x03\x07\x15\x17\x19\x07\x1b\x07\x05!\x11\x01\x00\x05#\x05%\x05'\x1d\t!\x17\x05\x1a\x065\x1d%'\x05)\x17\x05\x1a\x06\x1d\x05+\x1d-\x01\x05-\x1d1\x01\x05/\x1d5\x01\x051\x053\x1d;\x01\x055\x057\x1dA\x01\x059\x03\x01\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d;\x1d=\x1d?\x13\x07\x01\x0b\x03\x1dA\x05\x01\x03\x03W\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05Wy\x1f%\x01#\x17\x03\x05ae\r\x05GcIK\x1dC\r\x05GgIK\x1dE\x1dG\x1dI\x1f\r\t\xff\xff\xff\xff\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\x1dK\x03\x03w\x15\x03\x01\x01\x01\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dM\x03\x03\x7f\x15\x01\x01\x01\x13\x07\x05\t\x07\x07\x05\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\x03\x19)\x05\r\r\x0f)\x01\x0f\x1b)\x01\t\x13\x01\x11\x01\x05\x05\x05\x0b)\x03I\t)\x05\t\r\t)\x03\r\x13)\x03\t\x13)\x03\r\x07)\x03\x01\x07)\x05\r\r\x15)\x07\t\r\r\x15)\x03\t\x07\x04F\x02\x05\x01Q\x03\x13\x01\x07\x04\x1e\x02\x03\x01\x05\x0bP\x03\x03\x07\x04\xfb\x03!A\x07B\x03\x05\x03\r\x07B\x03\x07\x03\x11\x03B\x1f\t\x03\x1b\r\x06#\x03\x05\x03\x05\tG+\x0b\x0b\x05\x05\x1d\x03\x07\x0fF/\r\x03\x05\x05\t\x03\tG3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\x05F\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03'\x05\x15\x17\x05F\x11\x17\x03)\x03\x19\x05F\x11\x11\x03\x05\x03\x03\x15\x06?\x03\x05\x07\x1b\x1d\t\x17\x04\x03\x05\x0f\x1f\x06\x03\x01\x05\x01\x00J\x0bO''\x0f\x0b\t\t\x03\x11#!CS79Y9=)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11\x1f\x19)\x11\x0f
\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00main\x00public\x00cusolver_geqrf_ffi\x00cusolver_orgqr_ffi\x00\x08_\x19\x05;\x01\x0bC]_ik\x03m\x03o\x03M\x11OQsCSUuY\x07EEE\x11OQ{CSY}U\x03[\x03\x81\x05\x83\x85\x03\x87", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5a975e3c5a61..5201e8edef87 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -245,6 +245,7 @@ def run_one_test(self, func: Callable[..., jax.Array], logging.info("Writing the updated testdata at %s", output_file) with open(output_file, "w") as f: f.write(updated_testdata) + print(updated_testdata) if rtol is None: rtol = 1.e-7 diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ef6a5a11a56e..e6546e94abd4 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1678,7 +1678,7 @@ def _geqrf_abstract_eval(operand): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") *batch_dims, m, n = operand.shape - taus = operand.update(shape=(*batch_dims, min(m, n))) + taus = operand.update(shape=(*batch_dims, core.min_dim(m, n))) return operand, taus def _geqrf_batching_rule(batched_args, batch_dims): @@ -1707,60 +1707,20 @@ def _geqrf_lowering_rule(ctx, operand): ) return op.results -def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, - platform: str): - a_aval, taus_aval = ctx.avals_out - *batch_dims, m, n = a_aval.shape - # It should be possible to support fully-dynamic shapes, but since - # the last two dimensions (m, n) are used in more involved ways, we only - # support dynamic dimensions for the batch size for now. - if not is_constant_shape([m, n]): - raise NotImplementedError( - "Shape polymorphism for native serialization for qr on CPU and GPU is " - f"implemented only for the batch dimensions: {a_aval.shape}") - batch = math.prod(batch_dims) - - if batch == 0 or m == 0 or n == 0: - return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval) - - if not is_constant_shape(a_aval.shape): - if platform in ["cuda", "rocm"]: - # TODO(necula): remove the platform kwarg when we implement GPU support. 
-      raise NotImplementedError(
-          "Shape polymorphism for native serialization for QR is not "
-          f"implemented, try to upgrade jaxlib; b/261671778; {a_aval.shape}")
-
-  if (batched_geqrf_impl is not None and batch > 1 and m // batch <= 128 and
-      n // batch <= 128):
-    a_out, taus = batched_geqrf_impl(a_aval.dtype, a)
+def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str):
+  operand_aval, = ctx.avals_in
+  batch_dims = operand_aval.shape[:-2]
+  nb = len(batch_dims)
+  layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
+  result_layouts = [layout, tuple(range(nb, -1, -1))]
+  if target_name_prefix == "cpu":
+    target_name = lapack.build_lapack_fn_target("geqrf_ffi", operand_aval.dtype)
   else:
-    if platform in ["cuda", "rocm"]:
-      a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
-    else:
-      a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
-      ctx_args = (
-          (ctx,) if platform == "cpu" else ()
-      )
-      a_out, taus, *maybe_info_geqrf = geqrf_impl(
-          *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
-      )
-      if not ctx.is_forward_compat():
-        # Skip the info parameter verification for the FFI kernel.
-        return a_out, taus
-      # TODO(b/344892332): This parameter will no longer be needed after
-      # the forward compatibility period
-      info_geqrf = maybe_info_geqrf[0]
-    zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
-    ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED")
-    select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
-    ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval,
-                                 broadcast_dimensions=range(len(batch_dims)))
-    a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
-    select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_))
-    ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval,
-                                    broadcast_dimensions=range(len(batch_dims)))
-    taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval)
-  return a_out, taus
+    target_name = f"{target_name_prefix}solver_geqrf_ffi"
+  rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
+                          result_layouts=result_layouts,
+                          operand_output_aliases={0: 0})
+  return rule(ctx, a)
 
 geqrf_p = Primitive('geqrf')
 geqrf_p.multiple_results = True
@@ -1770,20 +1730,15 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
 mlir.register_lowering(geqrf_p, _geqrf_lowering_rule)
 mlir.register_lowering(
-    geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None,
-                     platform='cpu'),
+    geqrf_p, partial(_geqrf_cpu_gpu_lowering, target_name_prefix='cpu'),
     platform='cpu')
 mlir.register_lowering(
     geqrf_p,
-    partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
-            gpu_solver.cuda_geqrf_batched,
-            platform='cuda'),
+    partial(_geqrf_cpu_gpu_lowering, target_name_prefix='cu'),
     platform='cuda')
 mlir.register_lowering(
     geqrf_p,
-    partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
-            gpu_solver.rocm_geqrf_batched,
-            platform='rocm'),
+    partial(_geqrf_cpu_gpu_lowering, target_name_prefix='hip'),
     platform='rocm')
@@ -1813,7 +1768,7 @@ def _householder_product_abstract_eval(a, taus):
     raise ValueError("Argument to Householder product must have ndims >= 2")
   *batch_dims, m, n = a.shape
   *taus_batch_dims, k = taus.shape
-  if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
+  if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > core.min_dim(m, n):
     raise ValueError(f"Type mismatch for Householder product: 
{a=} {taus=}") if m < n: raise ValueError("Householder product inputs must have at least as many " @@ -1841,48 +1796,23 @@ def _householder_product_lowering_rule(ctx, a, taus): result_shapes=result_shapes) return [op.result] -def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, - platform: str): - a_aval, taus_aval = ctx.avals_in - *batch_dims, m, n = a_aval.shape - if not is_constant_shape([m, n]): - raise NotImplementedError( - "Shape polymorphism for native serialization for householder_product on " - f"CPU and GPU is implemented only for the batch dimensions: {a_aval.shape}") - - if m == 0 or n == 0: - return [mlir.full_like_aval(ctx, 0, a_aval)] - - if platform in ["rocm", "cuda"]: - # TODO(necula): remove the platform kwarg when we implement GPU support. - if not is_constant_shape(a_aval.shape): - raise NotImplementedError( - "Shape polymorphism for native serialization for householder_product " - f"on GPU is not implemented; b/261671778; {a_aval.shape}") - a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) +def _householder_product_cpu_gpu_lowering(ctx, a, taus, *, + target_name_prefix: str): + a_aval, _ = ctx.avals_in + batch_dims = a_aval.shape[:-2] + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + tau_layout = tuple(range(nb, -1, -1)) + if target_name_prefix == "cpu": + dtype = a_aval.dtype + prefix = "un" if dtypes.issubdtype(dtype, np.complexfloating) else "or" + target_name = lapack.build_lapack_fn_target(f"{prefix}gqr_ffi", dtype) else: - ctx_args = ( - (ctx,) if platform == "cpu" else () - ) - a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) - a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus, - a_shape_vals=a_shape_vals, - tau_shape_vals=tau_shape_vals) - if not ctx.is_forward_compat(): - # Skip the info parameter verification for the FFI kernel. 
- return [a] - # TODO(b/344892332): This parameter will no longer be needed after - # the forward compatibility period - info_orgqr = maybe_info_orgqr[0] - zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED") - select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) - ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval, - broadcast_dimensions=range(len(batch_dims))) - a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval) - return [a] - + target_name = f"{target_name_prefix}solver_orgqr_ffi" + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout, tau_layout], + result_layouts=[layout], + operand_output_aliases={0: 0}) + return rule(ctx, a, taus) householder_product_p = Primitive('householder_product') householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p)) @@ -1892,18 +1822,15 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, mlir.register_lowering( householder_product_p, - partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo, - platform='cpu'), + partial(_householder_product_cpu_gpu_lowering, target_name_prefix='cpu'), platform='cpu') mlir.register_lowering( householder_product_p, - partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr, - platform='cuda'), + partial(_householder_product_cpu_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( householder_product_p, - partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr, - platform='rocm'), + partial(_householder_product_cpu_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -1916,7 +1843,7 @@ def _qr_abstract_eval(operand, *, full_matrices): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") *batch_dims, m, n = operand.shape - k = m if full_matrices else min(m, n) + k = m if full_matrices else core.min_dim(m, n) q = operand.update(shape=(*batch_dims, m, k)) r = operand.update(shape=(*batch_dims, k, n)) else: @@ -1953,7 +1880,7 @@ def _qr_batching_rule(batched_args, batch_dims, *, full_matrices): def _qr_lowering(a, *, full_matrices): *batch_dims, m, n = a.shape if m == 0 or n == 0: - k = m if full_matrices else min(m, n) + k = m if full_matrices else core.min_dim(m, n) q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), (*batch_dims, m, k), (len(batch_dims), len(batch_dims) + 1)) diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index ff1e5570bb04..3dd933a616b7 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -151,86 +151,6 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver) -def _geqrf_hlo(platform, gpu_solver, dtype, a): - """QR decomposition.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = math.prod(batch_dims) - - lwork, opaque = gpu_solver.build_geqrf_descriptor( - np.dtype(dtype), batch, m, n) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - out = custom_call( - f"{platform}solver_geqrf", - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type), - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), - ], - operands=[a], - 
backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:3] - -cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver) -rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver) - -def _geqrf_batched_hlo(platform, gpu_blas, dtype, a): - """Batched QR decomposition.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = math.prod(batch_dims) - - if not gpu_blas: - raise GpuLibNotLinkedError() - - lwork, opaque = gpu_blas.build_geqrf_batched_descriptor( - np.dtype(dtype), batch, m, n) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - out = custom_call( - f"{platform}blas_geqrf_batched", - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type), - ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)), - ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)), - ], - operands=[a], - backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - [0], - [0], - ], - operand_output_aliases={0: 0} - ).results - return out[:2] - -cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas) -rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas) - - def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, indices, indptr, b, tol, reorder): """Sparse solver via QR decomposition. CUDA only.""" @@ -256,50 +176,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver) -def _orgqr_hlo(platform, gpu_solver, dtype, a, tau): - """Product of elementary Householder reflections.""" - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - assert len(dims) >= 2 - m, n = dims[-2:] - batch_dims = tuple(dims[:-2]) - num_bd = len(batch_dims) - batch = math.prod(batch_dims) - - tau_dims = ir.RankedTensorType(tau.type).shape - assert tau_dims[:-1] == dims[:-2] - k = tau_dims[-1] - - lwork, opaque = gpu_solver.build_orgqr_descriptor( - np.dtype(dtype), batch, m, n, k) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - out = custom_call( - f"{platform}solver_orgqr", - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), - ], - operands=[a, tau], - backend_config=opaque, - operand_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - ], - result_layouts=[ - layout, - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:2] - -cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver) -rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver) - - def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *, a_shape_vals: tuple[DimensionSize, ...], lower=False): """Symmetric (Hermitian) eigendecomposition.""" diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index e23cc0075139..a977d4be620b 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -201,170 +201,6 @@ def getrf_hlo(dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...]): result_shapes=result_shapes, ).results -# # ?geqrf: QR decomposition - - -def geqrf_hlo( - ctx, dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...] 
-): - a_type = ir.RankedTensorType(a.type) - assert len(a_shape_vals) >= 2 - m, n = a_shape_vals[-2:] - assert type(m) is int - assert type(n) is int - - batch_dims_vals = a_shape_vals[:-2] - num_bd = len(batch_dims_vals) - fn_base = prepare_lapack_call(fn_base="geqrf", dtype=dtype) - - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - - if ctx.is_forward_compat(): - fn = fn_base - if dtype == np.float32: - lwork = _lapack.lapack_sgeqrf_workspace(m, n) - elif dtype == np.float64: - lwork = _lapack.lapack_dgeqrf_workspace(m, n) - elif dtype == np.complex64: - lwork = _lapack.lapack_cgeqrf_workspace(m, n) - elif dtype == np.complex128: - lwork = _lapack.lapack_zgeqrf_workspace(m, n) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (min(m, n),), a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={4: 0}, - result_shapes=result_shapes, - ).results[:3] - fn = fn_base + "_ffi" - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (min(m, n),), a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[a], - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={}, - api_version=4, - ).results - - -# # ?orgqr: product of elementary Householder reflectors: -def orgqr_hlo(ctx, dtype, a: ir.Value, tau, *, - a_shape_vals: tuple[DimensionSize, ...], - tau_shape_vals: tuple[DimensionSize, ...]): - fn_base = "un" if dtype == np.complex64 or dtype == np.complex128 else "or" - fn_base = prepare_lapack_call(fn_base=fn_base + "gqr", dtype=dtype) - a_type = ir.RankedTensorType(a.type) - dims = a_type.shape - dims_vals = a_shape_vals - assert len(dims) >= 2 - m, n = dims[-2:] - assert m != ir.ShapedType.get_dynamic_size() - assert n != ir.ShapedType.get_dynamic_size() - batch_dims_vals = dims_vals[:-2] - num_bd = len(batch_dims_vals) - k = tau_shape_vals[-1] - assert type(k) is int - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - - if ctx.is_forward_compat(): - fn = fn_base - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - if dtype == np.float32: - lwork = _lapack.lapack_sorgqr_workspace(m, n, k) - elif dtype == np.float64: - lwork = _lapack.lapack_dorgqr_workspace(m, n, k) - elif dtype == np.complex64: - lwork = _lapack.lapack_cungqr_workspace(m, n, k) - elif dtype == np.complex128: - lwork = _lapack.lapack_zungqr_workspace(m, n, k) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - 
shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k), - hlo_s32(lwork), a, tau], - operand_layouts=[scalar_layout] * 5 + [ - layout, - tuple(range(num_bd, -1, -1)), - ], - result_layouts=[ - layout, - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={5: 0}, - result_shapes=result_shapes, - ).results[:2] - fn = fn_base + "_ffi" - shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - return custom_call( - fn, - result_types=result_types, - operands=[ - a, tau - ], - operand_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - ], - result_layouts=[ - layout, - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={}, - api_version=4, - ).results - # ?potrf: Cholesky decomposition diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 103357ac18ac..3ff625deea82 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -131,7 +131,7 @@ def test_custom_call_coverage(self): cpu_lu_lapack_getrf.data_2023_06_14, cuda_lu_pivots_to_permutation.data_2024_08_08, cuda_lu_cusolver_getrf.data_2024_08_19, - cuda_qr_cusolver_geqrf.data_2023_03_18, + cuda_qr_cusolver_geqrf.data_2024_09_26, cuda_eigh_cusolver_syev.data_2023_03_17, rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, @@ -394,41 +394,55 @@ def qr_harness(shape, dtype): dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): - # For lax.linalg.qr - if not config.enable_x64.value and dtype_name in ["f64", "c128"]: - self.skipTest("Test disabled for x32 mode") - + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] dtype = dict(f32=np.float32, f64=np.float64, c64=np.complex64, c128=np.complex128)[dtype_name] func = lambda: CompatTest.qr_harness((3, 3), dtype) - data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) - rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + + info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol) - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol) + # TODO(b/369826500): Remove legacy custom call test after mid March 2025. + data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) + self.run_one_test(func, data, rtol=rtol, + expect_current_custom_calls=info["custom_call_targets"]) + + # TODO(b/369826500): Remove legacy custom call test after mid March 2025. @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{batched}", dtype_name=dtype_name, batched=batched) for dtype_name in ("f32",) # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. 
for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf(self, dtype_name="f32", batched="unbatched"): - if jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - elif jtu.test_device_matches(["rocm"]): + def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): + if jtu.test_device_matches(["rocm"]): data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) + prefix = "hip" + elif jtu.test_device_matches(["cuda", "gpu"]): + data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) + prefix = "cu" else: self.skipTest("Unsupported platform") - # For lax.linalg.qr - dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] - rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] + dtype = dict(f32=np.float32)[dtype_name] + rtol = dict(f32=1e-3)[dtype_name] shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] func = lambda: CompatTest.qr_harness(shape, dtype) + self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ + f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_gpu_qr_solver_geqrf(self, dtype_name="f32"): + if jtu.test_device_matches(["rocm"]) or not jtu.test_device_matches(["gpu"]): + self.skipTest("Unsupported platform") + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + shape = (2, 3, 3) + func = lambda: CompatTest.qr_harness(shape, dtype) + data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2024_09_26[dtype_name]) self.run_one_test(func, data, rtol=rtol) def test_tpu_Qr(self): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index e52582eb7526..5ace4b5ecf18 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1007,7 +1007,7 @@ def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): if jtu.test_device_matches(['cpu']): err, msg = NotImplementedError, "Unsupported dtype float16" else: - err, msg = ValueError, r"Unsupported dtype dtype\('float16'\)" + err, msg = Exception, "Unsupported dtype" with self.assertRaisesRegex(err, msg): jnp.linalg.qr(arr) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 77e5273d172a..4b0630e44a20 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2770,22 +2770,30 @@ def test_vmap_error(self): arg_descriptors=[RandArg((3, 5), _f32)], polymorphic_shapes=["b, ..."]), [ - PolyHarness( + PolyHarness( # pylint: disable=g-complex-comprehension "qr", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}_{full_matrices=}", lambda x, full_matrices: lax.linalg.qr(x, full_matrices=full_matrices), arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)], - polymorphic_shapes=[poly]) + polymorphic_shapes=[poly], + symbolic_constraints=constraints) for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() - # m and n must be static for now - for shape, poly, full_matrices in [ - ((2, 0, 4), "b, ...", False), # m = 0 - ((2, 4, 0), "b, ...", False), # n = 0 - ((2, 3, 4, 4), "b1, b2, ...", False), # m == n - ((2, 3, 4, 4), "b1, b2, ...", True), - ((2, 3, 4, 5), "b1, b2, ...", False), # m < n - ((2, 3, 4, 5), "b1, b2, ...", True), - ((2, 3, 8, 4), "b1, b2, ...", False), # m > n - ((2, 3, 8, 4), "b1, b2, ...", True), + for shape, poly, 
full_matrices, constraints in [ + ((2, 0, 4), "b, ...", False, ()), # m = 0 + ((2, 4, 0), "b, ...", False, ()), # n = 0 + ((2, 3, 4, 4), "b1, b2, ...", False, ()), # m == n + ((2, 3, 4, 4), "b1, b2, ...", True, ()), + ((2, 3, 4, 5), "b1, b2, ...", False, ()), # m < n + ((2, 3, 4, 5), "b1, b2, ...", True, ()), + ((2, 3, 8, 4), "b1, b2, ...", False, ()), # m > n + ((2, 3, 8, 4), "b1, b2, ...", True, ()), + # Dynamic shapes are also supported for non-batch dimensions with + # some constraints. + ((2, 3, 4, 4), "b1, b2, m, m", False, ()), # m == n + ((2, 3, 4, 4), "b1, b2, m, m", True, ()), + ((2, 3, 4, 5), "b1, b2, m, n", False, ["m + 1 <= n"]), # m < n + ((2, 3, 4, 5), "b1, b2, m, n", True, ["m + 1 <= n"]), + ((2, 3, 8, 4), "b1, b2, m, n", False, ["n <= m"]), # m > n + ((2, 3, 8, 4), "b1, b2, m, n", True, ["n <= m"]), ] ], [ @@ -3465,9 +3473,6 @@ def test_harness(self, harness: PolyHarness): # Exclude some harnesses that are known to fail for native serialization # Set of harness.group_name:platform that are implemented with custom call custom_call_harnesses = { - "householder_product:gpu", - "vmap_geqrf:gpu", # used for linalg.qr - "vmap_qr:gpu", "qr:gpu", "vmap_svd:gpu", } name_device_key = f"{harness.group_name}:{jtu.device_under_test()}"
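
Reviewer note (not part of the patch): the new `_geqrf_cpu_gpu_lowering` and `_householder_product_cpu_gpu_lowering` above both reduce to picking a dtype- and platform-specific FFI custom-call target and passing explicit operand/result layouts to `ffi.ffi_lowering`. The short standalone sketch below only mirrors that layout and target-name computation so it can be checked against the new test data; the helper name `describe_geqrf_lowering` and the `lapack_<t>geqrf_ffi` placeholder are illustrative, not part of the change.

```python
# Sketch of the layout / target-name logic used by the new geqrf lowering.
# Only the formulas and the "<prefix>solver_geqrf_ffi" name come from the diff;
# the function itself is hypothetical.

def describe_geqrf_lowering(shape: tuple[int, ...], target_name_prefix: str):
  batch_dims = shape[:-2]
  nb = len(batch_dims)
  # Minor-to-major layout: column-major (Fortran) order for the trailing
  # (m, n) matrix, with the batch dimensions outermost.
  layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
  # The taus output has shape (*batch_dims, min(m, n)), i.e. one dim fewer.
  tau_layout = tuple(range(nb, -1, -1))
  if target_name_prefix == "cpu":
    # The real code resolves this per dtype via lapack.build_lapack_fn_target.
    target_name = "lapack_<t>geqrf_ffi"
  else:
    target_name = f"{target_name_prefix}solver_geqrf_ffi"
  return target_name, layout, tau_layout

# For the (2, 3, 3) operands in the new data_2024_09_26 test data this
# reproduces the dense<[1, 2, 0]> operand/result layout, the dense<[1, 0]>
# taus layout, and the cusolver_geqrf_ffi target visible in the modules:
print(describe_geqrf_lowering((2, 3, 3), "cu"))
# ('cusolver_geqrf_ffi', (1, 2, 0), (1, 0))
```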