From c0de16d64f6a8b19d98e67eeaeaef49b26237625 Mon Sep 17 00:00:00 2001
From: Ryo Kitagawa
Date: Fri, 25 Oct 2024 22:43:44 +0900
Subject: [PATCH] feat: add default type for parameter

---
 gokart/mypy.py                       | 109 +++++++++++++++++++++++++--
 test/test_mypy.py                    |  58 ++++++++++++++
 test/test_task_instance_parameter.py |   2 +-
 3 files changed, 162 insertions(+), 7 deletions(-)

diff --git a/gokart/mypy.py b/gokart/mypy.py
index 46f67a1c..e693c9a8 100644
--- a/gokart/mypy.py
+++ b/gokart/mypy.py
@@ -50,6 +50,7 @@
     NoneType,
     Type,
     TypeOfAny,
+    UnionType,
     get_proper_type,
 )
 from mypy.typevars import fill_typevars
@@ -329,6 +330,10 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
             with state.strict_optional_set(self._api.options.strict_optional):
                 init_type = self._infer_task_on_kart_attr_init_type(sym, stmt)
 
+            # infer Parameter type
+            if init_type is None:
+                init_type = self._generate_type_from_parameters(stmt.rvalue)
+
             found_attrs[lhs.name] = TaskOnKartAttribute(
                 name=lhs.name,
                 has_default=has_default,
@@ -402,24 +407,115 @@ def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Cont
 
         return default
 
+    def _generate_type_from_parameters(self, parameter: Expression) -> Optional[Type]:
+        """
+        Generate the default type from a Parameter call; return None when the parameter is not recognized.
+        For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type.
+        """
+        parameter_name = _extract_parameter_name(parameter)
+        if parameter_name is None:
+            return None
+
+        result_type = None
+        if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']:
+            result_type = self._api.named_type('builtins.str', [])
+        elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']:
+            result_type = self._api.named_type('builtins.int', [])
+        elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']:
+            result_type = self._api.named_type('builtins.float', [])
+        elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']:
+            result_type = self._api.named_type('builtins.bool', [])
+        elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']:
+            result_type = self._api.named_type('datetime.date', [])
+        elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']:
+            result_type = self._api.named_type('datetime.datetime', [])
+        elif parameter_name in ['luigi.parameter.TimeDeltaParameter']:
+            result_type = self._api.named_type('datetime.timedelta', [])
+        elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']:
+            result_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)])
+        elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']:
+            result_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
+        elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']:
+            result_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
+        elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']:
+            result_type = self._api.named_type('pathlib.Path', [])
+        elif parameter_name in ['gokart.parameter.TaskInstanceParameter']:
+            result_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])
+        elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']:
+            result_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])])
+        elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']:
+            result_type = self._api.named_type('builtins.bool', [])
+        elif parameter_name in ['luigi.parameter.NumericalParameter']:
+            result_type = self._get_type_from_args(parameter, 'var_type')
+        elif parameter_name in ['luigi.parameter.ChoiceParameter']:
+            result_type = self._get_type_from_args(parameter, 'var_type')
+        elif parameter_name in ['luigi.parameter.ChoiceListParameter']:
+            base_type = self._get_type_from_args(parameter, 'var_type')
+            if base_type is not None:
+                result_type = self._api.named_type('builtins.tuple', [base_type])
+        elif parameter_name in ['luigi.parameter.EnumParameter']:
+            result_type = self._get_type_from_args(parameter, 'enum')
+        elif parameter_name in ['luigi.parameter.EnumListParameter']:
+            base_type = self._get_type_from_args(parameter, 'enum')
+            if base_type is not None:
+                result_type = self._api.named_type('builtins.tuple', [base_type])
+
+        if result_type is None:
+            return None
+
+        # Optional* parameters also accept None, so widen the inferred type to a union with None.
+        if 'Optional' in parameter_name:
+            result_type = UnionType([result_type, NoneType()])
+
+        return result_type
+
+    def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]:
+        """
+        Get the type named by a keyword argument of the parameter call; return None when it cannot be resolved.
+
+        e.g.)
+        When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
+        """
+        ok, args = self._collect_parameter_args(parameter)
+        if not ok:
+            return None
+
+        if arg_key not in args:
+            return None
+
+        arg = args[arg_key]
+        if not isinstance(arg, NameExpr):
+            return None
+        if not isinstance(arg.node, TypeInfo):
+            return None
+        return Instance(arg.node, [])
+
 
 def is_parameter_call(expr: Expression) -> bool:
     """Checks if the expression is a call to luigi.Parameter()"""
-    if not isinstance(expr, CallExpr):
+    parameter_name = _extract_parameter_name(expr)
+    if parameter_name is None:
         return False
+    return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None
+
+
+def _extract_parameter_name(expr: Expression) -> Optional[str]:
+    """Extract the callee name if the expression is a call such as luigi.Parameter(); otherwise return None"""
+    if not isinstance(expr, CallExpr):
+        return None
 
     callee = expr.callee
     if isinstance(callee, MemberExpr):
         type_info = callee.node
         if type_info is None and isinstance(callee.expr, NameExpr):
-            return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
+            return f'{callee.expr.name}.{callee.name}'
     elif isinstance(callee, NameExpr):
         type_info = callee.node
     else:
-        return False
+        return None
 
     if isinstance(type_info, TypeInfo):
-        return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
+        return type_info.fullname
 
     # Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
     # https://github.com/spotify/luigi/pull/3297
@@ -429,8 +525,9 @@ def is_parameter_call(expr: Expression) -> bool:
     # class MyTask(gokart.TaskOnKart):
     #     param = Parameter()
     if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1':
-        return PARAMETER_TMP_MATCHER.match(type_info.name) is not None
-    return False
+        return type_info.name
+
+    return None
 
 
 def plugin(version: str) -> type[Plugin]:
diff --git a/test/test_mypy.py b/test/test_mypy.py
index 545fcdde..74b83a84 100644
--- a/test/test_mypy.py
+++ b/test/test_mypy.py
@@ -64,3 +64,61 @@ class MyTask(gokart.TaskOnKart):
             self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool"  [arg-type]', result[0])
             self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool"  [arg-type]', result[0])
             self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])
+
+    def test_parameter_has_default_type_invalid_pattern(self):
+        """
+        If the user doesn't annotate a parameter, mypy infers its default type from the Parameter class and reports mismatched arguments.
+        """
+        test_code = """
+import enum
+import luigi
+import gokart
+
+
+class MyEnum(enum.Enum):
+    FOO = enum.auto()
+
+class MyTask(gokart.TaskOnKart):
+    # NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
+    foo = luigi.IntParameter()
+    bar = luigi.DateParameter()
+    baz = gokart.TaskInstanceParameter()
+    qux = luigi.NumericalParameter(var_type=int)
+    quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int)
+    corge = luigi.EnumParameter(enum=MyEnum)
+
+MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1)
+"""
+        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
+            test_file.write(test_code.encode('utf-8'))
+            test_file.flush()
+            result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
+            self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int"  [arg-type]', result[0])
+            self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date"  [arg-type]', result[0])
+            self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]"  [arg-type]', result[0])
+            self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int"  [arg-type]', result[0])
+            self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int"  [arg-type]', result[0])
+            self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum"  [arg-type]', result[0])
+
+    def test_parameter_has_default_type_no_issue_pattern(self):
+        """
+        If the user doesn't annotate a parameter, mypy infers its default type from the Parameter class and accepts matching arguments.
+        """
+        test_code = """
+from datetime import date
+import luigi
+import gokart
+
+class MyTask(gokart.TaskOnKart):
+    # NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
+    foo = luigi.IntParameter()
+    bar = luigi.DateParameter()
+    baz = gokart.TaskInstanceParameter()
+
+MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart())
+"""
+        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
+            test_file.write(test_code.encode('utf-8'))
+            test_file.flush()
+            result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
+            self.assertIn('Success: no issues found', result[0])
diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py
index fe2f7b8a..5230d638 100644
--- a/test/test_task_instance_parameter.py
+++ b/test/test_task_instance_parameter.py
@@ -90,7 +90,7 @@ class _DummyPipelineD(TaskOnKart):
             subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)
 
         with self.assertRaises(TypeError):
-            _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask])
+            _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])
 
 
 if __name__ == '__main__':