diff --git a/mira/dkg/api.py b/mira/dkg/api.py index bf1b3de3b..22e1bb22d 100644 --- a/mira/dkg/api.py +++ b/mira/dkg/api.py @@ -73,7 +73,7 @@ def get_entity( curie: str = Path( ..., description="A compact URI (CURIE) for an entity in the form of ``:``", - example="ido:0000511", + examples=["ido:0000511"], ), ): """Get information about an entity (e.g., its name, description synonyms, alternative identifiers, @@ -97,7 +97,7 @@ def get_entities( ..., description="A comma-separated list of compact URIs (CURIEs) for an " "entity in the form of ``:,...``", - example="ido:0000511,ido:0000512", + examples=["ido:0000511,ido:0000512"], ), ): """ @@ -158,7 +158,7 @@ def get_transitive_closure( relation_types: List[str] = Query( ..., description="A list of relation types to get a transitive closure for", - example=DKG_REFINER_RELS, + examples=[DKG_REFINER_RELS], ), ): """Get a transitive closure of the requested type(s)""" @@ -384,13 +384,13 @@ def is_ontological_child( ) def search( request: Request, - q: str = Query(..., example="infect", description="The search query"), + q: str = Query(..., examples=["infect"], description="The search query"), limit: int = 25, offset: int = 0, prefixes: Optional[str] = Query( default=None, description="A comma-separated list of prefixes", - examples={ + examples=[{ "no prefix filter": { "summary": "Don't filter by prefix", "value": None, @@ -399,7 +399,7 @@ def search( "summary": "Search for units, which have Wikidata prefixes", "value": "wikidata", }, - }, + }], ), labels: Optional[str] = Query( default=None, diff --git a/mira/dkg/grounding.py b/mira/dkg/grounding.py index 277642c80..eecc07b5a 100644 --- a/mira/dkg/grounding.py +++ b/mira/dkg/grounding.py @@ -149,7 +149,7 @@ def ground_get( description="The text to be grounded. Warning: grounding does not work well for " "substring matches, i.e., if searching only for 'infected'. 
In these " "cases, using the search API is more appropriate.", - example="Infected Population", + examples=["Infected Population"], ), ): """Ground text with Gilda.""" diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 513bfb85a..fed575f2d 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -28,7 +28,6 @@ def stratify( template_model: TemplateModel, - *, key: str, strata: Collection[str], strata_curie_to_name: Optional[Mapping[str, str]] = None, @@ -42,6 +41,7 @@ def stratify( params_to_preserve: Optional[Collection[str]] = None, concepts_to_stratify: Optional[Collection[str]] = None, concepts_to_preserve: Optional[Collection[str]] = None, + param_renaming_uses_strata_names: Optional[bool] = False, ) -> TemplateModel: """Multiplies a model into several strata. @@ -95,23 +95,23 @@ def stratify( params_to_stratify : A list of parameters to stratify. If none given, will stratify all parameters. - params_to_preserve: + params_to_preserve : A list of parameters to preserve. If none given, will stratify all parameters. concepts_to_stratify : A list of concepts to stratify. If none given, will stratify all concepts. - concepts_to_preserve: + concepts_to_preserve : A list of concepts to preserve. If none given, will stratify all concepts. - + param_renaming_uses_strata_names : + If true, the strata names will be used in the parameter renaming. + If false, the strata indices will be used. 
Default: False Returns ------- : A stratified template model """ - strata = sorted(strata) - if strata_name_lookup and strata_curie_to_name is None: from mira.dkg.web_client import get_entities_web, MissingBaseUrlError try: @@ -137,8 +137,6 @@ def stratify( # List of new templates templates = [] - # Counter to keep track of how many times a parameter has been stratified - params_count = Counter() # Figure out excluded concepts if concepts_to_stratify is None: @@ -154,7 +152,10 @@ def stratify( concept_names - set(concepts_to_stratify) ) + stratum_index_map = {stratum: i for i, stratum in enumerate(strata)} + keep_unstratified_parameters = set() + all_param_mappings = defaultdict(set) for template in template_model.templates: # If the template doesn't have any concepts that need to be stratified # then we can just keep it as is and skip the rest of the loop @@ -165,65 +166,87 @@ def stratify( templates.append(deepcopy(template)) continue - # Generate a derived template for each strata - for stratum in strata: - new_template = template.with_context( - do_rename=modify_names, exclude_concepts=exclude_concepts, - curie_to_name_map=strata_curie_to_name, - **{key: stratum}, - ) - rewrite_rate_law(template_model=template_model, - old_template=template, - new_template=new_template, - params_count=params_count, - params_to_stratify=params_to_stratify, - params_to_preserve=params_to_preserve) - # parameters = list(template_model.get_parameters_from_rate_law(template.rate_law)) - # if len(parameters) == 1: - # new_template.set_mass_action_rate_law(parameters[0]) - templates.append(new_template) - - # assume all controllers have to get stratified together - # and mixing of strata doesn't occur during control - controllers = template.get_controllers() - if cartesian_control and controllers: - remaining_strata = [s for s in strata if s != stratum] - - # use itt.product to generate all combinations of remaining - # strata for remaining controllers. 
for example, if there + # Check if we will have any controllers in the template + ncontrollers = num_controllers(template) + # If we have controllers, and we want cartesian control then + # we will stratify controllers separately + stratify_controllers = (ncontrollers > 0) and cartesian_control + + # Generate a derived template for each stratum + for stratum, stratum_idx in stratum_index_map.items(): + template_strata = [] + new_template = deepcopy(template) + # We have to make sure that we only add the stratum to the + # list of template strata if we stratified any of the non-controllers + # in this first for loop + any_noncontrollers_stratified = False + # We apply this stratum to each concept except for controllers + # in case we will separately stratify those + for concept in new_template.get_concepts_flat( + exclude_controllers=stratify_controllers, + refresh=True): + if concept.name in exclude_concepts: + continue + concept.with_context( + do_rename=modify_names, + curie_to_name_map=strata_curie_to_name, + inplace=True, + **{key: stratum}) + any_noncontrollers_stratified = True + + # If we don't stratify controllers then we are done and can just + # make the new rate law, then append this new template + if not stratify_controllers: + # We only need to do this if we stratified any of the non-controllers + if any_noncontrollers_stratified: + template_strata = [stratum if + param_renaming_uses_strata_names else stratum_idx] + param_mappings = rewrite_rate_law(template_model=template_model, + old_template=template, + new_template=new_template, + template_strata=template_strata, + params_to_stratify=params_to_stratify, + params_to_preserve=params_to_preserve) + for old_param, new_param in param_mappings.items(): + all_param_mappings[old_param].add(new_param) + templates.append(new_template) + # Otherwise we are stratifying controllers separately + else: + # Use itt.product to generate all combinations of + # strata for controllers. 
For example, if there # are two controllers A and B and stratification is into # old, middle, and young, then there will be the following 9: # (A_old, B_old), (A_old, B_middle), (A_old, B_young), # (A_middle, B_old), (A_middle, B_middle), (A_middle, B_young), # (A_young, B_old), (A_young, B_middle), (A_young, B_young) - c_strata_tuples = itt.product(remaining_strata, repeat=len(controllers)) - for c_strata_tuple in c_strata_tuples: - stratified_controllers = [ - controller.with_context(do_rename=modify_names, **{key: c_stratum}) - if controller.name not in exclude_concepts - else controller - for controller, c_stratum in zip(controllers, c_strata_tuple) - ] - if isinstance(template, (GroupedControlledConversion, GroupedControlledProduction)): - stratified_template = new_template.with_controllers(stratified_controllers) - elif isinstance(template, (ControlledConversion, ControlledProduction, - ControlledDegradation, ControlledReplication)): - assert len(stratified_controllers) == 1 - stratified_template = new_template.with_controller(stratified_controllers[0]) - else: - raise NotImplementedError - # the old template is used here on purpose for easier bookkeeping - rewrite_rate_law(template_model=template_model, - old_template=template, - new_template=stratified_template, - params_count=params_count, - params_to_stratify=params_to_stratify, - params_to_preserve=params_to_preserve) + for c_strata_tuple in itt.product(strata, repeat=ncontrollers): + stratified_template = deepcopy(new_template) + stratified_controllers = stratified_template.get_controllers() + template_strata = [stratum if param_renaming_uses_strata_names + else stratum_idx] + # We now apply the stratum assigned to each controller in this particular + # tuple to the controller + for controller, c_stratum in zip(stratified_controllers, c_strata_tuple): + controller.with_context(do_rename=modify_names, inplace=True, + **{key: c_stratum}) + template_strata.append(c_stratum if 
param_renaming_uses_strata_names + else stratum_index_map[c_stratum]) + + # We can now rewrite the rate law for this stratified template, + # then append the new template + param_mappings = rewrite_rate_law(template_model=template_model, + old_template=template, + new_template=stratified_template, + template_strata=template_strata, + params_to_stratify=params_to_stratify, + params_to_preserve=params_to_preserve) + for old_param, new_param in param_mappings.items(): + all_param_mappings[old_param].add(new_param) templates.append(stratified_template) parameters = {} for parameter_key, parameter in template_model.parameters.items(): - if parameter_key not in params_count: + if parameter_key not in all_param_mappings: parameters[parameter_key] = parameter continue # We need to keep the original param if it has been broken @@ -231,11 +254,12 @@ def stratify( # generate the counted parameter variants elif parameter_key in keep_unstratified_parameters: parameters[parameter_key] = parameter - # note that `params_count[key]` will be 1 higher than the number of uses - for i in range(params_count[parameter_key]): + # We otherwise generate variants of the parameter based + # on the previously compiled parameter mappings + for stratified_param in all_param_mappings[parameter_key]: d = deepcopy(parameter) - d.name = f"{parameter_key}_{i}" - parameters[d.name] = d + d.name = stratified_param + parameters[stratified_param] = d # Create new initial values for each of the strata # of the original compartments, copied from the initial @@ -320,7 +344,7 @@ def rewrite_rate_law( template_model: TemplateModel, old_template: Template, new_template: Template, - params_count: Counter, + template_strata: List[int], params_to_stratify: Optional[Collection[str]] = None, params_to_preserve: Optional[Collection[str]] = None, ): @@ -337,9 +361,9 @@ def rewrite_rate_law( new_template : The new template. One of the templates created by stratification of ``old_template``.
- params_count : - A counter that keeps track of how many times a parameter has been - stratified. + template_strata : + A list of strata indices that have been applied to the template, + used for parameter naming. params_to_stratify : A list of parameters to stratify. If none given, will stratify all parameters. @@ -351,7 +375,7 @@ def rewrite_rate_law( # to the stratified controllers in for the originals rate_law = old_template.rate_law if not rate_law: - return + return {} # If the template has controllers/subjects that affect the rate law # and there is an overlap between these, then simple substitution @@ -362,28 +386,7 @@ def rewrite_rate_law( old_template.get_controllers()}: has_subject_controller_overlap = True - # Step 1. Identify the mass action symbol and rename it with a - parameters = list(template_model.get_parameters_from_rate_law(rate_law)) - for parameter in parameters: - # If a parameter is explicitly listed as one to preserve, then - # don't stratify it - if params_to_preserve is not None and parameter in params_to_preserve: - continue - # If we have an explicit stratification list then if something isn't - # in the list then don't stratify it. - elif params_to_stratify is not None and parameter not in params_to_stratify: - continue - # Otherwise we go ahead with stratification, i.e., in cases - # where nothing was said about parameter stratification or the - # parameter was listed explicitly to be stratified - else: - rate_law = rate_law.subs( - parameter, - sympy.Symbol(f"{parameter}_{params_count[parameter]}") - ) - params_count[parameter] += 1 # increment this each time to keep unique - - # Step 2. Rename symbols based on the new concepts + # Step 1. Rename controllers for old_controller, new_controller in zip( old_template.get_controllers(), new_template.get_controllers(), ): @@ -405,7 +408,7 @@ def rewrite_rate_law( sympy.Symbol(new_controller.name), ) - # Step 3. Rename subject and object + # Step 2. 
Rename subject and object old_cbr = old_template.get_concepts_by_role() new_cbr = new_template.get_concepts_by_role() if "subject" in old_cbr and "subject" in new_cbr: @@ -419,7 +422,31 @@ def rewrite_rate_law( sympy.Symbol(new_template.outcome.name), ) + # Step 3. Rename parameters by generating new parameters + # named according to the strata that were applied to the + # given template + parameters = list(template_model.get_parameters_from_rate_law(rate_law)) + param_mappings = {} + for parameter in parameters: + # If a parameter is explicitly listed as one to preserve, then + # don't stratify it + if params_to_preserve is not None and parameter in params_to_preserve: + continue + # If we have an explicit stratification list then if something isn't + # in the list then don't stratify it. + elif params_to_stratify is not None and parameter not in params_to_stratify: + continue + # Otherwise we go ahead with stratification, i.e., in cases + # where nothing was said about parameter stratification or the + # parameter was listed explicitly to be stratified + else: + param_suffix = '_'.join([str(s) for s in template_strata]) + new_param = f'{parameter}_{param_suffix}' + param_mappings[parameter] = new_param + rate_law = rate_law.subs(parameter, sympy.Symbol(new_param)) + new_template.rate_law = rate_law + return param_mappings def simplify_rate_laws(template_model: TemplateModel): diff --git a/mira/metamodel/templates.py b/mira/metamodel/templates.py index 3d14c740a..c6a22ca95 100644 --- a/mira/metamodel/templates.py +++ b/mira/metamodel/templates.py @@ -132,7 +132,8 @@ class Config: SympyExprStr: lambda e: sympy.parse_expr(e) } - def with_context(self, do_rename=False, curie_to_name_map=None, **context) -> "Concept": + def with_context(self, do_rename=False, curie_to_name_map=None, + inplace=False, **context) -> "Concept": """Return this concept with extra context. 
Parameters @@ -146,6 +147,10 @@ def with_context(self, do_rename=False, curie_to_name_map=None, **context) -> "C the context values are e.g. curies or longer names that should be shortened, like {"New York City": "nyc"}. If not provided ( default behavior), the context values will be used as names. + inplace : bool + If True, modify the concept in place. Default: False. + **context : + The context to add to the concept. Returns ------- @@ -164,14 +169,20 @@ def with_context(self, do_rename=False, curie_to_name_map=None, **context) -> "C name = '_'.join(name_list) else: name = self.name - concept = Concept( - name=name, - display_name=self.display_name, - identifiers=self.identifiers, - context=dict(ChainMap(context, self.context)), - units=self.units, - ) - concept._base_name = self._base_name + full_context = dict(ChainMap(context, self.context)) + if inplace: + self.name = name + self.context = full_context + concept = self + else: + concept = Concept( + name=name, + display_name=self.display_name, + identifiers=self.identifiers, + context=full_context, + units=self.units, + ) + concept._base_name = self._base_name return concept def get_curie(self, config: Optional[Config] = None) -> Tuple[str, str]: @@ -565,7 +576,7 @@ def with_context( """ raise NotImplementedError("This method can only be called on subclasses") - def get_concepts(self) -> List[Concept]: + def get_concepts(self) -> List[Union[Concept, List[Concept]]]: """Return the concepts in this template. Returns @@ -579,6 +590,26 @@ def get_concepts(self) -> List[Concept]: ) return [getattr(self, k) for k in self.concept_keys] + def get_concepts_flat(self, exclude_controllers=False, + refresh=False) -> List[Concept]: + """Return the concepts in this template as a flat list. + + Attributes where a list of concepts is expected are flattened. 
+ """ + concepts_flat = [] + for role, value in self.get_concepts_by_role().items(): + if role in {'controllers', 'controller'} and exclude_controllers: + continue + if isinstance(value, list): + if refresh: + setattr(self, role, [deepcopy(v) for v in value]) + concepts_flat.extend(getattr(self, role)) + else: + if refresh: + setattr(self, role, deepcopy(value)) + concepts_flat.append(getattr(self, role)) + return concepts_flat + def get_concepts_by_role(self) -> Dict[str, Concept]: """Return the concepts in this template as a dict keyed by role. @@ -620,7 +651,7 @@ def get_interactors(self) -> List[Concept]: interactors = controllers + ([subject] if subject else []) return interactors - def get_controllers(self): + def get_controllers(self) -> List[Concept]: """Return the controllers in this template. Returns diff --git a/tests/test_model_api.py b/tests/test_model_api.py index eae631461..65e8b45a6 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -235,8 +235,8 @@ def test_stratify(self): strat_templ_model = stratify( template_model=sir_templ_model, key=key, - strata=set(strata), - strata_curie_to_name=strata_name_map + strata=strata, + strata_curie_to_name=strata_name_map, ) strat_str = sorted_json_str(strat_templ_model.dict()) diff --git a/tests/test_ops.py b/tests/test_ops.py index d93bb3cc3..fe5bf8f50 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,8 +5,6 @@ from copy import deepcopy as _d import sympy -import requests -import itertools from mira.metamodel import * from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless @@ -51,8 +49,8 @@ def test_stratify_full(self): controller=infected.with_context(vaccination_status="unvaccinated", do_rename=True), rate_law=safe_parse_expr( - 'beta_0 * susceptible_population_unvaccinated * infected_population_unvaccinated', - local_dict={'beta_0': sympy.Symbol('beta_0')} + 'beta_0_0 * susceptible_population_unvaccinated * infected_population_unvaccinated', + 
local_dict={'beta_0_0': sympy.Symbol('beta_0_0')} ) ) expected_1 = ControlledConversion( @@ -63,8 +61,8 @@ def test_stratify_full(self): controller=infected.with_context(vaccination_status="vaccinated", do_rename=True), rate_law=safe_parse_expr( - 'beta_1 * susceptible_population_unvaccinated * infected_population_vaccinated', - local_dict={'beta_1': sympy.Symbol('beta_1')} + 'beta_0_1 * susceptible_population_unvaccinated * infected_population_vaccinated', + local_dict={'beta_0_1': sympy.Symbol('beta_0_1')} ) ) expected_2 = ControlledConversion( @@ -72,11 +70,11 @@ def test_stratify_full(self): do_rename=True), outcome=infected.with_context(vaccination_status="vaccinated", do_rename=True), - controller=infected.with_context(vaccination_status="vaccinated", + controller=infected.with_context(vaccination_status="unvaccinated", do_rename=True), rate_law=safe_parse_expr( - 'beta_2 * susceptible_population_vaccinated * infected_population_vaccinated', - local_dict={'beta_2': sympy.Symbol('beta_2')} + 'beta_1_0 * susceptible_population_vaccinated * infected_population_unvaccinated', + local_dict={'beta_1_0': sympy.Symbol('beta_1_0')} ) ) expected_3 = ControlledConversion( @@ -84,11 +82,11 @@ def test_stratify_full(self): do_rename=True), outcome=infected.with_context(vaccination_status="vaccinated", do_rename=True), - controller=infected.with_context(vaccination_status="unvaccinated", + controller=infected.with_context(vaccination_status="vaccinated", do_rename=True), rate_law=safe_parse_expr( - 'beta_3 * susceptible_population_vaccinated * infected_population_unvaccinated', - local_dict={'beta_3': sympy.Symbol('beta_3')} + 'beta_1_1 * susceptible_population_vaccinated * infected_population_vaccinated', + local_dict={'beta_1_1': sympy.Symbol('beta_1_1')} ) ) @@ -105,10 +103,10 @@ def test_stratify_full(self): tm_stratified = TemplateModel( templates=[expected_0, expected_1, expected_2, expected_3], parameters={ - "beta_0": Parameter(name="beta_0", value=0.1), - 
"beta_1": Parameter(name="beta_1", value=0.1), - "beta_2": Parameter(name="beta_2", value=0.1), - "beta_3": Parameter(name="beta_3", value=0.1), + "beta_0_0": Parameter(name="beta_0_0", value=0.1), + "beta_1_0": Parameter(name="beta_1_0", value=0.1), + "beta_1_1": Parameter(name="beta_1_1", value=0.1), + "beta_0_1": Parameter(name="beta_0_1", value=0.1), }, initials={ f"{susceptible.name}_vaccinated": Initial( @@ -132,14 +130,14 @@ def test_stratify_full(self): actual = stratify( tm, key="vaccination_status", - strata=["vaccinated", "unvaccinated"], + strata=["unvaccinated", "vaccinated"], cartesian_control=True, structure=[], modify_names=True, ) self.assertEqual(4, len(actual.templates)) self.assertEqual( - {"beta_0": 0.1, "beta_1": 0.1, "beta_2": 0.1, "beta_3": 0.1}, + {"beta_0_0": 0.1, "beta_0_1": 0.1, "beta_1_0": 0.1, "beta_1_1": 0.1}, {k: p.value for k, p in actual.parameters.items()} ) self.assertEqual( @@ -561,3 +559,21 @@ def test_stratify_excluded_species(): concepts_to_stratify=['susceptible_population']) assert len(tm.templates) == 5, templates + + +def test_stratify_parameter_consistency(): + templates = [ + NaturalDegradation(subject=Concept(name='A'), + rate_law=sympy.Symbol('alpha') * sympy.Symbol('A')), + NaturalDegradation(subject=Concept(name='A'), + rate_law=sympy.Symbol('alpha') * sympy.Symbol('A')), + NaturalDegradation(subject=Concept(name='B'), + rate_law=sympy.Symbol('alpha') * sympy.Symbol('B')), + ] + tm = TemplateModel(templates=templates, + parameters={'alpha': Parameter(name='alpha', value=0.1)}) + tm = stratify(tm, key='age', strata=['young', 'old'], structure=[]) + # This should be two (alpha_0 and alpha_1) instead of 6, which used to + # be the case when parameters would be incrementally numbered for each + # new template + assert len(tm.parameters) == 2