From 939cfa74256c9386935a925d267f025e4399fc6a Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Sat, 4 Feb 2023 20:05:27 +0900 Subject: [PATCH 01/13] add subclass bound --- gokart/parameter.py | 29 +++++++++++++ test/test_task_instance_parameter.py | 61 ++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/gokart/parameter.py b/gokart/parameter.py index 50e7009c..f07254d9 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -1,4 +1,5 @@ import bz2 +import gokart import json from logging import getLogger @@ -10,6 +11,14 @@ class TaskInstanceParameter(luigi.Parameter): + def __init__(self, *args, **kwargs): + bound = kwargs.pop('bound', gokart.TaskOnKart) + if isinstance(bound, type): + self._bound = bound + else: + raise ValueError(f'bound must be a type, not {type(bound)}') + super().__init__(*args, **kwargs) + @staticmethod def _recursive(param_dict): params = param_dict['params'] @@ -36,6 +45,12 @@ def serialize(self, x): values = dict(type=x.get_task_family(), params=params) return luigi.DictParameter().serialize(values) + def normalize(self, v): + if not isinstance(v, self._bound): + # if self._bound in v.mro(): + raise ValueError(f'{v} is not an instance of {self._bound}') + return v + class _TaskInstanceEncoder(json.JSONEncoder): @@ -48,12 +63,26 @@ def default(self, obj): class ListTaskInstanceParameter(luigi.Parameter): + def __init__(self, *args, **kwargs): + bound = kwargs.pop('bound', gokart.TaskOnKart) + if isinstance(bound, type): + self._bound = bound + else: + raise ValueError(f'bound must be a type, not {type(bound)}') + super().__init__(*args, **kwargs) + def parse(self, s): return [TaskInstanceParameter().parse(x) for x in list(json.loads(s))] def serialize(self, x): return json.dumps(x, cls=_TaskInstanceEncoder) + def normalize(self, values): + for v in values: + if not isinstance(v, self._bound): + raise ValueError(f'{v} is not an instance of {self._bound}') + return values + class ExplicitBoolParameter(luigi.BoolParameter): diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index a7be296f..6ee68680 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -40,6 +40,67 @@ def test_serialize_and_parse_list_params(self): parsed = gokart.TaskInstanceParameter().parse(s) self.assertEqual(parsed.task_id, original.task_id) + def test_invalid_bound(self): + self.assertRaises(ValueError, lambda: gokart.TaskInstanceParameter(bound=1)) # not type instance + + def test_params_with_correct_subclass_bound(self): + class _DummyCorrectSubTask(_DummySubTask): + task_namespace = __name__ + pass + + class TaskA(TaskOnKart): + subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) + + task = TaskA(subtask=_DummyCorrectSubTask()) + self.assertEqual(task.requires()['subtask'], _DummyCorrectSubTask()) + + def test_params_with_invalid_subclass_bound(self): + class _DummyInvalidSubTask(TaskOnKart): + task_namespace = __name__ + pass + + class TaskA(TaskOnKart): + subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) + + with self.assertRaises(ValueError): + TaskA(subtask=_DummyInvalidSubTask()) + + +class ListTaskInstanceParameterTest(unittest.TestCase): + + def setUp(self): + _DummyTask.clear_instance_cache() + + def test_invalid_bound(self): + self.assertRaises(ValueError, lambda: gokart.ListTaskInstanceParameter(bound=1)) # not type instance + + def test_list_params_with_correct_subclass_bound(self): + class _DummyCorrectSubTask(_DummySubTask): + task_namespace = __name__ + pass + + class TaskA(TaskOnKart): + subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) + + task = TaskA(subtask=[_DummyCorrectSubTask()]) + self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubTask(),)) + + def test_list_params_with_invalid_subclass_bound(self): + class _DummyCorrectSubTask(_DummySubTask): + task_namespace = __name__ + pass + + class _DummyInvalidSubTask(TaskOnKart): + task_namespace = __name__ + pass + + class TaskA(TaskOnKart): + subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) + + with self.assertRaises(ValueError): + TaskA(subtask=[_DummyInvalidSubTask(), _DummyCorrectSubTask]) + + if __name__ == '__main__': unittest.main() From 8beffb5fe42c64bd61e3d352241e5a298636dd38 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Sat, 4 Feb 2023 20:30:37 +0900 Subject: [PATCH 02/13] fix lint --- gokart/parameter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index f07254d9..0c30e578 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -1,11 +1,12 @@ import bz2 -import gokart import json from logging import getLogger import luigi from luigi import task_register +import gokart + logger = getLogger(__name__) @@ -47,7 +48,6 @@ def serialize(self, x): def normalize(self, v): if not isinstance(v, self._bound): - # if self._bound in v.mro(): raise ValueError(f'{v} is not an instance of {self._bound}') return v From 9d48aa4fc9e49bb1d495c63481c9148b0a275c76 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Sat, 4 Feb 2023 20:50:32 +0900 Subject: [PATCH 03/13] fix test --- test/test_task_instance_parameter.py | 47 +++++++++++++--------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index 6ee68680..1d44ae5e 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -11,6 +11,16 @@ class _DummySubTask(TaskOnKart): pass +class _DummyCorrectSubClassTask(_DummySubTask): + task_namespace = __name__ + pass + + +class _DummyInvalidSubClassTask(TaskOnKart): + task_namespace = __name__ + pass + + class _DummyTask(TaskOnKart): task_namespace = __name__ param = luigi.IntParameter() @@ -44,26 +54,21 @@ def test_invalid_bound(self): self.assertRaises(ValueError, lambda: gokart.TaskInstanceParameter(bound=1)) # not type instance def test_params_with_correct_subclass_bound(self): - class _DummyCorrectSubTask(_DummySubTask): + class _DummyPipelineA(TaskOnKart): task_namespace = __name__ - pass - - class TaskA(TaskOnKart): subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) - task = TaskA(subtask=_DummyCorrectSubTask()) - self.assertEqual(task.requires()['subtask'], _DummyCorrectSubTask()) + task = _DummyPipelineA(subtask=_DummyCorrectSubClassTask()) + self.assertEqual(task.requires()['subtask'], _DummyCorrectSubClassTask()) def test_params_with_invalid_subclass_bound(self): - class _DummyInvalidSubTask(TaskOnKart): - task_namespace = __name__ - pass - class TaskA(TaskOnKart): + class _DummyPipelineB(TaskOnKart): + task_namespace = __name__ subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) with self.assertRaises(ValueError): - TaskA(subtask=_DummyInvalidSubTask()) + _DummyPipelineB(subtask=_DummyInvalidSubClassTask()) class ListTaskInstanceParameterTest(unittest.TestCase): @@ -75,30 +80,20 @@ def test_invalid_bound(self): self.assertRaises(ValueError, lambda: gokart.ListTaskInstanceParameter(bound=1)) # not type instance def test_list_params_with_correct_subclass_bound(self): - class _DummyCorrectSubTask(_DummySubTask): + class _DummyPipelineC(TaskOnKart): task_namespace = __name__ - pass - - class TaskA(TaskOnKart): subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) - task = TaskA(subtask=[_DummyCorrectSubTask()]) - self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubTask(),)) + task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) + self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(),)) def test_list_params_with_invalid_subclass_bound(self): - class _DummyCorrectSubTask(_DummySubTask): + class _DummyPipelineD(TaskOnKart): task_namespace = __name__ - pass - - class _DummyInvalidSubTask(TaskOnKart): - task_namespace = __name__ - pass - - class TaskA(TaskOnKart): subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) with self.assertRaises(ValueError): - TaskA(subtask=[_DummyInvalidSubTask(), _DummyCorrectSubTask]) + _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask]) From 8b3f97c9d6c5b68171db15c7ed07673f3e6c3926 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Sat, 4 Feb 2023 20:53:15 +0900 Subject: [PATCH 04/13] fix test --- test/test_task_instance_parameter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index 1d44ae5e..eac93b38 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -54,6 +54,7 @@ def test_invalid_bound(self): self.assertRaises(ValueError, lambda: gokart.TaskInstanceParameter(bound=1)) # not type instance def test_params_with_correct_subclass_bound(self): + class _DummyPipelineA(TaskOnKart): task_namespace = __name__ subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) @@ -80,6 +81,7 @@ def test_invalid_bound(self): self.assertRaises(ValueError, lambda: gokart.ListTaskInstanceParameter(bound=1)) # not type instance def test_list_params_with_correct_subclass_bound(self): + class _DummyPipelineC(TaskOnKart): task_namespace = __name__ subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) @@ -88,6 +90,7 @@ class _DummyPipelineC(TaskOnKart): self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(),)) def test_list_params_with_invalid_subclass_bound(self): + class _DummyPipelineD(TaskOnKart): task_namespace = __name__ subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) @@ -96,6 +99,5 @@ class _DummyPipelineD(TaskOnKart): _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask]) - if __name__ == '__main__': unittest.main() From 03af64815e56697eef9c0225b00cc6da42b1c363 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Sat, 4 Feb 2023 20:54:30 +0900 Subject: [PATCH 05/13] fix lint --- test/test_task_instance_parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index eac93b38..fd5b660c 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -87,7 +87,7 @@ class _DummyPipelineC(TaskOnKart): subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) - self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(),)) + self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(), )) def test_list_params_with_invalid_subclass_bound(self): From 6ebd6f8291fc6b8703308d8081194e76ab180e2a Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Tue, 21 Feb 2023 21:06:59 +0900 Subject: [PATCH 06/13] change args --- gokart/parameter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index 0c30e578..ac6f6dca 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -12,8 +12,7 @@ class TaskInstanceParameter(luigi.Parameter): - def __init__(self, *args, **kwargs): - bound = kwargs.pop('bound', gokart.TaskOnKart) + def __init__(self, bound=gokart.TaskOnKart, *args, **kwargs): if isinstance(bound, type): self._bound = bound else: @@ -63,8 +62,7 @@ def default(self, obj): class ListTaskInstanceParameter(luigi.Parameter): - def __init__(self, *args, **kwargs): - bound = kwargs.pop('bound', gokart.TaskOnKart) + def __init__(self, bound=gokart.TaskOnKart, *args, **kwargs): if isinstance(bound, type): self._bound = bound else: From ebee6c40fbd30f9dd6e0f753c237b15a2f588c9b Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Tue, 21 Feb 2023 22:25:18 +0900 Subject: [PATCH 07/13] fix test name --- gokart/parameter.py | 38 +++++++++++++++------------- test/test_task_instance_parameter.py | 28 ++++++++++---------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index ac6f6dca..d0518c2f 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -12,11 +12,12 @@ class TaskInstanceParameter(luigi.Parameter): - def __init__(self, bound=gokart.TaskOnKart, *args, **kwargs): - if isinstance(bound, type): - self._bound = bound + def __init__(self, *args, **kwargs): + expected_type = kwargs.pop('expected_type', gokart.TaskOnKart) + if isinstance(expected_type, type): + self.expected_type = expected_type else: - raise ValueError(f'bound must be a type, not {type(bound)}') + raise TypeError(f'expected_type must be a type, not {type(expected_type)}') super().__init__(*args, **kwargs) @staticmethod @@ -45,10 +46,11 @@ def serialize(self, x): values = dict(type=x.get_task_family(), params=params) return luigi.DictParameter().serialize(values) - def normalize(self, v): - if not isinstance(v, self._bound): - raise ValueError(f'{v} is not an instance of {self._bound}') - return v + def _warn_on_wrong_param_type(self, param_name, param_value): + if self.__class__ != TaskInstanceParameter: + return + if not isinstance(param_value, self.expected_type): + raise TypeError(f'{param_value} is not an instance of {self.expected_type}') class _TaskInstanceEncoder(json.JSONEncoder): @@ -62,11 +64,12 @@ def default(self, obj): class ListTaskInstanceParameter(luigi.Parameter): - def __init__(self, bound=gokart.TaskOnKart, *args, **kwargs): - if isinstance(bound, type): - self._bound = bound + def __init__(self, *args, **kwargs): + expected_type = kwargs.pop('expected_type', gokart.TaskOnKart) + if isinstance(expected_type, type): + self.expected_type = expected_type else: - raise ValueError(f'bound must be a type, not {type(bound)}') + raise TypeError(f'expected_type must be a type, not {type(expected_type)}') super().__init__(*args, **kwargs) def parse(self, s): @@ -75,11 +78,12 @@ def parse(self, s): def serialize(self, x): return json.dumps(x, cls=_TaskInstanceEncoder) - def normalize(self, values): - for v in values: - if not isinstance(v, self._bound): - raise ValueError(f'{v} is not an instance of {self._bound}') - return values + def _warn_on_wrong_param_type(self, param_name, param_value): + if self.__class__ != ListTaskInstanceParameter: + return + for v in param_value: + if not isinstance(v, self.expected_type): + raise TypeError(f'{v} is not an instance of {self.expected_type}') class ExplicitBoolParameter(luigi.BoolParameter): diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index fd5b660c..f1b1f5b0 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -50,25 +50,25 @@ def test_serialize_and_parse_list_params(self): parsed = gokart.TaskInstanceParameter().parse(s) self.assertEqual(parsed.task_id, original.task_id) - def test_invalid_bound(self): - self.assertRaises(ValueError, lambda: gokart.TaskInstanceParameter(bound=1)) # not type instance + def test_invalid_class(self): + self.assertRaises(TypeError, lambda: gokart.TaskInstanceParameter(expected_type=1)) # not type instance - def test_params_with_correct_subclass_bound(self): + def test_params_with_correct_param_type(self): class _DummyPipelineA(TaskOnKart): task_namespace = __name__ - subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) + subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask) task = _DummyPipelineA(subtask=_DummyCorrectSubClassTask()) self.assertEqual(task.requires()['subtask'], _DummyCorrectSubClassTask()) - def test_params_with_invalid_subclass_bound(self): + def test_params_with_invalid_param_type(self): class _DummyPipelineB(TaskOnKart): task_namespace = __name__ - subtask = gokart.TaskInstanceParameter(bound=_DummySubTask) + subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): _DummyPipelineB(subtask=_DummyInvalidSubClassTask()) @@ -77,25 +77,25 @@ class ListTaskInstanceParameterTest(unittest.TestCase): def setUp(self): _DummyTask.clear_instance_cache() - def test_invalid_bound(self): - self.assertRaises(ValueError, lambda: gokart.ListTaskInstanceParameter(bound=1)) # not type instance + def test_invalid_class(self): + self.assertRaises(TypeError, lambda: gokart.ListTaskInstanceParameter(expected_type=1)) # not type instance - def test_list_params_with_correct_subclass_bound(self): + def test_list_params_with_correct_param_types(self): class _DummyPipelineC(TaskOnKart): task_namespace = __name__ - subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) + subtask = gokart.ListTaskInstanceParameter(expected_type=_DummySubTask) task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(), )) - def test_list_params_with_invalid_subclass_bound(self): + def test_list_params_with_invalid_param_types(self): class _DummyPipelineD(TaskOnKart): task_namespace = __name__ - subtask = gokart.ListTaskInstanceParameter(bound=_DummySubTask) + subtask = gokart.ListTaskInstanceParameter(expected_type=_DummySubTask) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask]) From ab66041b428638e60521a2f7000fb12664020ce8 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Wed, 22 Feb 2023 23:11:59 +0900 Subject: [PATCH 08/13] fix --- gokart/parameter.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index d0518c2f..a087444a 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -12,11 +12,12 @@ class TaskInstanceParameter(luigi.Parameter): - def __init__(self, *args, **kwargs): - expected_type = kwargs.pop('expected_type', gokart.TaskOnKart) - if isinstance(expected_type, type): + def __init__(self, expected_type=None, *args, **kwargs): + if expected_type is None: + self.expected_type = gokart.TaskOnKart + elif isinstance(expected_type, type): self.expected_type = expected_type - else: + elif expected_type is not None: raise TypeError(f'expected_type must be a type, not {type(expected_type)}') super().__init__(*args, **kwargs) @@ -64,11 +65,12 @@ def default(self, obj): class ListTaskInstanceParameter(luigi.Parameter): - def __init__(self, *args, **kwargs): - expected_type = kwargs.pop('expected_type', gokart.TaskOnKart) - if isinstance(expected_type, type): + def __init__(self, expected_type=None, *args, **kwargs): + if expected_type is None: + self.expected_type = gokart.TaskOnKart + elif isinstance(expected_type, type): self.expected_type = expected_type - else: + elif expected_type is not None: raise TypeError(f'expected_type must be a type, not {type(expected_type)}') super().__init__(*args, **kwargs) From fd3bfc152e1e2db2083b8a18f66841c705ded9e4 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Thu, 2 Mar 2023 20:12:13 +0900 Subject: [PATCH 09/13] fix --- gokart/parameter.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index a087444a..17859ee4 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -48,8 +48,6 @@ def serialize(self, x): return luigi.DictParameter().serialize(values) def _warn_on_wrong_param_type(self, param_name, param_value): - if self.__class__ != TaskInstanceParameter: - return if not isinstance(param_value, self.expected_type): raise TypeError(f'{param_value} is not an instance of {self.expected_type}') @@ -65,13 +63,13 @@ def default(self, obj): class ListTaskInstanceParameter(luigi.Parameter): - def __init__(self, expected_type=None, *args, **kwargs): - if expected_type is None: - self.expected_type = gokart.TaskOnKart - elif isinstance(expected_type, type): - self.expected_type = expected_type - elif expected_type is not None: - raise TypeError(f'expected_type must be a type, not {type(expected_type)}') + def __init__(self, expected_element_type=None, *args, **kwargs): + if expected_element_type is None: + self.expected_element_type = gokart.TaskOnKart + elif isinstance(expected_element_type, type): + self.expected_element_type = expected_element_type + elif expected_element_type is not None: + raise TypeError(f'expected_element_type must be a type, not {type(expected_element_type)}') super().__init__(*args, **kwargs) def parse(self, s): @@ -81,11 +79,9 @@ def serialize(self, x): return json.dumps(x, cls=_TaskInstanceEncoder) def _warn_on_wrong_param_type(self, param_name, param_value): - if self.__class__ != ListTaskInstanceParameter: - return for v in param_value: - if not isinstance(v, self.expected_type): - raise TypeError(f'{v} is not an instance of {self.expected_type}') + if not isinstance(v, self.expected_element_type): + raise TypeError(f'{v} is not an instance of {self.expected_element_type}') class ExplicitBoolParameter(luigi.BoolParameter): From 1372984795cd3c238dc32973218ae5c999020528 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Thu, 2 Mar 2023 20:13:14 +0900 Subject: [PATCH 10/13] fix --- test/test_task_instance_parameter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index f1b1f5b0..8350bccc 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -84,7 +84,7 @@ def test_list_params_with_correct_param_types(self): class _DummyPipelineC(TaskOnKart): task_namespace = __name__ - subtask = gokart.ListTaskInstanceParameter(expected_type=_DummySubTask) + subtask = gokart.ListTaskInstanceParameter(expected_element_type=_DummySubTask) task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(), )) @@ -93,7 +93,7 @@ def test_list_params_with_invalid_param_types(self): class _DummyPipelineD(TaskOnKart): task_namespace = __name__ - subtask = gokart.ListTaskInstanceParameter(expected_type=_DummySubTask) + subtask = gokart.ListTaskInstanceParameter(expected_element_type=_DummySubTask) with self.assertRaises(TypeError): _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask]) From ad6e70496fe753ff426d7a6593e12158fde0ea06 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Thu, 2 Mar 2023 20:15:28 +0900 Subject: [PATCH 11/13] fix else --- gokart/parameter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index 17859ee4..4e590206 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -17,7 +17,7 @@ def __init__(self, expected_type=None, *args, **kwargs): self.expected_type = gokart.TaskOnKart elif isinstance(expected_type, type): self.expected_type = expected_type - elif expected_type is not None: + else: raise TypeError(f'expected_type must be a type, not {type(expected_type)}') super().__init__(*args, **kwargs) @@ -68,7 +68,7 @@ def __init__(self, expected_element_type=None, *args, **kwargs): self.expected_element_type = gokart.TaskOnKart elif isinstance(expected_element_type, type): self.expected_element_type = expected_element_type - elif expected_element_type is not None: + else: raise TypeError(f'expected_element_type must be a type, not {type(expected_element_type)}') super().__init__(*args, **kwargs) From 950b664a7d91971deddc66cd3431e167f8e78189 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Thu, 2 Mar 2023 20:18:15 +0900 Subject: [PATCH 12/13] fix name --- gokart/parameter.py | 14 +++++++------- test/test_task_instance_parameter.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index 4e590206..0115e0b1 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -63,13 +63,13 @@ def default(self, obj): class ListTaskInstanceParameter(luigi.Parameter): - def __init__(self, expected_element_type=None, *args, **kwargs): - if expected_element_type is None: + def __init__(self, expected_elements_type=None, *args, **kwargs): + if expected_elements_type is None: self.expected_element_type = gokart.TaskOnKart - elif isinstance(expected_element_type, type): - self.expected_element_type = expected_element_type + elif isinstance(expected_elements_type, type): + self.expected_element_type = expected_elements_type else: - raise TypeError(f'expected_element_type must be a type, not {type(expected_element_type)}') + raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}') super().__init__(*args, **kwargs) def parse(self, s): @@ -80,8 +80,8 @@ def serialize(self, x): def _warn_on_wrong_param_type(self, param_name, param_value): for v in param_value: - if not isinstance(v, self.expected_element_type): - raise TypeError(f'{v} is not an instance of {self.expected_element_type}') + if not isinstance(v, self.expected_elements_type): + raise TypeError(f'{v} is not an instance of {self.expected_elements_type}') class ExplicitBoolParameter(luigi.BoolParameter): diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index 8350bccc..4935fa60 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -78,13 +78,13 @@ def setUp(self): _DummyTask.clear_instance_cache() def test_invalid_class(self): - self.assertRaises(TypeError, lambda: gokart.ListTaskInstanceParameter(expected_type=1)) # not type instance + self.assertRaises(TypeError, lambda: gokart.ListTaskInstanceParameter(expected_elements_type=1)) # not type instance def test_list_params_with_correct_param_types(self): class _DummyPipelineC(TaskOnKart): task_namespace = __name__ - subtask = gokart.ListTaskInstanceParameter(expected_element_type=_DummySubTask) + subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(), )) @@ -93,7 +93,7 @@ def test_list_params_with_invalid_param_types(self): class _DummyPipelineD(TaskOnKart): task_namespace = __name__ - subtask = gokart.ListTaskInstanceParameter(expected_element_type=_DummySubTask) + subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) with self.assertRaises(TypeError): _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask]) From 5fd143b8155053f8de8b31c4f5a47d69461740e8 Mon Sep 17 00:00:00 2001 From: ujiuji1259 Date: Thu, 2 Mar 2023 20:18:45 +0900 Subject: [PATCH 13/13] fix name --- gokart/parameter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gokart/parameter.py b/gokart/parameter.py index 0115e0b1..337e835b 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -65,9 +65,9 @@ class ListTaskInstanceParameter(luigi.Parameter): def __init__(self, expected_elements_type=None, *args, **kwargs): if expected_elements_type is None: - self.expected_element_type = gokart.TaskOnKart + self.expected_elements_type = gokart.TaskOnKart elif isinstance(expected_elements_type, type): - self.expected_element_type = expected_elements_type + self.expected_elements_type = expected_elements_type else: raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}') super().__init__(*args, **kwargs)