Skip to content

Commit

Permalink
Add support for response transformers for design and fitting objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
ruicoelhopedro committed May 27, 2024
1 parent 735ed21 commit 5762e47
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 6 deletions.
25 changes: 24 additions & 1 deletion piglot/objectives/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from piglot.objective import Composition, DynamicPlotter, GenericObjective, ObjectiveResult
from piglot.utils.assorted import stats_interp_to_common_grid, read_custom_module
from piglot.utils.reductions import Reduction, NegateReduction, read_reduction
from piglot.utils.response_transformer import ResponseTransformer, read_response_transformer
from piglot.utils.composition.responses import ResponseComposition, EndpointFlattenUtility


Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(
stochastic: bool = False,
composite: bool = False,
multi_objective: bool = False,
transformers: Dict[str, ResponseTransformer] = None,
) -> None:
super().__init__(
parameters,
Expand All @@ -139,6 +141,7 @@ def __init__(
)
self.solver = solver
self.targets = targets
self.transformers = transformers if transformers is not None else {}

def prepare(self) -> None:
"""Prepare the objective for optimisation."""
Expand Down Expand Up @@ -221,6 +224,10 @@ def _objective(self, values: np.ndarray, concurrent: bool = False) -> ObjectiveR
Objective result.
"""
raw_responses = self.solver.solve(values, concurrent)
# Transform responses
for name, transformer in self.transformers.items():
if name in raw_responses:
raw_responses[name] = transformer.transform(raw_responses[name])
# Interpolate responses to common grid and map to targets
responses_interp = {
target: [
Expand Down Expand Up @@ -268,12 +275,22 @@ def plot_case(self, case_hash: str, options: Dict[str, Any] = None) -> List[Figu
figures = []
# Load all responses
responses = self.solver.get_output_response(case_hash)
# Transform responses if necessary
for name, transformer in self.transformers.items():
if name in responses:
responses[name] = transformer.transform(responses[name])
# Plot each target
for target in self.targets:
# Build figure with individual responses for this target
fig, axis = plt.subplots()
for pred in target.prediction:
axis.plot(responses[pred].get_time(), responses[pred].get_data(), label=f'{pred}')
marker = 'o' if len(responses[pred].get_time()) < 2 else None
axis.plot(
responses[pred].get_time(),
responses[pred].get_data(),
label=f'{pred}',
marker=marker,
)
# Under stochasticity, plot response stats and mean confidence interval
if self.stochastic:
stats = stats_interp_to_common_grid([
Expand Down Expand Up @@ -355,6 +372,11 @@ def read(
# Sanitise target: check if it is associated to at least one case
if len(targets[target]) == 0:
raise ValueError(f"Design target '{target_name}' is not associated to any case.")
# Read transformers
transformers: Dict[str, ResponseTransformer] = {}
if 'transformers' in config:
for name, transformer_config in config.pop('transformers').items():
transformers[name] = read_response_transformer(transformer_config)
# Read custom class (if any)
target_class: type[DesignObjective] = DesignObjective
if 'custom_class' in config:
Expand All @@ -367,4 +389,5 @@ def read(
stochastic=bool(config.get('stochastic', False)),
composite=bool(config.get('composite', False)),
multi_objective=bool(config.get('multi_objective', False)),
transformers=transformers,
)
30 changes: 25 additions & 5 deletions piglot/objectives/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from piglot.utils.assorted import stats_interp_to_common_grid
from piglot.utils.reductions import Reduction, read_reduction
from piglot.utils.responses import Transformer, reduce_response, interpolate_response
from piglot.utils.response_transformer import ResponseTransformer, read_response_transformer
from piglot.utils.composition.responses import ResponseComposition, FixedFlatteningUtility
from piglot.objective import Composition, DynamicPlotter, GenericObjective, ObjectiveResult

Expand Down Expand Up @@ -47,9 +48,12 @@ def __init__(
self.weight = weight
self.reduction = reduction
# Load the data right away
data = np.genfromtxt(filename, skip_header=skip_header)[:, [x_col - 1, y_col - 1]]
self.x_data = data[:, 0]
self.y_data = data[:, 1]
data = np.genfromtxt(filename, skip_header=skip_header)
# Sanitise to ensure it is a 2D array
if len(data.shape) == 1:
data = data.reshape(1, -1)
self.x_data = data[:, x_col - 1]
self.y_data = data[:, y_col - 1]
# Apply the transformer
if self.transformer is not None:
self.x_data, self.y_data = self.transformer(self.x_data, self.y_data)
Expand Down Expand Up @@ -259,10 +263,12 @@ def __init__(
solver: Solver,
references: Dict[Reference, List[str]],
reduction: Reduction,
transformers: Dict[str, ResponseTransformer] = None,
) -> None:
self.solver = solver
self.references = references
self.reduction = reduction
self.transformers = transformers if transformers is not None else {}
# Assign the reduction to non-defined references
for reference in self.references:
if reference.reduction is None:
Expand Down Expand Up @@ -319,6 +325,10 @@ def solve(self, values: np.ndarray, concurrent: bool) -> Dict[Reference, List[Ou
Output results.
"""
result = self.solver.solve(values, concurrent)
# Transform responses
for name, transformer in self.transformers.items():
if name in result:
result[name] = transformer.transform(result[name])
# Populate output results
output = {}
for reference, cases in self.references.items():
Expand Down Expand Up @@ -348,6 +358,10 @@ def plot_case(self, case_hash: str, options: Dict[str, Any] = None) -> List[Figu
figures = []
# Load all responses
responses = self.solver.get_output_response(case_hash)
# Transform responses if necessary
for name, transformer in self.transformers.items():
if name in responses:
responses[name] = transformer.transform(responses[name])
# Plot each reference
for reference, names in self.references.items():
# Build figure, index axes and plot response
Expand Down Expand Up @@ -390,8 +404,9 @@ def plot_case(self, case_hash: str, options: Dict[str, Any] = None) -> List[Figu
else:
# Plot the individual responses
for name in names:
marker = 'o' if len(responses[name].get_time()) < 2 else None
axis.plot(responses[name].get_time(), responses[name].get_data(),
label=f'{name}')
label=f'{name}', marker=marker)
if reference_limits:
axis.set_xlim(xlim)
axis.set_ylim(ylim)
Expand Down Expand Up @@ -449,8 +464,13 @@ def read(config: Dict[str, Any], parameters: ParameterSet, output_dir: str) -> F
raise ValueError(f"Reference '{reference_name}' is not associated to any case.")
# Read the optional reduction
reduction = read_reduction(config.get('reduction', 'mse'))
# Read the optional transformers
transformers: Dict[str, ResponseTransformer] = {}
if 'transformers' in config:
for name, transformer_config in config['transformers'].items():
transformers[name] = read_response_transformer(transformer_config)
# Return the solver
return FittingSolver(solver, references, reduction)
return FittingSolver(solver, references, reduction, transformers=transformers)


class FittingObjective(GenericObjective):
Expand Down
202 changes: 202 additions & 0 deletions piglot/utils/response_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Module for defining transformations for responses."""
from typing import Union, Dict, Any, Type, List
from abc import ABC, abstractmethod
import numpy as np
from piglot.utils.assorted import read_custom_module
from piglot.solver.solver import OutputResult


class ResponseTransformer(ABC):
"""Abstract class for defining transformation functions."""

@abstractmethod
def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
OutputResult
Transformed time and data points of the response.
"""


class ChainResponse(ResponseTransformer):
"""Chain of response transformers."""

def __init__(self, transformers: List[Any]):
self.transformers = [read_response_transformer(t) for t in transformers]

def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
Output
Transformed time and data points of the response.
"""
for transformer in self.transformers:
response = transformer.transform(response)
return response


class MinimumResponse(ResponseTransformer):
"""Minimum of a response transformer."""

def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
OutputResult
Transformed time and data points of the response.
"""
return OutputResult(np.array([0.0]), np.array([np.min(response.data)]))


class MaximumResponse(ResponseTransformer):
"""Maximum of a response transformer."""

def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
OutputResult
Transformed time and data points of the response.
"""
return OutputResult(np.array([0.0]), np.array([np.max(response.data)]))


class NegateResponse(ResponseTransformer):
"""Negate a response transformer."""

def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
OutputResult
Transformed time and data points of the response.
"""
return OutputResult(response.time, -response.data)


class SquareResponse(ResponseTransformer):
"""Square a response transformer."""

def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
Output
Transformed time and data points of the response.
"""
return OutputResult(response.time, np.square(response.data))


class AffineTransformResponse(ResponseTransformer):
"""Affine transformation of a response transformer."""

def __init__(
self,
scale_x: float = 1.0,
offset_x: float = 0.0,
scale_y: float = 1.0,
offset_y: float = 0.0,
):
self.scale_x = scale_x
self.offset_x = offset_x
self.scale_y = scale_y
self.offset_y = offset_y

def transform(self, response: OutputResult) -> OutputResult:
"""Transform the input data.
Parameters
----------
response : OutputResult
Time and data points of the response.
Returns
-------
OutputResult
Transformed time and data points of the response.
"""
return OutputResult(
self.scale_x * response.time + self.offset_x,
self.scale_y * response.data + self.offset_y,
)


AVAILABLE_RESPONSE_TRANSFORMERS: Dict[str, Type[ResponseTransformer]] = {
'min': MinimumResponse,
'max': MaximumResponse,
'negate': NegateResponse,
'square': SquareResponse,
'chain': ChainResponse,
'affine': AffineTransformResponse,
}


def read_response_transformer(config: Union[str, Dict[str, Any]]) -> ResponseTransformer:
"""Read a response transformer from a configuration.
Parameters
----------
config : Union[str, Dict[str, Any]]
Configuration of the response transformer.
Returns
-------
ResponseTransformer
Response transformer.
"""
# Parse the transformer in the simple format
if isinstance(config, str):
name = config
if name == 'script':
raise ValueError('Need to pass the file path for the "script" transformer.')
if name not in AVAILABLE_RESPONSE_TRANSFORMERS:
raise ValueError(f'Response transformer "{name}" is not available.')
return AVAILABLE_RESPONSE_TRANSFORMERS[name]()
# Detailed format
if 'name' not in config:
raise ValueError('Need to pass the name of the response transformer.')
name = config.pop('name')
# Read script transformer
if name == 'script':
return read_custom_module(config, ResponseTransformer)()
if name not in AVAILABLE_RESPONSE_TRANSFORMERS:
raise ValueError(f'Response transformer "{name}" is not available.')
return AVAILABLE_RESPONSE_TRANSFORMERS[name](**config)

0 comments on commit 5762e47

Please sign in to comment.