diff --git a/jaxlib/cpu_feature_guard.c b/jaxlib/cpu_feature_guard.c index 7c8ff2951a79..d18478eb57d5 100644 --- a/jaxlib/cpu_feature_guard.c +++ b/jaxlib/cpu_feature_guard.c @@ -172,5 +172,12 @@ static struct PyModuleDef cpu_feature_guard_module = { #endif EXPORT_SYMBOL PyMODINIT_FUNC PyInit_cpu_feature_guard(void) { - return PyModule_Create(&cpu_feature_guard_module); + PyObject *module = PyModule_Create(&cpu_feature_guard_module); + if (module == NULL) { + return NULL; + } +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); +#endif + return module; }