Skip to content

Commit

Permalink
Implemented automatic hiding of strategies used as an executed baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
fjwillemsen committed Jun 25, 2024
1 parent 9c3c1bc commit 3221df1
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/autotuning_methodology/visualize_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def __init__(
compare_baselines: bool = plot_settings.get("compare_baselines", False)
compare_split_times: bool = plot_settings.get("compare_split_times", False)
confidence_level: float = plot_settings.get("confidence_level", 0.95)
self.plot_skip_strategies: list[str] = list()
if use_strategy_as_baseline is not None:
self.plot_skip_strategies.append(use_strategy_as_baseline)

# visualize
aggregation_data: list[tuple[Baseline, list[Curve], SearchspaceStatistics, np.ndarray]] = list()
Expand Down Expand Up @@ -382,6 +385,8 @@ def plot_baselines_comparison(

# plot normal strategies
for strategy_curve in strategies_curves:
if strategy_curve.name in self.plot_skip_strategies:
continue
(
_,
x_axis_range_real,
Expand Down Expand Up @@ -448,6 +453,9 @@ def plot_split_times_comparison(
# )
# )
lines: list[CurveBasis] = strategies_curves + baselines
for line in lines:
if isinstance(line, Curve) and line.name in self.plot_skip_strategies:
lines.remove(line)

# setup the subplots
num_rows = len(lines)
Expand Down Expand Up @@ -538,11 +546,13 @@ def plot_split_times_bar_comparison(
data_dict = dict.fromkeys(objective_time_keys)
data_table = list(
list(list() for _ in range(len(objective_time_keys) - len(print_skip)))
for _ in range(len(strategies_curves) + 1)
for _ in range((len(strategies_curves) - len(self.plot_skip_strategies)) + 1)
)
for objective_time_key in objective_time_keys:
data_dict[objective_time_key] = np.full((len(strategies_curves)), np.NaN)
for strategy_index, strategy_curve in enumerate(strategies_curves):
if strategy_curve.name in self.plot_skip_strategies:
continue
print_skip_counter = 0
strategy_labels.append(strategy_curve.display_name)
strategy_split_times = strategy_curve.get_split_times(fevals_or_time_range, x_type, searchspace_stats)
Expand Down Expand Up @@ -692,6 +702,8 @@ def normalize_multiple(curves: list) -> tuple:
color = self.colors[strategy_index]
label = f"{strategy['display_name']}"
strategy_curve = strategies_curves[strategy_index]
if strategy_curve.name in self.plot_skip_strategies:
continue

# get the plot data
if y_type == "scatter":
Expand Down Expand Up @@ -919,6 +931,8 @@ def plot_strategies_aggregated(
print("\n-------")
print("Quantification of aggregate performance across all search spaces:")
for strategy_index, strategy_performance in enumerate(strategies_performance):
if self.strategies[strategy_index]["name"] in self.plot_skip_strategies:
continue
displayname = self.strategies[strategy_index]["display_name"]
color = self.colors[strategy_index]
real_stopping_point_fraction = strategies_real_stopping_point_fraction[strategy_index]
Expand Down

0 comments on commit 3221df1

Please sign in to comment.