diff --git a/gokart/run_with_lock.py b/gokart/conflict_prevention_lock/run_with_lock.py similarity index 100% rename from gokart/run_with_lock.py rename to gokart/conflict_prevention_lock/run_with_lock.py diff --git a/gokart/conflict_prevention_lock/task_lock.py b/gokart/conflict_prevention_lock/task_lock.py new file mode 100644 index 00000000..6dbb54b1 --- /dev/null +++ b/gokart/conflict_prevention_lock/task_lock.py @@ -0,0 +1,102 @@ +import functools +import os +from logging import getLogger +from typing import NamedTuple, Optional + +import redis +from apscheduler.schedulers.background import BackgroundScheduler + +logger = getLogger(__name__) + + +class TaskLockParams(NamedTuple): + redis_host: Optional[str] + redis_port: Optional[int] + redis_timeout: Optional[int] + redis_key: str + should_task_lock: bool + raise_task_lock_exception_on_collision: bool + lock_extend_seconds: int + + +class TaskLockException(Exception): + pass + + +class RedisClient: + _instances: dict = {} + + def __new__(cls, *args, **kwargs): + key = (args, tuple(sorted(kwargs.items()))) + if cls not in cls._instances: + cls._instances[cls] = {} + if key not in cls._instances[cls]: + cls._instances[cls][key] = super(RedisClient, cls).__new__(cls) + return cls._instances[cls][key] + + def __init__(self, host: Optional[str], port: Optional[int]) -> None: + if not hasattr(self, '_redis_client'): + host = host or 'localhost' + port = port or 6379 + self._redis_client = redis.Redis(host=host, port=port) + + def get_redis_client(self): + return self._redis_client + + +def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int): + task_lock.extend(additional_time=redis_timeout, replace_ttl=True) + + +def set_task_lock(task_lock_params: TaskLockParams) -> redis.lock.Lock: + redis_client = RedisClient(host=task_lock_params.redis_host, port=task_lock_params.redis_port).get_redis_client() + blocking = not task_lock_params.raise_task_lock_exception_on_collision + task_lock = redis.lock.Lock(redis=redis_client, name=task_lock_params.redis_key, timeout=task_lock_params.redis_timeout, thread_local=False) + if not task_lock.acquire(blocking=blocking): + raise TaskLockException('Lock already taken by other task.') + return task_lock + + +def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockParams) -> BackgroundScheduler: + scheduler = BackgroundScheduler() + extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=task_lock_params.redis_timeout) + scheduler.add_job( + extend_lock, + 'interval', + seconds=task_lock_params.lock_extend_seconds, + max_instances=999999999, + misfire_grace_time=task_lock_params.redis_timeout, + coalesce=False, + ) + scheduler.start() + return scheduler + + +def make_task_lock_key(file_path: str, unique_id: Optional[str]): + basename_without_ext = os.path.splitext(os.path.basename(file_path))[0] + return f'{basename_without_ext}_{unique_id}' + + +def make_task_lock_params( + file_path: str, + unique_id: Optional[str], + redis_host: Optional[str] = None, + redis_port: Optional[int] = None, + redis_timeout: Optional[int] = None, + raise_task_lock_exception_on_collision: bool = False, + lock_extend_seconds: int = 10, +) -> TaskLockParams: + redis_key = make_task_lock_key(file_path, unique_id) + should_task_lock = redis_host is not None and redis_port is not None + if redis_timeout is not None: + assert redis_timeout > lock_extend_seconds, f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.' + task_lock_params = TaskLockParams( + redis_host=redis_host, + redis_port=redis_port, + redis_key=redis_key, + should_task_lock=should_task_lock, + redis_timeout=redis_timeout, + raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision, + lock_extend_seconds=lock_extend_seconds, + ) + return task_lock_params diff --git a/gokart/conflict_prevention_lock/task_lock_wrappers.py b/gokart/conflict_prevention_lock/task_lock_wrappers.py new file mode 100644 index 00000000..861cb3a3 --- /dev/null +++ b/gokart/conflict_prevention_lock/task_lock_wrappers.py @@ -0,0 +1,94 @@ +from logging import getLogger +from typing import Callable + +from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock + +logger = getLogger(__name__) + + +def _wrap_with_lock(func, task_lock_params: TaskLockParams): + if not task_lock_params.should_task_lock: + return func + + def wrapper(*args, **kwargs): + task_lock = set_task_lock(task_lock_params=task_lock_params) + scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) + + try: + logger.debug(f'Task lock of {task_lock_params.redis_key} locked.') + result = func(*args, **kwargs) + task_lock.release() + logger.debug(f'Task lock of {task_lock_params.redis_key} released.') + scheduler.shutdown() + return result + except BaseException as e: + logger.debug(f'Task lock of {task_lock_params.redis_key} released with BaseException.') + task_lock.release() + scheduler.shutdown() + raise e + + return wrapper + + +def wrap_with_run_lock(func: Callable, task_lock_params: TaskLockParams): + """Redis lock wrapper function for RunWithLock. + When a fucntion is wrapped by RunWithLock, the wrapped function will be simply wrapped with redis lock. + https://github.com/m3dev/gokart/issues/265 + """ + return _wrap_with_lock(func=func, task_lock_params=task_lock_params) + + +def wrap_with_dump_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable): + """Redis lock wrapper function for TargetOnKart.dump(). + When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check. + https://github.com/m3dev/gokart/issues/265 + """ + + if not task_lock_params.should_task_lock: + return func + + def wrapper(*args, **kwargs): + task_lock = set_task_lock(task_lock_params=task_lock_params) + scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) + + try: + logger.debug(f'Task lock of {task_lock_params.redis_key} locked.') + if not exist_check(): + func(*args, **kwargs) + finally: + logger.debug(f'Task lock of {task_lock_params.redis_key} released.') + task_lock.release() + scheduler.shutdown() + + return wrapper + + +def wrap_with_load_lock(func, task_lock_params: TaskLockParams): + """Redis lock wrapper function for TargetOnKart.load(). + When TargetOnKart.load() is called, redis lock will be locked and released before load(). + https://github.com/m3dev/gokart/issues/265 + """ + + if not task_lock_params.should_task_lock: + return func + + def wrapper(*args, **kwargs): + task_lock = set_task_lock(task_lock_params=task_lock_params) + scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) + + logger.debug(f'Task lock of {task_lock_params.redis_key} locked.') + task_lock.release() + logger.debug(f'Task lock of {task_lock_params.redis_key} released.') + scheduler.shutdown() + result = func(*args, **kwargs) + return result + + return wrapper + + +def wrap_with_remove_lock(func, task_lock_params: TaskLockParams): + """Redis lock wrapper function for TargetOnKart.remove(). + When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock. + https://github.com/m3dev/gokart/issues/265 + """ + return _wrap_with_lock(func=func, task_lock_params=task_lock_params) diff --git a/gokart/redis_lock.py b/gokart/redis_lock.py deleted file mode 100644 index 00fecc6d..00000000 --- a/gokart/redis_lock.py +++ /dev/null @@ -1,190 +0,0 @@ -import functools -import os -from logging import getLogger -from typing import Callable, NamedTuple, Optional - -import redis -from apscheduler.schedulers.background import BackgroundScheduler - -logger = getLogger(__name__) - - -class RedisParams(NamedTuple): - redis_host: Optional[str] - redis_port: Optional[int] - redis_timeout: Optional[int] - redis_key: str - should_redis_lock: bool - raise_task_lock_exception_on_collision: bool - lock_extend_seconds: int - - -class TaskLockException(Exception): - pass - - -class RedisClient: - _instances: dict = {} - - def __new__(cls, *args, **kwargs): - key = (args, tuple(sorted(kwargs.items()))) - if cls not in cls._instances: - cls._instances[cls] = {} - if key not in cls._instances[cls]: - cls._instances[cls][key] = super(RedisClient, cls).__new__(cls) - return cls._instances[cls][key] - - def __init__(self, host: Optional[str], port: Optional[int]) -> None: - if not hasattr(self, '_redis_client'): - host = host or 'localhost' - port = port or 6379 - self._redis_client = redis.Redis(host=host, port=port) - - def get_redis_client(self): - return self._redis_client - - -def _extend_lock(redis_lock: redis.lock.Lock, redis_timeout: int): - redis_lock.extend(additional_time=redis_timeout, replace_ttl=True) - - -def _set_redis_lock(redis_params: RedisParams) -> redis.lock.Lock: - redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client() - blocking = not redis_params.raise_task_lock_exception_on_collision - redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False) - if not redis_lock.acquire(blocking=blocking): - raise TaskLockException('Lock already taken by other task.') - return redis_lock - - -def _set_lock_scheduler(redis_lock: redis.lock.Lock, redis_params: RedisParams) -> BackgroundScheduler: - scheduler = BackgroundScheduler() - extend_lock = functools.partial(_extend_lock, redis_lock=redis_lock, redis_timeout=redis_params.redis_timeout) - scheduler.add_job( - extend_lock, - 'interval', - seconds=redis_params.lock_extend_seconds, - max_instances=999999999, - misfire_grace_time=redis_params.redis_timeout, - coalesce=False, - ) - scheduler.start() - return scheduler - - -def _wrap_with_lock(func, redis_params: RedisParams): - if not redis_params.should_redis_lock: - return func - - def wrapper(*args, **kwargs): - redis_lock = _set_redis_lock(redis_params=redis_params) - scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params) - - try: - logger.debug(f'Task lock of {redis_params.redis_key} locked.') - result = func(*args, **kwargs) - redis_lock.release() - logger.debug(f'Task lock of {redis_params.redis_key} released.') - scheduler.shutdown() - return result - except BaseException as e: - logger.debug(f'Task lock of {redis_params.redis_key} released with BaseException.') - redis_lock.release() - scheduler.shutdown() - raise e - - return wrapper - - -def wrap_with_run_lock(func: Callable, redis_params: RedisParams): - """Redis lock wrapper function for RunWithLock. - When a fucntion is wrapped by RunWithLock, the wrapped function will be simply wrapped with redis lock. - https://github.com/m3dev/gokart/issues/265 - """ - return _wrap_with_lock(func=func, redis_params=redis_params) - - -def wrap_with_dump_lock(func: Callable, redis_params: RedisParams, exist_check: Callable): - """Redis lock wrapper function for TargetOnKart.dump(). - When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check. - https://github.com/m3dev/gokart/issues/265 - """ - - if not redis_params.should_redis_lock: - return func - - def wrapper(*args, **kwargs): - redis_lock = _set_redis_lock(redis_params=redis_params) - scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params) - - try: - logger.debug(f'Task lock of {redis_params.redis_key} locked.') - if not exist_check(): - func(*args, **kwargs) - finally: - logger.debug(f'Task lock of {redis_params.redis_key} released.') - redis_lock.release() - scheduler.shutdown() - - return wrapper - - -def wrap_with_load_lock(func, redis_params: RedisParams): - """Redis lock wrapper function for TargetOnKart.load(). - When TargetOnKart.load() is called, redis lock will be locked and released before load(). - https://github.com/m3dev/gokart/issues/265 - """ - - if not redis_params.should_redis_lock: - return func - - def wrapper(*args, **kwargs): - redis_lock = _set_redis_lock(redis_params=redis_params) - scheduler = _set_lock_scheduler(redis_lock=redis_lock, redis_params=redis_params) - - logger.debug(f'Task lock of {redis_params.redis_key} locked.') - redis_lock.release() - logger.debug(f'Task lock of {redis_params.redis_key} released.') - scheduler.shutdown() - result = func(*args, **kwargs) - return result - - return wrapper - - -def wrap_with_remove_lock(func, redis_params: RedisParams): - """Redis lock wrapper function for TargetOnKart.remove(). - When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock. - https://github.com/m3dev/gokart/issues/265 - """ - return _wrap_with_lock(func=func, redis_params=redis_params) - - -def make_redis_key(file_path: str, unique_id: Optional[str]): - basename_without_ext = os.path.splitext(os.path.basename(file_path))[0] - return f'{basename_without_ext}_{unique_id}' - - -def make_redis_params( - file_path: str, - unique_id: Optional[str], - redis_host: Optional[str] = None, - redis_port: Optional[int] = None, - redis_timeout: Optional[int] = None, - raise_task_lock_exception_on_collision: bool = False, - lock_extend_seconds: int = 10, -): - redis_key = make_redis_key(file_path, unique_id) - should_redis_lock = redis_host is not None and redis_port is not None - if redis_timeout is not None: - assert redis_timeout > lock_extend_seconds, f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.' - redis_params = RedisParams( - redis_host=redis_host, - redis_port=redis_port, - redis_key=redis_key, - should_redis_lock=should_redis_lock, - redis_timeout=redis_timeout, - raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision, - lock_extend_seconds=lock_extend_seconds, - ) - return redis_params diff --git a/gokart/target.py b/gokart/target.py index 6ef96336..dd62891c 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -12,9 +12,10 @@ import pandas as pd from tqdm import tqdm +from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params +from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_with_dump_lock, wrap_with_load_lock, wrap_with_remove_lock, wrap_with_run_lock from gokart.file_processor import FileProcessor, make_file_processor from gokart.object_storage import ObjectStorage -from gokart.redis_lock import RedisParams, make_redis_params, wrap_with_dump_lock, wrap_with_load_lock, wrap_with_remove_lock, wrap_with_run_lock from gokart.zip_client_util import make_zip_client logger = getLogger(__name__) @@ -25,17 +26,17 @@ def exists(self) -> bool: return self._exists() def load(self) -> Any: - return wrap_with_load_lock(func=self._load, redis_params=self._get_redis_params())() + return wrap_with_load_lock(func=self._load, task_lock_params=self._get_task_lock_params())() def dump(self, obj, lock_at_dump: bool = True) -> None: if lock_at_dump: - wrap_with_dump_lock(func=self._dump, redis_params=self._get_redis_params(), exist_check=self.exists)(obj) + wrap_with_dump_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(obj) else: self._dump(obj) def remove(self) -> None: if self.exists(): - wrap_with_remove_lock(self._remove, redis_params=self._get_redis_params())() + wrap_with_remove_lock(self._remove, task_lock_params=self._get_task_lock_params())() def last_modification_time(self) -> datetime: return self._last_modification_time() @@ -44,14 +45,14 @@ def path(self) -> str: return self._path() def wrap_with_run_lock(self, func): - return wrap_with_run_lock(func=func, redis_params=self._get_redis_params()) + return wrap_with_run_lock(func=func, task_lock_params=self._get_task_lock_params()) @abstractmethod def _exists(self) -> bool: pass @abstractmethod - def _get_redis_params(self) -> RedisParams: + def _get_task_lock_params(self) -> TaskLockParams: pass @abstractmethod @@ -80,17 +81,17 @@ def __init__( self, target: luigi.target.FileSystemTarget, processor: FileProcessor, - redis_params: RedisParams, + task_lock_params: TaskLockParams, ) -> None: self._target = target self._processor = processor - self._redis_params = redis_params + self._task_lock_params = task_lock_params def _exists(self) -> bool: return self._target.exists() - def _get_redis_params(self) -> RedisParams: - return self._redis_params + def _get_task_lock_params(self) -> TaskLockParams: + return self._task_lock_params def _load(self) -> Any: with self._target.open('r') as f: @@ -117,19 +118,19 @@ def __init__( temporary_directory: str, load_function, save_function, - redis_params: RedisParams, + task_lock_params: TaskLockParams, ) -> None: self._zip_client = make_zip_client(file_path, temporary_directory) self._temporary_directory = temporary_directory self._save_function = save_function self._load_function = load_function - self._redis_params = redis_params + self._task_lock_params = task_lock_params def _exists(self) -> bool: return self._zip_client.exists() - def _get_redis_params(self) -> RedisParams: - return self._redis_params + def _get_task_lock_params(self) -> TaskLockParams: + return self._task_lock_params def _load(self) -> Any: self._zip_client.unpack_archive() @@ -217,22 +218,31 @@ def make_target( file_path: str, unique_id: Optional[str] = None, processor: Optional[FileProcessor] = None, - redis_params: Optional[RedisParams] = None, + task_lock_params: Optional[TaskLockParams] = None, store_index_in_feather: bool = True, ) -> TargetOnKart: - _redis_params = redis_params if redis_params is not None else make_redis_params(file_path=file_path, unique_id=unique_id) + _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather) - return SingleFileTarget(target=file_system_target, processor=processor, redis_params=_redis_params) + return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params) def make_model_target( - file_path: str, temporary_directory: str, save_function, load_function, unique_id: Optional[str] = None, redis_params: Optional[RedisParams] = None + file_path: str, + temporary_directory: str, + save_function, + load_function, + unique_id: Optional[str] = None, + task_lock_params: Optional[TaskLockParams] = None, ) -> TargetOnKart: - _redis_params = redis_params if redis_params is not None else make_redis_params(file_path=file_path, unique_id=unique_id) + _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest()) return ModelTarget( - file_path=file_path, temporary_directory=temporary_directory, save_function=save_function, load_function=load_function, redis_params=_redis_params + file_path=file_path, + temporary_directory=temporary_directory, + save_function=save_function, + load_function=load_function, + task_lock_params=_task_lock_params, ) diff --git a/gokart/task.py b/gokart/task.py index d592454f..b818cffc 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -11,10 +11,10 @@ from luigi.parameter import ParameterVisibility import gokart +from gokart.conflict_prevention_lock.task_lock import make_task_lock_params from gokart.file_processor import FileProcessor from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter -from gokart.redis_lock import make_redis_params from gokart.target import TargetOnKart from gokart.task_complete_check import task_complete_check_wrapper @@ -173,7 +173,7 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, @@ -183,7 +183,7 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b ) return gokart.target.make_target( - file_path=file_path, unique_id=unique_id, processor=processor, redis_params=redis_params, store_index_in_feather=self.store_index_in_feather + file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: @@ -192,7 +192,7 @@ def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, ) file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, @@ -207,7 +207,7 @@ def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, unique_id=unique_id, save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save, load_function=gokart.target.LargeDataFrameProcessor.load, - redis_params=redis_params, + task_lock_params=task_lock_params, ) def make_model_target( @@ -224,7 +224,7 @@ def make_model_target( file_path = os.path.join(self.workspace_directory, relative_file_path) assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.' unique_id = self.make_unique_id() if use_unique_id else None - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, @@ -239,7 +239,7 @@ def make_model_target( unique_id=unique_id, save_function=save_function, load_function=load_function, - redis_params=redis_params, + task_lock_params=task_lock_params, ) def load(self, target: Union[None, str, TargetOnKart] = None) -> Any: diff --git a/test/conflict_prevention_lock/__init__.py b/test/conflict_prevention_lock/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/conflict_prevention_lock/test_task_lock.py b/test/conflict_prevention_lock/test_task_lock.py new file mode 100644 index 00000000..744f9419 --- /dev/null +++ b/test/conflict_prevention_lock/test_task_lock.py @@ -0,0 +1,72 @@ +import random +import unittest +from unittest.mock import patch + +from gokart.conflict_prevention_lock.task_lock import RedisClient, TaskLockParams, make_task_lock_key, make_task_lock_params + + +class TestRedisClient(unittest.TestCase): + @staticmethod + def _get_randint(host, port): + return random.randint(0, 100000) + + def test_redis_client_is_singleton(self): + with patch('redis.Redis') as mock: + mock.side_effect = self._get_randint + + redis_client_0_0 = RedisClient(host='host_0', port='123') + redis_client_1 = RedisClient(host='host_1', port='123') + redis_client_0_1 = RedisClient(host='host_0', port='123') + + self.assertNotEqual(redis_client_0_0, redis_client_1) + self.assertEqual(redis_client_0_0, redis_client_0_1) + + self.assertEqual(redis_client_0_0.get_redis_client(), redis_client_0_1.get_redis_client()) + + +class TestMakeRedisKey(unittest.TestCase): + def test_make_redis_key(self): + result = make_task_lock_key(file_path='gs://test_ll/dir/fname.pkl', unique_id='12345') + self.assertEqual(result, 'fname_12345') + + +class TestMakeRedisParams(unittest.TestCase): + def test_make_task_lock_params_with_valid_host(self): + result = make_task_lock_params( + file_path='gs://aaa.pkl', unique_id='123', redis_host='0.0.0.0', redis_port='12345', redis_timeout=180, raise_task_lock_exception_on_collision=False + ) + expected = TaskLockParams( + redis_host='0.0.0.0', + redis_port='12345', + redis_key='aaa_123', + should_task_lock=True, + redis_timeout=180, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=10, + ) + self.assertEqual(result, expected) + + def test_make_task_lock_params_with_no_host(self): + result = make_task_lock_params( + file_path='gs://aaa.pkl', unique_id='123', redis_host=None, redis_port='12345', redis_timeout=180, raise_task_lock_exception_on_collision=False + ) + expected = TaskLockParams( + redis_host=None, + redis_port='12345', + redis_key='aaa_123', + should_task_lock=False, + redis_timeout=180, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=10, + ) + self.assertEqual(result, expected) + + def test_assert_when_redis_timeout_is_too_short(self): + with self.assertRaises(AssertionError): + make_task_lock_params( + file_path='test_dir/test_file.pkl', + unique_id='123abc', + redis_host='0.0.0.0', + redis_port=12345, + redis_timeout=2, + ) diff --git a/test/test_redis_lock.py b/test/conflict_prevention_lock/test_task_lock_wrappers.py similarity index 60% rename from test/test_redis_lock.py rename to test/conflict_prevention_lock/test_task_lock_wrappers.py index d4835eee..eec87bb4 100644 --- a/test/test_redis_lock.py +++ b/test/conflict_prevention_lock/test_task_lock_wrappers.py @@ -1,39 +1,11 @@ -import random import time import unittest from unittest.mock import MagicMock, patch import fakeredis -from gokart.redis_lock import ( - RedisClient, - RedisParams, - make_redis_key, - make_redis_params, - wrap_with_dump_lock, - wrap_with_load_lock, - wrap_with_remove_lock, - wrap_with_run_lock, -) - - -class TestRedisClient(unittest.TestCase): - @staticmethod - def _get_randint(host, port): - return random.randint(0, 100000) - - def test_redis_client_is_singleton(self): - with patch('redis.Redis') as mock: - mock.side_effect = self._get_randint - - redis_client_0_0 = RedisClient(host='host_0', port='123') - redis_client_1 = RedisClient(host='host_1', port='123') - redis_client_0_1 = RedisClient(host='host_0', port='123') - - self.assertNotEqual(redis_client_0_0, redis_client_1) - self.assertEqual(redis_client_0_0, redis_client_0_1) - - self.assertEqual(redis_client_0_0.get_redis_client(), redis_client_0_1.get_redis_client()) +from gokart.conflict_prevention_lock.task_lock import make_task_lock_params +from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_with_dump_lock, wrap_with_load_lock, wrap_with_remove_lock, wrap_with_run_lock def _sample_func_with_error(a: int, b: str): @@ -47,14 +19,14 @@ def _sample_long_func(a: int, b: str): class TestWrapWithRunLock(unittest.TestCase): def test_no_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host=None, redis_port=None, ) mock_func = MagicMock() - resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -63,17 +35,17 @@ def test_no_redis(self): self.assertEqual(resulted, mock_func()) def test_use_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: mock_func = MagicMock() redis_mock.side_effect = fakeredis.FakeRedis - resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -82,7 +54,7 @@ def test_use_redis(self): self.assertEqual(resulted, mock_func()) def test_check_lock_extended(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -91,14 +63,14 @@ def test_check_lock_extended(self): lock_extend_seconds=1, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis - resulted = wrap_with_run_lock(func=_sample_long_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc') expected = dict(a=123, b='abc') self.assertEqual(resulted, expected) def test_lock_is_removed_after_func_is_finished(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -107,10 +79,10 @@ def test_lock_is_removed_after_func_is_finished(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) mock_func = MagicMock() - resulted = wrap_with_run_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_run_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -120,10 +92,10 @@ def test_lock_is_removed_after_func_is_finished(self): fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] def test_lock_is_removed_after_func_is_finished_with_error(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -132,26 +104,26 @@ def test_lock_is_removed_after_func_is_finished_with_error(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) try: - wrap_with_run_lock(func=_sample_func_with_error, redis_params=redis_params)(a=123, b='abc') + wrap_with_run_lock(func=_sample_func_with_error, task_lock_params=task_lock_params)(a=123, b='abc') except Exception: fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] class TestWrapWithDumpLock(unittest.TestCase): def test_no_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host=None, redis_port=None, ) mock_func = MagicMock() - wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + wrap_with_dump_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -159,17 +131,17 @@ def test_no_redis(self): self.assertDictEqual(called_kwargs, dict(b='abc')) def test_use_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() - wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + wrap_with_dump_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -177,22 +149,22 @@ def test_use_redis(self): self.assertDictEqual(called_kwargs, dict(b='abc')) def test_if_func_is_skipped_when_cache_already_exists(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() - wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: True)(123, b='abc') + wrap_with_dump_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: True)(123, b='abc') mock_func.assert_not_called() def test_check_lock_extended(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -201,12 +173,12 @@ def test_check_lock_extended(self): lock_extend_seconds=1, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis - wrap_with_dump_lock(func=_sample_long_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + wrap_with_dump_lock(func=_sample_long_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc') def test_lock_is_removed_after_func_is_finished(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -215,10 +187,10 @@ def test_lock_is_removed_after_func_is_finished(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) mock_func = MagicMock() - wrap_with_dump_lock(func=mock_func, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + wrap_with_dump_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -227,10 +199,10 @@ def test_lock_is_removed_after_func_is_finished(self): fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] def test_lock_is_removed_after_func_is_finished_with_error(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -239,26 +211,26 @@ def test_lock_is_removed_after_func_is_finished_with_error(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) try: - wrap_with_dump_lock(func=_sample_func_with_error, redis_params=redis_params, exist_check=lambda: False)(123, b='abc') + wrap_with_dump_lock(func=_sample_func_with_error, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc') except Exception: fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] class TestWrapWithLoadLock(unittest.TestCase): def test_no_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host=None, redis_port=None, ) mock_func = MagicMock() - resulted = wrap_with_load_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_load_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -268,17 +240,17 @@ def test_no_redis(self): self.assertEqual(resulted, mock_func()) def test_use_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() - resulted = wrap_with_load_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_load_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -288,7 +260,7 @@ def test_use_redis(self): self.assertEqual(resulted, mock_func()) def test_check_lock_extended(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -297,14 +269,14 @@ def test_check_lock_extended(self): lock_extend_seconds=1, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis - resulted = wrap_with_load_lock(func=_sample_long_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_load_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc') expected = dict(a=123, b='abc') self.assertEqual(resulted, expected) def test_lock_is_removed_after_func_is_finished(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -313,10 +285,10 @@ def test_lock_is_removed_after_func_is_finished(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) mock_func = MagicMock() - resulted = wrap_with_load_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_load_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -326,10 +298,10 @@ def test_lock_is_removed_after_func_is_finished(self): fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] def test_lock_is_removed_after_func_is_finished_with_error(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -338,26 +310,26 @@ def test_lock_is_removed_after_func_is_finished_with_error(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) try: - wrap_with_load_lock(func=_sample_func_with_error, redis_params=redis_params)(123, b='abc') + wrap_with_load_lock(func=_sample_func_with_error, task_lock_params=task_lock_params)(123, b='abc') except Exception: fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] class TestWrapWithRemoveLock(unittest.TestCase): def test_no_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host=None, redis_port=None, ) mock_func = MagicMock() - resulted = wrap_with_remove_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_remove_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -366,17 +338,17 @@ def test_no_redis(self): self.assertEqual(resulted, mock_func()) def test_use_redis(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() - resulted = wrap_with_remove_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_remove_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args @@ -385,7 +357,7 @@ def test_use_redis(self): self.assertEqual(resulted, mock_func()) def test_check_lock_extended(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -394,14 +366,14 @@ def test_check_lock_extended(self): lock_extend_seconds=1, ) - with patch('gokart.redis_lock.redis.Redis') as redis_mock: + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis - resulted = wrap_with_remove_lock(func=_sample_long_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_remove_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc') expected = dict(a=123, b='abc') self.assertEqual(resulted, expected) def test_lock_is_removed_after_func_is_finished(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -410,10 +382,10 @@ def test_lock_is_removed_after_func_is_finished(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) mock_func = MagicMock() - resulted = wrap_with_remove_lock(func=mock_func, redis_params=redis_params)(123, b='abc') + resulted = wrap_with_remove_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) @@ -422,10 +394,10 @@ def test_lock_is_removed_after_func_is_finished(self): fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] + fake_redis[task_lock_params.redis_key] def test_lock_is_removed_after_func_is_finished_with_error(self): - redis_params = make_redis_params( + task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', @@ -434,59 +406,11 @@ def test_lock_is_removed_after_func_is_finished_with_error(self): server = fakeredis.FakeServer() - with patch('gokart.redis_lock.redis.Redis') as redis_mock: - redis_mock.return_value = fakeredis.FakeRedis(server=server, host=redis_params.redis_host, port=redis_params.redis_port) + with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: + redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) try: - wrap_with_remove_lock(func=_sample_func_with_error, redis_params=redis_params)(123, b='abc') + wrap_with_remove_lock(func=_sample_func_with_error, task_lock_params=task_lock_params)(123, b='abc') except Exception: fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): - fake_redis[redis_params.redis_key] - - -class TestMakeRedisKey(unittest.TestCase): - def test_make_redis_key(self): - result = make_redis_key(file_path='gs://test_ll/dir/fname.pkl', unique_id='12345') - self.assertEqual(result, 'fname_12345') - - -class TestMakeRedisParams(unittest.TestCase): - def test_make_redis_params_with_valid_host(self): - result = make_redis_params( - file_path='gs://aaa.pkl', unique_id='123', redis_host='0.0.0.0', redis_port='12345', redis_timeout=180, raise_task_lock_exception_on_collision=False - ) - expected = RedisParams( - redis_host='0.0.0.0', - redis_port='12345', - redis_key='aaa_123', - should_redis_lock=True, - redis_timeout=180, - raise_task_lock_exception_on_collision=False, - lock_extend_seconds=10, - ) - self.assertEqual(result, expected) - - def test_make_redis_params_with_no_host(self): - result = make_redis_params( - file_path='gs://aaa.pkl', unique_id='123', redis_host=None, redis_port='12345', redis_timeout=180, raise_task_lock_exception_on_collision=False - ) - expected = RedisParams( - redis_host=None, - redis_port='12345', - redis_key='aaa_123', - should_redis_lock=False, - redis_timeout=180, - raise_task_lock_exception_on_collision=False, - lock_extend_seconds=10, - ) - self.assertEqual(result, expected) - - def test_assert_when_redis_timeout_is_too_short(self): - with self.assertRaises(AssertionError): - make_redis_params( - file_path='test_dir/test_file.pkl', - unique_id='123abc', - redis_host='0.0.0.0', - redis_port=12345, - redis_timeout=2, - ) + fake_redis[task_lock_params.redis_key] diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index 9f172e48..bd9f5c2b 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -11,9 +11,9 @@ from luigi.util import inherits import gokart +from gokart.conflict_prevention_lock.run_with_lock import RunWithLock from gokart.file_processor import XmlFileProcessor from gokart.parameter import ListTaskInstanceParameter, TaskInstanceParameter -from gokart.run_with_lock import RunWithLock from gokart.target import ModelTarget, SingleFileTarget, TargetOnKart