Skip to content

Commit

Permalink
use ClientPool to prevent race conditions when using pylibmc as memca…
Browse files Browse the repository at this point in the history
…ched package
  • Loading branch information
drcpu-github committed Jul 22, 2023
1 parent eafb40b commit 2322e34
Showing 1 changed file with 61 additions and 14 deletions.
75 changes: 61 additions & 14 deletions src/cachelib/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,25 @@ def __init__(
servers: _t.Any = None,
default_timeout: int = 300,
key_prefix: _t.Optional[str] = None,
threads: int = 1,
blocking: bool = False,
):
BaseCache.__init__(self, default_timeout)

self.pylibmc_used = False

if servers is None or isinstance(servers, (list, tuple)):
if servers is None:
servers = ["127.0.0.1:11211"]
self._client = self.import_preferred_memcache_lib(servers)
self._client = self.import_preferred_memcache_lib(servers, threads)
if self._client is None:
raise RuntimeError("no memcache module found")
else:
# NOTE: servers is actually an already initialized memcache
# client.
self._client = servers

self.blocking = blocking
self.key_prefix = key_prefix

def _normalize_key(self, key: str) -> str:
Expand All @@ -81,7 +87,11 @@ def get(self, key: str) -> _t.Any:
# checks for so long keys can occur because it's tested from user
# submitted data etc we fail silently for getting.
if _test_memcached_key(key):
return self._client.get(key)
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return mc.get(self._normalize_key(key))
else:
return self._client.get(key)

def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
key_mapping = {}
Expand All @@ -90,7 +100,11 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
if _test_memcached_key(key):
key_mapping[encoded_key] = key
_keys = list(key_mapping)
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
d = rv = mc.get_multi(_keys) # type: _t.Dict[str, _t.Any]
else:
d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any]
if self.key_prefix:
rv = {}
for key, value in d.items():
Expand All @@ -104,14 +118,22 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]:
def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool:
key = self._normalize_key(key)
timeout = self._normalize_timeout(timeout)
return bool(self._client.add(key, value, timeout))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.add(key, value, timeout))
else:
return bool(self._client.add(key, value, timeout))

def set(
self, key: str, value: _t.Any, timeout: _t.Optional[int] = None
) -> _t.Optional[bool]:
key = self._normalize_key(key)
timeout = self._normalize_timeout(timeout)
return bool(self._client.set(key, value, timeout))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.set(key, value, timeout))
else:
return bool(self._client.set(key, value, timeout))

def get_many(self, *keys: str) -> _t.List[_t.Any]:
d = self.get_dict(*keys)
Expand All @@ -126,16 +148,26 @@ def set_many(
new_mapping[key] = value

timeout = self._normalize_timeout(timeout)
failed_keys = self._client.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
failed_keys = mc.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
else:
failed_keys = self._client.set_multi(
new_mapping, timeout
) # type: _t.List[_t.Any]
k_normkey = zip(mapping.keys(), new_mapping.keys()) # noqa: B905
return [k for k, nkey in k_normkey if nkey not in failed_keys]

def delete(self, key: str) -> bool:
key = self._normalize_key(key)
if _test_memcached_key(key):
return bool(self._client.delete(key))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.delete(key))
else:
return bool(self._client.delete(key))
return False

def delete_many(self, *keys: str) -> _t.List[_t.Any]:
Expand All @@ -144,17 +176,29 @@ def delete_many(self, *keys: str) -> _t.List[_t.Any]:
key = self._normalize_key(key)
if _test_memcached_key(key):
new_keys.append(key)
self._client.delete_multi(new_keys)
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
mc.delete_multi(new_keys)
else:
self._client.delete_multi(new_keys)
return [k for k in new_keys if not self.has(k)]

def has(self, key: str) -> bool:
key = self._normalize_key(key)
if _test_memcached_key(key):
return bool(self._client.append(key, ""))
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.append(key, ""))
else:
return bool(self._client.append(key, ""))
return False

def clear(self) -> bool:
return bool(self._client.flush_all())
if self.pylibmc_used:
with self._client.reserve(block=self.blocking) as mc:
return bool(mc.flush_all())
else:
return bool(self._client.flush_all())

def inc(self, key: str, delta: int = 1) -> _t.Optional[int]:
key = self._normalize_key(key)
Expand All @@ -166,14 +210,17 @@ def dec(self, key: str, delta: int = 1) -> _t.Optional[int]:
value = (self._client.get(key) or 0) - delta
return value if self.set(key, value) else None

def import_preferred_memcache_lib(self, servers: _t.Any) -> _t.Any:
def import_preferred_memcache_lib(self, servers: _t.Any, threads: int) -> _t.Any:
"""Returns an initialized memcache client. Used by the constructor."""
try:
import pylibmc # type: ignore
except ImportError:
pass
else:
return pylibmc.Client(servers)
self.pylibmc_used = True
_client_pool = pylibmc.ClientPool()
_client_pool.fill(pylibmc.Client(servers), threads)
return _client_pool

try:
from google.appengine.api import memcache # type: ignore
Expand Down

0 comments on commit 2322e34

Please sign in to comment.