Skip to content

Commit

Permalink
Merge pull request #332 from gyorilab/stratify_improvement
Browse files Browse the repository at this point in the history
Reimplement stratification logic for parameter consistency
  • Loading branch information
bgyori authored May 15, 2024
2 parents a8623e7 + a7f1875 commit 4003ed8
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 128 deletions.
12 changes: 6 additions & 6 deletions mira/dkg/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_entity(
curie: str = Path(
...,
description="A compact URI (CURIE) for an entity in the form of ``<prefix>:<local unique identifier>``",
example="ido:0000511",
examples=["ido:0000511"],
),
):
"""Get information about an entity (e.g., its name, description, synonyms, alternative identifiers,
Expand All @@ -97,7 +97,7 @@ def get_entities(
...,
description="A comma-separated list of compact URIs (CURIEs) for an "
"entity in the form of ``<prefix>:<local unique identifier>,...``",
example="ido:0000511,ido:0000512",
examples=["ido:0000511,ido:0000512"],
),
):
"""
Expand Down Expand Up @@ -158,7 +158,7 @@ def get_transitive_closure(
relation_types: List[str] = Query(
...,
description="A list of relation types to get a transitive closure for",
example=DKG_REFINER_RELS,
examples=[DKG_REFINER_RELS],
),
):
"""Get a transitive closure of the requested type(s)"""
Expand Down Expand Up @@ -384,13 +384,13 @@ def is_ontological_child(
)
def search(
request: Request,
q: str = Query(..., example="infect", description="The search query"),
q: str = Query(..., examples=["infect"], description="The search query"),
limit: int = 25,
offset: int = 0,
prefixes: Optional[str] = Query(
default=None,
description="A comma-separated list of prefixes",
examples={
examples=[{
"no prefix filter": {
"summary": "Don't filter by prefix",
"value": None,
Expand All @@ -399,7 +399,7 @@ def search(
"summary": "Search for units, which have Wikidata prefixes",
"value": "wikidata",
},
},
}],
),
labels: Optional[str] = Query(
default=None,
Expand Down
2 changes: 1 addition & 1 deletion mira/dkg/grounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def ground_get(
description="The text to be grounded. Warning: grounding does not work well for "
"substring matches, i.e., if searching only for 'infected'. In these "
"cases, using the search API is more appropriate.",
example="Infected Population",
examples=["Infected Population"],
),
):
"""Ground text with Gilda."""
Expand Down
207 changes: 117 additions & 90 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

def stratify(
template_model: TemplateModel,
*,
key: str,
strata: Collection[str],
strata_curie_to_name: Optional[Mapping[str, str]] = None,
Expand All @@ -42,6 +41,7 @@ def stratify(
params_to_preserve: Optional[Collection[str]] = None,
concepts_to_stratify: Optional[Collection[str]] = None,
concepts_to_preserve: Optional[Collection[str]] = None,
param_renaming_uses_strata_names: Optional[bool] = False,
) -> TemplateModel:
"""Multiplies a model into several strata.
Expand Down Expand Up @@ -95,23 +95,23 @@ def stratify(
params_to_stratify :
A list of parameters to stratify. If none given, will stratify all
parameters.
params_to_preserve:
params_to_preserve :
A list of parameters to preserve. If none given, will stratify all
parameters.
concepts_to_stratify :
A list of concepts to stratify. If none given, will stratify all
concepts.
concepts_to_preserve:
concepts_to_preserve :
A list of concepts to preserve. If none given, will stratify all
concepts.
param_renaming_uses_strata_names :
If true, the strata names will be used in the parameter renaming.
If false, the strata indices will be used. Default: False
Returns
-------
:
A stratified template model
"""
strata = sorted(strata)

if strata_name_lookup and strata_curie_to_name is None:
from mira.dkg.web_client import get_entities_web, MissingBaseUrlError
try:
Expand All @@ -137,8 +137,6 @@ def stratify(

# List of new templates
templates = []
# Counter to keep track of how many times a parameter has been stratified
params_count = Counter()

# Figure out excluded concepts
if concepts_to_stratify is None:
Expand All @@ -154,7 +152,10 @@ def stratify(
concept_names - set(concepts_to_stratify)
)

stratum_index_map = {stratum: i for i, stratum in enumerate(strata)}

keep_unstratified_parameters = set()
all_param_mappings = defaultdict(set)
for template in template_model.templates:
# If the template doesn't have any concepts that need to be stratified
# then we can just keep it as is and skip the rest of the loop
Expand All @@ -165,77 +166,100 @@ def stratify(
templates.append(deepcopy(template))
continue

# Generate a derived template for each strata
for stratum in strata:
new_template = template.with_context(
do_rename=modify_names, exclude_concepts=exclude_concepts,
curie_to_name_map=strata_curie_to_name,
**{key: stratum},
)
rewrite_rate_law(template_model=template_model,
old_template=template,
new_template=new_template,
params_count=params_count,
params_to_stratify=params_to_stratify,
params_to_preserve=params_to_preserve)
# parameters = list(template_model.get_parameters_from_rate_law(template.rate_law))
# if len(parameters) == 1:
# new_template.set_mass_action_rate_law(parameters[0])
templates.append(new_template)

# assume all controllers have to get stratified together
# and mixing of strata doesn't occur during control
controllers = template.get_controllers()
if cartesian_control and controllers:
remaining_strata = [s for s in strata if s != stratum]

# use itt.product to generate all combinations of remaining
# strata for remaining controllers. for example, if there
# Check if we will have any controllers in the template
ncontrollers = num_controllers(template)
# If we have controllers, and we want cartesian control then
# we will stratify controllers separately
stratify_controllers = (ncontrollers > 0) and cartesian_control

# Generate a derived template for each stratum
for stratum, stratum_idx in stratum_index_map.items():
template_strata = []
new_template = deepcopy(template)
# We have to make sure that we only add the stratum to the
# list of template strata if we stratified any of the non-controllers
# in this first for loop
any_noncontrollers_stratified = False
# We apply this stratum to each concept except for controllers
# in case we will separately stratify those
for concept in new_template.get_concepts_flat(
exclude_controllers=stratify_controllers,
refresh=True):
if concept.name in exclude_concepts:
continue
concept.with_context(
do_rename=modify_names,
curie_to_name_map=strata_curie_to_name,
inplace=True,
**{key: stratum})
any_noncontrollers_stratified = True

# If we don't stratify controllers then we are done and can just
# make the new rate law, then append this new template
if not stratify_controllers:
# We only need to do this if we stratified any of the non-controllers
if any_noncontrollers_stratified:
template_strata = [stratum if
param_renaming_uses_strata_names else stratum_idx]
param_mappings = rewrite_rate_law(template_model=template_model,
old_template=template,
new_template=new_template,
template_strata=template_strata,
params_to_stratify=params_to_stratify,
params_to_preserve=params_to_preserve)
for old_param, new_param in param_mappings.items():
all_param_mappings[old_param].add(new_param)
templates.append(new_template)
# Otherwise we are stratifying controllers separately
else:
# Use itt.product to generate all combinations of
# strata for controllers. For example, if there
# are two controllers A and B and stratification is into
# old, middle, and young, then there will be the following 9:
# (A_old, B_old), (A_old, B_middle), (A_old, B_young),
# (A_middle, B_old), (A_middle, B_middle), (A_middle, B_young),
# (A_young, B_old), (A_young, B_middle), (A_young, B_young)
c_strata_tuples = itt.product(remaining_strata, repeat=len(controllers))
for c_strata_tuple in c_strata_tuples:
stratified_controllers = [
controller.with_context(do_rename=modify_names, **{key: c_stratum})
if controller.name not in exclude_concepts
else controller
for controller, c_stratum in zip(controllers, c_strata_tuple)
]
if isinstance(template, (GroupedControlledConversion, GroupedControlledProduction)):
stratified_template = new_template.with_controllers(stratified_controllers)
elif isinstance(template, (ControlledConversion, ControlledProduction,
ControlledDegradation, ControlledReplication)):
assert len(stratified_controllers) == 1
stratified_template = new_template.with_controller(stratified_controllers[0])
else:
raise NotImplementedError
# the old template is used here on purpose for easier bookkeeping
rewrite_rate_law(template_model=template_model,
old_template=template,
new_template=stratified_template,
params_count=params_count,
params_to_stratify=params_to_stratify,
params_to_preserve=params_to_preserve)
for c_strata_tuple in itt.product(strata, repeat=ncontrollers):
stratified_template = deepcopy(new_template)
stratified_controllers = stratified_template.get_controllers()
template_strata = [stratum if param_renaming_uses_strata_names
else stratum_idx]
# We now apply the stratum assigned to each controller in this particular
# tuple to the controller
for controller, c_stratum in zip(stratified_controllers, c_strata_tuple):
controller.with_context(do_rename=modify_names, inplace=True,
**{key: c_stratum})
template_strata.append(c_stratum if param_renaming_uses_strata_names
else stratum_index_map[c_stratum])

                    # We can now rewrite the rate law for this stratified template,
# then append the new template
param_mappings = rewrite_rate_law(template_model=template_model,
old_template=template,
new_template=stratified_template,
template_strata=template_strata,
params_to_stratify=params_to_stratify,
params_to_preserve=params_to_preserve)
for old_param, new_param in param_mappings.items():
all_param_mappings[old_param].add(new_param)
templates.append(stratified_template)

parameters = {}
for parameter_key, parameter in template_model.parameters.items():
if parameter_key not in params_count:
if parameter_key not in all_param_mappings:
parameters[parameter_key] = parameter
continue
# We need to keep the original param if it has been broken
# up but not in every instance. We then also
# generate the counted parameter variants
elif parameter_key in keep_unstratified_parameters:
parameters[parameter_key] = parameter
# note that `params_count[key]` will be 1 higher than the number of uses
for i in range(params_count[parameter_key]):
# We otherwise generate variants of the parameter based
        # on the previously compiled parameter mappings
for stratified_param in all_param_mappings[parameter_key]:
d = deepcopy(parameter)
d.name = f"{parameter_key}_{i}"
parameters[d.name] = d
d.name = stratified_param
parameters[stratified_param] = d

# Create new initial values for each of the strata
# of the original compartments, copied from the initial
Expand Down Expand Up @@ -320,7 +344,7 @@ def rewrite_rate_law(
template_model: TemplateModel,
old_template: Template,
new_template: Template,
params_count: Counter,
template_strata: List[int],
params_to_stratify: Optional[Collection[str]] = None,
params_to_preserve: Optional[Collection[str]] = None,
):
Expand All @@ -337,9 +361,9 @@ def rewrite_rate_law(
new_template :
The new template. One of the templates created by stratification of
``old_template``.
params_count :
A counter that keeps track of how many times a parameter has been
stratified.
template_strata :
        A list of strata indices (or strata names, if the caller chose
        name-based renaming) that have been applied to the template,
        used for parameter naming.
params_to_stratify :
A list of parameters to stratify. If none given, will stratify all
parameters.
Expand All @@ -351,7 +375,7 @@ def rewrite_rate_law(
# to the stratified controllers in for the originals
rate_law = old_template.rate_law
if not rate_law:
return
return {}

# If the template has controllers/subjects that affect the rate law
# and there is an overlap between these, then simple substitution
Expand All @@ -362,28 +386,7 @@ def rewrite_rate_law(
old_template.get_controllers()}:
has_subject_controller_overlap = True

# Step 1. Identify the mass action symbol and rename it with a
parameters = list(template_model.get_parameters_from_rate_law(rate_law))
for parameter in parameters:
# If a parameter is explicitly listed as one to preserve, then
# don't stratify it
if params_to_preserve is not None and parameter in params_to_preserve:
continue
# If we have an explicit stratification list then if something isn't
# in the list then don't stratify it.
elif params_to_stratify is not None and parameter not in params_to_stratify:
continue
# Otherwise we go ahead with stratification, i.e., in cases
# where nothing was said about parameter stratification or the
# parameter was listed explicitly to be stratified
else:
rate_law = rate_law.subs(
parameter,
sympy.Symbol(f"{parameter}_{params_count[parameter]}")
)
params_count[parameter] += 1 # increment this each time to keep unique

# Step 2. Rename symbols based on the new concepts
# Step 1. Rename controllers
for old_controller, new_controller in zip(
old_template.get_controllers(), new_template.get_controllers(),
):
Expand All @@ -405,7 +408,7 @@ def rewrite_rate_law(
sympy.Symbol(new_controller.name),
)

# Step 3. Rename subject and object
# Step 2. Rename subject and object
old_cbr = old_template.get_concepts_by_role()
new_cbr = new_template.get_concepts_by_role()
if "subject" in old_cbr and "subject" in new_cbr:
Expand All @@ -419,7 +422,31 @@ def rewrite_rate_law(
sympy.Symbol(new_template.outcome.name),
)

# Step 3. Rename parameters by generating new parameters
# named according to the strata that were applied to the
# given template
parameters = list(template_model.get_parameters_from_rate_law(rate_law))
param_mappings = {}
for parameter in parameters:
# If a parameter is explicitly listed as one to preserve, then
# don't stratify it
if params_to_preserve is not None and parameter in params_to_preserve:
continue
# If we have an explicit stratification list then if something isn't
# in the list then don't stratify it.
elif params_to_stratify is not None and parameter not in params_to_stratify:
continue
# Otherwise we go ahead with stratification, i.e., in cases
# where nothing was said about parameter stratification or the
# parameter was listed explicitly to be stratified
else:
param_suffix = '_'.join([str(s) for s in template_strata])
new_param = f'{parameter}_{param_suffix}'
param_mappings[parameter] = new_param
rate_law = rate_law.subs(parameter, sympy.Symbol(new_param))

new_template.rate_law = rate_law
return param_mappings


def simplify_rate_laws(template_model: TemplateModel):
Expand Down
Loading

0 comments on commit 4003ed8

Please sign in to comment.