Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ Runtime validation of TaskInstanceParameter() and ListTaskInstanceParameter() by subclass bound #305

Merged
merged 13 commits into from
Mar 2, 2023
29 changes: 29 additions & 0 deletions gokart/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,21 @@
import luigi
from luigi import task_register

import gokart

logger = getLogger(__name__)


class TaskInstanceParameter(luigi.Parameter):

def __init__(self, *args, **kwargs):
ujiuji1259 marked this conversation as resolved.
Show resolved Hide resolved
bound = kwargs.pop('bound', gokart.TaskOnKart)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you mean a subtype bound, how about just using the 'type' keyword? The 'bound' keyword is ambiguous, since a bound can be either an upper or a lower bound.

if isinstance(bound, type):
self._bound = bound
else:
raise ValueError(f'bound must be a type, not {type(bound)}')
super().__init__(*args, **kwargs)

@staticmethod
def _recursive(param_dict):
params = param_dict['params']
Expand All @@ -36,6 +46,11 @@ def serialize(self, x):
values = dict(type=x.get_task_family(), params=params)
return luigi.DictParameter().serialize(values)

def normalize(self, v):
if not isinstance(v, self._bound):
raise ValueError(f'{v} is not an instance of {self._bound}')
ujiuji1259 marked this conversation as resolved.
Show resolved Hide resolved
return v


class _TaskInstanceEncoder(json.JSONEncoder):

Expand All @@ -48,12 +63,26 @@ def default(self, obj):

class ListTaskInstanceParameter(luigi.Parameter):
    """Parameter whose value is a collection of task instances.

    Every element is checked in ``normalize`` against ``bound``, which
    defaults to ``gokart.TaskOnKart`` when the keyword is not supplied.
    """

    def __init__(self, *args, **kwargs):
        """Create the parameter.

        Args:
            *args: Forwarded unchanged to ``luigi.Parameter``.
            **kwargs: Forwarded to ``luigi.Parameter`` after the optional
                ``bound`` keyword has been removed. ``bound`` is the class
                that every element must be an instance of.

        Raises:
            ValueError: If ``bound`` is not a type.
        """
        bound = kwargs.pop('bound', gokart.TaskOnKart)
        if not isinstance(bound, type):
            raise ValueError(f'bound must be a type, not {type(bound)}')
        self._bound = bound
        super().__init__(*args, **kwargs)

    def parse(self, s):
        """Decode the JSON string ``s`` and parse each element as a task instance.

        Each decoded element is handed to ``TaskInstanceParameter.parse``.
        """
        # One element parser is enough; the decoded JSON value is iterated
        # directly instead of being copied into a throwaway list first.
        element_parameter = TaskInstanceParameter()
        return [element_parameter.parse(x) for x in json.loads(s)]

    def serialize(self, x):
        """Return ``x`` encoded as a JSON string via ``_TaskInstanceEncoder``."""
        return json.dumps(x, cls=_TaskInstanceEncoder)

    def normalize(self, values):
        """Validate every element of ``values`` and return ``values`` unchanged.

        Raises:
            ValueError: If any element is not an instance of the bound.
        """
        for v in values:
            if not isinstance(v, self._bound):
                raise ValueError(f'{v} is not an instance of {self._bound}')
        return values


class ExplicitBoolParameter(luigi.BoolParameter):

Expand Down
58 changes: 58 additions & 0 deletions test/test_task_instance_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ class _DummySubTask(TaskOnKart):
pass


class _DummyCorrectSubClassTask(_DummySubTask):
    """Dummy task that derives from ``_DummySubTask``."""
    task_namespace = __name__


class _DummyInvalidSubClassTask(TaskOnKart):
    """Dummy task that derives from ``TaskOnKart`` directly, not from ``_DummySubTask``."""
    task_namespace = __name__


class _DummyTask(TaskOnKart):
task_namespace = __name__
param = luigi.IntParameter()
Expand Down Expand Up @@ -40,6 +50,54 @@ def test_serialize_and_parse_list_params(self):
parsed = gokart.TaskInstanceParameter().parse(s)
self.assertEqual(parsed.task_id, original.task_id)

def test_invalid_bound(self):
    """Constructing the parameter with a non-type ``bound`` raises ValueError."""
    with self.assertRaises(ValueError):
        gokart.TaskInstanceParameter(bound=1)  # not type instance

def test_params_with_correct_subclass_bound(self):
    """A subclass of the bound is accepted and exposed through ``requires()``."""

    class _DummyPipelineA(TaskOnKart):
        task_namespace = __name__
        subtask = gokart.TaskInstanceParameter(bound=_DummySubTask)

    pipeline = _DummyPipelineA(subtask=_DummyCorrectSubClassTask())
    required = pipeline.requires()
    self.assertEqual(required['subtask'], _DummyCorrectSubClassTask())

def test_params_with_invalid_subclass_bound(self):
    """A task that is not an instance of the bound is rejected with ValueError."""

    class _DummyPipelineB(TaskOnKart):
        task_namespace = __name__
        subtask = gokart.TaskInstanceParameter(bound=_DummySubTask)

    unrelated_task = _DummyInvalidSubClassTask()
    self.assertRaises(ValueError, lambda: _DummyPipelineB(subtask=unrelated_task))


class ListTaskInstanceParameterTest(unittest.TestCase):
    """Tests for the ``bound`` validation of ``gokart.ListTaskInstanceParameter``."""

    def setUp(self):
        # Start every test with an empty instance cache for _DummyTask.
        _DummyTask.clear_instance_cache()

    def test_invalid_bound(self):
        """A non-type ``bound`` raises ValueError at construction time."""
        with self.assertRaises(ValueError):
            gokart.ListTaskInstanceParameter(bound=1)  # not type instance

    def test_list_params_with_correct_subclass_bound(self):
        """A list holding only instances of the bound is accepted."""

        class _DummyPipelineC(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask)

        task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()])
        self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(), ))

    def test_list_params_with_invalid_subclass_bound(self):
        """A list containing an instance outside the bound raises ValueError."""

        class _DummyPipelineD(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask)

        # Both elements are task *instances*; the original passed the class
        # object for the second one, so the list did not actually mix a valid
        # and an invalid instance as the test intends.
        with self.assertRaises(ValueError):
            _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])


if __name__ == '__main__':
unittest.main()