From 146193e4cc8bff2bec05d1255728da2fa1177759 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Sat, 31 Aug 2024 18:31:54 +0200
Subject: [PATCH] Add type annotation for benchmark parser module
---
amlb/benchmarks/openml.py | 6 +++---
amlb/benchmarks/parser.py | 17 +++++++++++++----
amlb/utils/core.py | 2 +-
3 files changed, 17 insertions(+), 8 deletions(-)
diff --git a/amlb/benchmarks/openml.py b/amlb/benchmarks/openml.py
index 93c4cf42a..fa9befef1 100644
--- a/amlb/benchmarks/openml.py
+++ b/amlb/benchmarks/openml.py
@@ -26,8 +26,6 @@ def is_openml_benchmark(benchmark: str) -> bool:
def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]:
""" Loads benchmark defined by openml suite or task, from openml/s/X or openml/t/Y. """
domain, oml_type, oml_id = benchmark.split('/')
- path = None # benchmark file does not exist on disk
- name = benchmark # name is later passed as cli input again for containers, it needs to remain parsable
if domain == "test.openml":
log.debug("Setting openml server to the test server.")
@@ -62,4 +60,6 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]
id="{}.org/t/{}".format(domain, tid)))
else:
raise ValueError(f"The oml_type is {oml_type} but must be 's' or 't'")
- return name, path, tasks
+ # The first argument needs to remain parsable further in the pipeline as is
+ # The second argument is path, the benchmark does not exist on disk
+ return benchmark, None, tasks
diff --git a/amlb/benchmarks/parser.py b/amlb/benchmarks/parser.py
index c4aa3e246..f78434b58 100644
--- a/amlb/benchmarks/parser.py
+++ b/amlb/benchmarks/parser.py
@@ -1,11 +1,21 @@
-from typing import List
+from __future__ import annotations
+
+from typing import List, Tuple
from .openml import is_openml_benchmark, load_oml_benchmark
from .file import load_file_benchmark
-from amlb.utils import str_sanitize
+from amlb.utils import str_sanitize, Namespace
-def benchmark_load(name, benchmark_definition_dirs: List[str]):
+def benchmark_load(
+ name: str,
+ benchmark_definition_dirs: List[str]
+ ) -> Tuple[
+ Namespace | None,
+ List[Namespace],
+ str | None,
+ str
+ ]:
""" Loads the benchmark definition for the 'benchmark' cli input string.
:param name: the value for 'benchmark'
@@ -17,7 +27,6 @@ def benchmark_load(name, benchmark_definition_dirs: List[str]):
# which is why it is tried last.
if is_openml_benchmark(name):
benchmark_name, benchmark_path, tasks = load_oml_benchmark(name)
- # elif is_kaggle_benchmark(name):
else:
benchmark_name, benchmark_path, tasks = load_file_benchmark(name, benchmark_definition_dirs)
diff --git a/amlb/utils/core.py b/amlb/utils/core.py
index 130e910fc..a7b34892e 100644
--- a/amlb/utils/core.py
+++ b/amlb/utils/core.py
@@ -345,7 +345,7 @@ def str_iter(col, sep=", "):
return sep.join(map(str, col))
-def str_sanitize(s):
+def str_sanitize(s: str) ->str:
return re.sub(r"[^\w-]", "_", s)