Skip to content

Commit

Permalink
Fix issues flagged by mypy strict (#3490)
Browse files: browse the repository at this point in the history
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
  • Loading branch information
merelcht authored Jan 12, 2024
1 parent ce27c2d commit 8c8c713
Show file tree
Hide file tree
Showing 38 changed files with 480 additions and 362 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ clean:

lint:
pre-commit run -a --hook-stage manual $(hook)
mypy kedro
mypy kedro --strict --allow-any-generics
test:
pytest --numprocesses 4 --dist loadfile

Expand Down
2 changes: 1 addition & 1 deletion features/steps/sh_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, cmd: list[str], **kwargs) -> None:
**kwargs: keyword arguments such as env and cwd
"""
super().__init__( # type: ignore
super().__init__(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs
)

Expand Down
2 changes: 1 addition & 1 deletion kedro/config/abstract_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
conf_source: str,
env: str | None = None,
runtime_params: dict[str, Any] | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__()
self.conf_source = conf_source
Expand Down
40 changes: 21 additions & 19 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ def __init__( # noqa: PLR0913
except MissingConfigException:
self._globals = {}

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
if key == "globals":
# Update the cached value at self._globals since it is used by the globals resolver
self._globals = value
super().__setitem__(key, value)

def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912
"""Get configuration files by key, load and merge them, and
return them in the form of a config dictionary.
Expand All @@ -175,7 +175,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
self._register_runtime_params_resolver()

if key in self:
return super().__getitem__(key)
return super().__getitem__(key) # type: ignore[no-any-return]

if key not in self.config_patterns:
raise KeyError(
Expand All @@ -196,7 +196,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
else:
base_path = str(Path(self._fs.ls("", detail=False)[-1]) / self.base_env)
try:
base_config = self.load_and_merge_dir_config(
base_config = self.load_and_merge_dir_config( # type: ignore[no-untyped-call]
base_path, patterns, key, processed_files, read_environment_variables
)
except UnsupportedInterpolationType as exc:
Expand All @@ -216,7 +216,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
else:
env_path = str(Path(self._fs.ls("", detail=False)[-1]) / run_env)
try:
env_config = self.load_and_merge_dir_config(
env_config = self.load_and_merge_dir_config( # type: ignore[no-untyped-call]
env_path, patterns, key, processed_files, read_environment_variables
)
except UnsupportedInterpolationType as exc:
Expand Down Expand Up @@ -244,9 +244,9 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
f" the glob pattern(s): {[*self.config_patterns[key]]}"
)

return resulting_config
return resulting_config # type: ignore[no-any-return]

def __repr__(self): # pragma: no cover
def __repr__(self) -> str: # pragma: no cover
return (
f"OmegaConfigLoader(conf_source={self.conf_source}, env={self.env}, "
f"config_patterns={self.config_patterns})"
Expand Down Expand Up @@ -312,8 +312,8 @@ def load_and_merge_dir_config( # noqa: PLR0913
self._resolve_environment_variables(config)
config_per_file[config_filepath] = config
except (ParserError, ScannerError) as exc:
line = exc.problem_mark.line # type: ignore
cursor = exc.problem_mark.column # type: ignore
line = exc.problem_mark.line
cursor = exc.problem_mark.column
raise ParserError(
f"Invalid YAML or JSON file {Path(conf_path, config_filepath.name).as_posix()},"
f" unable to read line {line}, position {cursor}."
Expand Down Expand Up @@ -342,7 +342,7 @@ def load_and_merge_dir_config( # noqa: PLR0913
if not k.startswith("_")
}

def _is_valid_config_path(self, path):
def _is_valid_config_path(self, path: Path) -> bool:
"""Check if given path is a file path and file type is yaml or json."""
posix_path = path.as_posix()
return self._fs.isfile(str(posix_path)) and path.suffix in [
Expand All @@ -351,22 +351,22 @@ def _is_valid_config_path(self, path):
".json",
]

def _register_globals_resolver(self):
def _register_globals_resolver(self) -> None:
"""Register the globals resolver"""
OmegaConf.register_new_resolver(
"globals",
self._get_globals_value,
replace=True,
)

def _register_runtime_params_resolver(self):
def _register_runtime_params_resolver(self) -> None:
OmegaConf.register_new_resolver(
"runtime_params",
self._get_runtime_value,
replace=True,
)

def _get_globals_value(self, variable, default_value=_NO_VALUE):
def _get_globals_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
"""Return the globals values to the resolver"""
if variable.startswith("_"):
raise InterpolationResolutionError(
Expand All @@ -383,7 +383,7 @@ def _get_globals_value(self, variable, default_value=_NO_VALUE):
f"Globals key '{variable}' not found and no default value provided."
)

def _get_runtime_value(self, variable, default_value=_NO_VALUE):
def _get_runtime_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
"""Return the runtime params values to the resolver"""
runtime_oc = OmegaConf.create(self.runtime_params)
interpolated_value = OmegaConf.select(
Expand All @@ -397,7 +397,7 @@ def _get_runtime_value(self, variable, default_value=_NO_VALUE):
)

@staticmethod
def _register_new_resolvers(resolvers: dict[str, Callable]):
def _register_new_resolvers(resolvers: dict[str, Callable]) -> None:
"""Register custom resolvers"""
for name, resolver in resolvers.items():
if not OmegaConf.has_resolver(name):
Expand All @@ -406,7 +406,7 @@ def _register_new_resolvers(resolvers: dict[str, Callable]):
OmegaConf.register_new_resolver(name=name, resolver=resolver)

@staticmethod
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]):
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]) -> None:
duplicates = []

filepaths = list(seen_files_to_keys.keys())
Expand Down Expand Up @@ -449,7 +449,9 @@ def _resolve_environment_variables(config: DictConfig) -> None:
OmegaConf.resolve(config)

@staticmethod
def _destructive_merge(config, env_config, env_path):
def _destructive_merge(
config: dict[str, Any], env_config: dict[str, Any], env_path: str
) -> dict[str, Any]:
# Destructively merge the two env dirs. The chosen env will override base.
common_keys = config.keys() & env_config.keys()
if common_keys:
Expand All @@ -464,11 +466,11 @@ def _destructive_merge(config, env_config, env_path):
return config

@staticmethod
def _soft_merge(config, env_config):
def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any:
# Soft merge the two env dirs. The chosen env will override base if keys clash.
return OmegaConf.to_container(OmegaConf.merge(config, env_config))

def _is_hidden(self, path_str: str):
def _is_hidden(self, path_str: str) -> bool:
"""Check if path contains any hidden directory or is a hidden file"""
path = Path(path_str)
conf_path = Path(self.conf_source).resolve().as_posix()
Expand Down
29 changes: 18 additions & 11 deletions kedro/framework/cli/catalog.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""A collection of CLI commands for working with Kedro catalog."""
from __future__ import annotations

import copy
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Any

import click
import yaml
Expand All @@ -11,21 +15,22 @@
from kedro.framework.project import pipelines, settings
from kedro.framework.session import KedroSession
from kedro.framework.startup import ProjectMetadata
from kedro.io import AbstractDataset


def _create_session(package_name: str, **kwargs):
def _create_session(package_name: str, **kwargs: Any) -> KedroSession:
kwargs.setdefault("save_on_close", False)
return KedroSession.create(**kwargs)


# noqa: missing-function-docstring
@click.group(name="Kedro")
def catalog_cli(): # pragma: no cover
def catalog_cli() -> None: # pragma: no cover
pass


@catalog_cli.group()
def catalog():
def catalog() -> None:
"""Commands for working with catalog."""


Expand All @@ -42,7 +47,7 @@ def catalog():
callback=split_string,
)
@click.pass_obj
def list_datasets(metadata: ProjectMetadata, pipeline, env):
def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
"""Show datasets per type."""
title = "Datasets in '{}' pipeline"
not_mentioned = "Datasets not mentioned in pipeline"
Expand Down Expand Up @@ -111,11 +116,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
secho(yaml.dump(result))


def _map_type_to_datasets(datasets, datasets_meta):
def _map_type_to_datasets(
datasets: set[str], datasets_meta: dict[str, AbstractDataset]
) -> dict:
"""Build dictionary with a dataset type as a key and list of
datasets of the specific type as a value.
"""
mapping = defaultdict(list)
mapping = defaultdict(list) # type: ignore[var-annotated]
for dataset in datasets:
is_param = dataset.startswith("params:") or dataset == "parameters"
if not is_param:
Expand All @@ -136,7 +143,7 @@ def _map_type_to_datasets(datasets, datasets_meta):
help="Name of a pipeline.",
)
@click.pass_obj
def create_catalog(metadata: ProjectMetadata, pipeline_name, env):
def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> None:
"""Create Data Catalog YAML configuration with missing datasets.
Add ``MemoryDataset`` datasets to Data Catalog YAML configuration
Expand Down Expand Up @@ -185,7 +192,7 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name, env):
click.echo("All datasets are already configured.")


def _add_missing_datasets_to_catalog(missing_ds, catalog_path):
def _add_missing_datasets_to_catalog(missing_ds: list[str], catalog_path: Path) -> None:
if catalog_path.is_file():
catalog_config = yaml.safe_load(catalog_path.read_text()) or {}
else:
Expand All @@ -204,7 +211,7 @@ def _add_missing_datasets_to_catalog(missing_ds, catalog_path):
@catalog.command("rank")
@env_option
@click.pass_obj
def rank_catalog_factories(metadata: ProjectMetadata, env):
def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None:
"""List all dataset factories in the catalog, ranked by priority by which they are matched."""
session = _create_session(metadata.package_name, env=env)
context = session.load_context()
Expand All @@ -219,7 +226,7 @@ def rank_catalog_factories(metadata: ProjectMetadata, env):
@catalog.command("resolve")
@env_option
@click.pass_obj
def resolve_patterns(metadata: ProjectMetadata, env):
def resolve_patterns(metadata: ProjectMetadata, env: str) -> None:
"""Resolve catalog factories against pipeline datasets. Note that this command is runner
agnostic and thus won't take into account any default dataset creation defined in the runner."""

Expand Down Expand Up @@ -268,5 +275,5 @@ def resolve_patterns(metadata: ProjectMetadata, env):
secho(yaml.dump(explicit_datasets))


def _trim_filepath(project_path: str, file_path: str):
def _trim_filepath(project_path: str, file_path: str) -> str:
return file_path.replace(project_path, "", 1)
24 changes: 13 additions & 11 deletions kedro/framework/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
This module implements commands available from the kedro CLI.
"""
from __future__ import annotations

import importlib
import sys
from collections import defaultdict
from pathlib import Path
from typing import Sequence
from typing import Any, Sequence

import click

Expand Down Expand Up @@ -42,7 +44,7 @@

@click.group(context_settings=CONTEXT_SETTINGS, name="Kedro")
@click.version_option(version, "--version", "-V", help="Show version and exit")
def cli(): # pragma: no cover
def cli() -> None: # pragma: no cover
"""Kedro is a CLI for creating and using Kedro projects. For more
information, type ``kedro info``.
Expand All @@ -51,7 +53,7 @@ def cli(): # pragma: no cover


@cli.command()
def info():
def info() -> None:
"""Get more information about kedro."""
click.secho(LOGO, fg="green")
click.echo(
Expand Down Expand Up @@ -104,12 +106,12 @@ def __init__(self, project_path: Path):

def main(
self,
args=None,
prog_name=None,
complete_var=None,
standalone_mode=True,
**extra,
):
args: Any | None = None,
prog_name: Any | None = None,
complete_var: Any | None = None,
standalone_mode: bool = True,
**extra: Any,
) -> Any:
if self._metadata:
extra.update(obj=self._metadata)

Expand Down Expand Up @@ -182,13 +184,13 @@ def project_groups(self) -> Sequence[click.MultiCommand]:
raise KedroCliError(
f"Cannot load commands from {self._metadata.package_name}.cli"
)
user_defined = project_cli.cli # type: ignore
user_defined = project_cli.cli
# return built-in commands, plugin commands and user defined commands
# (overriding happens as follows built-in < plugins < cli.py)
return [*built_in, *plugins, user_defined]


def main(): # pragma: no cover
def main() -> None: # pragma: no cover
"""Main entry point. Look for a ``cli.py``, and, if found, add its
commands to `kedro`'s before invoking the CLI.
"""
Expand Down
2 changes: 1 addition & 1 deletion kedro/framework/cli/hooks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_CLI_PLUGIN_HOOKS = "kedro.cli_hooks"


def get_cli_hook_manager():
def get_cli_hook_manager() -> PluginManager:
"""Create or return the global _hook_manager singleton instance."""
global _cli_hook_manager # noqa: PLW0603
if _cli_hook_manager is None:
Expand Down
4 changes: 2 additions & 2 deletions kedro/framework/cli/hooks/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def before_command_run(
self,
project_metadata: ProjectMetadata,
command_args: list[str],
):
) -> None:
"""Hooks to be invoked before a CLI command runs.
It receives the ``project_metadata`` as well as
all command line arguments that were used, including the command
Expand All @@ -32,7 +32,7 @@ def before_command_run(
@cli_hook_spec
def after_command_run(
self, project_metadata: ProjectMetadata, command_args: list[str], exit_code: int
):
) -> None:
"""Hooks to be invoked after a CLI command runs.
It receives the ``project_metadata`` as well as
all command line arguments that were used, including the command
Expand Down
Loading

0 comments on commit 8c8c713

Please sign in to comment.