From 6e1d8504cd9611615bff092092368d265e216126 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 27 Sep 2024 08:02:28 -0700 Subject: [PATCH] Include the sdy MLIR dialect in jaxlib. We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build. The size impact is small (around 400K uncompressed). PiperOrigin-RevId: 679592303 --- jaxlib/tools/build_wheel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 52a17c451aea..3c40c2d11fb5 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -351,6 +351,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}",