diff --git a/mart/generate_config.py b/mart/generate_config.py new file mode 100644 index 00000000..f729c96e --- /dev/null +++ b/mart/generate_config.py @@ -0,0 +1,44 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import fire +from omegaconf import OmegaConf + +from .utils.config import ( + DEFAULT_CONFIG_DIR, + DEFAULT_CONFIG_NAME, + DEFAULT_VERSION_BASE, + compose, +) + + +def main( + *overrides, + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, + export_node: str | None = None, + resolve: bool = False, + sort_keys: bool = True, +): + cfg = compose( + *overrides, + version_base=version_base, + config_dir=config_dir, + config_name=config_name, + export_node=export_node, + ) + + cfg_yaml = OmegaConf.to_yaml(cfg, resolve=resolve, sort_keys=sort_keys) + + # OmegaConf.to_yaml() already ends with `\n`. + print(cfg_yaml, end="") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py index 50e71b3d..b243a608 100644 --- a/mart/utils/__init__.py +++ b/mart/utils/__init__.py @@ -1,4 +1,5 @@ from .adapters import * +from .config import * from .export import * from .monkey_patch import * from .pylogger import * diff --git a/mart/utils/config.py b/mart/utils/config.py new file mode 100644 index 00000000..36dbb6e6 --- /dev/null +++ b/mart/utils/config.py @@ -0,0 +1,42 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import os + +from hydra import compose as hydra_compose +from hydra import initialize_config_dir + +DEFAULT_VERSION_BASE = "1.2" +DEFAULT_CONFIG_DIR = "." +DEFAULT_CONFIG_NAME = "lightning.yaml" + +__all__ = ["compose"] + + +def compose( + *overrides, + version_base: str = DEFAULT_VERSION_BASE, + config_dir: str = DEFAULT_CONFIG_DIR, + config_name: str = DEFAULT_CONFIG_NAME, + export_node: str | None = None, +): + # Add an absolute path {config_dir} to the search path of configs, preceding those in mart.configs. + if not os.path.isabs(config_dir): + config_dir = os.path.abspath(config_dir) + + # hydra.initialize_config_dir() requires an absolute path, + # while hydra.initialize() searches paths relatively to mart. + with initialize_config_dir(version_base=version_base, config_dir=config_dir): + cfg = hydra_compose(config_name=config_name, overrides=overrides) + + # Export a sub-tree. + if export_node is not None: + for key in export_node.split("."): + cfg = cfg[key] + + return cfg