Skip to content

Commit

Permalink
Improve pthread based lock
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny-rentner committed Sep 26, 2022
1 parent 1ebfa8f commit 9f88a2f
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 116 deletions.
22 changes: 17 additions & 5 deletions UltraDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def fix_unregister(name, rtype):
#More details at: https://bugs.python.org/issue38119
remove_shm_from_resource_tracker()

running_on_linux = sys.platform.startswith("linux")

class UltraDict(collections.UserDict, dict):

Exceptions = Exceptions
Expand Down Expand Up @@ -121,13 +123,16 @@ def __init__(self, file_path):

self.lock = SharedMutex(file_path, lambda: True)

def acquire(self, *args, **kwargs):
def acquire(self, block=True, timeout=None, *args, **kwargs):
if self.has_lock:
self.has_lock += 1
return True

self.lock.lock()
self.has_lock = 1
if self.lock.lock(block=block, timeout=timeout):
self.has_lock = 1
return True
else:
return False

def release(self, *args, **kwargs):
if self.has_lock > 0:
Expand Down Expand Up @@ -512,7 +517,7 @@ def finalize(weak_self, name):

# Local lock for all processes and threads created by the same interpreter
if shared_lock:
if shared_lock == 'pymutex':
if shared_lock == 'pymutex' or running_on_linux:
self.lock = self.SharedMutexLock(f'{self.name}_mutex')
else:
self.lock = self.SharedLock(self, 'lock_remote', 'lock_pid_remote')
Expand Down Expand Up @@ -602,7 +607,14 @@ def get_memory(*, create=True, name=None, size=0):
"""
Attach an existing SharedMemory object with `name`.
If `create` is True, create the object if it does not exist.
If `create is True`, create the object if it does not exist. Throw an
exception if the memory already exists.
If `create is False`, attach to an existing memoery. Throw an exception
if the memory does not exist.
If `create is None`, either silently attach to an existing memory or
create it if it does not exist.
"""
assert size > 0 or not create
if name:
Expand Down
8 changes: 6 additions & 2 deletions examples/parallel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
#
# Two processes are incrementing a counter in parallel
# UltraDict Example
# -----------------
#
# Four processes are incrementing a counter in parallel.
#
# In this example we use the shared_lock=True parameter.
# This way of shared locking is safe accross independent
# processes on all operting systems but it is slower than
# using the built-in default locking method using `multiprocessing.RLock()`.
#
# UltraDict uses the atomics package internally for shared locking.
# UltraDict uses the pthread library on Linux and the atomics package on other OS for shared locking.


# Make the example find UltraDict
import sys, os
Expand Down
66 changes: 0 additions & 66 deletions examples/parallel_linux_faster.py

This file was deleted.

56 changes: 27 additions & 29 deletions pymutex/mutex.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,17 @@
import mmap
import weakref
import logging
import ctypes as c
import ctypes.util
import ctypes, ctypes.util
from . import utils

_pt = c.CDLL(ctypes.util.find_library('pthread'))
_pt = ctypes.CDLL(ctypes.util.find_library('pthread'))

# pthread structs' size from pthreadtypes-arch.h
_PTHREAD_MUTEX_ATTRS_SIZE = 4
import platform as _platform
_arch = _platform.architecture()[0]
if _arch == '64bit':
if c.sizeof(c.POINTER(c.c_int)) == 8:
if ctypes.sizeof(ctypes.POINTER(ctypes.c_int)) == 8:
_PTHREAD_MUTEX_SIZE = 40
else:
_PTHREAD_MUTEX_SIZE = 32
Expand All @@ -65,11 +64,11 @@ def configure_default_logging():
logger.addHandler(handler)


# posix timespec in types/struct_timespec.h
class _timespec(c.Structure):
# posix timespec in types/struct_timespectypes.h
class _timespec(ctypes.Structure):
_fields_ = [
('tv_sec', c.c_long),
('tv_nsec', c.c_long)
('tv_sec', ctypes.c_long),
('tv_nsec', ctypes.c_long)
]

class InvalidSharedState(Exception):
Expand All @@ -91,9 +90,8 @@ class _MutexState:
When SharedMutex is collected by the GC, _MutexState is destroyed.
"""

def __init__(
self, pathname: str, mutex_ptr, mutex_attrs_ptr, mutex_fd: int,
mutex_mmap: mmap.mmap, recover_shared_state_cb):
def __init__(self, pathname: str, mutex_ptr, mutex_attrs_ptr, mutex_fd: int,
mutex_mmap: mmap.mmap, recover_shared_state_cb):
self.mutex_ptr = mutex_ptr
self.mutex_attrs_ptr = mutex_attrs_ptr
self.locked = False
Expand Down Expand Up @@ -163,67 +161,67 @@ def __init__(self, mutex_file: str, recover_shared_state_cb):
self._finalizer = weakref.finalize(self, _mutex_finalizer, self._state)
return
try:
mutex_attrs = c.create_string_buffer(_PTHREAD_MUTEX_ATTRS_SIZE)
mutex_attrs_ptr = c.byref(mutex_attrs)
mutex_attrs = ctypes.create_string_buffer(_PTHREAD_MUTEX_ATTRS_SIZE)
mutex_attrs_ptr = ctypes.byref(mutex_attrs)
e = _pt.pthread_mutexattr_init(mutex_attrs_ptr)
if e != 0:
raise OSError(e, os.strerror(e))
# set type to PTHREAD_MUTEX_ERRORCHECK
e = _pt.pthread_mutexattr_settype(mutex_attrs_ptr, c.c_int(2))
e = _pt.pthread_mutexattr_settype(mutex_attrs_ptr, ctypes.c_int(2))
if e != 0:
raise OSError(e, os.strerror(e))
# set robustness to PTHREAD_MUTEX_ROBUST
try:
e = _pt.pthread_mutexattr_setrobust(mutex_attrs_ptr, c.c_int(1))
e = _pt.pthread_mutexattr_setrobust(mutex_attrs_ptr, ctypes.c_int(1))
except AttributeError:
pass
if e != 0:
raise OSError(e, os.strerror(e))
# set sharing mode to PTHREAD_PROCESS_SHARED
e = _pt.pthread_mutexattr_setpshared(mutex_attrs_ptr, c.c_int(1))
e = _pt.pthread_mutexattr_setpshared(mutex_attrs_ptr, ctypes.c_int(1))
if e != 0:
raise OSError(e, os.strerror(e))
mutex = c.create_string_buffer(_PTHREAD_MUTEX_SIZE)
e = _pt.pthread_mutex_init(c.byref(mutex), mutex_attrs_ptr)
mutex = ctypes.create_string_buffer(_PTHREAD_MUTEX_SIZE)
e = _pt.pthread_mutex_init(ctypes.byref(mutex), mutex_attrs_ptr)
if e != 0:
_pt.pthread_mutexattr_destroy(mutex_attrs_ptr)
raise OSError(e, os.strerror(e))
try:
assert os.write(mutex_fd, mutex) == _PTHREAD_MUTEX_SIZE, 'Failed to store the mutex'
# Share the mutex by creating a memory mapped file
mutex_mmap = mmap.mmap(mutex_fd, 0, mmap.MAP_SHARED, mmap.PROT_WRITE | mmap.PROT_READ)
mutex = c.c_char.from_buffer(mutex_mmap)
mutex = ctypes.c_char.from_buffer(mutex_mmap)
# Process's user will have write and read permissions on the mutex file from now
# FIXME: Should it restrict all processes to be running as same user?
os.fchmod(mutex_fd, stat.S_IRUSR | stat.S_IWUSR)
except:
_pt.pthread_mutex_destroy(c.byref(mutex))
_pt.pthread_mutex_destroy(ctypes.byref(mutex))
_pt.pthread_mutexattr_destroy(mutex_attrs_ptr)
raise
except:
os.remove(mutex_file)
os.close(mutex_fd)
raise
self._state = _MutexState(mutex_file, c.byref(mutex), mutex_attrs_ptr, mutex_fd, mutex_mmap, recover_shared_state_cb)
self._state = _MutexState(mutex_file, ctypes.byref(mutex), mutex_attrs_ptr, mutex_fd, mutex_mmap, recover_shared_state_cb)
self._finalizer = weakref.finalize(self, _mutex_finalizer, self._state)

@property
def owns_lock(self):
"""Returns True if this mutex instance owns the lock, False otherwise."""
return self._state.locked

def lock(self, blocking: bool = True, timeout: float = 0):
"""Lock the mutex. If "blocking" is True, the current thread
def lock(self, block: bool = True, timeout: float = 0):
"""Lock the mutex. If "block" is True, the current thread
blocks until the mutex becomes available, if "timeout" > 0 the
thread blocks until timeout (in seconds) expires. Returns
True if the mutex was locked, False otherwise."""
if self._state.mutex_ptr is None: raise RuntimeError('Invalid state')
if blocking:
#if self._state.mutex_ptr is None: raise RuntimeError('Invalid state')
if block:
# The robustness setting only works for lock attempts that
# comes *after* the thread holding the lock terminates
# without releasing it. The blocking call will be made of
# several calls to pthread_mutex_timedlock.
if timeout > 0:
if timeout and timeout > 0:
locked = self._mutex_timedlock(timeout, False)
else:
locked = self._mutex_timedlock(_MUTEX_LOCK_HEARTBEAT)
Expand All @@ -242,7 +240,7 @@ def lock(self, blocking: bool = True, timeout: float = 0):

def unlock(self):
"""Unlock the mutex. Raises PermissionError if the current thread does not owns the lock."""
if self._state.mutex_ptr is None: raise RuntimeError('Invalid mutex state')
#if self._state.mutex_ptr is None: raise RuntimeError('Invalid mutex state')
e = _pt.pthread_mutex_unlock(self._state.mutex_ptr)
if e == 0:
self._state.locked = False
Expand All @@ -256,7 +254,7 @@ def _mutex_timedlock(self, timeout: float, until_lock = True):
current_timeout = time.clock_gettime(time.CLOCK_REALTIME) + timeout
e = _pt.pthread_mutex_timedlock(
self._state.mutex_ptr,
c.byref(_timespec(
ctypes.byref(_timespec(
int(current_timeout),
int((current_timeout * 1000 % 1000) * 1_000_000)
))
Expand Down Expand Up @@ -318,7 +316,7 @@ def _mutex_load(self, mutex_file: str, recover_shared_state_cb):
mutex_mmap = mmap.mmap(mutex_fd, 0, mmap.MAP_SHARED, mmap.PROT_WRITE | mmap.PROT_READ)
self._state = _MutexState(
mutex_file,
c.byref(c.c_char.from_buffer(mutex_mmap)),
ctypes.byref(ctypes.c_char.from_buffer(mutex_mmap)),
None,
mutex_fd,
mutex_mmap,
Expand Down
8 changes: 3 additions & 5 deletions tests/performance/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ def __init__(self):
self.lock = self.ultra.lock

def acquire(self):
#self.lock.acquire()
self.lock.test_and_inc()
self.lock.acquire(block=False)

def release(self):
#self.lock.release()
self.lock.test_and_dec()
self.lock.release()

class TestSharedMutexLock:
name = 'SharedMutexLock'
Expand All @@ -61,7 +59,7 @@ def __init__(self):
self.lock = self.ultra.lock

def acquire(self):
self.lock.acquire()
self.lock.acquire(block=False)

def release(self):
self.lock.release()
Expand Down
15 changes: 6 additions & 9 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ def test_parameter_passing(self):
# Connect `other` dict to `ultra` dict via `name`
other = UltraDict(name=ultra.name)

self.assertIsInstance(ultra.lock, ultra.SharedLock)
self.assertIsInstance(other.lock, other.SharedLock)
if sys.platform.startswith("linux"):
self.assertIsInstance(ultra.lock, ultra.SharedMutexLock)
self.assertIsInstance(other.lock, other.SharedMutexLock)
else:
self.assertIsInstance(ultra.lock, ultra.SharedLock)
self.assertIsInstance(other.lock, other.SharedLock)

self.assertEqual(ultra.buffer_size, other.buffer_size)

Expand Down Expand Up @@ -162,13 +166,6 @@ def test_example_parallel(self):
self.assertReturnCode(ret)
self.assertEqual(ret.stdout.splitlines()[-1], b'Counter: 100000 == 100000', self.exec_show_output(ret))

@unittest.skipIf(sys.platform.startswith("win"), "not for Windows, requires libpthread")
def test_example_parallel_linux_faster(self):
filename = "examples/parallel_linux_faster.py"
ret = self.exec(filename)
self.assertReturnCode(ret)
self.assertEqual(ret.stdout.splitlines()[-1], b'Counter: 100000 == 100000', self.exec_show_output(ret))

def test_example_nested(self):
filename = "examples/nested.py"
ret = self.exec(filename)
Expand Down

0 comments on commit 9f88a2f

Please sign in to comment.