diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 67314b6c2..f4def331b 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -451,3 +451,41 @@ def test_n_way_comparison(self): sorted_json_str(local_response.dict()), sorted_json_str(resp_model.dict()), ) + + def test_n_way_comparison_askenet(self): + sir_parameterized_ctx = TemplateModel( + templates=[ + t.with_context(location="geonames:5128581") + for t in sir_parameterized.templates + ] + ) + askenet_list = [] + for sp in [sir_parameterized, sir_parameterized_ctx]: + askenet_list.append( + AskeNetPetriNetModel(Model(sp)).to_json() + ) + + response = self.client.post( + "/api/askenet_model_comparison", + json={"petrinet_models": askenet_list}, + ) + self.assertEqual(200, response.status_code) + + # See if the response json can be parsed with ModelComparisonResponse + resp_model = ModelComparisonResponse(**response.json()) + + # Check that the response is the same as the local version + # explicitly don't use TemplateModelComparison.from_template_models + local = TemplateModelComparison( + template_models=[sir_parameterized, sir_parameterized_ctx], + refinement_func=is_ontological_child_web + ) + model_comparson_graph_data = local.model_comparison + local_response = ModelComparisonResponse( + graph_comparison_data=model_comparson_graph_data, + similarity_scores=model_comparson_graph_data.get_similarity_scores(), + ) + self.assertEqual( + sorted_json_str(local_response.dict()), + sorted_json_str(resp_model.dict()), + )