diff --git a/piglot/objectives/design.py b/piglot/objectives/design.py index 94f6dfb..90cdc2c 100644 --- a/piglot/objectives/design.py +++ b/piglot/objectives/design.py @@ -10,6 +10,7 @@ from piglot.objective import Composition, DynamicPlotter, GenericObjective, ObjectiveResult from piglot.utils.assorted import stats_interp_to_common_grid, read_custom_module from piglot.utils.reductions import Reduction, NegateReduction, read_reduction +from piglot.utils.response_transformer import ResponseTransformer, read_response_transformer from piglot.utils.composition.responses import ResponseComposition, EndpointFlattenUtility @@ -127,6 +128,7 @@ def __init__( stochastic: bool = False, composite: bool = False, multi_objective: bool = False, + transformers: Dict[str, ResponseTransformer] = None, ) -> None: super().__init__( parameters, @@ -139,6 +141,7 @@ def __init__( ) self.solver = solver self.targets = targets + self.transformers = transformers if transformers is not None else {} def prepare(self) -> None: """Prepare the objective for optimisation.""" @@ -221,6 +224,10 @@ def _objective(self, values: np.ndarray, concurrent: bool = False) -> ObjectiveR Objective result. """ raw_responses = self.solver.solve(values, concurrent) + # Transform responses + for name, transformer in self.transformers.items(): + if name in raw_responses: + raw_responses[name] = transformer.transform(raw_responses[name]) # Interpolate responses to common grid and map to targets responses_interp = { target: [ @@ -268,12 +275,22 @@ def plot_case(self, case_hash: str, options: Dict[str, Any] = None) -> List[Figu figures = [] # Load all responses responses = self.solver.get_output_response(case_hash) + # Transform responses if necessary + for name, transformer in self.transformers.items(): + if name in responses: + responses[name] = transformer.transform(responses[name]) # Plot each target for target in self.targets: # Build figure with individual responses for this target fig, axis = plt.subplots() for pred in target.prediction: - axis.plot(responses[pred].get_time(), responses[pred].get_data(), label=f'{pred}') + marker = 'o' if len(responses[pred].get_time()) < 2 else None + axis.plot( + responses[pred].get_time(), + responses[pred].get_data(), + label=f'{pred}', + marker=marker, + ) # Under stochasticity, plot response stats and mean confidence interval if self.stochastic: stats = stats_interp_to_common_grid([ @@ -355,6 +372,11 @@ def read( # Sanitise target: check if it is associated to at least one case if len(targets[target]) == 0: raise ValueError(f"Design target '{target_name}' is not associated to any case.") + # Read transformers + transformers: Dict[str, ResponseTransformer] = {} + if 'transformers' in config: + for name, transformer_config in config.pop('transformers').items(): + transformers[name] = read_response_transformer(transformer_config) # Read custom class (if any) target_class: type[DesignObjective] = DesignObjective if 'custom_class' in config: @@ -367,4 +389,5 @@ def read( stochastic=bool(config.get('stochastic', False)), composite=bool(config.get('composite', False)), multi_objective=bool(config.get('multi_objective', False)), + transformers=transformers, ) diff --git a/piglot/objectives/fitting.py b/piglot/objectives/fitting.py index 85c1fae..54c9f15 100644 --- a/piglot/objectives/fitting.py +++ b/piglot/objectives/fitting.py @@ -12,6 +12,7 @@ from piglot.utils.assorted import stats_interp_to_common_grid from piglot.utils.reductions import Reduction, read_reduction from piglot.utils.responses import Transformer, reduce_response, interpolate_response +from piglot.utils.response_transformer import ResponseTransformer, read_response_transformer from piglot.utils.composition.responses import ResponseComposition, FixedFlatteningUtility from piglot.objective import Composition, DynamicPlotter, GenericObjective, ObjectiveResult @@ -47,9 +48,12 @@ def __init__( self.weight = weight self.reduction = reduction # Load the data right away - data = np.genfromtxt(filename, skip_header=skip_header)[:, [x_col - 1, y_col - 1]] - self.x_data = data[:, 0] - self.y_data = data[:, 1] + data = np.genfromtxt(filename, skip_header=skip_header) + # Sanitise to ensure it is a 2D array + if len(data.shape) == 1: + data = data.reshape(1, -1) + self.x_data = data[:, x_col - 1] + self.y_data = data[:, y_col - 1] # Apply the transformer if self.transformer is not None: self.x_data, self.y_data = self.transformer(self.x_data, self.y_data) @@ -259,10 +263,12 @@ def __init__( solver: Solver, references: Dict[Reference, List[str]], reduction: Reduction, + transformers: Dict[str, ResponseTransformer] = None, ) -> None: self.solver = solver self.references = references self.reduction = reduction + self.transformers = transformers if transformers is not None else {} # Assign the reduction to non-defined references for reference in self.references: if reference.reduction is None: @@ -319,6 +325,10 @@ def solve(self, values: np.ndarray, concurrent: bool) -> Dict[Reference, List[Ou Output results. """ result = self.solver.solve(values, concurrent) + # Transform responses + for name, transformer in self.transformers.items(): + if name in result: + result[name] = transformer.transform(result[name]) # Populate output results output = {} for reference, cases in self.references.items(): @@ -348,6 +358,10 @@ def plot_case(self, case_hash: str, options: Dict[str, Any] = None) -> List[Figu figures = [] # Load all responses responses = self.solver.get_output_response(case_hash) + # Transform responses if necessary + for name, transformer in self.transformers.items(): + if name in responses: + responses[name] = transformer.transform(responses[name]) # Plot each reference for reference, names in self.references.items(): # Build figure, index axes and plot response @@ -390,8 +404,9 @@ def plot_case(self, case_hash: str, options: Dict[str, Any] = None) -> List[Figu else: # Plot the individual responses for name in names: + marker = 'o' if len(responses[name].get_time()) < 2 else None axis.plot(responses[name].get_time(), responses[name].get_data(), - label=f'{name}') + label=f'{name}', marker=marker) if reference_limits: axis.set_xlim(xlim) axis.set_ylim(ylim) @@ -449,8 +464,13 @@ def read(config: Dict[str, Any], parameters: ParameterSet, output_dir: str) -> F raise ValueError(f"Reference '{reference_name}' is not associated to any case.") # Read the optional reduction reduction = read_reduction(config.get('reduction', 'mse')) + # Read the optional transformers + transformers: Dict[str, ResponseTransformer] = {} + if 'transformers' in config: + for name, transformer_config in config['transformers'].items(): + transformers[name] = read_response_transformer(transformer_config) # Return the solver - return FittingSolver(solver, references, reduction) + return FittingSolver(solver, references, reduction, transformers=transformers) class FittingObjective(GenericObjective): diff --git a/piglot/utils/response_transformer.py b/piglot/utils/response_transformer.py new file mode 100644 index 0000000..7d756e8 --- /dev/null +++ b/piglot/utils/response_transformer.py @@ -0,0 +1,202 @@ +"""Module for defining transformations for responses.""" +from typing import Union, Dict, Any, Type, List +from abc import ABC, abstractmethod +import numpy as np +from piglot.utils.assorted import read_custom_module +from piglot.solver.solver import OutputResult + + +class ResponseTransformer(ABC): + """Abstract class for defining transformation functions.""" + + @abstractmethod + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + OutputResult + Transformed time and data points of the response. + """ + + +class ChainResponse(ResponseTransformer): + """Chain of response transformers.""" + + def __init__(self, transformers: List[Any]): + self.transformers = [read_response_transformer(t) for t in transformers] + + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + Output + Transformed time and data points of the response. + """ + for transformer in self.transformers: + response = transformer.transform(response) + return response + + +class MinimumResponse(ResponseTransformer): + """Minimum of a response transformer.""" + + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + OutputResult + Transformed time and data points of the response. + """ + return OutputResult(np.array([0.0]), np.array([np.min(response.data)])) + + +class MaximumResponse(ResponseTransformer): + """Maximum of a response transformer.""" + + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + OutputResult + Transformed time and data points of the response. + """ + return OutputResult(np.array([0.0]), np.array([np.max(response.data)])) + + +class NegateResponse(ResponseTransformer): + """Negate a response transformer.""" + + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + OutputResult + Transformed time and data points of the response. + """ + return OutputResult(response.time, -response.data) + + +class SquareResponse(ResponseTransformer): + """Square a response transformer.""" + + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + Output + Transformed time and data points of the response. + """ + return OutputResult(response.time, np.square(response.data)) + + +class AffineTransformResponse(ResponseTransformer): + """Affine transformation of a response transformer.""" + + def __init__( + self, + scale_x: float = 1.0, + offset_x: float = 0.0, + scale_y: float = 1.0, + offset_y: float = 0.0, + ): + self.scale_x = scale_x + self.offset_x = offset_x + self.scale_y = scale_y + self.offset_y = offset_y + + def transform(self, response: OutputResult) -> OutputResult: + """Transform the input data. + + Parameters + ---------- + response : OutputResult + Time and data points of the response. + + Returns + ------- + OutputResult + Transformed time and data points of the response. + """ + return OutputResult( + self.scale_x * response.time + self.offset_x, + self.scale_y * response.data + self.offset_y, + ) + + +AVAILABLE_RESPONSE_TRANSFORMERS: Dict[str, Type[ResponseTransformer]] = { + 'min': MinimumResponse, + 'max': MaximumResponse, + 'negate': NegateResponse, + 'square': SquareResponse, + 'chain': ChainResponse, + 'affine': AffineTransformResponse, +} + + +def read_response_transformer(config: Union[str, Dict[str, Any]]) -> ResponseTransformer: + """Read a response transformer from a configuration. + + Parameters + ---------- + config : Union[str, Dict[str, Any]] + Configuration of the response transformer. + + Returns + ------- + ResponseTransformer + Response transformer. + """ + # Parse the transformer in the simple format + if isinstance(config, str): + name = config + if name == 'script': + raise ValueError('Need to pass the file path for the "script" transformer.') + if name not in AVAILABLE_RESPONSE_TRANSFORMERS: + raise ValueError(f'Response transformer "{name}" is not available.') + return AVAILABLE_RESPONSE_TRANSFORMERS[name]() + # Detailed format + if 'name' not in config: + raise ValueError('Need to pass the name of the response transformer.') + name = config.pop('name') + # Read script transformer + if name == 'script': + return read_custom_module(config, ResponseTransformer)() + if name not in AVAILABLE_RESPONSE_TRANSFORMERS: + raise ValueError(f'Response transformer "{name}" is not available.') + return AVAILABLE_RESPONSE_TRANSFORMERS[name](**config)