diff --git a/gokart/conflict_prevention_lock/task_lock_wrappers.py b/gokart/conflict_prevention_lock/task_lock_wrappers.py index cb7c5d1e..1aee2552 100644 --- a/gokart/conflict_prevention_lock/task_lock_wrappers.py +++ b/gokart/conflict_prevention_lock/task_lock_wrappers.py @@ -1,6 +1,6 @@ import functools from logging import getLogger -from typing import Any, Callable +from typing import Callable from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock @@ -83,19 +83,18 @@ def wrapper(*args, **kwargs): return wrapper -def wrap_run_with_lock(run_func: Callable[[], Any], task_lock_params: TaskLockParams): +def wrap_run_with_lock(run_func: Callable[[], None], task_lock_params: TaskLockParams) -> Callable[[], None]: @functools.wraps(run_func) - def wrapped(): + def wrapped() -> None: task_lock = set_task_lock(task_lock_params=task_lock_params) scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) try: logger.debug(f'Task RUN lock of {task_lock_params.redis_key} locked.') - result = run_func() + run_func() task_lock.release() logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released.') scheduler.shutdown() - return result except BaseException as e: logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released with BaseException.') task_lock.release() diff --git a/gokart/task.py b/gokart/task.py index 2f351d2a..d4a3e015 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -3,9 +3,10 @@ import os import random import types +from functools import partial from importlib import import_module from logging import getLogger -from typing import Any, Callable, Dict, Generator, Generic, Iterable, List, Optional, Set, TypeVar, Union, overload +from typing import Any, Callable, Dict, Generator, Generic, Iterable, List, Optional, Protocol, Set, TypeVar, Union, overload import luigi import pandas as pd @@ -29,6 +30,10 @@ K = TypeVar('K') +class RunWrapperFunc(Protocol): + def __call__(self, run_func: Callable[[], None]) -> Callable[[], None]: ... + + class TaskOnKart(luigi.Task, Generic[T]): """ This is a wrapper class of luigi.Task. @@ -111,16 +116,23 @@ def __init__(self, *args, **kwargs): super(TaskOnKart, self).__init__(*args, **kwargs) self._rerun_state = self.rerun self._lock_at_dump = True + # store callbacks to wrap `run` + self._wrapper_funcs: List[RunWrapperFunc] = [] if self.complete_check_at_run: - self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore + self._wrapper_funcs.append(partial(task_complete_check_wrapper, complete_check_func=self.complete)) if self.should_lock_run: self._lock_at_dump = False assert self.redis_host is not None, 'redis_host must be set when should_lock_run is True.' assert self.redis_port is not None, 'redis_port must be set when should_lock_run is True.' task_lock_params = make_task_lock_params_for_run(task_self=self) - self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore + self._wrapper_funcs.append(partial(wrap_run_with_lock, task_lock_params=task_lock_params)) + + wrapped_run = self.run + for func in self._wrapper_funcs[::-1]: + wrapped_run = func(wrapped_run) + self.run = wrapped_run # type: ignore def input(self) -> FlattenableItems[TargetOnKart]: return super().input() @@ -566,3 +578,29 @@ def _make_representation(self, param_obj: luigi.Parameter, param_value): if isinstance(param_obj, ListTaskInstanceParameter): return f"[{', '.join(f'{v.get_task_family()}({v.make_unique_id()})' for v in param_value)}]" return param_obj.serialize(param_value) + + def __getstate__(self): + """NOTE: overwrite __getstate__ to avoid pickling error + + `run` method is wrapped by some functions with instance parameters. + This shows `it's not the same object as` error when pickling. + """ + state = self.__dict__.copy() + while hasattr(state['run'], '__wrapped__'): + state['run'] = state['run'].__wrapped__ + return state + + def __setstate__(self, state): + """NOTE: overwrite __setstate__ to avoid pickling error + + `run` method is wrapped by some functions with instance parameters. + This shows `it's not the same object as` error when pickling. + """ + run = state.pop('run') + if '_wrapper_funcs' not in state: + return self.__dict__.update(state) + + for func in state['_wrapper_funcs'][::-1]: + run = func(run) + state['run'] = run + self.__dict__.update(state) diff --git a/gokart/task_complete_check.py b/gokart/task_complete_check.py index 53c9f92d..f6902114 100644 --- a/gokart/task_complete_check.py +++ b/gokart/task_complete_check.py @@ -5,9 +5,9 @@ logger = getLogger(__name__) -def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable): +def task_complete_check_wrapper(run_func: Callable[[], None], complete_check_func: Callable[[], bool]) -> Callable[[], None]: @functools.wraps(run_func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> None: if complete_check_func(): logger.warning(f'{run_func.__name__} is skipped because the task is already completed.') return diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index e3946b49..86a8a682 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -1,5 +1,6 @@ import os import pathlib +import pickle import unittest from datetime import datetime from typing import Any, Dict, List, cast @@ -657,5 +658,18 @@ def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self): task.dump.assert_not_called() +class TestPickleTaskOnKart: + def test_pickle_and_unpickle(self): + task = _DummyTask(redis_host='0.0.0.0', redis_port=12345, redis_timeout=180, should_lock_run=True) + pickled = pickle.dumps(task) + unpickled = pickle.loads(pickled) + assert task.to_str_params() == unpickled.to_str_params() + + task = _DummyTask(should_lock_run=False) + pickled = pickle.dumps(task) + unpickled = pickle.loads(pickled) + assert task.to_str_params() == unpickled.to_str_params() + + if __name__ == '__main__': unittest.main()