Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: infer type of parameter when user doesn't specify type #403

Merged
merged 2 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 102 additions & 52 deletions gokart/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,23 @@
from typing import Callable, Final, Iterator, Literal, Optional

import luigi
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.expandtype import expand_type
from mypy.nodes import (
ARG_NAMED_OPT,
ARG_POS,
Argument,
AssignmentStmt,
Block,
CallExpr,
ClassDef,
Context,
EllipsisExpr,
Expression,
FuncDef,
IfStmt,
JsonDict,
MemberExpr,
NameExpr,
PlaceholderNode,
RefExpr,
Statement,
SymbolTableNode,
TempNode,
TypeInfo,
Var,
Expand All @@ -45,12 +41,11 @@
from mypy.typeops import map_type_from_supertype
from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
Type,
TypeOfAny,
get_proper_type,
UnionType,
)
from mypy.typevars import fill_typevars

Expand Down Expand Up @@ -327,7 +322,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:

current_attr_names.add(lhs.name)
with state.strict_optional_set(self._api.options.strict_optional):
init_type = self._infer_task_on_kart_attr_init_type(sym, stmt)
init_type = sym.type

# infer Parameter type
if init_type is None:
init_type = self._infer_type_from_parameters(stmt.rvalue)

found_attrs[lhs.name] = TaskOnKartAttribute(
name=lhs.name,
Expand Down Expand Up @@ -361,65 +360,115 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp
return True, args
return False, {}

def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Context) -> Type | None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(just a memo for this change)
This function aims to infer type for pydantic field, which is not needed on gokart.

"""Infer __init__ argument type for an attribute.
def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]:
"""
Generate default type from Parameter.
For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type.
"""
parameter_name = _extract_parameter_name(parameter)
if parameter_name is None:
return None

underlying_type: Optional[Type] = None
if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']:
underlying_type = self._api.named_type('builtins.str', [])
elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']:
underlying_type = self._api.named_type('builtins.int', [])
elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']:
underlying_type = self._api.named_type('builtins.float', [])
elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']:
underlying_type = self._api.named_type('builtins.bool', [])
elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']:
underlying_type = self._api.named_type('datetime.date', [])
elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']:
underlying_type = self._api.named_type('datetime.datetime', [])
elif parameter_name in ['luigi.parameter.TimeDeltaParameter']:
underlying_type = self._api.named_type('datetime.timedelta', [])
elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']:
underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']:
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']:
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']:
underlying_type = self._api.named_type('pathlib.Path', [])
elif parameter_name in ['gokart.parameter.TaskInstanceParameter']:
underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']:
underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])])
elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']:
underlying_type = self._api.named_type('builtins.bool', [])
elif parameter_name in ['luigi.parameter.NumericalParameter']:
underlying_type = self._get_type_from_args(parameter, 'var_type')
elif parameter_name in ['luigi.parameter.ChoiceParameter']:
underlying_type = self._get_type_from_args(parameter, 'var_type')
elif parameter_name in ['luigi.parameter.ChoiceListPareameter']:
base_type = self._get_type_from_args(parameter, 'var_type')
if base_type is not None:
underlying_type = self._api.named_type('builtins.tuple', [base_type])
elif parameter_name in ['luigi.parameter.EnumParameter']:
underlying_type = self._get_type_from_args(parameter, 'enum')
elif parameter_name in ['luigi.parameter.EnumListParameter']:
base_type = self._get_type_from_args(parameter, 'enum')
if base_type is not None:
underlying_type = self._api.named_type('builtins.tuple', [base_type])

Comment on lines +373 to +415
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

memo: After Python 3.9 EOL, we can use match case syntax.

if underlying_type is None:
return None

# When parameter has Optional, it can be none value.
if 'Optional' in parameter_name:
return UnionType([underlying_type, NoneType()])

return underlying_type

def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]:
"""
get type from parameter arguments.

In particular, possibly use the signature of __set__.
e.x)
When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
"""
default = sym.type
if sym.implicit:
return default
t = get_proper_type(sym.type)

# Perform a simple-minded inference from the signature of __set__, if present.
# We can't use mypy.checkmember here, since this plugin runs before type checking.
# We only support some basic scanerios here, which is hopefully sufficient for
# the vast majority of use cases.
if not isinstance(t, Instance):
return default
setter = t.type.get('__set__')

if not setter:
return default

if isinstance(setter.node, FuncDef):
super_info = t.type.get_containing_type_info('__set__')
assert super_info
if setter.type:
setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info))
else:
return AnyType(TypeOfAny.unannotated)
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
ARG_POS,
ARG_POS,
ARG_POS,
]:
return expand_type_by_instance(setter_type.arg_types[2], t)
else:
self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context)
else:
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)

return default
ok, args = self._collect_parameter_args(parameter)
if not ok:
return None

if arg_key not in args:
return None

arg = args[arg_key]
if not isinstance(arg, NameExpr):
return None
if not isinstance(arg.node, TypeInfo):
return None
return Instance(arg.node, [])


def is_parameter_call(expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
parameter_name = _extract_parameter_name(expr)
if parameter_name is None:
return False
return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None


def _extract_parameter_name(expr: Expression) -> Optional[str]:
"""Extract name if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return None

callee = expr.callee
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
return f'{callee.expr.name}.{callee.name}'
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False
return None

if isinstance(type_info, TypeInfo):
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
return type_info.fullname

# Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
# https://github.com/spotify/luigi/pull/3297
Expand All @@ -429,8 +478,9 @@ def is_parameter_call(expr: Expression) -> bool:
# class MyTask(gokart.TaskOnKart):
# param = Parameter()
if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1':
return PARAMETER_TMP_MATCHER.match(type_info.name) is not None
return False
return type_info.name

return None


def plugin(version: str) -> type[Plugin]:
Expand Down
58 changes: 58 additions & 0 deletions test/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,61 @@ class MyTask(gokart.TaskOnKart):
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])

def test_parameter_has_default_type_invalid_pattern(self):
"""
If user doesn't set the type of the parameter, mypy infer the default type from Parameter types.
"""
test_code = """
import enum
import luigi
import gokart
class MyEnum(enum.Enum):
FOO = enum.auto()
class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
qux = luigi.NumericalParameter(var_type=int)
quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int)
corge = luigi.EnumParameter(enum=MyEnum)
MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1)
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', result[0])
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', result[0])
self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', result[0])

def test_parameter_has_default_type_no_issue_pattern(self):
"""
If user doesn't set the type of the parameter, mypy infer the default type from Parameter types.
"""
test_code = """
from datetime import date
import luigi
import gokart
class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart())
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('Success: no issues found', result[0])
2 changes: 1 addition & 1 deletion test/test_task_instance_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class _DummyPipelineD(TaskOnKart):
subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

with self.assertRaises(TypeError):
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask])
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])


if __name__ == '__main__':
Expand Down