forked from snel-repo/ComputationThruDynamicsBenchmark
-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
56 lines (51 loc) · 1.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def make_data_tag(dm_cfg):
    """Build a compact string tag describing a dataset configuration.

    The tag has the form
    ``{system}{obs_dim}_{n_samples}S_{n_timesteps}T_{pts_per_period}P_{seed}seed``
    and, when ``obs_noise`` is present and truthy, an extra
    ``_{obs_noise}{k1=v1,k2=v2,...}`` suffix built from ``obs_noise_params``.
    The noise parameters are appended directly after the noise name, with
    no separator.

    Parameters
    ----------
    dm_cfg : mapping-like config
        Must support both ``"key" in dm_cfg`` and attribute access
        (presumably an OmegaConf ``DictConfig`` — confirm against callers).
        Required fields: ``system``, ``n_samples``, ``n_timesteps``,
        ``pts_per_period``, ``seed``. Optional fields: ``obs_dim``,
        ``obs_noise``, ``obs_noise_params`` (a mapping with ``.items()``).

    Returns
    -------
    str
        The assembled data tag.
    """
    dim_part = dm_cfg.obs_dim if "obs_dim" in dm_cfg else ""
    noise_part = dm_cfg.obs_noise if "obs_noise" in dm_cfg else ""
    params_part = (
        ",".join(
            f"{name}={value}"
            for name, value in dm_cfg.obs_noise_params.items()
        )
        if "obs_noise_params" in dm_cfg
        else ""
    )
    pieces = [
        f"{dm_cfg.system}{dim_part}",
        f"{dm_cfg.n_samples}S",
        f"{dm_cfg.n_timesteps}T",
        f"{dm_cfg.pts_per_period}P",
        f"{dm_cfg.seed}seed",
    ]
    # Noise parameters only appear in the tag when a noise model is named.
    if noise_part:
        pieces.append(f"{noise_part}{params_part}")
    return "_".join(pieces)
def flatten(dictionary, level=None):
    """Flatten a nested dictionary, joining keys at successive levels with '.'.

    Each leaf value of the hierarchical ``dictionary`` is stored under a
    single key formed by joining the path of keys leading to it with '.'.
    It is used internally for converting the configuration dictionary to
    more convenient formats. Implementation was inspired by `this
    StackOverflow post
    <https://stackoverflow.com/questions/6037503/python-unflatten-dict>`_.

    Parameters
    ----------
    dictionary : dict
        The hierarchical dictionary to be flattened. Any value that is a
        ``dict`` (including subclasses such as ``OrderedDict``) is
        descended into; all other values are treated as leaves. Empty
        nested dictionaries contribute no keys.
    level : list, optional
        Sequence of keys to prefix onto every key of this dictionary,
        enabling recursive calls. By default, no prefix.

    Returns
    -------
    dict
        The flattened dictionary. Path components are converted with
        ``str`` before joining, so non-string keys are supported.

    See Also
    --------
    lfads_tf2.utils.unflatten : Performs the opposite of this operation.
    """
    # Use a None sentinel rather than a mutable ``[]`` default so no list is
    # ever shared between calls; ``list(...)`` also accepts tuple prefixes.
    prefix = [] if level is None else list(level)
    tmp_dict = {}
    for key, val in dictionary.items():
        path = prefix + [key]
        if isinstance(val, dict):
            # Recurse with the extended key path; the result is already flat.
            tmp_dict.update(flatten(val, path))
        else:
            tmp_dict[".".join(map(str, path))] = val
    return tmp_dict
def trial_function(trial):
    """Return ``trial.experiment_tag`` unchanged.

    Presumably passed as a trial-naming callback (e.g. to a hyperparameter
    tuner) — confirm against callers.
    """
    tag = trial.experiment_tag
    return tag