diff --git a/amlb/benchmark.py b/amlb/benchmark.py index 3cc90e0e4..b611ca66d 100644 --- a/amlb/benchmark.py +++ b/amlb/benchmark.py @@ -120,9 +120,8 @@ def __init__( Benchmark.data_loader = DataLoader(rconfig()) self._job_history = self._load_job_history(job_history=job_history) - self.framework_def, self.framework_name = load_framework_definition( - framework_name, rget() - ) + framework = load_framework_definition(framework_name, rget()) + self.framework_def, self.framework_name = framework, framework.name log.debug("Using framework definition: %s.", self.framework_def) self.constraint_def, self.constraint_name = rget().constraint_definition( diff --git a/amlb/frameworks/definitions.py b/amlb/frameworks/definitions.py index 977e3b004..3484f2104 100644 --- a/amlb/frameworks/definitions.py +++ b/amlb/frameworks/definitions.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import copy import itertools import logging import os +from dataclasses import dataclass from typing import List, Optional, Union, TYPE_CHECKING from amlb.utils import Namespace, config_load, str_sanitize @@ -238,8 +241,37 @@ def _remove_frameworks_with_unknown_parent(frameworks: Namespace): del frameworks[framework] -def load_framework_definition(framework_name: str, configuration: "Resources"): +@dataclass +class Image: + author: str + image_name: str + tag: str + + +@dataclass +class Framework: + name: str + description: str + project: str + abstract: bool + module: str + version: str + params: dict + # Image + image: Image + # Setup + setup_env: dict + setup_args: list[str] + _setup_cmd: str | None + setup_cmd: str | None + setup_script: str | None + + +def load_framework_definition( + framework_name: str, configuration: "Resources" +) -> Framework: tag = None if ":" in framework_name: framework_name, tag = framework_name.split(":", 1) - return configuration.framework_definition(framework_name, tag) + definition_ns, name = configuration.framework_definition(framework_name, tag) + return Framework(**Namespace.dict(definition_ns))