Skip to content

Commit

Permalink
Merge branch 'master' into support-python312
Browse files Browse the repository at this point in the history
  • Loading branch information
kitagry committed Jan 16, 2024
2 parents c65c97f + cc1e849 commit 1ac87d8
Show file tree
Hide file tree
Showing 12 changed files with 635 additions and 591 deletions.
66 changes: 0 additions & 66 deletions docs/using_task_cache_collision_lock.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,69 +49,3 @@ How to use
3. Done

With the above configuration, every task that inherits from gokart.TaskOnKart will ask the redis server whether any other node is trying to access the same cache file at the same time, whenever it accesses the file with dump or load.


Advanced: Using efficient task cache collision lock
---------------------------------------------------

The cache lock introduced above will prevent cache collision.
However, the above setting checks for collisions only when a task accesses the cache file (i.e. ``task.dump()``, ``task.load()`` and ``task.remove()``).
This still allows applications to execute ``run()`` of the same task at the same time, which is not time efficient.

Settings in this section will prevent running ``run()`` at the same time for efficiency.

If you try to run() the same task on multiple worker nodes at the same time, run() will fail on the second and subsequent node's tasks.
gokart will execute other unaffected tasks in the meantime. Since we have also set up the retry process, we will come back to the failed task later.
When it comes back, the first worker node has already completed run() and a cache has been created, so there is no need to run() on the second and subsequent nodes.
In this way, efficient distributed processing is made possible.


This setting must be applied to each gokart task whose ``run()`` you want to lock.

1. Set normal cache collision lock

Follow the steps in ``How to use`` to set up cache collision lock.


2. Decorate ``run()`` with ``@RunWithLock``

Decorate ``run()`` of your gokart tasks you want to lock with ``@RunWithLock``.

.. code:: python

    from gokart.run_with_lock import RunWithLock

    class SomeTask(gokart.TaskOnKart):
        @RunWithLock
        def run(self):
            ...

3. Set ``redis_fail_on_collision`` parameter to true.

This parameter will affect the behavior when the task's lock is taken by other applications or nodes.
Setting ``redis_fail_on_collision=True`` will make the task fail if the task's lock is taken by others.

The parameter can be set by config file.

.. code::

    [TaskOnKart]
    redis_host=localhost
    redis_port=6379
    redis_fail_on_collision=true

4. Set retry parameters

Set the following parameters so that a failed task is retried.

* ``retry_count``: the maximum number of retries
* ``retry_delay``: the delay before a failed task is retried, in seconds

.. code::

    [scheduler]
    retry_count=10000
    retry_delay=10

    [worker]
    keep_alive=true

8 changes: 6 additions & 2 deletions gokart/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import luigi

import gokart
from gokart.task import TaskOnKart


Expand Down Expand Up @@ -38,8 +39,11 @@ def _get_output(task: TaskOnKart) -> Any:


def _reset_register(keep={'gokart', 'luigi'}):
luigi.task_register.Register._reg = [x for x in luigi.task_register.Register._reg
if x.__module__.split('.')[0] in keep] # avoid TaskClassAmbigiousException
"""Reset luigi.task_register.Register._reg every time gokart.build is called, to avoid TaskClassAmbigiousException."""
luigi.task_register.Register._reg = [
x for x in luigi.task_register.Register._reg if ((x.__module__.split('.')[0] in keep) # keep luigi and gokart
or (issubclass(x, gokart.PandasTypeConfig))) # PandasTypeConfig should be kept
]


def build(task: TaskOnKart, return_value: bool = True, reset_register: bool = True, log_level: int = logging.ERROR, **env_params) -> Optional[Any]:
Expand Down
10 changes: 5 additions & 5 deletions gokart/redis_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RedisParams(NamedTuple):
redis_timeout: Optional[int]
redis_key: str
should_redis_lock: bool
redis_fail_on_collision: bool
raise_task_lock_exception_on_collision: bool
lock_extend_seconds: int


Expand Down Expand Up @@ -50,7 +50,7 @@ def _extend_lock(redis_lock: redis.lock.Lock, redis_timeout: int):

def _set_redis_lock(redis_params: RedisParams) -> redis.lock.Lock:
redis_client = RedisClient(host=redis_params.redis_host, port=redis_params.redis_port).get_redis_client()
blocking = not redis_params.redis_fail_on_collision
blocking = not redis_params.raise_task_lock_exception_on_collision
redis_lock = redis.lock.Lock(redis=redis_client, name=redis_params.redis_key, timeout=redis_params.redis_timeout, thread_local=False)
if not redis_lock.acquire(blocking=blocking):
raise TaskLockException('Lock already taken by other task.')
Expand Down Expand Up @@ -94,7 +94,7 @@ def wrapper(*args, **kwargs):
return wrapper


def wrap_with_run_lock(func, redis_params: RedisParams):
def wrap_with_run_lock(func: Callable, redis_params: RedisParams):
"""Redis lock wrapper function for RunWithLock.
When a function is wrapped by RunWithLock, the wrapped function will be simply wrapped with redis lock.
https://github.com/m3dev/gokart/issues/265
Expand Down Expand Up @@ -168,7 +168,7 @@ def make_redis_params(file_path: str,
redis_host: Optional[str] = None,
redis_port: Optional[int] = None,
redis_timeout: Optional[int] = None,
redis_fail_on_collision: bool = False,
raise_task_lock_exception_on_collision: bool = False,
lock_extend_seconds: int = 10):
redis_key = make_redis_key(file_path, unique_id)
should_redis_lock = redis_host is not None and redis_port is not None
Expand All @@ -179,6 +179,6 @@ def make_redis_params(file_path: str,
redis_key=redis_key,
should_redis_lock=should_redis_lock,
redis_timeout=redis_timeout,
redis_fail_on_collision=redis_fail_on_collision,
raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision,
lock_extend_seconds=lock_extend_seconds)
return redis_params
2 changes: 1 addition & 1 deletion gokart/run_with_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ def _run_with_lock(cls, func, output_list: list):
return func()

output = output_list.pop()
wrapped_func = output.wrap_with_lock(func)
wrapped_func = output.wrap_with_run_lock(func)
return cls._run_with_lock(func=wrapped_func, output_list=output_list)
2 changes: 1 addition & 1 deletion gokart/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def last_modification_time(self) -> datetime:
def path(self) -> str:
return self._path()

def wrap_with_lock(self, func):
def wrap_with_run_lock(self, func):
return wrap_with_run_lock(func=func, redis_params=self._get_redis_params())

@abstractmethod
Expand Down
21 changes: 14 additions & 7 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.redis_lock import make_redis_params
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper

logger = getLogger(__name__)

Expand Down Expand Up @@ -61,10 +62,7 @@ class TaskOnKart(luigi.Task):
redis_host = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
redis_port = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
redis_timeout = luigi.IntParameter(default=180, description='Redis lock will be released after `redis_timeout` seconds', significant=False)
redis_fail_on_collision: bool = luigi.BoolParameter(
default=False,
description='True for failing the task immediately when the cache is locked, instead of waiting for the lock to be released',
significant=False)

fail_on_empty_dump: bool = ExplicitBoolParameter(default=False, description='Fail when task dumps empty DF', significant=False)
store_index_in_feather: bool = ExplicitBoolParameter(default=True,
description='Wether to store index when using feather as a output object.',
Expand All @@ -76,6 +74,9 @@ class TaskOnKart(luigi.Task):
description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
significant=False)
complete_check_at_run: bool = ExplicitBoolParameter(default=False,
description='Check if output file exists at run. If exists, run() will be skipped.',
significant=False)

def __init__(self, *args, **kwargs):
self._add_configuration(kwargs, 'TaskOnKart')
Expand All @@ -86,6 +87,9 @@ def __init__(self, *args, **kwargs):
self._rerun_state = self.rerun
self._lock_at_dump = True

if self.complete_check_at_run:
self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)

def output(self):
return self.make_target()

Expand Down Expand Up @@ -169,7 +173,8 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
redis_host=self.redis_host,
redis_port=self.redis_port,
redis_timeout=self.redis_timeout,
redis_fail_on_collision=self.redis_fail_on_collision)
raise_task_lock_exception_on_collision=False)

return gokart.target.make_target(file_path=file_path,
unique_id=unique_id,
processor=processor,
Expand All @@ -186,7 +191,8 @@ def make_large_data_frame_target(self, relative_file_path: Optional[str] = None,
redis_host=self.redis_host,
redis_port=self.redis_port,
redis_timeout=self.redis_timeout,
redis_fail_on_collision=self.redis_fail_on_collision)
raise_task_lock_exception_on_collision=False)

return gokart.target.make_model_target(file_path=file_path,
temporary_directory=self.local_temporary_directory,
unique_id=unique_id,
Expand Down Expand Up @@ -215,7 +221,8 @@ def make_model_target(self,
redis_host=self.redis_host,
redis_port=self.redis_port,
redis_timeout=self.redis_timeout,
redis_fail_on_collision=self.redis_fail_on_collision)
raise_task_lock_exception_on_collision=False)

return gokart.target.make_model_target(file_path=file_path,
temporary_directory=self.local_temporary_directory,
unique_id=unique_id,
Expand Down
17 changes: 17 additions & 0 deletions gokart/task_complete_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import functools
from logging import getLogger
from typing import Callable

logger = getLogger(__name__)


def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable) -> Callable:
    """Wrap ``run_func`` so that it is skipped when the completion check passes.

    Args:
        run_func: The callable to guard. It is invoked with whatever positional
            and keyword arguments the returned wrapper receives.
        complete_check_func: Zero-argument callable evaluated on every call of
            the wrapper, before ``run_func``. A truthy result means
            "already completed".

    Returns:
        A wrapper carrying ``run_func``'s metadata (via ``functools.wraps``).
        When ``complete_check_func()`` is truthy the wrapper logs a warning and
        returns ``None`` without calling ``run_func``; otherwise it returns
        ``run_func(*args, **kwargs)``.
    """

    @functools.wraps(run_func)
    def wrapper(*args, **kwargs):
        if complete_check_func():
            # Not every callable exposes ``__name__`` (e.g. ``functools.partial``);
            # fall back to its repr so that skipping can never raise.
            func_name = getattr(run_func, '__name__', repr(run_func))
            # Lazy %-style arguments: the message is only formatted if emitted.
            logger.warning('%s is skipped because the task is already completed.', func_name)
            return None
        return run_func(*args, **kwargs)

    return wrapper
Loading

0 comments on commit 1ac87d8

Please sign in to comment.