diff --git a/py_rql/parser.py b/py_rql/parser.py index 0d993e1..9ed972d 100644 --- a/py_rql/parser.py +++ b/py_rql/parser.py @@ -2,6 +2,8 @@ # Copyright © 2022 Ingram Micro Inc. All rights reserved. # +from threading import Lock + from cachetools import LFUCache from lark import Lark from lark.exceptions import LarkError @@ -15,17 +17,23 @@ def __init__(self, *args, **kwargs): super(RQLLarkParser, self).__init__(*args, **kwargs) self._cache = LFUCache(maxsize=1000) + self._lock = Lock() def parse_query(self, query): cache_key = hash(query) - if cache_key in self._cache: - return self._cache[cache_key] + try: - rql_ast = self.parse(query) - self._cache[cache_key] = rql_ast - return rql_ast - except LarkError: - raise RQLFilterParsingError() + return self._cache[cache_key] + except KeyError: + + try: + rql_ast = self.parse(query) + with self._lock: + self._cache[cache_key] = rql_ast + + return rql_ast + except LarkError: + raise RQLFilterParsingError() RQLParser = RQLLarkParser(RQL_GRAMMAR, parser='lalr', start='start') diff --git a/tests/test_init.py b/tests/test_init.py index e2956b7..0c06b0c 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,11 +1,17 @@ # # Copyright © 2022 Ingram Micro Inc. All rights reserved. # +import time +from threading import Thread + import pytest +from cachetools import LFUCache from lark import Tree from py_rql import parse from py_rql.exceptions import RQLFilterParsingError +from py_rql.grammar import RQL_GRAMMAR +from py_rql.parser import RQLLarkParser def test_parse_ok(): @@ -15,3 +21,40 @@ def test_parse_ok(): def test_parse_fail(): with pytest.raises(RQLFilterParsingError): parse('a=') + + +def test_parse_locks(): + class Cache(LFUCache): + def pop(self, key): + time.sleep(0.5) + + return super().pop(key) + + cache = Cache(maxsize=1) + parser = RQLLarkParser(RQL_GRAMMAR, parser='lalr', start='start') + parser._cache = cache + parser.parse_query('a=b') + + def func1(): + parser.parse_query('b=c') + + has_exception = False + + def func2(): + nonlocal has_exception + + try: + parser.parse_query('c=d') + except KeyError: + has_exception = True + + t1 = Thread(target=func1) + t2 = Thread(target=func2) + + t1.start() + t2.start() + t1.join() + t2.join() + + assert not has_exception + assert hash('c=d') in cache