Skip to content

Commit

Permalink
Multi-objective post-composition
Browse files Browse the repository at this point in the history
  • Loading branch information
ruicoelhopedro committed Apr 10, 2024
1 parent 6b544ff commit 5dd8e8e
Show file tree
Hide file tree
Showing 9 changed files with 560 additions and 173 deletions.
2 changes: 1 addition & 1 deletion piglot/bin/piglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(config_path: str = None):
stop_criteria=stop,
)
# Re-run the best case
if 'skip_last_run' not in config:
if 'skip_last_run' not in config and best_params is not None:
objective(parameters.normalise(best_params))


Expand Down
111 changes: 111 additions & 0 deletions piglot/bin/piglot_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import argparse
from tempfile import TemporaryDirectory
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import trapezoid
from scipy.spatial import ConvexHull
from PIL import Image
import torch
from piglot.parameter import read_parameters
Expand Down Expand Up @@ -321,6 +323,85 @@ def plot_gp(args):
plt.close()


def plot_pareto(args):
"""Driver for plotting the Pareto front for a given multi-objective optimisation problem.
Parameters
----------
args : dict
Passed arguments.
"""
config = parse_config_file(args.config)
objective = read_objective(config["objective"], read_parameters(config), config["output"])
if objective.num_objectives != 2:
raise ValueError("Can only plot the Pareto front for a two-objective optimisation problem.")
data = objective.get_history()
fig, ax = plt.subplots()
# Read all the points and variances
total_points = np.array([entry['values'] for entry in data.values()]).T
variances = np.array([entry['variances'] for entry in data.values() if 'variances' in entry]).T
has_variance = variances.size > 0
# Separate the dominated points
dominated = []
nondominated = []
pareto = pd.read_table(os.path.join(config["output"], 'pareto_front')).to_numpy()
for i, point in enumerate(total_points):
if np.isclose(point, pareto).all(axis=1).any():
nondominated.append((point, variances[i, :] if has_variance else None))
else:
dominated.append((point, variances[i, :] if has_variance else None))
# Build the Pareto hull
hull = ConvexHull(pareto)
for simplex in hull.simplices:
# Hacky: filter the line between the two endpoints (assuming the list is sorted by x)
if abs(simplex[0] - simplex[1]) < pareto.shape[0] - 1:
ax.plot(pareto[simplex, 0], pareto[simplex, 1], 'r', ls='--')
# Plot the points
if has_variance:
ax.errorbar(
[point[0][0] for point in nondominated],
[point[0][1] for point in nondominated],
xerr=np.sqrt([point[1][0] for point in nondominated]),
yerr=np.sqrt([point[1][0] for point in nondominated]),
c='r',
fmt='o',
label='Pareto front',
)
if args.all:
ax.errorbar(
[point[0][0] for point in dominated],
[point[0][1] for point in dominated],
xerr=np.sqrt([point[1][0] for point in dominated]),
yerr=np.sqrt([point[1][0] for point in dominated]),
c='k',
fmt='o',
label='Dominated points',
)
else:
ax.scatter(
[point[0][0] for point in nondominated],
[point[0][1] for point in nondominated],
c='r',
label='Pareto front',
)
if args.all:
ax.scatter(
[point[0][0] for point in dominated],
[point[0][1] for point in dominated],
c='k',
label='Dominated points',
)
if args.log:
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Objective 1')
ax.set_ylabel('Objective 2')
ax.legend()
ax.grid()
fig.tight_layout()
plt.show()


def main(passed_args: List[str] = None):
"""Entry point for this script."""
# Global argument parser settings
Expand Down Expand Up @@ -544,6 +625,36 @@ def main(passed_args: List[str] = None):
)
sp_gp.set_defaults(func=plot_gp)

# Pareto front plotting
sp_pareto = subparsers.add_parser(
'pareto',
help='plot the Pareto front for a given multi-objective optimisation problem',
description=("Plot the Pareto front for a given multi-objective optimisation problem. This "
"must be executed in the same path as the running piglot instance."),
)
sp_pareto.add_argument(
'config',
type=str,
help="Path for the used or generated configuration file.",
)
sp_pareto.add_argument(
'--save_fig',
default=None,
type=str,
help=("Path to save the generated figure. If used, graphical output is skipped."),
)
sp_pareto.add_argument(
'--log',
action='store_true',
help="Plot in a log scale."
)
sp_pareto.add_argument(
'--all',
action='store_true',
help="Plot the both the Pareto front and the dominated points."
)
sp_pareto.set_defaults(func=plot_pareto)

args = parser.parse_args() if passed_args is None else parser.parse_args(passed_args)
args.func(args)

Expand Down
Loading

0 comments on commit 5dd8e8e

Please sign in to comment.