From 2322e341e2bb13bf8e231461257a1b1687e36e4e Mon Sep 17 00:00:00 2001 From: drcpu Date: Sat, 22 Jul 2023 21:07:05 +0200 Subject: [PATCH] use ClientPool to prevent race conditions when using pylibmc as memcached package --- src/cachelib/memcached.py | 75 +++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/src/cachelib/memcached.py b/src/cachelib/memcached.py index 4fbcce68..462d7352 100644 --- a/src/cachelib/memcached.py +++ b/src/cachelib/memcached.py @@ -49,12 +49,17 @@ def __init__( servers: _t.Any = None, default_timeout: int = 300, key_prefix: _t.Optional[str] = None, + threads: int = 1, + blocking: bool = False, ): BaseCache.__init__(self, default_timeout) + + self.pylibmc_used = False + if servers is None or isinstance(servers, (list, tuple)): if servers is None: servers = ["127.0.0.1:11211"] - self._client = self.import_preferred_memcache_lib(servers) + self._client = self.import_preferred_memcache_lib(servers, threads) if self._client is None: raise RuntimeError("no memcache module found") else: @@ -62,6 +67,7 @@ def __init__( # client. self._client = servers + self.blocking = blocking self.key_prefix = key_prefix def _normalize_key(self, key: str) -> str: @@ -81,7 +87,11 @@ def get(self, key: str) -> _t.Any: # checks for so long keys can occur because it's tested from user # submitted data etc we fail silently for getting. if _test_memcached_key(key): - return self._client.get(key) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + return mc.get(self._normalize_key(key)) + else: + return self._client.get(key) def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]: key_mapping = {} @@ -90,7 +100,11 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]: if _test_memcached_key(key): key_mapping[encoded_key] = key _keys = list(key_mapping) - d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any] + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + d = rv = mc.get_multi(_keys) # type: _t.Dict[str, _t.Any] + else: + d = rv = self._client.get_multi(_keys) # type: _t.Dict[str, _t.Any] if self.key_prefix: rv = {} for key, value in d.items(): @@ -104,14 +118,22 @@ def get_dict(self, *keys: str) -> _t.Dict[str, _t.Any]: def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool: key = self._normalize_key(key) timeout = self._normalize_timeout(timeout) - return bool(self._client.add(key, value, timeout)) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + return bool(mc.add(key, value, timeout)) + else: + return bool(self._client.add(key, value, timeout)) def set( self, key: str, value: _t.Any, timeout: _t.Optional[int] = None ) -> _t.Optional[bool]: key = self._normalize_key(key) timeout = self._normalize_timeout(timeout) - return bool(self._client.set(key, value, timeout)) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + return bool(mc.set(key, value, timeout)) + else: + return bool(self._client.set(key, value, timeout)) def get_many(self, *keys: str) -> _t.List[_t.Any]: d = self.get_dict(*keys) @@ -126,16 +148,26 @@ def set_many( new_mapping[key] = value timeout = self._normalize_timeout(timeout) - failed_keys = self._client.set_multi( - new_mapping, timeout - ) # type: _t.List[_t.Any] + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + failed_keys = mc.set_multi( + new_mapping, timeout + ) # type: _t.List[_t.Any] + else: + failed_keys = self._client.set_multi( + new_mapping, timeout + ) # type: _t.List[_t.Any] k_normkey = zip(mapping.keys(), new_mapping.keys()) # noqa: B905 return [k for k, nkey in k_normkey if nkey not in failed_keys] def delete(self, key: str) -> bool: key = self._normalize_key(key) if _test_memcached_key(key): - return bool(self._client.delete(key)) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + return bool(mc.delete(key)) + else: + return bool(self._client.delete(key)) return False def delete_many(self, *keys: str) -> _t.List[_t.Any]: @@ -144,17 +176,29 @@ def delete_many(self, *keys: str) -> _t.List[_t.Any]: key = self._normalize_key(key) if _test_memcached_key(key): new_keys.append(key) - self._client.delete_multi(new_keys) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + mc.delete_multi(new_keys) + else: + self._client.delete_multi(new_keys) return [k for k in new_keys if not self.has(k)] def has(self, key: str) -> bool: key = self._normalize_key(key) if _test_memcached_key(key): - return bool(self._client.append(key, "")) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + return bool(mc.append(key, "")) + else: + return bool(self._client.append(key, "")) return False def clear(self) -> bool: - return bool(self._client.flush_all()) + if self.pylibmc_used: + with self._client.reserve(block=self.blocking) as mc: + return bool(mc.flush_all()) + else: + return bool(self._client.flush_all()) def inc(self, key: str, delta: int = 1) -> _t.Optional[int]: key = self._normalize_key(key) @@ -166,14 +210,17 @@ def dec(self, key: str, delta: int = 1) -> _t.Optional[int]: value = (self._client.get(key) or 0) - delta return value if self.set(key, value) else None - def import_preferred_memcache_lib(self, servers: _t.Any) -> _t.Any: + def import_preferred_memcache_lib(self, servers: _t.Any, threads: int) -> _t.Any: """Returns an initialized memcache client. Used by the constructor.""" try: import pylibmc # type: ignore except ImportError: pass else: - return pylibmc.Client(servers) + self.pylibmc_used = True + _client_pool = pylibmc.ClientPool() + _client_pool.fill(pylibmc.Client(servers), threads) + return _client_pool try: from google.appengine.api import memcache # type: ignore