Skip to content

Commit

Permalink
Refactoring: move conflict prevention lock codes (#351)
Browse files Browse the repository at this point in the history
* move code

* add

* add

* rename directory

* rename test directory

* add

* rename test

* add

* fix wrapper test

* fix ruff
  • Loading branch information
mski-iksm authored Feb 26, 2024
1 parent c912f45 commit f4adc46
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 371 deletions.
File renamed without changes.
102 changes: 102 additions & 0 deletions gokart/conflict_prevention_lock/task_lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import functools
import os
from logging import getLogger
from typing import NamedTuple, Optional

import redis
from apscheduler.schedulers.background import BackgroundScheduler

logger = getLogger(__name__)


class TaskLockParams(NamedTuple):
redis_host: Optional[str]
redis_port: Optional[int]
redis_timeout: Optional[int]
redis_key: str
should_task_lock: bool
raise_task_lock_exception_on_collision: bool
lock_extend_seconds: int


class TaskLockException(Exception):
pass


class RedisClient:
_instances: dict = {}

def __new__(cls, *args, **kwargs):
key = (args, tuple(sorted(kwargs.items())))
if cls not in cls._instances:
cls._instances[cls] = {}
if key not in cls._instances[cls]:
cls._instances[cls][key] = super(RedisClient, cls).__new__(cls)
return cls._instances[cls][key]

def __init__(self, host: Optional[str], port: Optional[int]) -> None:
if not hasattr(self, '_redis_client'):
host = host or 'localhost'
port = port or 6379
self._redis_client = redis.Redis(host=host, port=port)

def get_redis_client(self):
return self._redis_client


def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int):
task_lock.extend(additional_time=redis_timeout, replace_ttl=True)


def set_task_lock(task_lock_params: TaskLockParams) -> redis.lock.Lock:
redis_client = RedisClient(host=task_lock_params.redis_host, port=task_lock_params.redis_port).get_redis_client()
blocking = not task_lock_params.raise_task_lock_exception_on_collision
task_lock = redis.lock.Lock(redis=redis_client, name=task_lock_params.redis_key, timeout=task_lock_params.redis_timeout, thread_local=False)
if not task_lock.acquire(blocking=blocking):
raise TaskLockException('Lock already taken by other task.')
return task_lock


def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockParams) -> BackgroundScheduler:
scheduler = BackgroundScheduler()
extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=task_lock_params.redis_timeout)
scheduler.add_job(
extend_lock,
'interval',
seconds=task_lock_params.lock_extend_seconds,
max_instances=999999999,
misfire_grace_time=task_lock_params.redis_timeout,
coalesce=False,
)
scheduler.start()
return scheduler


def make_task_lock_key(file_path: str, unique_id: Optional[str]):
basename_without_ext = os.path.splitext(os.path.basename(file_path))[0]
return f'{basename_without_ext}_{unique_id}'


def make_task_lock_params(
file_path: str,
unique_id: Optional[str],
redis_host: Optional[str] = None,
redis_port: Optional[int] = None,
redis_timeout: Optional[int] = None,
raise_task_lock_exception_on_collision: bool = False,
lock_extend_seconds: int = 10,
) -> TaskLockParams:
redis_key = make_task_lock_key(file_path, unique_id)
should_task_lock = redis_host is not None and redis_port is not None
if redis_timeout is not None:
assert redis_timeout > lock_extend_seconds, f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.'
task_lock_params = TaskLockParams(
redis_host=redis_host,
redis_port=redis_port,
redis_key=redis_key,
should_task_lock=should_task_lock,
redis_timeout=redis_timeout,
raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision,
lock_extend_seconds=lock_extend_seconds,
)
return task_lock_params
94 changes: 94 additions & 0 deletions gokart/conflict_prevention_lock/task_lock_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from logging import getLogger
from typing import Callable

from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock

logger = getLogger(__name__)


def _wrap_with_lock(func, task_lock_params: TaskLockParams):
if not task_lock_params.should_task_lock:
return func

def wrapper(*args, **kwargs):
task_lock = set_task_lock(task_lock_params=task_lock_params)
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

try:
logger.debug(f'Task lock of {task_lock_params.redis_key} locked.')
result = func(*args, **kwargs)
task_lock.release()
logger.debug(f'Task lock of {task_lock_params.redis_key} released.')
scheduler.shutdown()
return result
except BaseException as e:
logger.debug(f'Task lock of {task_lock_params.redis_key} released with BaseException.')
task_lock.release()
scheduler.shutdown()
raise e

return wrapper


def wrap_with_run_lock(func: Callable, task_lock_params: TaskLockParams):
"""Redis lock wrapper function for RunWithLock.
When a fucntion is wrapped by RunWithLock, the wrapped function will be simply wrapped with redis lock.
https://github.com/m3dev/gokart/issues/265
"""
return _wrap_with_lock(func=func, task_lock_params=task_lock_params)


def wrap_with_dump_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable):
"""Redis lock wrapper function for TargetOnKart.dump().
When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check.
https://github.com/m3dev/gokart/issues/265
"""

if not task_lock_params.should_task_lock:
return func

def wrapper(*args, **kwargs):
task_lock = set_task_lock(task_lock_params=task_lock_params)
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

try:
logger.debug(f'Task lock of {task_lock_params.redis_key} locked.')
if not exist_check():
func(*args, **kwargs)
finally:
logger.debug(f'Task lock of {task_lock_params.redis_key} released.')
task_lock.release()
scheduler.shutdown()

return wrapper


def wrap_with_load_lock(func, task_lock_params: TaskLockParams):
"""Redis lock wrapper function for TargetOnKart.load().
When TargetOnKart.load() is called, redis lock will be locked and released before load().
https://github.com/m3dev/gokart/issues/265
"""

if not task_lock_params.should_task_lock:
return func

def wrapper(*args, **kwargs):
task_lock = set_task_lock(task_lock_params=task_lock_params)
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

logger.debug(f'Task lock of {task_lock_params.redis_key} locked.')
task_lock.release()
logger.debug(f'Task lock of {task_lock_params.redis_key} released.')
scheduler.shutdown()
result = func(*args, **kwargs)
return result

return wrapper


def wrap_with_remove_lock(func, task_lock_params: TaskLockParams):
"""Redis lock wrapper function for TargetOnKart.remove().
When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock.
https://github.com/m3dev/gokart/issues/265
"""
return _wrap_with_lock(func=func, task_lock_params=task_lock_params)
190 changes: 0 additions & 190 deletions gokart/redis_lock.py

This file was deleted.

Loading

0 comments on commit f4adc46

Please sign in to comment.