Skip to content

Commit

Permalink
Add type annotation for benchmark parser module
Browse files: browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Aug 31, 2024
1 parent 6d14f98 commit 146193e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
6 changes: 3 additions & 3 deletions amlb/benchmarks/openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def is_openml_benchmark(benchmark: str) -> bool:
def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]:
""" Loads benchmark defined by openml suite or task, from openml/s/X or openml/t/Y. """
domain, oml_type, oml_id = benchmark.split('/')
path = None # benchmark file does not exist on disk
name = benchmark # name is later passed as cli input again for containers, it needs to remain parsable

if domain == "test.openml":
log.debug("Setting openml server to the test server.")
Expand Down Expand Up @@ -62,4 +60,6 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]
id="{}.org/t/{}".format(domain, tid)))
else:
raise ValueError(f"The oml_type is {oml_type} but must be 's' or 't'")
return name, path, tasks
# The first returned value (the benchmark name) must remain parsable as-is further down the pipeline.
# The second returned value is the path; it is None because the benchmark does not exist on disk.
return benchmark, None, tasks
17 changes: 13 additions & 4 deletions amlb/benchmarks/parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from typing import List
from __future__ import annotations

from typing import List, Tuple

from .openml import is_openml_benchmark, load_oml_benchmark
from .file import load_file_benchmark
from amlb.utils import str_sanitize
from amlb.utils import str_sanitize, Namespace


def benchmark_load(name, benchmark_definition_dirs: List[str]):
def benchmark_load(
name: str,
benchmark_definition_dirs: List[str]
) -> Tuple[
Namespace | None,
List[Namespace],
str | None,
str
]:
""" Loads the benchmark definition for the 'benchmark' cli input string.
:param name: the value for 'benchmark'
Expand All @@ -17,7 +27,6 @@ def benchmark_load(name, benchmark_definition_dirs: List[str]):
# which is why it is tried last.
if is_openml_benchmark(name):
benchmark_name, benchmark_path, tasks = load_oml_benchmark(name)
# elif is_kaggle_benchmark(name):
else:
benchmark_name, benchmark_path, tasks = load_file_benchmark(name, benchmark_definition_dirs)

Expand Down
2 changes: 1 addition & 1 deletion amlb/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def str_iter(col, sep=", "):
return sep.join(map(str, col))


def str_sanitize(s):
def str_sanitize(s: str) ->str:
return re.sub(r"[^\w-]", "_", s)


Expand Down

0 comments on commit 146193e

Please sign in to comment.