Skip to content

Commit

Permalink
Separate out framework def as dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Nov 28, 2024
1 parent af1f7fa commit e8ab966
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
5 changes: 2 additions & 3 deletions amlb/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ def __init__(
Benchmark.data_loader = DataLoader(rconfig())

self._job_history = self._load_job_history(job_history=job_history)
self.framework_def, self.framework_name = load_framework_definition(
framework_name, rget()
)
framework = load_framework_definition(framework_name, rget())
self.framework_def, self.framework_name = framework, framework.name
log.debug("Using framework definition: %s.", self.framework_def)

self.constraint_def, self.constraint_name = rget().constraint_definition(
Expand Down
36 changes: 34 additions & 2 deletions amlb/frameworks/definitions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import copy
import itertools
import logging
import os
from dataclasses import dataclass
from typing import List, Optional, Union, TYPE_CHECKING

from amlb.utils import Namespace, config_load, str_sanitize
Expand Down Expand Up @@ -238,8 +241,37 @@ def _remove_frameworks_with_unknown_parent(frameworks: Namespace):
del frameworks[framework]


def load_framework_definition(framework_name: str, configuration: "Resources"):
@dataclass
class Image:
author: str
image_name: str
tag: str


@dataclass
class Framework:
name: str
description: str
project: str
abstract: bool
module: str
version: str
params: dict
# Image
image: Image
# Setup
setup_env: dict
setup_args: list[str]
_setup_cmd: str | None
setup_cmd: str | None
setup_script: str | None


def load_framework_definition(
framework_name: str, configuration: "Resources"
) -> Framework:
tag = None
if ":" in framework_name:
framework_name, tag = framework_name.split(":", 1)
return configuration.framework_definition(framework_name, tag)
definition_ns, name = configuration.framework_definition(framework_name, tag)
return Framework(**Namespace.dict(definition_ns))

0 comments on commit e8ab966

Please sign in to comment.