From 17f70645f5f81a5b02e31245b48299d99be931dd Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Sun, 27 Oct 2024 13:53:37 +0900 Subject: [PATCH] feat: infer type of parameter when user doesn't specify type (#403) * feat: infer type of parameter when user doesn't specify type * chore: delete _infer_task_on_kart_attr_init_type --- gokart/mypy.py | 154 ++++++++++++++++++--------- test/test_mypy.py | 58 ++++++++++ test/test_task_instance_parameter.py | 2 +- 3 files changed, 161 insertions(+), 53 deletions(-) diff --git a/gokart/mypy.py b/gokart/mypy.py index 46f67a1c..b275d11e 100644 --- a/gokart/mypy.py +++ b/gokart/mypy.py @@ -10,19 +10,16 @@ from typing import Callable, Final, Iterator, Literal, Optional import luigi -from mypy.expandtype import expand_type, expand_type_by_instance +from mypy.expandtype import expand_type from mypy.nodes import ( ARG_NAMED_OPT, - ARG_POS, Argument, AssignmentStmt, Block, CallExpr, ClassDef, - Context, EllipsisExpr, Expression, - FuncDef, IfStmt, JsonDict, MemberExpr, @@ -30,7 +27,6 @@ PlaceholderNode, RefExpr, Statement, - SymbolTableNode, TempNode, TypeInfo, Var, @@ -45,12 +41,11 @@ from mypy.typeops import map_type_from_supertype from mypy.types import ( AnyType, - CallableType, Instance, NoneType, Type, TypeOfAny, - get_proper_type, + UnionType, ) from mypy.typevars import fill_typevars @@ -327,7 +322,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]: current_attr_names.add(lhs.name) with state.strict_optional_set(self._api.options.strict_optional): - init_type = self._infer_task_on_kart_attr_init_type(sym, stmt) + init_type = sym.type + + # infer Parameter type + if init_type is None: + init_type = self._infer_type_from_parameters(stmt.rvalue) found_attrs[lhs.name] = TaskOnKartAttribute( name=lhs.name, @@ -361,65 +360,115 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp return True, args return False, {} - def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Context) -> Type | None: - """Infer __init__ argument type for an attribute. + def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]: + """ + Generate default type from Parameter. + For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type. + """ + parameter_name = _extract_parameter_name(parameter) + if parameter_name is None: + return None + + underlying_type: Optional[Type] = None + if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']: + underlying_type = self._api.named_type('builtins.str', []) + elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']: + underlying_type = self._api.named_type('builtins.int', []) + elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']: + underlying_type = self._api.named_type('builtins.float', []) + elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']: + underlying_type = self._api.named_type('builtins.bool', []) + elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']: + underlying_type = self._api.named_type('datetime.date', []) + elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']: + underlying_type = self._api.named_type('datetime.datetime', []) + elif parameter_name in ['luigi.parameter.TimeDeltaParameter']: + underlying_type = self._api.named_type('datetime.timedelta', []) + elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']: + underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)]) + elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']: + underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)]) + elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']: + underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)]) + elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']: + underlying_type = self._api.named_type('pathlib.Path', []) + elif parameter_name in ['gokart.parameter.TaskInstanceParameter']: + underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)]) + elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']: + underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])]) + elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']: + underlying_type = self._api.named_type('builtins.bool', []) + elif parameter_name in ['luigi.parameter.NumericalParameter']: + underlying_type = self._get_type_from_args(parameter, 'var_type') + elif parameter_name in ['luigi.parameter.ChoiceParameter']: + underlying_type = self._get_type_from_args(parameter, 'var_type') + elif parameter_name in ['luigi.parameter.ChoiceListPareameter']: + base_type = self._get_type_from_args(parameter, 'var_type') + if base_type is not None: + underlying_type = self._api.named_type('builtins.tuple', [base_type]) + elif parameter_name in ['luigi.parameter.EnumParameter']: + underlying_type = self._get_type_from_args(parameter, 'enum') + elif parameter_name in ['luigi.parameter.EnumListParameter']: + base_type = self._get_type_from_args(parameter, 'enum') + if base_type is not None: + underlying_type = self._api.named_type('builtins.tuple', [base_type]) + + if underlying_type is None: + return None + + # When parameter has Optional, it can be none value. + if 'Optional' in parameter_name: + return UnionType([underlying_type, NoneType()]) + + return underlying_type + + def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]: + """ + get type from parameter arguments. - In particular, possibly use the signature of __set__. + e.x) + When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type. """ - default = sym.type - if sym.implicit: - return default - t = get_proper_type(sym.type) - - # Perform a simple-minded inference from the signature of __set__, if present. - # We can't use mypy.checkmember here, since this plugin runs before type checking. - # We only support some basic scanerios here, which is hopefully sufficient for - # the vast majority of use cases. - if not isinstance(t, Instance): - return default - setter = t.type.get('__set__') - - if not setter: - return default - - if isinstance(setter.node, FuncDef): - super_info = t.type.get_containing_type_info('__set__') - assert super_info - if setter.type: - setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info)) - else: - return AnyType(TypeOfAny.unannotated) - if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [ - ARG_POS, - ARG_POS, - ARG_POS, - ]: - return expand_type_by_instance(setter_type.arg_types[2], t) - else: - self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context) - else: - self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context) - - return default + ok, args = self._collect_parameter_args(parameter) + if not ok: + return None + + if arg_key not in args: + return None + + arg = args[arg_key] + if not isinstance(arg, NameExpr): + return None + if not isinstance(arg.node, TypeInfo): + return None + return Instance(arg.node, []) def is_parameter_call(expr: Expression) -> bool: """Checks if the expression is a call to luigi.Parameter()""" - if not isinstance(expr, CallExpr): + parameter_name = _extract_parameter_name(expr) + if parameter_name is None: return False + return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None + + +def _extract_parameter_name(expr: Expression) -> Optional[str]: + """Extract name if the expression is a call to luigi.Parameter()""" + if not isinstance(expr, CallExpr): + return None callee = expr.callee if isinstance(callee, MemberExpr): type_info = callee.node if type_info is None and isinstance(callee.expr, NameExpr): - return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None + return f'{callee.expr.name}.{callee.name}' elif isinstance(callee, NameExpr): type_info = callee.node else: - return False + return None if isinstance(type_info, TypeInfo): - return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None + return type_info.fullname # Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1. # https://github.com/spotify/luigi/pull/3297 @@ -429,8 +478,9 @@ def is_parameter_call(expr: Expression) -> bool: # class MyTask(gokart.TaskOnKart): # param = Parameter() if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1': - return PARAMETER_TMP_MATCHER.match(type_info.name) is not None - return False + return type_info.name + + return None def plugin(version: str) -> type[Plugin]: diff --git a/test/test_mypy.py b/test/test_mypy.py index 545fcdde..74b83a84 100644 --- a/test/test_mypy.py +++ b/test/test_mypy.py @@ -64,3 +64,61 @@ class MyTask(gokart.TaskOnKart): self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0]) self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0]) self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0]) + + def test_parameter_has_default_type_invalid_pattern(self): + """ + If user doesn't set the type of the parameter, mypy infer the default type from Parameter types. + """ + test_code = """ +import enum +import luigi +import gokart + + +class MyEnum(enum.Enum): + FOO = enum.auto() + +class MyTask(gokart.TaskOnKart): + # NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it. + foo = luigi.IntParameter() + bar = luigi.DateParameter() + baz = gokart.TaskInstanceParameter() + qux = luigi.NumericalParameter(var_type=int) + quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int) + corge = luigi.EnumParameter(enum=MyEnum) + +MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1) +""" + with tempfile.NamedTemporaryFile(suffix='.py') as test_file: + test_file.write(test_code.encode('utf-8')) + test_file.flush() + result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) + self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) + self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', result[0]) + self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', result[0]) + self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) + self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0]) + self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', result[0]) + + def test_parameter_has_default_type_no_issue_pattern(self): + """ + If user doesn't set the type of the parameter, mypy infer the default type from Parameter types. + """ + test_code = """ +from datetime import date +import luigi +import gokart + +class MyTask(gokart.TaskOnKart): + # NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it. + foo = luigi.IntParameter() + bar = luigi.DateParameter() + baz = gokart.TaskInstanceParameter() + +MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart()) +""" + with tempfile.NamedTemporaryFile(suffix='.py') as test_file: + test_file.write(test_code.encode('utf-8')) + test_file.flush() + result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) + self.assertIn('Success: no issues found', result[0]) diff --git a/test/test_task_instance_parameter.py b/test/test_task_instance_parameter.py index fe2f7b8a..5230d638 100644 --- a/test/test_task_instance_parameter.py +++ b/test/test_task_instance_parameter.py @@ -90,7 +90,7 @@ class _DummyPipelineD(TaskOnKart): subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) with self.assertRaises(TypeError): - _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask]) + _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()]) if __name__ == '__main__':