-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
551516e
commit be151b2
Showing
29 changed files
with
487 additions
and
554 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
"""Module for output fields from Curve solver.""" | ||
from __future__ import annotations | ||
from typing import Dict, Any, Tuple | ||
import os | ||
import copy | ||
import numpy as np | ||
from piglot.parameter import ParameterSet | ||
from piglot.solver.solver import InputData, OutputField, OutputResult | ||
from piglot.utils.solver_utils import write_parameters, get_case_name | ||
|
||
|
||
class CurveInputData(InputData): | ||
"""Container for dummy input data.""" | ||
|
||
def __init__( | ||
self, | ||
case_name: str, | ||
expression: str, | ||
parametric: str, | ||
bounds: Tuple[float, float], | ||
points: int, | ||
) -> None: | ||
super().__init__() | ||
self.case_name = case_name | ||
self.expression = expression | ||
self.parametric = parametric | ||
self.bounds = bounds | ||
self.points = points | ||
self.input_file: str = None | ||
|
||
def prepare( | ||
self, | ||
values: np.ndarray, | ||
parameters: ParameterSet, | ||
tmp_dir: str = None, | ||
) -> CurveInputData: | ||
"""Prepare the input data for the simulation with a given set of parameters. | ||
Parameters | ||
---------- | ||
values : np.ndarray | ||
Parameters to run for. | ||
parameters : ParameterSet | ||
Parameter set for this problem. | ||
tmp_dir : str, optional | ||
Temporary directory to run the analyses, by default None | ||
Returns | ||
------- | ||
CurveInputData | ||
Input data prepared for the simulation. | ||
""" | ||
result = copy.copy(self) | ||
# Write the input file (with the name placeholder) | ||
tmp_file = os.path.join(tmp_dir, f'{self.case_name}.tmp') | ||
with open(tmp_file, 'w', encoding='utf8') as file: | ||
file.write(f'{self.expression}') | ||
# Write the parameters to the input file | ||
result.input_file = os.path.join(tmp_dir, f'{self.case_name}.dat') | ||
write_parameters(parameters.to_dict(values), tmp_file, result.input_file) | ||
return result | ||
|
||
def check(self, parameters: ParameterSet) -> None: | ||
"""Check if the input data is valid according to the given parameters. | ||
Parameters | ||
---------- | ||
parameters : ParameterSet | ||
Parameter set for this problem. | ||
""" | ||
# Generate a dummy set of parameters (to ensure proper handling of output parameters) | ||
values = np.array([parameter.inital_value for parameter in parameters]) | ||
param_dict = parameters.to_dict(values, input_normalised=False) | ||
for parameter in param_dict: | ||
if parameter not in self.expression: | ||
raise ValueError(f"Parameter '{parameter}' not found in expression.") | ||
|
||
def name(self) -> str: | ||
"""Return the name of the input data. | ||
Returns | ||
------- | ||
str | ||
Name of the input data. | ||
""" | ||
return self.case_name | ||
|
||
def get_current(self, target_dir: str) -> CurveInputData: | ||
"""Get the current input data. | ||
Parameters | ||
---------- | ||
target_dir : str | ||
Target directory to copy the input file. | ||
Returns | ||
------- | ||
CurveInputData | ||
Current input data. | ||
""" | ||
result = CurveInputData(os.path.join(target_dir, self.case_name), self.expression, | ||
self.parametric, self.bounds, self.points) | ||
result.input_file = os.path.join(target_dir, self.case_name + '.dat') | ||
return result | ||
|
||
|
||
class Curve(OutputField): | ||
"""Curve output reader.""" | ||
|
||
def check(self, input_data: CurveInputData) -> None: | ||
"""Sanity checks on the input file. | ||
Parameters | ||
---------- | ||
input_data : CurveInputData | ||
Input data for this case. | ||
""" | ||
|
||
def get(self, input_data: CurveInputData) -> OutputResult: | ||
"""Reads reactions from a Curve analysis. | ||
Parameters | ||
---------- | ||
input_data : CurveInputData | ||
Input data for this case. | ||
Returns | ||
------- | ||
array | ||
2D array with parametric value and corresponding expression value. | ||
""" | ||
input_file = input_data.input_file | ||
casename = get_case_name(input_file) | ||
output_dir = os.path.dirname(input_file) | ||
output_filename = os.path.join(output_dir, f'{casename}.out') | ||
# Ensure the file exists | ||
if not os.path.exists(output_filename): | ||
return OutputResult(np.empty(0), np.empty(0)) | ||
data = np.genfromtxt(output_filename) | ||
return OutputResult(data[:, 0], data[:, 1]) | ||
|
||
@staticmethod | ||
def read(config: Dict[str, Any]) -> Curve: | ||
"""Read the output field from the configuration dictionary. | ||
Parameters | ||
---------- | ||
config : Dict[str, Any] | ||
Configuration dictionary. | ||
Returns | ||
------- | ||
Reaction | ||
Output field to use for this problem. | ||
""" | ||
return Curve() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
"""Module for Curve solver.""" | ||
from typing import Dict, Any | ||
import os | ||
import time | ||
import shutil | ||
from multiprocessing.pool import ThreadPool as Pool | ||
import numpy as np | ||
import sympy | ||
from piglot.parameter import ParameterSet | ||
from piglot.solver.solver import Solver, Case, CaseResult, OutputField, OutputResult | ||
from piglot.solver.curve.fields import CurveInputData, Curve | ||
|
||
|
||
class CurveSolver(Solver): | ||
"""Curve solver.""" | ||
|
||
def __init__( | ||
self, | ||
cases: Dict[Case, Dict[str, OutputField]], | ||
parameters: ParameterSet, | ||
output_dir: str, | ||
parallel: int, | ||
tmp_dir: str, | ||
) -> None: | ||
"""Constructor for the Curve solver class. | ||
Parameters | ||
---------- | ||
cases : Dict[Case, Dict[str, OutputField]] | ||
Cases to be run and respective output fields. | ||
parameters : ParameterSet | ||
Parameter set for this problem. | ||
output_dir : str | ||
Path to the output directory. | ||
parallel : int | ||
Number of parallel processes to use. | ||
tmp_dir : str | ||
Path to the temporary directory. | ||
""" | ||
super().__init__(cases, parameters, output_dir) | ||
self.parallel = parallel | ||
self.tmp_dir = tmp_dir | ||
|
||
def _run_case(self, values: np.ndarray, case: Case, tmp_dir: str) -> CaseResult: | ||
"""Run a single case wth Curve. | ||
Parameters | ||
---------- | ||
values: np.ndarray | ||
Current parameter values | ||
case : Case | ||
Case to run. | ||
tmp_dir: str | ||
Temporary directory to run the simulation | ||
Returns | ||
------- | ||
CaseResult | ||
Results for this case | ||
""" | ||
# Copy input file replacing parameters by passed value | ||
input_data: CurveInputData = case.input_data.prepare(values, self.parameters, tmp_dir) | ||
# Run dummy solver | ||
begin_time = time.time() | ||
# Read the expression from the input file | ||
with open(input_data.input_file, 'r', encoding='utf8') as file: | ||
expression_str = file.read() | ||
symbs = sympy.symbols(input_data.parametric) | ||
expression = sympy.lambdify(symbs, expression_str) | ||
# Evaluate the expression on the grid | ||
grid = np.linspace(input_data.bounds[0], input_data.bounds[1], input_data.points) | ||
curve = [expression(**{input_data.parametric: x}) for x in grid] | ||
# Write out the curve | ||
output_file = os.path.splitext(input_data.input_file)[0] + '.out' | ||
with open(output_file, 'w', encoding='utf8') as file: | ||
for x, y in zip(grid, curve): | ||
file.write(f'{x} {y}\n') | ||
# Read results from output directories | ||
responses = {name: field.get(input_data) for name, field in case.fields.items()} | ||
end_time = time.time() | ||
return CaseResult( | ||
begin_time, | ||
end_time - begin_time, | ||
values, | ||
True, | ||
self.parameters.hash(values), | ||
responses, | ||
) | ||
|
||
def _solve( | ||
self, | ||
values: np.ndarray, | ||
concurrent: bool, | ||
) -> Dict[Case, CaseResult]: | ||
"""Internal solver for the prescribed problems. | ||
Parameters | ||
---------- | ||
values : array | ||
Current parameters to evaluate. | ||
concurrent : bool | ||
Whether this run may be concurrent to another one (so use unique file names). | ||
Returns | ||
------- | ||
Dict[Case, CaseResult] | ||
Results for each case. | ||
""" | ||
# Ensure tmp directory is clean | ||
tmp_dir = f'{self.tmp_dir}_{self.parameters.hash(values)}' if concurrent else self.tmp_dir | ||
if os.path.isdir(tmp_dir): | ||
shutil.rmtree(tmp_dir) | ||
os.mkdir(tmp_dir) | ||
|
||
def run_case(case: Case) -> CaseResult: | ||
return self._run_case(values, case, tmp_dir) | ||
# Run cases (in parallel if specified) | ||
if self.parallel > 1: | ||
with Pool(self.parallel) as pool: | ||
results = pool.map(run_case, self.cases) | ||
else: | ||
results = map(run_case, self.cases) | ||
# Ensure we actually resolve the map | ||
results = list(results) | ||
# Cleanup temporary directories | ||
if concurrent: | ||
shutil.rmtree(tmp_dir) | ||
# Build output dict | ||
return dict(zip(self.cases, results)) | ||
|
||
def get_current_response(self) -> Dict[str, OutputResult]: | ||
"""Get the responses from a given output field for all cases. | ||
Returns | ||
------- | ||
Dict[str, OutputResult] | ||
Output responses. | ||
""" | ||
fields = self.get_output_fields() | ||
return { | ||
name: field.get(case.input_data.get_current(self.tmp_dir)) | ||
for name, (case, field) in fields.items() | ||
} | ||
|
||
@staticmethod | ||
def read(config: Dict[str, Any], parameters: ParameterSet, output_dir: str) -> Solver: | ||
"""Read the solver from the configuration dictionary. | ||
Parameters | ||
---------- | ||
config : Dict[str, Any] | ||
Configuration dictionary. | ||
parameters : ParameterSet | ||
Parameter set for this problem. | ||
output_dir : str | ||
Path to the output directory. | ||
Returns | ||
------- | ||
Solver | ||
Solver to use for this problem. | ||
""" | ||
# Read the parallelism and temporary directory (if present) | ||
parallel = int(config.get('parallel', 1)) | ||
tmp_dir = os.path.join(output_dir, config.get('tmp_dir', 'tmp')) | ||
# Read the cases | ||
if 'cases' not in config: | ||
raise ValueError("Missing 'cases' in solver configuration.") | ||
cases = [] | ||
for case_name, case_config in config['cases'].items(): | ||
if 'expression' not in case_config: | ||
raise ValueError("Missing 'expression' in solver configuration.") | ||
if 'parametric' not in case_config: | ||
raise ValueError("Missing 'parametric' in solver configuration.") | ||
if 'bounds' not in case_config: | ||
raise ValueError("Missing 'bounds' in solver configuration.") | ||
expression = case_config['expression'] | ||
parametric = case_config['parametric'] | ||
bounds = case_config['bounds'] | ||
points = int(case_config['points']) if 'points' in case_config else 100 | ||
input_data = CurveInputData(case_name, expression, parametric, bounds, points) | ||
cases.append(Case(input_data, {case_name: Curve()})) | ||
# Return the solver | ||
return CurveSolver(cases, parameters, output_dir, parallel=parallel, tmp_dir=tmp_dir) |
Oops, something went wrong.