From 2d05017ca887ce5289afe7cec18f1bc34a9aa493 Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Sun, 14 Jan 2024 13:37:18 -0800 Subject: [PATCH] removed race conditions --- .gitignore | 2 +- CMakeLists.txt | 2 +- include/LockPool.h | 3 +- libmc/__init__.py | 24 ++++++++++++-- libmc/_client.pyx | 15 ++++++--- misc/memcached_server | 5 +++ src/ClientPool.cpp | 1 - tests/test_client_pool.cpp | 57 ++++++++++++-------------------- tests/test_client_pool.py | 66 ++++++++++++++++++++++++++++++++++++-- 9 files changed, 125 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index 6d48b7e3..0949ba73 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,4 @@ tests/resources/keys_*.txt /go.mod /go.sum -/cython_debug \ No newline at end of file +/cython_debug diff --git a/CMakeLists.txt b/CMakeLists.txt index 04a4799a..bddda399 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ if (NOT CMAKE_BUILD_TYPE) endif (NOT CMAKE_BUILD_TYPE) set(CMAKE_CXX_FLAGS_COMMON "-Wall -fno-rtti -fno-exceptions") -set(CMAKE_CXX_FLAGS_DEBUG "-DDEBUG -g2 ${CMAKE_CXX_FLAGS_COMMON}" CACHE STRING "CXX DEBUG FLAGS" FORCE) +set(CMAKE_CXX_FLAGS_DEBUG "-DDEBUG -g2 ${CMAKE_CXX_FLAGS_COMMON} -fsanitize=thread" CACHE STRING "CXX DEBUG FLAGS" FORCE) set(CMAKE_CXX_FLAGS_RELEASE "-DNDEBUG -O3 ${CMAKE_CXX_FLAGS_COMMON}" CACHE STRING "CXX RELEASE FLAGS" FORCE) set(CMAKE_INSTALL_INCLUDE include CACHE PATH "Output directory for header files") set(CMAKE_INSTALL_LIBDIR lib CACHE PATH "Output directory for libraries") diff --git a/include/LockPool.h b/include/LockPool.h index f63a97ab..7d82b943 100644 --- a/include/LockPool.h +++ b/include/LockPool.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace douban { namespace mc { @@ -14,7 +15,7 @@ class OrderedLock { std::queue m_fifo_locks; protected: std::mutex m_fifo_access; - bool m_locked; + std::atomic m_locked; protected: OrderedLock() : m_locked(true) {}; diff --git a/libmc/__init__.py b/libmc/__init__.py index 2e4a1122..7154b0dc 100644 --- a/libmc/__init__.py +++ b/libmc/__init__.py @@ -1,4 +1,4 @@ -import os +import os, functools from ._client import ( PyClient, PyClientPool, ThreadUnsafe, encode_value, @@ -44,12 +44,32 @@ class Client(PyClient): class ClientPool(PyClientPool): pass +class ThreadedClient(): + @functools.wraps(ClientPool.__init__) + def __init__(self, *args, **kwargs): + self._client_pool = ClientPool(*args, **kwargs) + + def update_servers(self, servers): + return self._client_pool.update_servers(servers) + + def __getattr__(self, key): + if not hasattr(Client, key): + raise AttributeError + result = getattr(Client, key) + if callable(result): + @functools.wraps(result) + def wrapper(*args, **kwargs): + with self._client_pool.client() as mc: + return getattr(mc, key)(*args, **kwargs) + return wrapper + return result + DYNAMIC_LIBRARIES = [os.path.abspath(_libmc_so_file)] __all__ = [ - 'Client', 'ClientPool', 'ThreadUnsafe', '__VERSION__', + 'Client', 'ClientPool', 'ThreadedClient', 'ThreadUnsafe', '__VERSION__', 'encode_value', 'decode_value', 'MC_DEFAULT_EXPTIME', 'MC_POLL_TIMEOUT', 'MC_CONNECT_TIMEOUT', diff --git a/libmc/_client.pyx b/libmc/_client.pyx index b7162b78..c181098d 100644 --- a/libmc/_client.pyx +++ b/libmc/_client.pyx @@ -449,9 +449,6 @@ cdef class PyClientShell(PyClientSettings): self._thread_ident = None self._created_stack = traceback.extract_stack() - def __dealloc__(self): - del self._imp - def config(self, int opt, int val): self._imp.config(opt, val) @@ -1109,7 +1106,7 @@ cdef class PyClientShell(PyClientSettings): self._get_current_thread_ident())) def _get_current_thread_ident(self): - return (os.getpid(), threading.current_thread().name) + return (os.getpid(), threading.current_thread().native_id) def get_last_error(self): return self.last_error @@ -1137,6 +1134,9 @@ cdef class PyClient(PyClientShell): return True return False + def __dealloc__(self): + del self._imp + cdef class PyPoolClient(PyClientShell): cdef IndexedClient* _indexed @@ -1172,6 +1172,8 @@ cdef class PyClientPool(PyClientSettings): self.connect() self._initialized = True worker = self._imp._acquire() + return self.setup(worker) + # prone to race conditions pending mux and possibly fails to update if worker.index >= len(self.clients): self.clients += [None] * (worker.index - len(self.clients)) self.clients.append(self.setup(worker)) @@ -1184,8 +1186,8 @@ cdef class PyClientPool(PyClientSettings): @contextmanager def client(self): + worker = self.acquire() try: - worker = self.acquire() yield worker finally: self.release(worker) @@ -1200,3 +1202,6 @@ cdef class PyClientPool(PyClientSettings): self.servers = servers return True return False + + def __dealloc__(self): + del self._imp diff --git a/misc/memcached_server b/misc/memcached_server index 47fa7b82..14496462 100755 --- a/misc/memcached_server +++ b/misc/memcached_server @@ -153,6 +153,11 @@ case "$1" in cygdb . --skip-interpreter -- --args "$dbg" setup.py test fi ;; + run-test) + cd "$( dirname "${BASH_SOURCE[0]}" )/.." &> /dev/null + shift + python setup.py test -a "-k $*" + ;; *) printf 'Usage: %s {start|stop|restart} \n' "$prog" exit 1 diff --git a/src/ClientPool.cpp b/src/ClientPool.cpp index abe40971..87a08f3e 100644 --- a/src/ClientPool.cpp +++ b/src/ClientPool.cpp @@ -1,6 +1,5 @@ //#include #include -#include #include "ClientPool.h" namespace douban { diff --git a/tests/test_client_pool.cpp b/tests/test_client_pool.cpp index 6ae7ed73..d77b563c 100644 --- a/tests/test_client_pool.cpp +++ b/tests/test_client_pool.cpp @@ -14,18 +14,7 @@ const unsigned int start_port = 21211; const char host[] = "127.0.0.1"; unsigned int n_threads = 8; -TEST(test_client_pool, simple_set_get) { - uint32_t ports[n_servers]; - const char* hosts[n_servers]; - for (unsigned int i = 0; i < n_servers; i++) { - ports[i] = start_port + i; - hosts[i] = host; - } - - ClientPool* pool = new ClientPool(); - pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32); - pool->init(hosts, ports, n_servers); - +void inner_test_loop(ClientPool* pool) { retrieval_result_t **r_results = NULL; message_result_t **m_results = NULL; size_t nResults = 0; @@ -37,7 +26,7 @@ TEST(test_client_pool, simple_set_get) { const char* keys = &key[0]; const char* values = &value[0]; - for (unsigned int j = 0; j < n_ops * n_threads; j++) { + for (unsigned int j = 0; j < n_ops; j++) { gen_random(key, data_size); gen_random(value, data_size); auto c = pool->acquire(); @@ -48,6 +37,23 @@ TEST(test_client_pool, simple_set_get) { c->destroyRetrievalResult(); pool->release(c); } +} + +TEST(test_client_pool, simple_set_get) { + uint32_t ports[n_servers]; + const char* hosts[n_servers]; + for (unsigned int i = 0; i < n_servers; i++) { + ports[i] = start_port + i; + hosts[i] = host; + } + + ClientPool* pool = new ClientPool(); + pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32); + pool->init(hosts, ports, n_servers); + + for (unsigned int j = 0; j < n_threads; j++) { + inner_test_loop(pool); + } delete pool; } @@ -67,30 +73,7 @@ TEST(test_client_pool, threaded_set_get) { pool->init(hosts, ports, n_servers); for (unsigned int i = 0; i < n_threads; i++) { - threads[i] = std::thread([&pool]() { - retrieval_result_t **r_results = NULL; - message_result_t **m_results = NULL; - size_t nResults = 0; - flags_t flags[] = {}; - size_t data_lens[] = {data_size}; - exptime_t exptime = 0; - char key[data_size + 1]; - char value[data_size + 1]; - const char* keys = &key[0]; - const char* values = &value[0]; - - for (unsigned int j = 0; j < n_ops; j++) { - gen_random(key, data_size); - gen_random(value, data_size); - auto c = pool->acquire(); - c->set(&keys, data_lens, flags, exptime, NULL, 0, &values, data_lens, 1, &m_results, &nResults); - c->destroyMessageResult(); - c->get(&keys, data_lens, 1, &r_results, &nResults); - ASSERT_N_STREQ(r_results[0]->data_block, values, data_size); - c->destroyRetrievalResult(); - pool->release(c); - } - }); + threads[i] = std::thread([&pool] { inner_test_loop(pool); }); } for (unsigned int i = 0; i < n_threads; i++) { threads[i].join(); diff --git a/tests/test_client_pool.py b/tests/test_client_pool.py index 6344199f..81dab0b1 100644 --- a/tests/test_client_pool.py +++ b/tests/test_client_pool.py @@ -1,8 +1,7 @@ # coding: utf-8 import unittest from threading import Thread -from libmc import ClientPool - +from libmc import ClientPool#, ThreadedClient class ThreadedSingleServerCase(unittest.TestCase): def setUp(self): @@ -44,3 +43,66 @@ def test_pool_client_threaded(self): for t in ts: t.join() + +''' +class ClientOps: + def client_misc(self, mc, i=0): + tid = mc._get_current_thread_ident() + (i,) + tid = "_".join(map(str, tid)) + f, t = 'foo_' + tid, 'tuiche_' + tid + mc.get_multi([f, t]) + mc.delete(f) + mc.delete(t) + assert mc.get(f) is None + assert mc.get(t) is None + + mc.set(f, 'biu') + mc.set(t, 'bb') + assert mc.get(f) == 'biu' + assert mc.get(t) == 'bb' + assert (mc.get_multi([f, t]) == + {f: 'biu', t: 'bb'}) + mc.set_multi({f: 1024, t: '8964'}) + assert (mc.get_multi([f, t]) == + {f: 1024, t: '8964'}) + + def client_threads(self, target): + ts = [Thread(target=target) for i in range(8)] + for t in ts: + t.start() + + for t in ts: + t.join() + +class ThreadedSingleServerCase(unittest.TestCase, ClientOps): + def setUp(self): + self.pool = ClientPool(["127.0.0.1:21211"]) + + def misc(self): + for i in range(5): + self.test_pool_client_misc(i) + + def test_pool_client_misc(self, i=0): + with self.pool.client() as mc: + self.client_misc(mc, i) + + def test_acquire(self): + with self.pool.client() as mc: + pass + + def test_pool_client_threaded(self): + with open('debug.log', 'a') as f: + f.write("stdout working\n") + self.client_threads(self.misc) + +class ThreadedClientWrapperCheck(unittest.TestCase, ClientOps): + def setUp(self): + self.imp = ThreadedClient(["127.0.0.1:21211"]) + + def misc(self): + for i in range(5): + self.client_misc(self.imp, i) + + def test_many_threads(self): + self.client_threads(self.misc) +'''