Skip to content

Commit

Permalink
Merge pull request #316 from gyorilab/odes
Browse files Browse the repository at this point in the history
New features in rate law settings and ODE simulation + photosynthesis model
  • Loading branch information
bgyori authored Mar 28, 2024
2 parents 37d55c4 + 2fbf1d9 commit 0da409b
Show file tree
Hide file tree
Showing 8 changed files with 833 additions and 37 deletions.
10 changes: 9 additions & 1 deletion mira/metamodel/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import datetime
import sys
from typing import List, Dict, Set, Optional, Mapping, Tuple, Any
from typing import List, Dict, Set, Optional, Mapping, Tuple, Any, Union

import networkx as nx
import sympy
Expand Down Expand Up @@ -603,6 +603,14 @@ def generate_model_graph(self, use_display_name: bool = False) -> nx.DiGraph:

return graph

def set_rate_law(self, template_name: str,
rate_law: Union[str, sympy.Expr, SympyExprStr],
local_dict=None):
"""Set the rate law of a template with a given name."""
for template in self.templates:
if template.name == template_name:
template.set_rate_law(rate_law, local_dict=local_dict)

def draw_graph(
self,
path: str,
Expand Down
24 changes: 24 additions & 0 deletions mira/metamodel/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,30 @@ def with_mass_action_rate_law(self, parameter, independent=False) -> "Template":
template.set_mass_action_rate_law(parameter, independent=independent)
return template

def set_rate_law(self, rate_law: Union[str, sympy.Expr, SympyExprStr],
local_dict=None):
"""Set the rate law of this template to the given rate law."""
if isinstance(rate_law, SympyExprStr):
self.rate_law = rate_law
elif isinstance(rate_law, sympy.Expr):
self.rate_law = SympyExprStr(rate_law)
elif isinstance(rate_law, str):
try:
rate = SympyExprStr(safe_parse_expr(rate_law,
local_dict=local_dict))
except Exception as e:
logger.warning(f"Could not parse rate law into "
f"symbolic expression: {rate_law}. "
f"Not setting rate law.")
return
self.rate_law = rate

def with_rate_law(self, rate_law: Union[str, sympy.Expr, SympyExprStr],
local_dict=None) -> "Template":
template = self.copy(deep=True)
template.set_rate_law(rate_law, local_dict=local_dict)
return template

def get_parameter_names(self) -> Set[str]:
"""Get the set of parameter names.
Expand Down
65 changes: 57 additions & 8 deletions mira/modeling/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(self, model: Model, initialized: bool):
self.y = sympy.MatrixSymbol('y', len(model.variables), 1)
self.vmap = {variable.key: idx for idx, variable
in enumerate(model.variables.values())}
self.observable_map = {obs_key: idx for idx, obs_key
in enumerate(model.observables)}
real_params = {k: v for k, v in model.parameters.items()
if not v.placeholder}
self.p = sympy.MatrixSymbol('p', len(real_params), 1)
Expand All @@ -27,9 +29,9 @@ def __init__(self, model: Model, initialized: bool):
parameter_map = {parameter.concept.name: parameter.key
for parameter in real_params.values()}

'''
Following code block is agnostic towards the case if the ODE model was created with parameter and agent
values initialized when creating parameters or when calling the simulate_ode method.'''
"""Following code block is agnostic towards the case if the ODE model
was created with parameter and initial values initialized when
creating parameters or when calling the simulate_ode method."""
if initialized:
self.parameter_values = []
self.variable_values = []
Expand Down Expand Up @@ -73,11 +75,42 @@ def __init__(self, model: Model, initialized: bool):
self.kinetics = sympy.Matrix(self.kinetics)
self.kinetics_lmbd = sympy.lambdify([self.y], self.kinetics)

observables = []
for obs_name, model_obs in model.observables.items():
expr = deepcopy(model_obs.observable.expression).args[0]
for symbol in expr.free_symbols:
sym_str = str(symbol)
if sym_str in concept_map:
expr = expr.subs(symbol,
self.y[self.vmap[concept_map[sym_str]]])
elif sym_str in self.pmap:
expr = expr.subs(symbol,
self.p[self.pmap[parameter_map[sym_str]]])
elif model.template_model.time and \
sym_str == model.template_model.time.name:
expr = expr.subs(symbol, 't')
else:
assert False, sym_str
observables.append(expr)
self.observables = sympy.Matrix(observables)
self.observables_lmbd = sympy.lambdify([self.y], self.observables)

def get_interpretable_kinetics(self):
# Return kinetics but with y and p substituted
# based on vmap and pmap
subs = {self.y[v]: sympy.Symbol(k) for k, v in self.vmap.items()}
subs.update({self.p[p]: sympy.Symbol(k) for k, p in self.pmap.items()})
return sympy.Matrix([
k.subs(subs) for k in self.kinetics
])

def set_parameters(self, params):
"""Set the parameters of the model."""
for p, v in params.items():
self.kinetics = self.kinetics.subs(self.p[self.pmap[p]], v)
self.observables = self.observables.subs(self.p[self.pmap[p]], v)
self.kinetics_lmbd = sympy.lambdify([self.y], self.kinetics)
self.observables_lmbd = sympy.lambdify([self.y], self.observables)

def get_rhs(self):
"""Return the right-hand side of the ODE system."""
Expand All @@ -96,7 +129,7 @@ def rhs(t, y):
# ode_model: OdeModel, times, initials=None,
# parameters=None
def simulate_ode_model(ode_model: OdeModel, times, initials=None,
parameters=None):
parameters=None, with_observables=False):
"""Simulate an ODE model given initial conditions, parameters and a
time span.
Expand All @@ -112,6 +145,9 @@ def simulate_ode_model(ode_model: OdeModel, times, initials=None,
times:
A one-dimensional array of time values, typically from
a linear space like ``numpy.linspace(0, 25, 100)``
with_observables:
A boolean indicating whether to return the observables
as well as the variables.
Returns
-------
Expand All @@ -130,14 +166,27 @@ def simulate_ode_model(ode_model: OdeModel, times, initials=None,

initials = ode_model.variable_values
for index, expression in enumerate(initials):
initials[index] = float(expression.subs(parameters).args[0])
# Only substitute if this is an expression. Once the model has been
# simulated, this is actually a float.
if isinstance(expression, sympy.Expr):
initials[index] = float(expression.subs(parameters).args[0])

ode_model.set_parameters(parameters)
solver = scipy.integrate.ode(f=rhs)

solver.set_initial_value(initials)
res = numpy.zeros((len(times), ode_model.y.shape[0]))
res[0, :] = initials
num_vars = ode_model.y.shape[0]
num_obs = len(ode_model.observable_map)
num_cols = num_vars + (num_obs if with_observables else 0)
res = numpy.zeros((len(times), num_cols))
res[0, :num_vars] = initials
for idx, time in enumerate(times[1:]):
res[idx + 1, :] = solver.integrate(time)
res[idx + 1, :num_vars] = solver.integrate(time)

if with_observables:
for tidx, t in enumerate(times):
obs_res = \
ode_model.observables_lmbd(res[tidx, :num_vars][:, None])
for idx, val in enumerate(obs_res):
res[tidx, num_vars + idx] = obs_res[idx]
return res
Loading

0 comments on commit 0da409b

Please sign in to comment.