Skip to content

Commit

Permalink
Added py::mod_gil_not_used() to PYBIND11_MODULE register_jax_dialects
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Aug 19, 2024
1 parent 292161a commit b1b3ea2
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jaxlib/mlir/_mlir_libs/register_jax_dialects.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Registers MLIR dialects used by JAX.
// This module is called by mlir/__init__.py during initialization.
#include <pybind11/pybind11.h>

#include "mlir-c/Dialect/Arith.h"
#include "mlir-c/Dialect/Func.h"
Expand All @@ -14,11 +15,13 @@
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

namespace py = pybind11;

#define REGISTER_DIALECT(name) \
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
mlirDialectHandleInsertDialect(name##_dialect, registry)

PYBIND11_MODULE(register_jax_dialects, m) {
PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) {
m.doc() = "Registers upstream MLIR dialects used by JAX.";

m.def("register_dialects", [](MlirDialectRegistry registry) {
Expand Down

0 comments on commit b1b3ea2

Please sign in to comment.