Skip to content

Commit

Permalink
Thread-safety changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Sep 18, 2024
1 parent e36c55b commit d4d57ea
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 9 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
os.path.join(src_test_py_dir, 'jpy_rt_test.py'),
os.path.join(src_test_py_dir, 'jpy_mt_test.py'),
os.path.join(src_test_py_dir, 'jpy_diag_test.py'),
# os.path.join(src_test_py_dir, 'jpy_perf_test.py'),
os.path.join(src_test_py_dir, 'jpy_perf_test.py'),
]

# Python unit tests that require target/test-classes or target/classes
Expand All @@ -97,6 +97,7 @@
os.path.join(src_test_py_dir, 'jpy_java_embeddable_test.py'),
os.path.join(src_test_py_dir, 'jpy_obj_test.py'),
os.path.join(src_test_py_dir, 'jpy_eval_exec_test.py'),
os.path.join(src_test_py_dir, 'jpy_mt_eval_exec_test.py'),
]

# e.g. jdk_home_dir = '/home/marta/jdk1.7.0_15'
Expand Down
24 changes: 16 additions & 8 deletions src/main/c/jni/org_jpy_PyLib.c
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,11 @@ JNIEXPORT void JNICALL Java_org_jpy_PyLib_incRef
if (Py_IsInitialized()) {
JPy_BEGIN_GIL_STATE

refCount = pyObject->ob_refcnt;
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION <= 12
refCount = pyObject->ob_refcnt;
#else
refCount = Py_REFCNT(pyObject);
#endif
JPy_DIAG_PRINT(JPy_DIAG_F_MEM, "Java_org_jpy_PyLib_incRef: pyObject=%p, refCount=%d, type='%s'\n", pyObject, refCount, Py_TYPE(pyObject)->tp_name);
JPy_INCREF(pyObject);

Expand All @@ -1150,7 +1154,11 @@ JNIEXPORT void JNICALL Java_org_jpy_PyLib_decRef
if (Py_IsInitialized()) {
JPy_BEGIN_GIL_STATE

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION <= 12
refCount = pyObject->ob_refcnt;
#else
refCount = Py_REFCNT(pyObject);
#endif
if (refCount <= 0) {
JPy_DIAG_PRINT(JPy_DIAG_F_ALL, "Java_org_jpy_PyLib_decRef: error: refCount <= 0: pyObject=%p, refCount=%d\n", pyObject, refCount);
} else {
Expand Down Expand Up @@ -1183,13 +1191,13 @@ JNIEXPORT void JNICALL Java_org_jpy_PyLib_decRefs
buf = (*jenv)->GetLongArrayElements(jenv, objIds, &isCopy);
for (i = 0; i < len; i++) {
pyObject = (PyObject*) buf[i];
refCount = pyObject->ob_refcnt;
if (refCount <= 0) {
JPy_DIAG_PRINT(JPy_DIAG_F_ALL, "Java_org_jpy_PyLib_decRefs: error: refCount <= 0: pyObject=%p, refCount=%d\n", pyObject, refCount);
} else {
JPy_DIAG_PRINT(JPy_DIAG_F_MEM, "Java_org_jpy_PyLib_decRefs: pyObject=%p, refCount=%d, type='%s'\n", pyObject, refCount, Py_TYPE(pyObject)->tp_name);
JPy_DECREF(pyObject);
}
// refCount = pyObject->ob_refcnt;
// if (refCount <= 0) {
// JPy_DIAG_PRINT(JPy_DIAG_F_ALL, "Java_org_jpy_PyLib_decRefs: error: refCount <= 0: pyObject=%p, refCount=%d\n", pyObject, refCount);
// } else {
// JPy_DIAG_PRINT(JPy_DIAG_F_MEM, "Java_org_jpy_PyLib_decRefs: pyObject=%p, refCount=%d, type='%s'\n", pyObject, refCount, Py_TYPE(pyObject)->tp_name);
// JPy_DECREF(pyObject);
// }
}
(*jenv)->ReleaseLongArrayElements(jenv, objIds, buf, JNI_ABORT);
JPy_END_GIL_STATE
Expand Down
66 changes: 66 additions & 0 deletions src/main/c/jpy_jtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,54 @@
#include "jpy_conv.h"
#include "jpy_compat.h"

#ifdef Py_GIL_DISABLED
typedef struct {
PyMutex lock;
PyThreadState* owner;
int recursion_level;
} ReentrantLock;

static void acquire_lock(ReentrantLock* self) {
PyThreadState* current_thread = PyThreadState_Get();

if (self->owner == current_thread) {
self->recursion_level++;
return;
}

PyMutex_Lock(&(self->lock));

self->owner = current_thread;
self->recursion_level = 1;
}

static void release_lock(ReentrantLock* self) {
if (self->owner != PyThreadState_Get()) {
PyErr_SetString(PyExc_RuntimeError, "Lock not owned by current thread");
return;
}

self->recursion_level--;
if (self->recursion_level == 0) {
self->owner = NULL;
PyMutex_Unlock(&(self->lock));
}
}

static ReentrantLock get_type_rlock = {{0}, NULL, 0};
static ReentrantLock resolve_type_rlock = {{0}, NULL, 0};

#define ACQUIRE_GET_TYPE_LOCK() acquire_lock(&get_type_rlock)
#define RELEASE_GET_TYPE_LOCK() release_lock(&get_type_rlock)
#define ACQUIRE_RESOLVE_TYPE_LOCK() acquire_lock(&resolve_type_rlock)
#define RELEASE_RESOLVE_TYPE_LOCK() release_lock(&resolve_type_rlock)

#else
#define ACQUIRE_GET_TYPE_LOCK()
#define RELEASE_GET_TYPE_LOCK()
#define ACQUIRE_RESOLVE_TYPE_LOCK()
#define RELEASE_RESOLVE_TYPE_LOCK()
#endif

JPy_JType* JType_New(JNIEnv* jenv, jclass classRef, jboolean resolve);
int JType_ResolveType(JNIEnv* jenv, JPy_JType* type);
Expand All @@ -52,6 +100,8 @@ static int JType_MatchVarArgPyArgAsFPType(const JPy_ParamDescriptor *paramDescri
static int JType_MatchVarArgPyArgIntType(const JPy_ParamDescriptor *paramDescriptor, PyObject *pyArg, int idx,
struct JPy_JType *expectedComponentType);



JPy_JType* JType_GetTypeForObject(JNIEnv* jenv, jobject objectRef, jboolean resolve)
{
JPy_JType* type;
Expand Down Expand Up @@ -151,6 +201,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
return NULL;
}

ACQUIRE_GET_TYPE_LOCK();
typeValue = PyDict_GetItem(JPy_Types, typeKey);
if (typeValue == NULL) {

Expand All @@ -160,6 +211,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
type = JType_New(jenv, classRef, resolve);
if (type == NULL) {
JPy_DECREF(typeKey);
RELEASE_GET_TYPE_LOCK();
return NULL;
}

Expand All @@ -184,6 +236,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
PyDict_DelItem(JPy_Types, typeKey);
JPy_DECREF(typeKey);
JPy_DECREF(type);
RELEASE_GET_TYPE_LOCK();
return NULL;
}

Expand All @@ -195,6 +248,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
PyDict_DelItem(JPy_Types, typeKey);
JPy_DECREF(typeKey);
JPy_DECREF(type);
RELEASE_GET_TYPE_LOCK();
return NULL;
}

Expand All @@ -206,6 +260,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
PyDict_DelItem(JPy_Types, typeKey);
JPy_DECREF(typeKey);
JPy_DECREF(type);
RELEASE_GET_TYPE_LOCK();
return NULL;
}

Expand All @@ -231,6 +286,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
"jpy internal error: attributes in 'jpy.%s' must be of type '%s', but found a '%s'",
JPy_MODULE_ATTR_NAME_TYPES, JType_Type.tp_name, Py_TYPE(typeValue)->tp_name);
JPy_DECREF(typeKey);
RELEASE_GET_TYPE_LOCK();
return NULL;
}

Expand All @@ -240,6 +296,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve)
}

JPy_DIAG_PRINT(JPy_DIAG_F_TYPE, "JType_GetType: javaName=\"%s\", found=%d, resolve=%d, resolved=%d, type=%p\n", type->javaName, found, resolve, type->isResolved, type);
RELEASE_GET_TYPE_LOCK();

if (!type->isResolved && resolve) {
if (JType_ResolveType(jenv, type) < 0) {
Expand Down Expand Up @@ -968,7 +1025,10 @@ int JType_ResolveType(JNIEnv* jenv, JPy_JType* type)
{
PyTypeObject* typeObj;

ACQUIRE_RESOLVE_TYPE_LOCK();

if (type->isResolved || type->isResolving) {
RELEASE_RESOLVE_TYPE_LOCK();
return 0;
}

Expand All @@ -980,6 +1040,7 @@ int JType_ResolveType(JNIEnv* jenv, JPy_JType* type)
if (!baseType->isResolved) {
if (JType_ResolveType(jenv, baseType) < 0) {
type->isResolving = JNI_FALSE;
RELEASE_RESOLVE_TYPE_LOCK();
return -1;
}
}
Expand All @@ -988,24 +1049,29 @@ int JType_ResolveType(JNIEnv* jenv, JPy_JType* type)
//printf("JType_ResolveType 1\n");
if (JType_ProcessClassConstructors(jenv, type) < 0) {
type->isResolving = JNI_FALSE;
RELEASE_RESOLVE_TYPE_LOCK();
return -1;
}

//printf("JType_ResolveType 2\n");
if (JType_ProcessClassMethods(jenv, type) < 0) {
type->isResolving = JNI_FALSE;
RELEASE_RESOLVE_TYPE_LOCK();
return -1;
}

//printf("JType_ResolveType 3\n");
if (JType_ProcessClassFields(jenv, type) < 0) {
type->isResolving = JNI_FALSE;
RELEASE_RESOLVE_TYPE_LOCK();
return -1;
}

//printf("JType_ResolveType 4\n");
type->isResolving = JNI_FALSE;
type->isResolved = JNI_TRUE;

RELEASE_RESOLVE_TYPE_LOCK();
return 0;
}

Expand Down
3 changes: 3 additions & 0 deletions src/main/c/jpy_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ PyMODINIT_FUNC JPY_MODULE_INIT_FUNC(void)
if (JPy_Module == NULL) {
JPY_RETURN(NULL);
}
#ifdef Py_GIL_DISABLED
PyUnstable_Module_SetGIL(JPy_Module, Py_MOD_GIL_NOT_USED);
#endif
#elif defined(JPY_COMPAT_27)
JPy_Module = Py_InitModule3(JPY_MODULE_NAME, JPy_Functions, JPY_MODULE_DOC);
if (JPy_Module == NULL) {
Expand Down
54 changes: 54 additions & 0 deletions src/test/java/org/jpy/fixtures/MultiThreadedEvalTestFixture.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.jpy.fixtures;

import org.jpy.PyInputMode;
import org.jpy.PyLib;
import org.jpy.PyObject;

import java.util.List;

public class MultiThreadedEvalTestFixture {

public static void expression(String expression, int numThreads) {
PyObject globals = PyLib.getCurrentGlobals();
PyObject locals = PyLib.getCurrentLocals();

List<Thread> threads = new java.util.ArrayList<>();
for (int i = 0; i < numThreads; i++) {
threads.add(new Thread(() -> {
PyObject.executeCode(expression, PyInputMode.EXPRESSION, globals, locals);
}));
}
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

public static void script(String expression, int numThreads) {
List<Thread> threads = new java.util.ArrayList<>();
PyObject globals = PyLib.getCurrentGlobals();
PyObject locals = PyLib.getCurrentLocals();
for (int i = 0; i < numThreads; i++) {
threads.add(new Thread(() -> {
PyObject.executeCode(expression, PyInputMode.SCRIPT, globals, locals);
}));
}
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}

}
1 change: 1 addition & 0 deletions src/test/python/jpy_eval_exec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
jpyutil.init_jvm(jvm_maxmem='512M', jvm_classpath=['target/classes', 'target/test-classes'])
import jpy


class TestEvalExec(unittest.TestCase):
def setUp(self):
self.fixture = jpy.get_type("org.jpy.fixtures.EvalTestFixture")
Expand Down
69 changes: 69 additions & 0 deletions src/test/python/jpy_mt_eval_exec_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest

import jpyutil

jpyutil.init_jvm(jvm_maxmem='512M', jvm_classpath=['target/classes', 'target/test-classes'])
import jpy
# jpy.diag.flags = jpy.diag.F_TYPE

NUM_THREADS = 20


class MultiThreadedTestEvalExec(unittest.TestCase):
def setUp(self):
self.fixture = jpy.get_type("org.jpy.fixtures.MultiThreadedEvalTestFixture")
self.assertIsNotNone(self.fixture)

def test_inc_baz(self):
baz = 15
self.fixture.script("baz = baz + 1; self.assertGreater(baz, 15)", NUM_THREADS)
# note: this *is* correct wrt python semantics w/ exec(code, globals(), locals())
# https://bugs.python.org/issue4831 (Note: it's *not* a bug, is working as intended)
self.assertEqual(baz, 15)

def test_exec_import(self):
import sys
self.assertTrue("json" not in sys.modules)
self.fixture.script("import json", NUM_THREADS)
self.assertTrue("json" in sys.modules)

def test_java_threading_jpy_get_type(self):
self.fixture.script("j_child1_class = jpy.get_type(\"org.jpy.fixtures.CyclicReferenceChild1\");j_child2_class ="
"jpy.get_type(\"org.jpy.fixtures.CyclicReferenceChild2\")", NUM_THREADS)

def test_py_threading_jpy_get_type(self):
import threading

test_self = self

class MyThread(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)

def run(self):
barrier.wait()
j_child1_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild1")
j_child2_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild2")
j_child2 = j_child2_class()
j_child1 = j_child1_class.of(8)
test_self.assertEqual(88, j_child1.parentMethod())
test_self.assertEqual(888, j_child1.grandParentMethod())
test_self.assertIsNone(j_child1.refChild2(j_child2))
test_self.assertEqual(8, j_child1.get_x())
test_self.assertEqual(10, j_child1.y)
test_self.assertEqual(100, j_child1.z)

barrier = threading.Barrier(NUM_THREADS)
threads = []
for i in range(NUM_THREADS):
t = MyThread()
t.start()
threads.append(t)

for t in threads:
t.join()


if __name__ == '__main__':
print('\nRunning ' + __file__)
unittest.main()

0 comments on commit d4d57ea

Please sign in to comment.