From 146193e4cc8bff2bec05d1255728da2fa1177759 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sat, 31 Aug 2024 18:31:54 +0200 Subject: [PATCH] Add type annotation for benchmark parser module --- amlb/benchmarks/openml.py | 6 +++--- amlb/benchmarks/parser.py | 17 +++++++++++++---- amlb/utils/core.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/amlb/benchmarks/openml.py b/amlb/benchmarks/openml.py index 93c4cf42a..fa9befef1 100644 --- a/amlb/benchmarks/openml.py +++ b/amlb/benchmarks/openml.py @@ -26,8 +26,6 @@ def is_openml_benchmark(benchmark: str) -> bool: def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]: """ Loads benchmark defined by openml suite or task, from openml/s/X or openml/t/Y. """ domain, oml_type, oml_id = benchmark.split('/') - path = None # benchmark file does not exist on disk - name = benchmark # name is later passed as cli input again for containers, it needs to remain parsable if domain == "test.openml": log.debug("Setting openml server to the test server.") @@ -62,4 +60,6 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace] id="{}.org/t/{}".format(domain, tid))) else: raise ValueError(f"The oml_type is {oml_type} but must be 's' or 't'") - return name, path, tasks + # The first argument needs to remain parsable further in the pipeline as is + # The second argument is path, the benchmark does not exist on disk + return benchmark, None, tasks diff --git a/amlb/benchmarks/parser.py b/amlb/benchmarks/parser.py index c4aa3e246..f78434b58 100644 --- a/amlb/benchmarks/parser.py +++ b/amlb/benchmarks/parser.py @@ -1,11 +1,21 @@ -from typing import List +from __future__ import annotations + +from typing import List, Tuple from .openml import is_openml_benchmark, load_oml_benchmark from .file import load_file_benchmark -from amlb.utils import str_sanitize +from amlb.utils import str_sanitize, Namespace -def benchmark_load(name, benchmark_definition_dirs: List[str]): +def benchmark_load( + name: str, + benchmark_definition_dirs: List[str] + ) -> Tuple[ + Namespace | None, + List[Namespace], + str | None, + str + ]: """ Loads the benchmark definition for the 'benchmark' cli input string. :param name: the value for 'benchmark' @@ -17,7 +27,6 @@ def benchmark_load(name, benchmark_definition_dirs: List[str]): # which is why it is tried last. if is_openml_benchmark(name): benchmark_name, benchmark_path, tasks = load_oml_benchmark(name) - # elif is_kaggle_benchmark(name): else: benchmark_name, benchmark_path, tasks = load_file_benchmark(name, benchmark_definition_dirs) diff --git a/amlb/utils/core.py b/amlb/utils/core.py index 130e910fc..a7b34892e 100644 --- a/amlb/utils/core.py +++ b/amlb/utils/core.py @@ -345,7 +345,7 @@ def str_iter(col, sep=", "): return sep.join(map(str, col)) -def str_sanitize(s): +def str_sanitize(s: str) ->str: return re.sub(r"[^\w-]", "_", s)