Skip to content

Commit

Permalink
Merge pull request #15 from edong6768/changes/0.1.13
Browse files Browse the repository at this point in the history
Changes/0.1.13
  • Loading branch information
edong6768 committed May 12, 2024
2 parents 3f31657 + 199670d commit b07c560
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package-dir = {""="src"}

[project]
name = "malet"
version = "0.1.12"
version = "0.1.13"
description = "Malet: a tool for machine learning experiment"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
34 changes: 20 additions & 14 deletions src/malet/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(self, exp_config_path):
self.name = os.path.split(exp_config_path)[0].split('/')[-1]
self.grid = self.raw_config.get('grid')
self.static_configs = {k:self.raw_config[k] for k in set(self.raw_config)-{'grid_fields', 'grid'}}

assert not (f:={k for k in self.static_configs.keys() if k in self.grid_fields}), f'Overlapping fields {f} in Static configs and grid fields.'

self.grid_iter = self.__get_iter()

Expand Down Expand Up @@ -107,19 +109,24 @@ def __extract_grid_order(cfg_str):
return dupless_names

@staticmethod
def __ravel_group(study):
# Return list of study if there is no 'group'
if 'group' not in study: return [study]
def __ravel_group(grid):
# Return list of grid if there is no 'group'
if 'group' not in grid: return [grid]
group = grid['group']

# Ravel grouped fields into list of experiment plans.
grid_g = lambda g: ([*zip(g.keys(), value)] for value in zip(*g.values()))
if type(study['group'])==dict:
r_groups = grid_g(study['group'])
elif type(study['group'])==list:
r_groups = (chain(*gs) for gs in product(*map(grid_g, study['group'])))
study.pop('group')

raveled_study = ({**study, **dict(g)} for g in r_groups)
def grid_g(g):
g_len = [*map(len, g.items())]
assert all([l==g_len[0] for l in g_len]), f'Grouped fields should have same length, got fields with length {dict(zip(g.keys(), g_len))}'
return ([*zip(g.keys(), value)] for value in zip(*g.values()))

if isinstance(group, dict):
r_groups = grid_g(group)
elif isinstance(group, list):
r_groups = (chain(*gs) for gs in product(*map(grid_g, group)))
grid.pop('group')

raveled_study = ({**grid, **dict(g)} for g in r_groups)
return raveled_study

def __get_iter(self):
Expand Down Expand Up @@ -562,11 +569,10 @@ def __get_and_split_configs(cfg_file, exp_bs, exp_bi):

configiter = ConfigIter(cfg_file)

assert exp_bs.isdigit() or (exp_bs in configiter.grid_fields), f'Enter valid splits (int | Literal{configiter.grid_fields}).'
assert isinstance(exp_bs, int) or (exp_bs in configiter.grid_fields), f'Enter valid splits (int | Literal{configiter.grid_fields}).'

# if total exp split is given as integer : uniformly split
if exp_bs.isdigit():
exp_bs, exp_bi = map(int, [exp_bs, exp_bi])
if isinstance(exp_bs, int):
assert exp_bs > 0, 'Total number of experiment splits should be larger than 0'
assert exp_bs > exp_bi, 'Experiment split index should be smaller than the total number of experiment splits'
if exp_bs>1:
Expand Down
6 changes: 4 additions & 2 deletions src/malet/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def draw_metric(tsv_file, plot_config, save_name='', preprcs_df=lambda *x: x):
#--- initial filter for df according to FLAGS.filter (except epoch and metric)
if pflt:
filt_dict = [*map(lambda flt: re.split('(?<!,) ', flt.strip()), pflt.split('/'))] # split ' ' except ', '
log.df = select_df(log.df, {fk:[*map(str2value, fvs)] for fk, *fvs in filt_dict if fk not in {'step', 'metric'}})
log.df = select_df(log.df, {fk:[*map(str2value, fvs)] for fk, *fvs in filt_dict if fk[-1]!='!' and fk not in {'step', 'metric'}})
log.df = select_df(log.df, {fk[:-1]:[*map(str2value, fvs)] for fk, *fvs in filt_dict if fk[-1]=='!' in fk and fk not in {'step', 'metric'}}, equal=False)

#--- melt and explode metric in log.df
if 'metric' not in pmlf and 'metric' not in x_fields:
Expand All @@ -99,7 +100,8 @@ def draw_metric(tsv_file, plot_config, save_name='', preprcs_df=lambda *x: x):
#---filter df according to FLAGS.filter step and metrics
if pflt:
e_rng = lambda fvs: [*range(*map(int, fvs[0].split(':')))] if (len(fvs)==1 and ':' in fvs[0]) else fvs # CNG 'a:b' step filter later
df = select_df(df, {fk:[*map(str2value, e_rng(fvs))] for fk, *fvs in filt_dict if fk in {'step', 'metric'}})
df = select_df(df, {fk:[*map(str2value, e_rng(fvs))] for fk, *fvs in filt_dict if fk[-1]!='!' and fk in {'step', 'metric'}})
df = select_df(df, {fk[:-1]:[*map(str2value, e_rng(fvs))] for fk, *fvs in filt_dict if fk[-1]=='!' and fk in {'step', 'metric'}}, equal=False)


#---set mlines according to FLAGS.multi_line_fields
Expand Down
5 changes: 3 additions & 2 deletions src/malet/plot_utils/data_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import pandas as pd

def select_df(df, filt_dict, *exclude_fields, drop=False):
def select_df(df, filt_dict, *exclude_fields, equal=True, drop=False):
"""Select df rows with matching from given filt_dict except ``exclude_fields``"""
assert not df.empty, 'Given dataframe is empty.'
assert not (k:=set(filt_dict.keys()) - set(df.index.names)), f'filt_dict keys {k} is not in df.'
Expand All @@ -13,7 +13,8 @@ def select_df(df, filt_dict, *exclude_fields, drop=False):
for i, k in enumerate(filt_keys):
values = nest(filt_dict[k])
assert not (v:=set(values)-(vs:=set(df.index.get_level_values(k)))), f"Values {v} are not in field '{k}': {sorted(vs)}"
df = df.loc[df.index.get_level_values(k).isin(values)]
fltr = df.index.get_level_values(k).isin(values)
df = df.loc[fltr if equal else ~fltr]
assert not df.empty, f"Filter {k}:{values} return empty dataframe. Inspect {dict((k, filt_dict[k]) for k in filt_keys[:i+1])}"

if drop:
Expand Down
14 changes: 8 additions & 6 deletions src/malet/plot_utils/plot_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def ax_draw_curve(ax: Axes,

ax.plot(tick_values, metric_values, label=label, color=color, linewidth=linewidth,
marker=marker, markersize=markersize, markevery=markevery)

if len(x_values)%markevery!=0:
ax.plot(tick_values[-1], metric_values[-1], color=color, marker=marker, markersize=markersize)

if f'{y_field}_std' in df:
x_values, metric_std = map(np.array, zip(*dict(df[f'{y_field}_std']).items()))

Expand All @@ -68,7 +72,7 @@ def ax_draw_curve(ax: Axes,
[s[0]] + [i for i in s[1:] if i not in 'aeiou']) if len(s)>3 else s
abv_annot = [*map(abv, annotate_field)]
for i, (x,y,t) in enumerate(zip(x_values, metric_values, tick_values)):
if i%markevery: continue
if i%markevery and i!=len(x_values)-1: continue
txt = '\n'.join([f'{y:.5f}'+(f'$\pm${metric_std[i]:.5f}' if (f'{y_field}_std' in df and pd.notna(metric_std[i])) else ''), str(x)]
+[f'{i}={df.loc[x][j]}' for i, j in zip(abv_annot, annotate_field)])
ax.annotate(txt, (t,y), textcoords="offset points", xytext=(0,10), ha='center')
Expand Down Expand Up @@ -136,11 +140,9 @@ def ax_draw_best_stared_curve(ax: Axes,
best_idx = list(metric_values).index((max if best_at_max else min)(metric_values))
for i, (_,y,t) in enumerate(zip(x_values, metric_values, tick_values)):
if i%markevery: continue
if i==0:
ax.plot(tick_values[i], metric_values[i], label=label, color=color, linewidth=linewidth,
marker=marker, markersize=markersize, markevery=markevery)
elif i==best_idx:
ax.plot(tick_values[i], metric_values[i], color='green', marker='*', markersize=markersize+10)
if i==best_idx:
ax.plot(tick_values[i], metric_values[i], color='green',
marker='*', markersize=markersize+10)
else:
ax.plot(tick_values[i], metric_values[i], color=color,
marker=marker, markersize=markersize, markevery=markevery)
Expand Down

0 comments on commit b07c560

Please sign in to comment.