-
Notifications
You must be signed in to change notification settings - Fork 1
/
ppp_wildcards.py
398 lines (348 loc) · 15.8 KB
/
ppp_wildcards.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import fnmatch
import os
import json
from typing import Optional
import logging
import yaml
from ppp_logging import DEBUG_LEVEL
def _sorted_deterministically(iterable, key=None):
    """
    Sort values, tolerating mutually incomparable types.

    The natural ordering is used whenever it works. If comparing two values raises ``TypeError``
    (e.g. an ``int`` and a ``str``), the values are instead ordered by ``(type name, repr)``,
    which is deterministic for a given input.

    Args:
        iterable (iterable): The values to sort.
        key (callable, optional): Extracts the value to sort on from each item.

    Returns:
        list: The sorted items.
    """
    items = list(iterable)
    try:
        return sorted(items, key=key)
    except TypeError:

        def fallback_key(item):
            value = item if key is None else key(item)
            return (type(value).__name__, repr(value))

        return sorted(items, key=fallback_key)


def deep_freeze(obj):
    """
    Deep freeze an object, i.e. turn it into a hashable equivalent.

    Dicts become tuples of ``(key, frozen value)`` pairs sorted by key. Lists and tuples become
    tuples of frozen elements in their original order. Sets become tuples of frozen elements in
    sorted order. Anything else is returned unchanged.

    Note that a list, a tuple and a set with the same elements (in the same order) freeze to the
    same tuple.

    Args:
        obj (object): The object to freeze.

    Returns:
        object: The frozen object.
    """
    if isinstance(obj, dict):
        # dict keys are unique, so ordering by key alone is a total order of the items
        items = _sorted_deterministically(obj.items(), key=lambda kv: kv[0])
        return tuple((k, deep_freeze(v)) for k, v in items)
    if isinstance(obj, (list, tuple)):
        # tuples are frozen too: a tuple that contains lists or dicts is not hashable as-is
        return tuple(deep_freeze(i) for i in obj)
    if isinstance(obj, set):
        return tuple(deep_freeze(i) for i in _sorted_deterministically(obj))
    return obj
class PPPWildcard:
    """
    A wildcard object.

    Attributes:
        key (str): The key of the wildcard.
        file (str): The path to the file where the wildcard is defined.
        unprocessed_choices (list): The unprocessed choices of the wildcard, as given to the
            constructor (strings, or dicts with choice/wildcard options).
        choices (Optional[list[dict]]): The processed choices of the wildcard; None until they are
            assigned (this class never assigns them).
        options (Optional[dict]): The options of the wildcard; None until they are assigned
            (this class never assigns them).
    """

    def __init__(self, fullpath: str, key: str, choices: list):
        self.key: str = key
        self.file: str = fullpath
        self.unprocessed_choices: list = choices
        self.choices: Optional[list[dict]] = None
        self.options: Optional[dict] = None

    def __hash__(self) -> int:
        # Only the key and the raw choices take part in the hash; the (possibly nested, mutable)
        # choice lists/dicts are frozen into tuples so they can be hashed.
        return hash((self.key, deep_freeze(self.unprocessed_choices)))
class PPPWildcards:
    """
    A class to manage wildcards.

    Wildcards are read from ".txt", ".json", ".yaml" and ".yml" files (extension compared
    case-insensitively) found recursively in the configured folders, skipping entries whose name
    starts with ".". The modification time of every loaded file is remembered, so files that did
    not change are not parsed again on the next refresh.

    Attributes:
        wildcards (dict[str, PPPWildcard]): The wildcards, keyed by their full ("/"-separated) key.
    """

    DEFAULT_WILDCARDS_FOLDER = "wildcards"

    # Lowercase extensions of the files that are scanned for wildcards.
    __WILDCARD_EXTENSIONS = (".txt", ".json", ".yaml", ".yml")
    # Keys allowed in a choices options dictionary (see is_dict_choices_options).
    __CHOICES_OPTIONS_KEYS = frozenset(
        ("sampler", "repeating", "count", "from", "to", "prefix", "suffix", "separator")
    )
    # Keys allowed in a choice options dictionary (see is_dict_choice_options).
    __CHOICE_OPTIONS_KEYS = frozenset(("labels", "weight", "if", "content", "text"))

    def __init__(self, logger):
        self.__logger: logging.Logger = logger
        self.__debug_level = DEBUG_LEVEL.none
        self.__wildcards_folders = []
        # absolute path of each loaded file -> its modification time when it was loaded
        self.__wildcard_files = {}
        self.wildcards: dict[str, PPPWildcard] = {}

    def __hash__(self) -> int:
        # (key, PPPWildcard) pairs sorted by key; each wildcard is hashed by PPPWildcard.__hash__
        return hash(deep_freeze(self.wildcards))

    def refresh_wildcards(self, debug_level: DEBUG_LEVEL, wildcards_folders: Optional[list[str]]):
        """
        Initialize (or refresh) the wildcards.

        Files that no longer exist, or are no longer inside any of the folders, are forgotten
        together with their wildcards. Then every folder is scanned; files whose modification time
        did not change since they were loaded are skipped.

        Args:
            debug_level (DEBUG_LEVEL): The debug level used for logging.
            wildcards_folders (Optional[list[str]]): The folders to load wildcards from. None clears
                all the wildcards.
        """
        self.__debug_level = debug_level
        if wildcards_folders is None:
            self.__wildcards_folders = []
            self.wildcards = {}
            self.__wildcard_files = {}
            return
        self.__wildcards_folders = wildcards_folders
        # The cached file paths are absolute, so compare them against absolute folder paths:
        # os.path.commonpath raises ValueError when mixing absolute and relative paths, and a
        # trailing separator in a folder would otherwise never match.
        absolute_folders = [os.path.abspath(folder) for folder in wildcards_folders]
        for fullpath in list(self.__wildcard_files.keys()):
            if not os.path.exists(fullpath) or not self.__is_in_any_folder(
                os.path.dirname(fullpath), absolute_folders
            ):
                self.__remove_wildcards_from_file(fullpath)
        for folder in wildcards_folders:
            self.__get_wildcards_in_directory(folder, folder)

    def get_wildcards(self, key: str) -> list[PPPWildcard]:
        """
        Get all wildcards that match a key.

        Args:
            key (str): The key to match (an fnmatch-style pattern).

        Returns:
            list: A list of all wildcards that match the key, sorted by key.
        """
        keys = sorted(fnmatch.filter(self.wildcards.keys(), key))
        return [self.wildcards[k] for k in keys]

    @staticmethod
    def __is_in_any_folder(path: str, folders: list[str]) -> bool:
        """
        Check whether a path is one of the given folders or lies beneath one of them.

        Args:
            path (str): The absolute path to check.
            folders (list[str]): The absolute folder paths.

        Returns:
            bool: Whether the path is inside any of the folders.
        """
        for folder in folders:
            try:
                common = os.path.commonpath([path, folder])
            except ValueError:
                # e.g. paths on different drives (Windows): cannot be inside this folder
                continue
            if os.path.normcase(common) == os.path.normcase(folder):
                return True
        return False

    @staticmethod
    def __get_external_key_parts(full_path: str, base: str) -> list[str]:
        """
        Get the key parts for a file: its path relative to the base folder, without extension.

        Args:
            full_path (str): The path to the file.
            base (str): The base path for the wildcards.

        Returns:
            list[str]: The key parts (one per path component).
        """
        return os.path.relpath(os.path.splitext(full_path)[0], base).split(os.sep)

    def __get_keys_in_dict(self, dictionary: dict, prefix="") -> list[str]:
        """
        Get all keys in a dictionary.

        Args:
            dictionary (dict): The dictionary to check.
            prefix (str): The prefix for the current key.

        Returns:
            list: A list of all leaf keys in the dictionary, nested keys joined with "/". Non-string
            keys (e.g. YAML integers or booleans) are converted to strings.
        """
        keys = []
        for key, value in dictionary.items():
            if isinstance(value, dict):
                # f-string instead of "+" so non-string keys don't raise TypeError
                keys.extend(self.__get_keys_in_dict(value, f"{prefix}{key}/"))
            else:
                keys.append(f"{prefix}{key}")
        return keys

    def __get_nested(self, dictionary: dict, keys: str) -> object:
        """
        Get a nested value from a dictionary.

        Args:
            dictionary (dict): The dictionary to check.
            keys (str): The "/"-separated keys to get the value from.

        Returns:
            object: The value of the nested keys in the dictionary, or None if not found.
        """
        current = dictionary
        for part in keys.split("/"):
            if not isinstance(current, dict):
                return None
            if part in current:
                current = current[part]
            else:
                # the key may be a non-string (e.g. an int in YAML) that was stringified
                current = next((v for k, v in current.items() if str(k) == part), None)
            if current is None:
                return None
        return current

    def __remove_wildcards_from_file(self, full_path: str, debug=True):
        """
        Forget a file and remove all wildcards defined in it.

        Anonymous wildcards created while loading the file are removed too, since they are
        registered with the same file path.

        Args:
            full_path (str): The path to the file.
            debug (bool): Whether to print debug messages or not.
        """
        was_loaded = full_path in self.__wildcard_files
        if debug and was_loaded and self.__debug_level != DEBUG_LEVEL.none:
            self.__logger.debug(f"Removing wildcards from file: {full_path}")
        self.__wildcard_files.pop(full_path, None)
        for key in [k for k, wildcard in self.wildcards.items() if wildcard.file == full_path]:
            del self.wildcards[key]

    def __get_wildcards_in_file(self, base, full_path: str):
        """
        Load (or reload) the wildcards in a file, if it is a wildcard file that changed since it was
        last loaded.

        Read/parse errors are logged as warnings instead of aborting the whole refresh; the file is
        then not retried until its modification time changes.

        Args:
            base (str): The base path for the wildcards.
            full_path (str): The path to the file.
        """
        extension = os.path.splitext(full_path)[1].lower()
        if extension not in self.__WILDCARD_EXTENSIONS:
            return
        try:
            last_modified = os.path.getmtime(full_path)
        except OSError as e:
            # the file may have disappeared since its directory was listed
            self.__logger.warning(f"Cannot access wildcard file '{full_path}': {e}")
            return
        last_modified_cached = self.__wildcard_files.get(full_path, None)
        if last_modified_cached is not None and last_modified == last_modified_cached:
            return
        self.__remove_wildcards_from_file(full_path, False)
        if last_modified_cached is not None and self.__debug_level != DEBUG_LEVEL.none:
            self.__logger.debug(f"Updating wildcards from file: {full_path}")
        try:
            if extension == ".txt":
                self.__get_wildcards_in_text_file(full_path, base)
            else:
                self.__get_wildcards_in_structured_file(full_path, base)
        except (OSError, ValueError, yaml.YAMLError) as e:
            # don't keep whatever was registered before the failure
            self.__remove_wildcards_from_file(full_path, False)
            self.__logger.warning(f"Error loading wildcards from file '{full_path}': {e}")
        self.__wildcard_files[full_path] = last_modified

    def is_dict_choices_options(self, d: dict) -> bool:
        """
        Check if a dictionary is a valid choices options dictionary.

        Args:
            d (dict): The dictionary to check.

        Returns:
            bool: Whether the dictionary is a valid choices options dictionary or not.
        """
        return all(k in self.__CHOICES_OPTIONS_KEYS for k in d.keys())

    def is_dict_choice_options(self, d: dict) -> bool:
        """
        Check if a dictionary is a valid choice options dictionary.

        Args:
            d (dict): The dictionary to check.

        Returns:
            bool: Whether the dictionary is a valid choice options dictionary or not.
        """
        return all(k in self.__CHOICE_OPTIONS_KEYS for k in d.keys())

    def __get_choices(self, obj: object, full_path: str, key_parts: list[str]) -> Optional[list]:
        """
        We process the choices in the object and return them as a list.

        Nested lists (and list contents of choice dicts) are turned into anonymous wildcards that
        the choice references.

        Args:
            obj (object): the value of a wildcard
            full_path (str): path to the file where the wildcard is defined
            key_parts (list[str]): parts of the key for the wildcard

        Returns:
            Optional[list]: list of choices (strings or dicts), or None if obj is None, an empty
            list or of an unsupported type. Invalid items are skipped with a warning.
        """
        if obj is None:
            return None
        if isinstance(obj, (str, dict)):
            return [obj]
        if isinstance(obj, (int, float, bool)):
            return [str(obj)]
        if not isinstance(obj, list) or len(obj) == 0:
            return None
        choices = []
        for i, c in enumerate(obj):
            choice = None  # stays None for an invalid choice
            if isinstance(c, str):
                choice = c
            elif isinstance(c, (int, float, bool)):
                choice = str(c)
            elif isinstance(c, list):
                # we create an anonymous wildcard
                choice = self.__create_anonymous_wildcard(full_path, key_parts, i, c)
            elif isinstance(c, dict):
                if self.is_dict_choices_options(c) or self.is_dict_choice_options(c):
                    # we assume it is a choice or wildcard parameters in object format
                    choice = c
                    choice_content = choice.get("content", choice.get("text", None))
                    if isinstance(choice_content, list):
                        # we create an anonymous wildcard; it replaces the content (or text)
                        choice["content"] = self.__create_anonymous_wildcard(
                            full_path, key_parts, i, choice_content
                        )
                        if "text" in choice:
                            del choice["text"]
                elif len(c) == 1:
                    # we assume it is an anonymous wildcard with options
                    firstkey = next(iter(c))
                    choice = self.__create_anonymous_wildcard(full_path, key_parts, i, c[firstkey], firstkey)
            if choice is None:
                self.__logger.warning(
                    f"Invalid choice {i+1} in wildcard '{'/'.join(key_parts)}' in file '{full_path}'!"
                )
            else:
                choices.append(choice)
        return choices

    def __create_anonymous_wildcard(self, full_path, key_parts, i, content, options=None):
        """
        Create an anonymous wildcard.

        Args:
            full_path (str): The path to the file that contains it.
            key_parts (list[str]): The parts of the key.
            i (int): The index of the wildcard.
            content (object): The content of the wildcard.
            options (str): The options for the choice where the wildcard is defined.

        Returns:
            str: The resulting value for the choice: a reference "__key__" to the new wildcard,
            prefixed with "options::" when options are given.
        """
        new_parts = key_parts + [f"#ANON_{i}"]
        self.__add_wildcard(content, full_path, new_parts)
        value = f"__{'/'.join(new_parts)}__"
        if options is not None:
            value = f"{options}::{value}"
        return value

    def __register_wildcard(self, fullkey: str, obj: object, full_path: str, key_parts: list[str]):
        """
        Process the choices of a wildcard and store it, unless its key is already taken.

        Args:
            fullkey (str): The full key of the wildcard.
            obj (object): The raw value of the wildcard.
            full_path (str): The path to the file that contains it.
            key_parts (list[str]): The parts of the key.
        """
        existing = self.wildcards.get(fullkey, None)
        if existing is not None:
            self.__logger.warning(
                f"Duplicate wildcard '{fullkey}' in file '{full_path}' and '{existing.file}'!"
            )
            return
        choices = self.__get_choices(obj, full_path, key_parts)
        if choices is None:
            self.__logger.warning(f"Invalid wildcard '{fullkey}' in file '{full_path}'!")
        else:
            self.wildcards[fullkey] = PPPWildcard(full_path, fullkey, choices)

    def __add_wildcard(self, content: object, full_path: str, external_key_parts: list[str]):
        """
        Add a wildcard to the wildcards dictionary.

        If the content is a dict, the last key part is dropped and one wildcard is added per leaf
        of the dict, keyed by the remaining parts plus the dict's nested keys. Otherwise a single
        wildcard is added for the whole content.

        Args:
            content (object): The content of the wildcard.
            full_path (str): The path to the file that contains it.
            external_key_parts (list[str]): The parts of the key.
        """
        key_parts = external_key_parts.copy()
        if isinstance(content, dict):
            key_parts.pop()
            for key in self.__get_keys_in_dict(content):
                tmp_key_parts = key_parts + key.split("/")
                self.__register_wildcard(
                    "/".join(tmp_key_parts), self.__get_nested(content, key), full_path, tmp_key_parts
                )
            return
        if isinstance(content, str):
            content = [content]
        elif isinstance(content, (int, float, bool)):
            content = [str(content)]
        if not isinstance(content, list):
            self.__logger.warning(f"Invalid wildcard in file '{full_path}'!")
            return
        self.__register_wildcard("/".join(key_parts), content, full_path, key_parts)

    def __get_wildcards_in_structured_file(self, full_path, base):
        """
        Get all wildcards in a structured (JSON or YAML) file.

        Args:
            full_path (str): The path to the file.
            base (str): The base path for the wildcards.
        """
        extension = os.path.splitext(full_path)[1].lower()
        # "utf-8-sig" skips a UTF-8 BOM, which json would otherwise reject
        with open(full_path, "r", encoding="utf-8-sig") as file:
            if extension == ".json":
                content = json.load(file)
            else:
                content = yaml.safe_load(file)
        self.__add_wildcard(content, full_path, self.__get_external_key_parts(full_path, base))

    def __get_wildcards_in_text_file(self, full_path, base):
        """
        Get all wildcards in a text file (one choice per line).

        Blank lines and lines starting with "#" are ignored; anything from the first "#" of a line
        on is treated as a comment and removed, together with the whitespace before it.

        Args:
            full_path (str): The path to the file.
            base (str): The base path for the wildcards.
        """
        choices = []
        # "utf-8-sig" skips a UTF-8 BOM, which would otherwise end up in the first choice
        # (or stop a leading comment line from being recognized)
        with open(full_path, "r", encoding="utf-8-sig") as file:
            for raw_line in file:
                line = raw_line.strip("\n\r")
                stripped = line.strip()
                if stripped == "" or stripped.startswith("#"):
                    continue
                if "#" in line:
                    line = line.split("#", 1)[0].rstrip()
                choices.append(line)
        self.__add_wildcard(choices, full_path, self.__get_external_key_parts(full_path, base))

    def __get_wildcards_in_directory(self, base: str, directory: str):
        """
        Get all wildcards in a directory, recursively.

        Entries are visited in sorted order so that, when keys collide, which definition wins
        does not depend on the file system's listing order.

        Args:
            base (str): The base path for the wildcards.
            directory (str): The path to the directory.
        """
        if not os.path.isdir(directory):
            self.__logger.warning(f"Wildcard directory '{directory}' does not exist!")
            return
        try:
            entries = sorted(os.listdir(directory))
        except OSError as e:
            self.__logger.warning(f"Cannot read wildcard directory '{directory}': {e}")
            return
        for filename in entries:
            # hidden files and folders are skipped
            if filename.startswith("."):
                continue
            full_path = os.path.abspath(os.path.join(directory, filename))
            if os.path.isdir(full_path):
                self.__get_wildcards_in_directory(base, full_path)
            elif os.path.isfile(full_path):
                self.__get_wildcards_in_file(base, full_path)