-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: infer type of parameter when user doesn't specify type #403
Merged
+161
−53
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,27 +10,23 @@ | |
from typing import Callable, Final, Iterator, Literal, Optional | ||
|
||
import luigi | ||
from mypy.expandtype import expand_type, expand_type_by_instance | ||
from mypy.expandtype import expand_type | ||
from mypy.nodes import ( | ||
ARG_NAMED_OPT, | ||
ARG_POS, | ||
Argument, | ||
AssignmentStmt, | ||
Block, | ||
CallExpr, | ||
ClassDef, | ||
Context, | ||
EllipsisExpr, | ||
Expression, | ||
FuncDef, | ||
IfStmt, | ||
JsonDict, | ||
MemberExpr, | ||
NameExpr, | ||
PlaceholderNode, | ||
RefExpr, | ||
Statement, | ||
SymbolTableNode, | ||
TempNode, | ||
TypeInfo, | ||
Var, | ||
|
@@ -45,12 +41,11 @@ | |
from mypy.typeops import map_type_from_supertype | ||
from mypy.types import ( | ||
AnyType, | ||
CallableType, | ||
Instance, | ||
NoneType, | ||
Type, | ||
TypeOfAny, | ||
get_proper_type, | ||
UnionType, | ||
) | ||
from mypy.typevars import fill_typevars | ||
|
||
|
@@ -327,7 +322,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]: | |
|
||
current_attr_names.add(lhs.name) | ||
with state.strict_optional_set(self._api.options.strict_optional): | ||
init_type = self._infer_task_on_kart_attr_init_type(sym, stmt) | ||
init_type = sym.type | ||
|
||
# infer Parameter type | ||
if init_type is None: | ||
init_type = self._infer_type_from_parameters(stmt.rvalue) | ||
|
||
found_attrs[lhs.name] = TaskOnKartAttribute( | ||
name=lhs.name, | ||
|
@@ -361,65 +360,115 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp | |
return True, args | ||
return False, {} | ||
|
||
def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Context) -> Type | None: | ||
"""Infer __init__ argument type for an attribute. | ||
def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]: | ||
""" | ||
Generate default type from Parameter. | ||
For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type. | ||
""" | ||
parameter_name = _extract_parameter_name(parameter) | ||
if parameter_name is None: | ||
return None | ||
|
||
underlying_type: Optional[Type] = None | ||
if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']: | ||
underlying_type = self._api.named_type('builtins.str', []) | ||
elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']: | ||
underlying_type = self._api.named_type('builtins.int', []) | ||
elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']: | ||
underlying_type = self._api.named_type('builtins.float', []) | ||
elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']: | ||
underlying_type = self._api.named_type('builtins.bool', []) | ||
elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']: | ||
underlying_type = self._api.named_type('datetime.date', []) | ||
elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']: | ||
underlying_type = self._api.named_type('datetime.datetime', []) | ||
elif parameter_name in ['luigi.parameter.TimeDeltaParameter']: | ||
underlying_type = self._api.named_type('datetime.timedelta', []) | ||
elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']: | ||
underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)]) | ||
elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']: | ||
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)]) | ||
elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']: | ||
underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)]) | ||
elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']: | ||
underlying_type = self._api.named_type('pathlib.Path', []) | ||
elif parameter_name in ['gokart.parameter.TaskInstanceParameter']: | ||
underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)]) | ||
elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']: | ||
underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])]) | ||
elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']: | ||
underlying_type = self._api.named_type('builtins.bool', []) | ||
elif parameter_name in ['luigi.parameter.NumericalParameter']: | ||
underlying_type = self._get_type_from_args(parameter, 'var_type') | ||
elif parameter_name in ['luigi.parameter.ChoiceParameter']: | ||
underlying_type = self._get_type_from_args(parameter, 'var_type') | ||
elif parameter_name in ['luigi.parameter.ChoiceListPareameter']: | ||
base_type = self._get_type_from_args(parameter, 'var_type') | ||
if base_type is not None: | ||
underlying_type = self._api.named_type('builtins.tuple', [base_type]) | ||
elif parameter_name in ['luigi.parameter.EnumParameter']: | ||
underlying_type = self._get_type_from_args(parameter, 'enum') | ||
elif parameter_name in ['luigi.parameter.EnumListParameter']: | ||
base_type = self._get_type_from_args(parameter, 'enum') | ||
if base_type is not None: | ||
underlying_type = self._api.named_type('builtins.tuple', [base_type]) | ||
|
||
Comment on lines
+373
to
+415
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. memo: After Python 3.9 EOL, we can use |
||
if underlying_type is None: | ||
return None | ||
|
||
# When parameter has Optional, it can be none value. | ||
if 'Optional' in parameter_name: | ||
return UnionType([underlying_type, NoneType()]) | ||
|
||
return underlying_type | ||
|
||
def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]: | ||
""" | ||
get type from parameter arguments. | ||
|
||
In particular, possibly use the signature of __set__. | ||
e.x) | ||
When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type. | ||
""" | ||
default = sym.type | ||
if sym.implicit: | ||
return default | ||
t = get_proper_type(sym.type) | ||
|
||
# Perform a simple-minded inference from the signature of __set__, if present. | ||
# We can't use mypy.checkmember here, since this plugin runs before type checking. | ||
# We only support some basic scanerios here, which is hopefully sufficient for | ||
# the vast majority of use cases. | ||
if not isinstance(t, Instance): | ||
return default | ||
setter = t.type.get('__set__') | ||
|
||
if not setter: | ||
return default | ||
|
||
if isinstance(setter.node, FuncDef): | ||
super_info = t.type.get_containing_type_info('__set__') | ||
assert super_info | ||
if setter.type: | ||
setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info)) | ||
else: | ||
return AnyType(TypeOfAny.unannotated) | ||
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [ | ||
ARG_POS, | ||
ARG_POS, | ||
ARG_POS, | ||
]: | ||
return expand_type_by_instance(setter_type.arg_types[2], t) | ||
else: | ||
self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context) | ||
else: | ||
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context) | ||
|
||
return default | ||
ok, args = self._collect_parameter_args(parameter) | ||
if not ok: | ||
return None | ||
|
||
if arg_key not in args: | ||
return None | ||
|
||
arg = args[arg_key] | ||
if not isinstance(arg, NameExpr): | ||
return None | ||
if not isinstance(arg.node, TypeInfo): | ||
return None | ||
return Instance(arg.node, []) | ||
|
||
|
||
def is_parameter_call(expr: Expression) -> bool: | ||
"""Checks if the expression is a call to luigi.Parameter()""" | ||
if not isinstance(expr, CallExpr): | ||
parameter_name = _extract_parameter_name(expr) | ||
if parameter_name is None: | ||
return False | ||
return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None | ||
|
||
|
||
def _extract_parameter_name(expr: Expression) -> Optional[str]: | ||
"""Extract name if the expression is a call to luigi.Parameter()""" | ||
if not isinstance(expr, CallExpr): | ||
return None | ||
|
||
callee = expr.callee | ||
if isinstance(callee, MemberExpr): | ||
type_info = callee.node | ||
if type_info is None and isinstance(callee.expr, NameExpr): | ||
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None | ||
return f'{callee.expr.name}.{callee.name}' | ||
elif isinstance(callee, NameExpr): | ||
type_info = callee.node | ||
else: | ||
return False | ||
return None | ||
|
||
if isinstance(type_info, TypeInfo): | ||
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None | ||
return type_info.fullname | ||
|
||
# Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1. | ||
# https://github.com/spotify/luigi/pull/3297 | ||
|
@@ -429,8 +478,9 @@ def is_parameter_call(expr: Expression) -> bool: | |
# class MyTask(gokart.TaskOnKart): | ||
# param = Parameter() | ||
if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1': | ||
return PARAMETER_TMP_MATCHER.match(type_info.name) is not None | ||
return False | ||
return type_info.name | ||
|
||
return None | ||
|
||
|
||
def plugin(version: str) -> type[Plugin]: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(just a memo for this change)
This function aims to infer type for pydantic field, which is not needed on gokart.