From b1b3ea276b94c070db42dc88a3531fbed551189f Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 20 Aug 2024 00:03:56 +0200 Subject: [PATCH] Added py::mod_gil_not_used() to PYBIND11_MODULE register_jax_dialects --- jaxlib/mlir/_mlir_libs/register_jax_dialects.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index e1958c211b33..2e10062945b5 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,5 +1,6 @@ // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. +#include #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" @@ -14,11 +15,13 @@ #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +namespace py = pybind11; + #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ mlirDialectHandleInsertDialect(name##_dialect, registry) -PYBIND11_MODULE(register_jax_dialects, m) { +PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) { m.doc() = "Registers upstream MLIR dialects used by JAX."; m.def("register_dialects", [](MlirDialectRegistry registry) {