From 5a1d0a6c2637432d49df3bf36f4c3255294be86e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 27 Sep 2024 08:52:42 -0700 Subject: [PATCH] Include the sdy MLIR dialect in jaxlib. We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build. The size impact is small (around 400K uncompressed). PiperOrigin-RevId: 679608092 --- jaxlib/tools/build_wheel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 52a17c451aea..3c40c2d11fb5 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -351,6 +351,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}",