Skip to content

Commit

Permalink
Add complete_check_at_run parameter (#333)
Browse files Browse the repository at this point in the history
* add

* add

* add
  • Loading branch information
mski-iksm authored Nov 21, 2023
1 parent 06bc372 commit 919b4c7
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
7 changes: 7 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.redis_lock import make_redis_params
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper

logger = getLogger(__name__)

Expand Down Expand Up @@ -76,6 +77,9 @@ class TaskOnKart(luigi.Task):
description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
significant=False)
complete_check_at_run: bool = ExplicitBoolParameter(default=False,
description='Check if output file exists at run. If exists, run() will be skipped.',
significant=False)

def __init__(self, *args, **kwargs):
self._add_configuration(kwargs, 'TaskOnKart')
Expand All @@ -86,6 +90,9 @@ def __init__(self, *args, **kwargs):
self._rerun_state = self.rerun
self._lock_at_dump = True

if self.complete_check_at_run:
self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)

def output(self):
    """Return the target built by ``self.make_target()`` called with no arguments."""
    return self.make_target()

Expand Down
15 changes: 15 additions & 0 deletions gokart/task_complete_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from functools import wraps
from logging import getLogger
from typing import Callable

logger = getLogger(__name__)


def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable):
    """Wrap ``run_func`` so that it is skipped when the task is already complete.

    Args:
        run_func: The callable to guard (for example a task's bound ``run``).
        complete_check_func: Zero-argument callable; a truthy result means the
            work is already done and ``run_func`` must not be called.

    Returns:
        A callable that forwards its arguments to ``run_func`` and returns its
        result, or logs a warning and returns ``None`` when
        ``complete_check_func()`` is truthy at call time.
    """
    # Not every callable exposes ``__name__`` (e.g. ``functools.partial``
    # objects or callable instances). Fall back to ``repr`` so that building
    # the log message can never turn a skip into an AttributeError.
    run_name = getattr(run_func, '__name__', repr(run_func))

    # ``wraps`` keeps the wrapped callable's name and docstring on the wrapper.
    @wraps(run_func)
    def wrapper(*args, **kwargs):
        # The completeness check is evaluated on every call, not at wrap time.
        if complete_check_func():
            logger.warning(f'{run_name} is skipped because the task is already completed.')
            return None
        return run_func(*args, **kwargs)

    return wrapper
61 changes: 61 additions & 0 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,5 +561,66 @@ def test_serialize_and_deserialize_default_values(self):
self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())


class _DummyTaskWithNonCompleted(gokart.TaskOnKart):
    """Test double whose ``complete()`` always reports the task as unfinished."""

    def complete(self):
        return False

    def run(self):
        self.dump('hello')

    def dump(self, obj):
        """Override ``dump()`` to do nothing."""


class _DummyTaskWithCompleted(gokart.TaskOnKart):
    """Test double whose ``complete()`` always reports the task as finished."""

    def complete(self):
        return True

    def run(self):
        self.dump('hello')

    def dump(self, obj):
        """Override ``dump()`` to do nothing."""


class TestCompleteCheckAtRun(unittest.TestCase):
    """Check whether ``run()`` reaches ``dump()`` for each combination of
    ``complete_check_at_run`` and the value returned by ``complete()``."""

    @staticmethod
    def _run_and_get_dump_mock(task_class, complete_check_at_run):
        """Build the task, swap ``dump`` for a mock, call ``run()`` and return the mock."""
        task = task_class(complete_check_at_run=complete_check_at_run)
        task.dump = MagicMock()
        task.run()
        return task.dump

    def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
        dump_mock = self._run_and_get_dump_mock(_DummyTaskWithNonCompleted, complete_check_at_run=False)

        # run() was executed, so dump() should have been called.
        dump_mock.assert_called_once()

    def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
        dump_mock = self._run_and_get_dump_mock(_DummyTaskWithCompleted, complete_check_at_run=False)

        # The check is disabled, so run() still calls dump() even though the task is completed.
        dump_mock.assert_called_once()

    def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
        dump_mock = self._run_and_get_dump_mock(_DummyTaskWithNonCompleted, complete_check_at_run=True)

        # The task is not completed, so run() should still call dump().
        dump_mock.assert_called_once()

    def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
        dump_mock = self._run_and_get_dump_mock(_DummyTaskWithCompleted, complete_check_at_run=True)

        # The task is completed, so calling run() should not call dump().
        dump_mock.assert_not_called()


# Run the unittest test runner only when this module is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 919b4c7

Please sign in to comment.