Skip to content

Commit

Permalink
removed race conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
kentslaney committed Jan 14, 2024
1 parent 4079333 commit 2d05017
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ tests/resources/keys_*.txt
/go.mod
/go.sum

/cython_debug
/cython_debug
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ if (NOT CMAKE_BUILD_TYPE)
endif (NOT CMAKE_BUILD_TYPE)

set(CMAKE_CXX_FLAGS_COMMON "-Wall -fno-rtti -fno-exceptions")
set(CMAKE_CXX_FLAGS_DEBUG "-DDEBUG -g2 ${CMAKE_CXX_FLAGS_COMMON}" CACHE STRING "CXX DEBUG FLAGS" FORCE)
set(CMAKE_CXX_FLAGS_DEBUG "-DDEBUG -g2 ${CMAKE_CXX_FLAGS_COMMON} -fsanitize=thread" CACHE STRING "CXX DEBUG FLAGS" FORCE)
set(CMAKE_CXX_FLAGS_RELEASE "-DNDEBUG -O3 ${CMAKE_CXX_FLAGS_COMMON}" CACHE STRING "CXX RELEASE FLAGS" FORCE)
set(CMAKE_INSTALL_INCLUDE include CACHE PATH "Output directory for header files")
set(CMAKE_INSTALL_LIBDIR lib CACHE PATH "Output directory for libraries")
Expand Down
3 changes: 2 additions & 1 deletion include/LockPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <queue>
#include <deque>
#include <vector>
#include <atomic>

namespace douban {
namespace mc {
Expand All @@ -14,7 +15,7 @@ class OrderedLock {
std::queue<std::condition_variable*> m_fifo_locks;
protected:
std::mutex m_fifo_access;
bool m_locked;
std::atomic<bool> m_locked;

protected:
OrderedLock() : m_locked(true) {};
Expand Down
24 changes: 22 additions & 2 deletions libmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, functools
from ._client import (
PyClient, PyClientPool, ThreadUnsafe,
encode_value,
Expand Down Expand Up @@ -44,12 +44,32 @@ class Client(PyClient):
class ClientPool(PyClientPool):
pass

class ThreadedClient():
@functools.wraps(ClientPool.__init__)
def __init__(self, *args, **kwargs):
self._client_pool = ClientPool(*args, **kwargs)

def update_servers(self, servers):
return self._client_pool.update_servers(servers)

def __getattr__(self, key):
if not hasattr(Client, key):
raise AttributeError
result = getattr(Client, key)
if callable(result):
@functools.wraps(result)
def wrapper(*args, **kwargs):
with self._client_pool.client() as mc:
return getattr(mc, key)(*args, **kwargs)
return wrapper
return result


DYNAMIC_LIBRARIES = [os.path.abspath(_libmc_so_file)]


__all__ = [
'Client', 'ClientPool', 'ThreadUnsafe', '__VERSION__',
'Client', 'ClientPool', 'ThreadedClient', 'ThreadUnsafe', '__VERSION__',
'encode_value', 'decode_value',

'MC_DEFAULT_EXPTIME', 'MC_POLL_TIMEOUT', 'MC_CONNECT_TIMEOUT',
Expand Down
15 changes: 10 additions & 5 deletions libmc/_client.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,6 @@ cdef class PyClientShell(PyClientSettings):
self._thread_ident = None
self._created_stack = traceback.extract_stack()

def __dealloc__(self):
del self._imp

def config(self, int opt, int val):
self._imp.config(<config_options_t>opt, val)

Expand Down Expand Up @@ -1109,7 +1106,7 @@ cdef class PyClientShell(PyClientSettings):
self._get_current_thread_ident()))

def _get_current_thread_ident(self):
return (os.getpid(), threading.current_thread().name)
return (os.getpid(), threading.current_thread().native_id)

def get_last_error(self):
return self.last_error
Expand Down Expand Up @@ -1137,6 +1134,9 @@ cdef class PyClient(PyClientShell):
return True
return False

def __dealloc__(self):
del self._imp

cdef class PyPoolClient(PyClientShell):
cdef IndexedClient* _indexed

Expand Down Expand Up @@ -1172,6 +1172,8 @@ cdef class PyClientPool(PyClientSettings):
self.connect()
self._initialized = True
worker = self._imp._acquire()
return self.setup(worker)
# prone to race conditions pending mux and possibly fails to update
if worker.index >= len(self.clients):
self.clients += [None] * (worker.index - len(self.clients))
self.clients.append(self.setup(worker))
Expand All @@ -1184,8 +1186,8 @@ cdef class PyClientPool(PyClientSettings):

@contextmanager
def client(self):
worker = self.acquire()
try:
worker = self.acquire()
yield worker
finally:
self.release(worker)
Expand All @@ -1200,3 +1202,6 @@ cdef class PyClientPool(PyClientSettings):
self.servers = servers
return True
return False

def __dealloc__(self):
del self._imp
5 changes: 5 additions & 0 deletions misc/memcached_server
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ case "$1" in
cygdb . --skip-interpreter -- --args "$dbg" setup.py test
fi
;;
run-test)
cd "$( dirname "${BASH_SOURCE[0]}" )/.." &> /dev/null
shift
python setup.py test -a "-k $*"
;;
*)
printf 'Usage: %s {start|stop|restart} <port>\n' "$prog"
exit 1
Expand Down
1 change: 0 additions & 1 deletion src/ClientPool.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//#include <execution>
#include <thread>
#include <atomic>
#include "ClientPool.h"

namespace douban {
Expand Down
57 changes: 20 additions & 37 deletions tests/test_client_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,7 @@ const unsigned int start_port = 21211;
const char host[] = "127.0.0.1";
unsigned int n_threads = 8;

TEST(test_client_pool, simple_set_get) {
uint32_t ports[n_servers];
const char* hosts[n_servers];
for (unsigned int i = 0; i < n_servers; i++) {
ports[i] = start_port + i;
hosts[i] = host;
}

ClientPool* pool = new ClientPool();
pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32);
pool->init(hosts, ports, n_servers);

void inner_test_loop(ClientPool* pool) {
retrieval_result_t **r_results = NULL;
message_result_t **m_results = NULL;
size_t nResults = 0;
Expand All @@ -37,7 +26,7 @@ TEST(test_client_pool, simple_set_get) {
const char* keys = &key[0];
const char* values = &value[0];

for (unsigned int j = 0; j < n_ops * n_threads; j++) {
for (unsigned int j = 0; j < n_ops; j++) {
gen_random(key, data_size);
gen_random(value, data_size);
auto c = pool->acquire();
Expand All @@ -48,6 +37,23 @@ TEST(test_client_pool, simple_set_get) {
c->destroyRetrievalResult();
pool->release(c);
}
}

TEST(test_client_pool, simple_set_get) {
uint32_t ports[n_servers];
const char* hosts[n_servers];
for (unsigned int i = 0; i < n_servers; i++) {
ports[i] = start_port + i;
hosts[i] = host;
}

ClientPool* pool = new ClientPool();
pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32);
pool->init(hosts, ports, n_servers);

for (unsigned int j = 0; j < n_threads; j++) {
inner_test_loop(pool);
}

delete pool;
}
Expand All @@ -67,30 +73,7 @@ TEST(test_client_pool, threaded_set_get) {
pool->init(hosts, ports, n_servers);

for (unsigned int i = 0; i < n_threads; i++) {
threads[i] = std::thread([&pool]() {
retrieval_result_t **r_results = NULL;
message_result_t **m_results = NULL;
size_t nResults = 0;
flags_t flags[] = {};
size_t data_lens[] = {data_size};
exptime_t exptime = 0;
char key[data_size + 1];
char value[data_size + 1];
const char* keys = &key[0];
const char* values = &value[0];

for (unsigned int j = 0; j < n_ops; j++) {
gen_random(key, data_size);
gen_random(value, data_size);
auto c = pool->acquire();
c->set(&keys, data_lens, flags, exptime, NULL, 0, &values, data_lens, 1, &m_results, &nResults);
c->destroyMessageResult();
c->get(&keys, data_lens, 1, &r_results, &nResults);
ASSERT_N_STREQ(r_results[0]->data_block, values, data_size);
c->destroyRetrievalResult();
pool->release(c);
}
});
threads[i] = std::thread([&pool] { inner_test_loop(pool); });
}
for (unsigned int i = 0; i < n_threads; i++) {
threads[i].join();
Expand Down
66 changes: 64 additions & 2 deletions tests/test_client_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# coding: utf-8
import unittest
from threading import Thread
from libmc import ClientPool

from libmc import ClientPool#, ThreadedClient

class ThreadedSingleServerCase(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -44,3 +43,66 @@ def test_pool_client_threaded(self):

for t in ts:
t.join()

'''
class ClientOps:
def client_misc(self, mc, i=0):
tid = mc._get_current_thread_ident() + (i,)
tid = "_".join(map(str, tid))
f, t = 'foo_' + tid, 'tuiche_' + tid
mc.get_multi([f, t])
mc.delete(f)
mc.delete(t)
assert mc.get(f) is None
assert mc.get(t) is None
mc.set(f, 'biu')
mc.set(t, 'bb')
assert mc.get(f) == 'biu'
assert mc.get(t) == 'bb'
assert (mc.get_multi([f, t]) ==
{f: 'biu', t: 'bb'})
mc.set_multi({f: 1024, t: '8964'})
assert (mc.get_multi([f, t]) ==
{f: 1024, t: '8964'})
def client_threads(self, target):
ts = [Thread(target=target) for i in range(8)]
for t in ts:
t.start()
for t in ts:
t.join()
class ThreadedSingleServerCase(unittest.TestCase, ClientOps):
def setUp(self):
self.pool = ClientPool(["127.0.0.1:21211"])
def misc(self):
for i in range(5):
self.test_pool_client_misc(i)
def test_pool_client_misc(self, i=0):
with self.pool.client() as mc:
self.client_misc(mc, i)
def test_acquire(self):
with self.pool.client() as mc:
pass
def test_pool_client_threaded(self):
with open('debug.log', 'a') as f:
f.write("stdout working\n")
self.client_threads(self.misc)
class ThreadedClientWrapperCheck(unittest.TestCase, ClientOps):
def setUp(self):
self.imp = ThreadedClient(["127.0.0.1:21211"])
def misc(self):
for i in range(5):
self.client_misc(self.imp, i)
def test_many_threads(self):
self.client_threads(self.misc)
'''

0 comments on commit 2d05017

Please sign in to comment.