Skip to content

Commit

Permalink
feat: infer type of parameter when user doesn't specify type
Browse files Browse the repository at this point in the history
  • Loading branch information
kitagry committed Oct 26, 2024
1 parent 8bd8fe0 commit 4589525
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 7 deletions.
109 changes: 103 additions & 6 deletions gokart/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
NoneType,
Type,
TypeOfAny,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars
Expand Down Expand Up @@ -329,6 +330,10 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
with state.strict_optional_set(self._api.options.strict_optional):
init_type = self._infer_task_on_kart_attr_init_type(sym, stmt)

# infer Parameter type
if init_type is None:
init_type = self._infer_type_from_parameters(stmt.rvalue)

found_attrs[lhs.name] = TaskOnKartAttribute(
name=lhs.name,
has_default=has_default,
Expand Down Expand Up @@ -402,24 +407,115 @@ def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Cont

return default

def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]:
"""
Generate default type from Parameter.
For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type.
"""
parameter_name = _extract_parameter_name(parameter)
if parameter_name is None:
return None

underlying_type: Optional[Type] = None
if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']:
underlying_type = self._api.named_type('builtins.str', [])
elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']:
underlying_type = self._api.named_type('builtins.int', [])
elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']:
underlying_type = self._api.named_type('builtins.float', [])
elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']:
underlying_type = self._api.named_type('builtins.bool', [])
elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']:
underlying_type = self._api.named_type('datetime.date', [])
elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']:
underlying_type = self._api.named_type('datetime.datetime', [])
elif parameter_name in ['luigi.parameter.TimeDeltaParameter']:
underlying_type = self._api.named_type('datetime.timedelta', [])
elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']:
underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']:
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']:
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']:
underlying_type = self._api.named_type('pathlib.Path', [])
elif parameter_name in ['gokart.parameter.TaskInstanceParameter']:
underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])
elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']:
underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])])
elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']:
underlying_type = self._api.named_type('builtins.bool', [])
elif parameter_name in ['luigi.parameter.NumericalParameter']:
underlying_type = self._get_type_from_args(parameter, 'var_type')
elif parameter_name in ['luigi.parameter.ChoiceParameter']:
underlying_type = self._get_type_from_args(parameter, 'var_type')
elif parameter_name in ['luigi.parameter.ChoiceListPareameter']:
base_type = self._get_type_from_args(parameter, 'var_type')
if base_type is not None:
underlying_type = self._api.named_type('builtins.tuple', [base_type])
elif parameter_name in ['luigi.parameter.EnumParameter']:
underlying_type = self._get_type_from_args(parameter, 'enum')
elif parameter_name in ['luigi.parameter.EnumListParameter']:
base_type = self._get_type_from_args(parameter, 'enum')
if base_type is not None:
underlying_type = self._api.named_type('builtins.tuple', [base_type])

if underlying_type is None:
return None

# When parameter has Optional, it can be none value.
if 'Optional' in parameter_name:
return UnionType([underlying_type, NoneType()])

return underlying_type

def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]:
"""
get type from parameter arguments.
e.x)
When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
"""
ok, args = self._collect_parameter_args(parameter)
if not ok:
return None

if arg_key not in args:
return None

arg = args[arg_key]
if not isinstance(arg, NameExpr):
return None
if not isinstance(arg.node, TypeInfo):
return None
return Instance(arg.node, [])


def is_parameter_call(expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
parameter_name = _extract_parameter_name(expr)
if parameter_name is None:
return False
return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None


def _extract_parameter_name(expr: Expression) -> Optional[str]:
"""Extract name if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return None

callee = expr.callee
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
return f'{callee.expr.name}.{callee.name}'
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False
return None

if isinstance(type_info, TypeInfo):
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
return type_info.fullname

# Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
# https://github.com/spotify/luigi/pull/3297
Expand All @@ -429,8 +525,9 @@ def is_parameter_call(expr: Expression) -> bool:
# class MyTask(gokart.TaskOnKart):
# param = Parameter()
if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1':
return PARAMETER_TMP_MATCHER.match(type_info.name) is not None
return False
return type_info.name

return None


def plugin(version: str) -> type[Plugin]:
Expand Down
58 changes: 58 additions & 0 deletions test/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,61 @@ class MyTask(gokart.TaskOnKart):
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])

def test_parameter_has_default_type_invalid_pattern(self):
"""
If user doesn't set the type of the parameter, mypy infer the default type from Parameter types.
"""
test_code = """
import enum
import luigi
import gokart
class MyEnum(enum.Enum):
FOO = enum.auto()
class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
qux = luigi.NumericalParameter(var_type=int)
quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int)
corge = luigi.EnumParameter(enum=MyEnum)
MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1)
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', result[0])
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', result[0])
self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', result[0])

def test_parameter_has_default_type_no_issue_pattern(self):
"""
If user doesn't set the type of the parameter, mypy infer the default type from Parameter types.
"""
test_code = """
from datetime import date
import luigi
import gokart
class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart())
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('Success: no issues found', result[0])
2 changes: 1 addition & 1 deletion test/test_task_instance_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class _DummyPipelineD(TaskOnKart):
subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

with self.assertRaises(TypeError):
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask])
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])


if __name__ == '__main__':
Expand Down

0 comments on commit 4589525

Please sign in to comment.