Skip to content

Commit

Permalink
feat: infer type of parameter when user doesn't specify type (#403)
Browse files Browse the repository at this point in the history
* feat: infer type of parameter when user doesn't specify type

* chore: delete _infer_task_on_kart_attr_init_type
  • Loading branch information
kitagry authored Oct 27, 2024
1 parent 599837b commit 17f7064
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 53 deletions.
154 changes: 102 additions & 52 deletions gokart/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,23 @@
from typing import Callable, Final, Iterator, Literal, Optional

import luigi
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.expandtype import expand_type
from mypy.nodes import (
ARG_NAMED_OPT,
ARG_POS,
Argument,
AssignmentStmt,
Block,
CallExpr,
ClassDef,
Context,
EllipsisExpr,
Expression,
FuncDef,
IfStmt,
JsonDict,
MemberExpr,
NameExpr,
PlaceholderNode,
RefExpr,
Statement,
SymbolTableNode,
TempNode,
TypeInfo,
Var,
Expand All @@ -45,12 +41,11 @@
from mypy.typeops import map_type_from_supertype
from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
Type,
TypeOfAny,
get_proper_type,
UnionType,
)
from mypy.typevars import fill_typevars

Expand Down Expand Up @@ -327,7 +322,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:

current_attr_names.add(lhs.name)
with state.strict_optional_set(self._api.options.strict_optional):
init_type = self._infer_task_on_kart_attr_init_type(sym, stmt)
init_type = sym.type

# infer Parameter type
if init_type is None:
init_type = self._infer_type_from_parameters(stmt.rvalue)

found_attrs[lhs.name] = TaskOnKartAttribute(
name=lhs.name,
Expand Down Expand Up @@ -361,65 +360,115 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp
return True, args
return False, {}

def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Context) -> Type | None:
"""Infer __init__ argument type for an attribute.
def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]:
"""
Generate default type from Parameter.
For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type.
"""
parameter_name = _extract_parameter_name(parameter)
if parameter_name is None:
return None

underlying_type: Optional[Type] = None
if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']:
underlying_type = self._api.named_type('builtins.str', [])
elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']:
underlying_type = self._api.named_type('builtins.int', [])
elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']:
underlying_type = self._api.named_type('builtins.float', [])
elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']:
underlying_type = self._api.named_type('builtins.bool', [])
elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']:
underlying_type = self._api.named_type('datetime.date', [])
elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']:
underlying_type = self._api.named_type('datetime.datetime', [])
elif parameter_name in ['luigi.parameter.TimeDeltaParameter']:
underlying_type = self._api.named_type('datetime.timedelta', [])
elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']:
underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']:
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']:
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']:
underlying_type = self._api.named_type('pathlib.Path', [])
elif parameter_name in ['gokart.parameter.TaskInstanceParameter']:
underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']:
underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])])
elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']:
underlying_type = self._api.named_type('builtins.bool', [])
elif parameter_name in ['luigi.parameter.NumericalParameter']:
underlying_type = self._get_type_from_args(parameter, 'var_type')
elif parameter_name in ['luigi.parameter.ChoiceParameter']:
underlying_type = self._get_type_from_args(parameter, 'var_type')
elif parameter_name in ['luigi.parameter.ChoiceListPareameter']:
base_type = self._get_type_from_args(parameter, 'var_type')
if base_type is not None:
underlying_type = self._api.named_type('builtins.tuple', [base_type])
elif parameter_name in ['luigi.parameter.EnumParameter']:
underlying_type = self._get_type_from_args(parameter, 'enum')
elif parameter_name in ['luigi.parameter.EnumListParameter']:
base_type = self._get_type_from_args(parameter, 'enum')
if base_type is not None:
underlying_type = self._api.named_type('builtins.tuple', [base_type])

if underlying_type is None:
return None

# When parameter has Optional, it can be none value.
if 'Optional' in parameter_name:
return UnionType([underlying_type, NoneType()])

return underlying_type

def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]:
"""
get type from parameter arguments.
In particular, possibly use the signature of __set__.
e.x)
When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
"""
default = sym.type
if sym.implicit:
return default
t = get_proper_type(sym.type)

# Perform a simple-minded inference from the signature of __set__, if present.
# We can't use mypy.checkmember here, since this plugin runs before type checking.
# We only support some basic scanerios here, which is hopefully sufficient for
# the vast majority of use cases.
if not isinstance(t, Instance):
return default
setter = t.type.get('__set__')

if not setter:
return default

if isinstance(setter.node, FuncDef):
super_info = t.type.get_containing_type_info('__set__')
assert super_info
if setter.type:
setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info))
else:
return AnyType(TypeOfAny.unannotated)
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
ARG_POS,
ARG_POS,
ARG_POS,
]:
return expand_type_by_instance(setter_type.arg_types[2], t)
else:
self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context)
else:
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)

return default
ok, args = self._collect_parameter_args(parameter)
if not ok:
return None

if arg_key not in args:
return None

arg = args[arg_key]
if not isinstance(arg, NameExpr):
return None
if not isinstance(arg.node, TypeInfo):
return None
return Instance(arg.node, [])


def is_parameter_call(expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
parameter_name = _extract_parameter_name(expr)
if parameter_name is None:
return False
return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None


def _extract_parameter_name(expr: Expression) -> Optional[str]:
"""Extract name if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return None

callee = expr.callee
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
return f'{callee.expr.name}.{callee.name}'
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False
return None

if isinstance(type_info, TypeInfo):
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
return type_info.fullname

# Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
# https://github.com/spotify/luigi/pull/3297
Expand All @@ -429,8 +478,9 @@ def is_parameter_call(expr: Expression) -> bool:
# class MyTask(gokart.TaskOnKart):
# param = Parameter()
if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1':
return PARAMETER_TMP_MATCHER.match(type_info.name) is not None
return False
return type_info.name

return None


def plugin(version: str) -> type[Plugin]:
Expand Down
58 changes: 58 additions & 0 deletions test/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,61 @@ class MyTask(gokart.TaskOnKart):
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])

def test_parameter_has_default_type_invalid_pattern(self):
"""
If user doesn't set the type of the parameter, mypy infer the default type from Parameter types.
"""
test_code = """
import enum
import luigi
import gokart
class MyEnum(enum.Enum):
FOO = enum.auto()
class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
qux = luigi.NumericalParameter(var_type=int)
quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int)
corge = luigi.EnumParameter(enum=MyEnum)
MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1)
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', result[0])
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', result[0])
self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', result[0])

def test_parameter_has_default_type_no_issue_pattern(self):
"""
If user doesn't set the type of the parameter, mypy infer the default type from Parameter types.
"""
test_code = """
from datetime import date
import luigi
import gokart
class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart())
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('Success: no issues found', result[0])
2 changes: 1 addition & 1 deletion test/test_task_instance_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class _DummyPipelineD(TaskOnKart):
subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

with self.assertRaises(TypeError):
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask])
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])


if __name__ == '__main__':
Expand Down

0 comments on commit 17f7064

Please sign in to comment.