Skip to content

Commit

Permalink
Add observables replacement and add more tests for model composition
Browse files Browse the repository at this point in the history
  • Loading branch information
nanglo123 committed Jun 12, 2024
1 parent 526b394 commit e08fdd7
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 54 deletions.
157 changes: 117 additions & 40 deletions mira/metamodel/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"compose_two_models"
]

from copy import deepcopy
import sympy

from .comparison import TemplateModelComparison, get_dkg_refinement_closure
from .template_model import Author, Annotations, TemplateModel

Expand Down Expand Up @@ -82,6 +85,20 @@ def compose_two_models(tm0, tm1):
refinement_func=rf_func)
compare_graph = compare.model_comparison
comparison_result = compare_graph.get_similarity_scores()
tm_keys = [tm_key for tm_key in compare_graph.template_models]
outer_tm_id = tm_keys[0]
inner_tm_id = tm_keys[1]

# Create a copy of observables in case we need to modify observable
# expressions
new_observables = {key: deepcopy(value) for d in (tm1.observables,
tm0.observables)
for key, value in d.items()}

# Create a copy of templates in case we need to modify template rate laws
# when doing time substitutions
tm0.templates = deepcopy(tm0.templates)
tm1.templates = deepcopy(tm1.templates)

new_annotations = annotation_composition(tm0.annotations,
tm1.annotations)
Expand All @@ -96,7 +113,6 @@ def compose_two_models(tm0, tm1):
new_templates = tm0.templates + tm1.templates
new_parameters = {**tm1.parameters, **tm0.parameters}
new_initials = {**tm1.initials, **tm0.initials}
new_observables = {**tm1.observables, **tm0.observables}

composed_tm = TemplateModel(templates=new_templates,
parameters=new_parameters,
Expand All @@ -113,56 +129,105 @@ def compose_two_models(tm0, tm1):
new_templates = []
new_parameters = {}
new_initials = {}
new_observables = {}
concept_map = {}

# We wouldn't have an edge from a template to a concept node,
# so we only need to check if the source edge tuple contains a template
# or a concept id
inter_model_edge_dict = {
inter_template_edges = {
inter_model_edge[0:2]: inter_model_edge[2]
for inter_model_edge in compare_graph.inter_model_edges if
(inter_model_edge[0][1] not in compare_graph.concept_nodes[0] and
inter_model_edge[0][1] not in compare_graph.concept_nodes[1])
}

# same type of checking for inter-concept edges
inter_concept_edges = {
inter_model_edge[0:2]: inter_model_edge[2]
for inter_model_edge in compare_graph.inter_model_edges if
(inter_model_edge[0][1] not in compare_graph.template_nodes[0] and
inter_model_edge[0][1] not in compare_graph.template_nodes[1])
}

# doesn't handle case where if we add the template that has the less
# specific concept for some reason. If we add the template that
# contains S instead of the template that has S_old, this code will
# add S_old and not S even though S_old isn't a present concept in
# the added template
# will a case like this ever happen?

# for source_target_concept_edge, relation in inter_concept_edges.items():
# replaced_tm_id, replaced_concept_id = source_target_concept_edge[1]
# new_tm_id, new_concept_id = source_target_concept_edge[0]
#
# replaced_concept = compare_graph.concept_nodes[replaced_tm_id][
# replaced_concept_id]
# new_concept = compare_graph.concept_nodes[new_tm_id][new_concept_id]
# concept_map.setdefault(replaced_concept.name, set())
# concept_map[replaced_concept.name].add(new_concept.name)

# process templates that are present in a relation first
# we only process the source template because either it's a template
# equality relation, so we prioritize the first tm passed in,
# or it's a refinement relationship in which we want to add the more
# specific template which is the source template
for source_target_edge, relation in inter_model_edge_dict.items():
tm_id, template_id = source_target_edge[0]
tm, added_template = compare_graph.template_models[tm_id], \
compare_graph.template_nodes[tm_id][template_id]
process_template(new_templates, added_template, tm,
new_parameters, new_initials, new_observables)

tm_keys = [tm_key for tm_key in compare_graph.template_models]
outer_tm_id = tm_keys[0]
inner_tm_id = tm_keys[1]
for source_target_template_edge, relation in inter_template_edges.items():
new_tm_id, new_template_id = source_target_template_edge[0]

new_tm, new_template = compare_graph.template_models[
new_tm_id], \
compare_graph.template_nodes[new_tm_id][new_template_id]

old_tm_id, old_template_id = source_target_template_edge[
1]
replaced_tm, replaced_template = compare_graph.template_models[
old_tm_id], compare_graph.template_nodes[old_tm_id][
old_template_id]

replaced_concepts = replaced_template.get_concepts()
new_concepts = new_template.get_concepts()
process_template(new_templates, new_template, new_tm,
new_parameters, new_initials)

# For each concept from an old template that has been replaced from
# find all the new concepts related to the replaced concept
for replaced_concept in replaced_concepts:
concept_map.setdefault(replaced_concept.name, set())
old_concept_id = next((key for key, val in
compare_graph.concept_nodes[
old_tm_id].items() if val
== replaced_concept))
for new_concept in new_concepts:
new_concept_id = next((key for key, val in
compare_graph.concept_nodes[
new_tm_id].items() if val
== new_concept))
lookup = ((new_tm_id, new_concept_id),
(old_tm_id, old_concept_id))
if lookup in inter_concept_edges:
concept_map[replaced_concept.name].add(
new_concept.name)

update_observable_expressions(new_observables, concept_map)

for outer_template_id, outer_template in enumerate(tm0.templates):
for inner_template_id, inner_template in enumerate(tm1.templates):

# only process templates that haven't been pre-processed
# by checking to see if they aren't present in the
# inter_edge_dict mapping

# process inner template first such that outer_template from
# tm0 take priority
if not check_template_in_inter_edge_dict(inter_model_edge_dict,
inner_tm_id,
inner_template_id):
if not check_template_in_inter_edge_dict(
inter_template_edges, inner_tm_id, inner_template_id):
process_template(new_templates, inner_template, tm1,
new_parameters, new_initials,
new_observables)

if not check_template_in_inter_edge_dict(inter_model_edge_dict,
outer_tm_id,
outer_template_id):
new_parameters, new_initials)
if not check_template_in_inter_edge_dict(
inter_template_edges,
outer_tm_id,
outer_template_id):
process_template(new_templates, outer_template, tm0,
new_parameters, new_initials,
new_observables)
new_parameters, new_initials)

composed_tm = TemplateModel(templates=new_templates,
parameters=new_parameters,
Expand Down Expand Up @@ -204,8 +269,8 @@ def check_template_in_inter_edge_dict(inter_edge_dict, tm_id, template_id):
return False


def process_template(templates, added_template, tm, parameters, initials,
observables):
def process_template(templates, added_template, added_tm, parameters,
initials):
"""Helper method that updates the dictionaries that contain the attributes
to be used for the new composed template model
Expand All @@ -216,7 +281,7 @@ def process_template(templates, added_template, tm, parameters, initials,
added_template :
The template that was added to the list of templates for the composed
template model
tm :
added_tm :
The input template model to the model_compose method that contains the
template to be added
parameters :
Expand All @@ -225,23 +290,34 @@ def process_template(templates, added_template, tm, parameters, initials,
initials :
The dictionary of initials to update that will be used for the
composed template model
observables :
The dictionary observables to update that will be used for the
composed template model
"""
if added_template not in templates:
templates.append(added_template)
parameters.update({param_name: tm.parameters[param_name] for param_name
parameters.update({param_name: added_tm.parameters[param_name] for
param_name
in added_template.get_parameter_names()})
initials.update({initial_name: tm.initials[initial_name] for
initials.update({initial_name: added_tm.initials[initial_name] for
initial_name in added_template.get_concept_names()
if initial_name in tm.initials})
if initial_name in added_tm.initials})


def update_observable_expressions(observables, concept_map):
"""Helper method that updates observables expressions based on which
concepts were replaced
def update_observables():
# TODO: Clarify on how to update observables for template models
# that are partially similar
pass
Parameters
----------
observables :
The dictionary of observables to update
concept_map :
The mapping of old concepts to the list of new concepts
"""
for observable in observables.values():
for old_concept_name, new_concept_list in concept_map.items():
new_expression = sum([sympy.Symbol(new_concept_name) for
new_concept_name in new_concept_list])
observable.expression = observable.expression.subs(sympy.Symbol(
old_concept_name), new_expression)


def substitute_time(tm, time_0, time_1):
Expand All @@ -262,8 +338,9 @@ def substitute_time(tm, time_0, time_1):
The time that will be substituted
"""
for template in tm.templates:
template.rate_law = template.rate_law.subs(time_1.units.expression,
time_0.units.expression)
if template.rate_law:
template.rate_law = template.rate_law.subs(time_1.units.expression,
time_0.units.expression)
for observable in tm.observables.values():
observable.expression = observable.expression.subs(
time_1.units.expression, time_0.units.expression)
Expand Down
51 changes: 37 additions & 14 deletions tests/test_model_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
"https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations"
"/main/petrinet/examples/sir.json")

halfar_petrinet_tm = model_from_url(
"https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations"
"/main/petrinet/examples/halfar.json"
)

infection = ControlledConversion(
subject=susceptible,
outcome=infected,
Expand Down Expand Up @@ -70,22 +75,40 @@ def test_compose_list():
assert len(composed_model.templates) == 4


def test_model_priority():
"""Test to see that we prioritize the first template model passed in"""
composed_tm = compose_two_models(sir, sir_petrinet_tm)
def test_disjoint_models():
composed_tm = compose_two_models(sir_petrinet_tm, halfar_petrinet_tm)
assert len(composed_tm.initials) == len(sir_petrinet_tm.initials) + len(
halfar_petrinet_tm.initials)
assert len(composed_tm.templates) == len(sir_petrinet_tm.templates) + len(
halfar_petrinet_tm.templates)
# shared parameter "gamma"
assert len(composed_tm.parameters) == len(sir_petrinet_tm.parameters) + len(
halfar_petrinet_tm.parameters) - 1
assert (len(composed_tm.observables) == len(sir_petrinet_tm.observables) +
len(halfar_petrinet_tm.observables))

assert len(composed_tm.templates) == 2
assert len(composed_tm.initials) == 0
assert len(composed_tm.parameters) == 0
assert composed_tm.templates[0].rate_law is None
assert composed_tm.templates[1].rate_law is None

composed_tm = compose_two_models(sir_petrinet_tm, sir)
assert len(composed_tm.templates) == 2
assert len(composed_tm.initials) == 3
assert len(composed_tm.parameters) == 2
assert composed_tm.templates[0].rate_law
assert composed_tm.templates[1].rate_law
def test_model_priority():
"""Test to see that we prioritize the first template model passed in and
checks for observable expression concept replacement"""
composed_tm_0 = compose_two_models(sir, sir_petrinet_tm)
assert len(composed_tm_0.templates) == 2
assert len(composed_tm_0.initials) == 0
assert len(composed_tm_0.parameters) == 0
assert len(composed_tm_0.observables) == 1
assert (str(composed_tm_0.observables["noninf"].expression) ==
"immune_population + susceptible_population")
assert composed_tm_0.templates[0].rate_law is None
assert composed_tm_0.templates[1].rate_law is None

composed_tm_1 = compose_two_models(sir_petrinet_tm, sir)
assert len(composed_tm_1.templates) == 2
assert len(composed_tm_1.initials) == 3
assert len(composed_tm_1.parameters) == 2
assert len(composed_tm_1.observables) == 1
assert str(composed_tm_1.observables["noninf"].expression) == "R + S"
assert composed_tm_1.templates[0].rate_law
assert composed_tm_1.templates[1].rate_law


def test_template_inclusion():
Expand Down

0 comments on commit e08fdd7

Please sign in to comment.