Skip to content

Commit

Permalink
Add curve fitting solver
Browse files Browse the repository at this point in the history
  • Loading branch information
ruicoelhopedro committed Jan 31, 2024
1 parent 551516e commit be151b2
Show file tree
Hide file tree
Showing 29 changed files with 487 additions and 554 deletions.
4 changes: 2 additions & 2 deletions piglot/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from piglot.parameter import ParameterSet
from piglot.solver.solver import Solver
from piglot.solver.links.solver import LinksSolver
from piglot.solver.dummy.solver import DummySolver
from piglot.solver.curve.solver import CurveSolver


AVAILABLE_SOLVERS: Dict[str, Type[Solver]] = {
'links': LinksSolver,
'dummy': DummySolver,
'curve': CurveSolver,
}


Expand Down
File renamed without changes.
157 changes: 157 additions & 0 deletions piglot/solver/curve/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Module for output fields from Curve solver."""
from __future__ import annotations
from typing import Dict, Any, Tuple
import os
import copy
import numpy as np
from piglot.parameter import ParameterSet
from piglot.solver.solver import InputData, OutputField, OutputResult
from piglot.utils.solver_utils import write_parameters, get_case_name


class CurveInputData(InputData):
    """Container for the input data of a Curve (analytical expression) case."""

    def __init__(
        self,
        case_name: str,
        expression: str,
        parametric: str,
        bounds: Tuple[float, float],
        points: int,
    ) -> None:
        super().__init__()
        # Name used to derive the temporary (.tmp) and input (.dat) file names.
        self.case_name = case_name
        # Symbolic expression containing parameter placeholders to substitute.
        self.expression = expression
        # Name of the parametric (independent) variable in the expression.
        self.parametric = parametric
        # Lower and upper bounds of the evaluation grid.
        self.bounds = bounds
        # Number of grid points to evaluate the expression at.
        self.points = points
        # Path of the parameter-substituted input file; set by prepare()/get_current().
        self.input_file: str | None = None

    def prepare(
        self,
        values: np.ndarray,
        parameters: ParameterSet,
        tmp_dir: str = None,
    ) -> CurveInputData:
        """Prepare the input data for the simulation with a given set of parameters.

        Writes the raw expression to a temporary file and then substitutes the
        parameter values into it, producing the final `.dat` input file.

        Parameters
        ----------
        values : np.ndarray
            Parameters to run for.
        parameters : ParameterSet
            Parameter set for this problem.
        tmp_dir : str, optional
            Temporary directory to run the analyses, by default None

        Returns
        -------
        CurveInputData
            Input data prepared for the simulation (shallow copy of this instance
            with `input_file` pointing at the written file).
        """
        result = copy.copy(self)
        # Write the input file (with the name placeholder)
        tmp_file = os.path.join(tmp_dir, f'{self.case_name}.tmp')
        with open(tmp_file, 'w', encoding='utf8') as file:
            file.write(f'{self.expression}')
        # Write the parameters to the input file
        result.input_file = os.path.join(tmp_dir, f'{self.case_name}.dat')
        write_parameters(parameters.to_dict(values), tmp_file, result.input_file)
        return result

    def check(self, parameters: ParameterSet) -> None:
        """Check if the input data is valid according to the given parameters.

        Parameters
        ----------
        parameters : ParameterSet
            Parameter set for this problem.

        Raises
        ------
        ValueError
            If any parameter name does not appear in the expression.
        """
        # Generate a dummy set of parameters (to ensure proper handling of output parameters)
        # NOTE(review): 'inital_value' (sic) — presumably matches the attribute
        # spelling on the project's Parameter class; confirm before "fixing".
        values = np.array([parameter.inital_value for parameter in parameters])
        param_dict = parameters.to_dict(values, input_normalised=False)
        for parameter in param_dict:
            # Plain substring test: a parameter absent from the expression would
            # silently have no effect on the curve, so fail loudly instead.
            if parameter not in self.expression:
                raise ValueError(f"Parameter '{parameter}' not found in expression.")

    def name(self) -> str:
        """Return the name of the input data.

        Returns
        -------
        str
            Name of the input data.
        """
        return self.case_name

    def get_current(self, target_dir: str) -> CurveInputData:
        """Get the current input data.

        Parameters
        ----------
        target_dir : str
            Target directory to copy the input file.

        Returns
        -------
        CurveInputData
            Current input data, with `input_file` pointing into `target_dir`.
        """
        result = CurveInputData(os.path.join(target_dir, self.case_name), self.expression,
                                self.parametric, self.bounds, self.points)
        result.input_file = os.path.join(target_dir, self.case_name + '.dat')
        return result


class Curve(OutputField):
    """Curve output reader.

    Reads the two-column (parametric value, expression value) file written by
    the Curve solver next to the case's input file.
    """

    def check(self, input_data: CurveInputData) -> None:
        """Sanity checks on the input file.

        No checks are needed for curve cases: the solver itself generates the
        output file, so there is nothing to validate up front.

        Parameters
        ----------
        input_data : CurveInputData
            Input data for this case.
        """

    def get(self, input_data: CurveInputData) -> OutputResult:
        """Read the curve from a Curve analysis output file.

        Parameters
        ----------
        input_data : CurveInputData
            Input data for this case.

        Returns
        -------
        OutputResult
            Parametric values and corresponding expression values.
        """
        input_file = input_data.input_file
        casename = get_case_name(input_file)
        output_dir = os.path.dirname(input_file)
        output_filename = os.path.join(output_dir, f'{casename}.out')
        # Ensure the file exists: a missing file means the case has not run yet,
        # so return an empty result instead of raising.
        if not os.path.exists(output_filename):
            return OutputResult(np.empty(0), np.empty(0))
        # atleast_2d: genfromtxt returns a 1-D array for a single-row file,
        # which would make the column indexing below raise an IndexError.
        data = np.atleast_2d(np.genfromtxt(output_filename))
        return OutputResult(data[:, 0], data[:, 1])

    @staticmethod
    def read(config: Dict[str, Any]) -> Curve:
        """Read the output field from the configuration dictionary.

        Parameters
        ----------
        config : Dict[str, Any]
            Configuration dictionary (unused: curve fields take no options).

        Returns
        -------
        Curve
            Output field to use for this problem.
        """
        return Curve()
184 changes: 184 additions & 0 deletions piglot/solver/curve/solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Module for Curve solver."""
from typing import Dict, Any
import os
import time
import shutil
from multiprocessing.pool import ThreadPool as Pool
import numpy as np
import sympy
from piglot.parameter import ParameterSet
from piglot.solver.solver import Solver, Case, CaseResult, OutputField, OutputResult
from piglot.solver.curve.fields import CurveInputData, Curve


class CurveSolver(Solver):
    """Solver that evaluates an analytical expression over a 1-D parametric grid."""

    def __init__(
        self,
        cases: Dict[Case, Dict[str, OutputField]],
        parameters: ParameterSet,
        output_dir: str,
        parallel: int,
        tmp_dir: str,
    ) -> None:
        """Constructor for the Curve solver class.

        Parameters
        ----------
        cases : Dict[Case, Dict[str, OutputField]]
            Cases to be run and respective output fields.
        parameters : ParameterSet
            Parameter set for this problem.
        output_dir : str
            Path to the output directory.
        parallel : int
            Number of parallel processes to use.
        tmp_dir : str
            Path to the temporary directory.
        """
        super().__init__(cases, parameters, output_dir)
        self.parallel = parallel
        self.tmp_dir = tmp_dir

    def _run_case(self, values: np.ndarray, case: Case, tmp_dir: str) -> CaseResult:
        """Run a single case with the Curve solver.

        Parameters
        ----------
        values: np.ndarray
            Current parameter values
        case : Case
            Case to run.
        tmp_dir: str
            Temporary directory to run the simulation

        Returns
        -------
        CaseResult
            Results for this case
        """
        # Copy input file replacing parameters by passed value
        input_data: CurveInputData = case.input_data.prepare(values, self.parameters, tmp_dir)
        begin_time = time.time()
        # Read the (parameter-substituted) expression from the input file
        with open(input_data.input_file, 'r', encoding='utf8') as file:
            expression_str = file.read()
        # Build a callable taking the parametric variable as a keyword argument
        symbs = sympy.symbols(input_data.parametric)
        expression = sympy.lambdify(symbs, expression_str)
        # Evaluate the expression on the grid
        grid = np.linspace(input_data.bounds[0], input_data.bounds[1], input_data.points)
        curve = [expression(**{input_data.parametric: x}) for x in grid]
        # Write out the curve so the output fields can read it back
        output_file = os.path.splitext(input_data.input_file)[0] + '.out'
        with open(output_file, 'w', encoding='utf8') as file:
            for x, y in zip(grid, curve):
                file.write(f'{x} {y}\n')
        # Read results from output directories
        responses = {name: field.get(input_data) for name, field in case.fields.items()}
        end_time = time.time()
        return CaseResult(
            begin_time,
            end_time - begin_time,
            values,
            True,
            self.parameters.hash(values),
            responses,
        )

    def _solve(
        self,
        values: np.ndarray,
        concurrent: bool,
    ) -> Dict[Case, CaseResult]:
        """Internal solver for the prescribed problems.

        Parameters
        ----------
        values : array
            Current parameters to evaluate.
        concurrent : bool
            Whether this run may be concurrent to another one (so use unique file names).

        Returns
        -------
        Dict[Case, CaseResult]
            Results for each case.
        """
        # Ensure tmp directory is clean
        tmp_dir = f'{self.tmp_dir}_{self.parameters.hash(values)}' if concurrent else self.tmp_dir
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
        # makedirs (not mkdir): the parent output directory may not exist yet
        os.makedirs(tmp_dir)

        def run_case(case: Case) -> CaseResult:
            return self._run_case(values, case, tmp_dir)
        # Run cases (in parallel if specified)
        if self.parallel > 1:
            with Pool(self.parallel) as pool:
                results = pool.map(run_case, self.cases)
        else:
            results = map(run_case, self.cases)
        # Ensure we actually resolve the map
        results = list(results)
        # Cleanup temporary directories
        if concurrent:
            shutil.rmtree(tmp_dir)
        # Build output dict
        return dict(zip(self.cases, results))

    def get_current_response(self) -> Dict[str, OutputResult]:
        """Get the responses from a given output field for all cases.

        Returns
        -------
        Dict[str, OutputResult]
            Output responses.
        """
        fields = self.get_output_fields()
        return {
            name: field.get(case.input_data.get_current(self.tmp_dir))
            for name, (case, field) in fields.items()
        }

    @staticmethod
    def read(config: Dict[str, Any], parameters: ParameterSet, output_dir: str) -> Solver:
        """Read the solver from the configuration dictionary.

        Parameters
        ----------
        config : Dict[str, Any]
            Configuration dictionary.
        parameters : ParameterSet
            Parameter set for this problem.
        output_dir : str
            Path to the output directory.

        Returns
        -------
        Solver
            Solver to use for this problem.

        Raises
        ------
        ValueError
            If 'cases' is missing, or a case lacks 'expression', 'parametric'
            or 'bounds'.
        """
        # Read the parallelism and temporary directory (if present)
        parallel = int(config.get('parallel', 1))
        tmp_dir = os.path.join(output_dir, config.get('tmp_dir', 'tmp'))
        # Read the cases
        if 'cases' not in config:
            raise ValueError("Missing 'cases' in solver configuration.")
        cases = []
        for case_name, case_config in config['cases'].items():
            if 'expression' not in case_config:
                raise ValueError("Missing 'expression' in solver configuration.")
            if 'parametric' not in case_config:
                raise ValueError("Missing 'parametric' in solver configuration.")
            if 'bounds' not in case_config:
                raise ValueError("Missing 'bounds' in solver configuration.")
            expression = case_config['expression']
            parametric = case_config['parametric']
            bounds = case_config['bounds']
            # Number of evaluation points (default: 100)
            points = int(case_config.get('points', 100))
            input_data = CurveInputData(case_name, expression, parametric, bounds, points)
            # Each case exposes a single output field, named after the case itself
            cases.append(Case(input_data, {case_name: Curve()}))
        # Return the solver
        return CurveSolver(cases, parameters, output_dir, parallel=parallel, tmp_dir=tmp_dir)
Loading

0 comments on commit be151b2

Please sign in to comment.