-
Notifications
You must be signed in to change notification settings - Fork 0
/
recipe_utils.py
executable file
·121 lines (92 loc) · 3.78 KB
/
recipe_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 16 11:33:59 2021
Recipe utilities
@author: edwardcui
"""
import os
import subprocess
import pathlib
from typing import Dict, List, Text, Optional
import yaml
import json
import jinja2
# Project root directory: the parent of the directory that contains this file.
PROJECT_DIR = str(pathlib.Path(__file__).parent.parent)
# Metadata file that get_metadata() reads when no path (or None) is passed.
DEFAULT_METADATA = os.path.join(PROJECT_DIR, "metadata.yaml")
def get_metadata(metadata_file: Optional[Text] = DEFAULT_METADATA) -> Dict:
    """Load the metadata YAML file and render its templated fields.

    Args:
        metadata_file: Path to the metadata YAML file. ``None`` falls back to
            ``DEFAULT_METADATA`` (``metadata.yaml`` in the project directory).

    Returns:
        The parsed metadata dictionary, after ``parse_templated_fields`` has
        rendered the templated values in its ``*configurations*`` sections.
    """
    if metadata_file is None:
        metadata_file = DEFAULT_METADATA
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open(metadata_file, "r", encoding="utf-8") as fid:
        metadata = yaml.safe_load(fid)
    return parse_templated_fields(metadata)
def get_config(metadata: Dict, field: Optional[Text] = "system_configurations",
               filter_type: Optional[List[Text]] = None) -> Dict:
    """Flatten one configuration section into a ``{setting_name: value}`` dict.

    Args:
        metadata: Metadata dictionary. ``metadata[field]`` maps each setting
            name to a dict holding at least a ``"value"`` key, plus a
            ``"type"`` key when ``filter_type`` is given.
        field: Name of the section to flatten.
        filter_type: If given, keep only settings whose ``"type"`` is in it.

    Returns:
        A new dict mapping each kept setting name to its ``"value"``.
    """
    section = metadata[field]
    keep_all = filter_type is None
    return {
        name: entry["value"]
        for name, entry in section.items()
        # Short-circuit: "type" is only looked up when filtering is requested.
        if keep_all or entry["type"] in filter_type
    }
def parse_templated_fields(metadata: Dict) -> Dict:
    """Render Jinja2-templated values in every ``*configurations*`` section.

    The template context is built from the metadata itself:

    * Each top-level field whose name does not contain ``"configurations"``
      is exposed under its own name.
    * Each setting inside a ``*configurations*`` section is exposed under its
      setting name (via ``get_config``).

    A string is treated as a template only when it contains both ``"{{ "``
    and ``" }}"`` (note the spaces). It is re-rendered until no such markers
    remain, so a template may refer to other templated settings.

    Settings are handled according to their ``"type"``:

    * ``"string"`` / ``"str"``: the value is rendered if it is a string.
    * ``"array"``: each element is rendered in place. Nested lists/dicts are
      rendered recursively, and non-string scalars (numbers, booleans,
      ``None``) are left unchanged.
    * ``"object"``: every string key and value in the (possibly nested)
      dict/list is rendered. Other leaves are left unchanged.

    Args:
        metadata: Metadata dictionary as loaded from the metadata YAML file.
            It is modified in place.

    Returns:
        The same ``metadata`` object, with templated values rendered.

    Raises:
        ValueError: If a templated value cannot be fully rendered, because it
            renders to itself or still contains template markers after more
            than 100 rendering passes.
    """
    # Names available inside templates.
    context = {}
    for field in metadata:
        if "configurations" not in field:
            context[field] = metadata[field]
        else:
            context.update(get_config(metadata, field))

    def _render_text(text, cur_key):
        """Render ``text`` repeatedly until no ``{{ ... }}`` markers remain."""
        if not isinstance(text, str):
            # None, numbers, booleans, dates, ... cannot hold templates.
            return text
        passes = 0
        while "{{ " in text and " }}" in text:
            rendered = jinja2.Template(text).render(**context)
            passes += 1
            # A pass that changes nothing would repeat forever, and the cap
            # stops templates that keep expanding into new templates.
            if rendered == text or passes > 100:
                raise ValueError(f"Cannot parse templated field {cur_key}")
            text = rendered
        return text

    def _render_nested(obj, cur_key):
        """Render every string key and value inside nested dicts and lists.

        This is done structurally rather than via a JSON round-trip, so
        backslashes, quotes and non-JSON values (e.g. dates) survive intact.
        """
        if isinstance(obj, dict):
            return {_render_nested(key, cur_key): _render_nested(value, cur_key)
                    for key, value in obj.items()}
        if isinstance(obj, list):
            return [_render_nested(item, cur_key) for item in obj]
        return _render_text(obj, cur_key)

    for config_sec, configs in metadata.items():
        if "configurations" not in config_sec:
            continue
        for cur_key, cur_val in configs.items():
            value_type = cur_val["type"]
            if value_type in ("string", "str"):
                cur_val["value"] = _render_text(cur_val["value"], cur_key)
            elif value_type == "array":
                values = cur_val["value"]
                if isinstance(values, list):
                    # Update the list in place: the context built above holds
                    # this same list object (get_config returns it directly).
                    for index, item in enumerate(values):
                        values[index] = _render_nested(item, cur_key)
                else:
                    cur_val["value"] = _render_nested(values, cur_key)
            elif value_type == "object":
                cur_val["value"] = _render_nested(cur_val["value"], cur_key)
    return metadata
def run_shell_command(command, verbose=True):
    """Run ``command``, optionally echoing its output live, and capture it.

    Args:
        command: Program and arguments, passed unchanged to
            ``subprocess.Popen`` (without ``shell=True``).
        verbose: If True, print each output line as soon as it is read.

    Returns:
        ``(returncode, output)``. ``output`` holds every line the command
        wrote to stdout or stderr (stderr is merged into stdout). Each line
        is stripped of surrounding whitespace, and the lines are joined with
        ``os.linesep``.
    """
    lines = []
    # Invalid UTF-8 bytes are replaced instead of raising UnicodeDecodeError.
    # The context manager closes the pipe even if reading or printing fails.
    with subprocess.Popen(command,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          encoding="utf-8",
                          errors="replace") as process:
        # readline() returns "" only at EOF. Stopping there and then waiting
        # avoids busy-polling between EOF and process exit, which previously
        # also recorded spurious empty lines.
        for raw_line in iter(process.stdout.readline, ""):
            line = raw_line.strip()
            if verbose:
                print(line, flush=True)
            lines.append(line)
        returncode = process.wait()
    return returncode, os.linesep.join(lines)
if __name__ == '__main__':
    # Smoke check: load the template metadata (path is relative to the
    # current working directory) and extract the default config section.
    # Use a context manager so the file handle is closed after parsing.
    with open("tfx_template/metadata.yaml", "r", encoding="utf-8") as fid:
        metadata = yaml.safe_load(fid)
    config = get_config(metadata)