From 5c704da1eb43e673ae5e7d71555cfeb2b33f8145 Mon Sep 17 00:00:00 2001 From: Vincent Emonet Date: Mon, 16 Oct 2023 12:56:31 +0200 Subject: [PATCH] change the predict endpoint params to take one JSON request, can now take list of subjects and objects to compute predictions for --- README.md | 51 +-- docs/getting-started/expose-model.md | 38 +- pyproject.toml | 1 + src/trapi_predict_kit/__init__.py | 2 +- src/trapi_predict_kit/decorators.py | 9 +- src/trapi_predict_kit/trapi.py | 54 +-- src/trapi_predict_kit/trapi_parser.py | 476 +++++++++++--------------- src/trapi_predict_kit/types.py | 24 +- src/trapi_predict_kit/utils.py | 48 ++- tests/conftest.py | 16 +- tests/test_trapi.py | 32 +- tests/test_utils.py | 12 +- 12 files changed, 382 insertions(+), 381 deletions(-) diff --git a/README.md b/README.md index 6575516..e603899 100644 --- a/README.md +++ b/README.md @@ -58,22 +58,19 @@ The `trapi_predict_kit` package provides a decorator `@trapi_predict` to annotat The annotated predict functions are expected to take 2 input arguments: the input ID (string) and options for the prediction (dictionary). And it should return a dictionary with a list of predicted associated entities hits. Here is an example: ```python -from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput +from trapi_predict_kit import trapi_predict, PredictInput, PredictOutput -@trapi_predict(path='/predict', +@trapi_predict( + path='/predict', name="Get predicted targets for a given entity", description="Return the predicted targets for a given entity: drug (DrugBank ID) or disease (OMIM ID), with confidence scores.", edges=[ { 'subject': 'biolink:Drug', 'predicate': 'biolink:treats', + 'inverse': 'biolink:treated_by', 'object': 'biolink:Disease', }, - { - 'subject': 'biolink:Disease', - 'predicate': 'biolink:treated_by', - 'object': 'biolink:Drug', - }, ], nodes={ "biolink:Disease": { @@ -88,22 +85,19 @@ from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput } } ) -def get_predictions( - input_id: str, options: PredictOptions - ) -> PredictOutput: +def get_predictions(request: PredictInput) -> PredictOutput: + predictions = [] # Add the code the load the model and get predictions here - predictions = { - "hits": [ - { - "id": "DB00001", - "type": "biolink:Drug", - "score": 0.12345, - "label": "Leipirudin", - } - ], - "count": 1, - } - return predictions + # Available props: request.subjects, request.objects, request.options + for subject in request.subjects: + predictions.append({ + "subject": subject, + "object": "DB00001", + "score": 0.12345, + "object_label": "Leipirudin", + "object_type": "biolink:Drug", + }) + return {"hits": predictions, "count": len(predictions)} ``` ### Define the TRAPI object @@ -293,3 +287,16 @@ The deployment of new releases is done automatically by a GitHub Action workflow 3. Create a new release on GitHub, which will automatically trigger the publish workflow, and publish the new release to PyPI. You can also manually trigger the workflow from the Actions tab in your GitHub repository webpage. + +Or use `hatch`: + +```bash +hatch build +hatch publish -u "__token__" +``` + +And create the release with `gh`: + +```bash +gh release create +``` diff --git a/docs/getting-started/expose-model.md b/docs/getting-started/expose-model.md index 2ff4e89..0543a7c 100644 --- a/docs/getting-started/expose-model.md +++ b/docs/getting-started/expose-model.md @@ -11,22 +11,19 @@ The annotated predict functions are expected to take 2 input arguments: the inpu Here is an example: ```python -from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput +from trapi_predict_kit import trapi_predict, PredictInput, PredictOutput -@trapi_predict(path='/predict', +@trapi_predict( + path='/predict', name="Get predicted targets for a given entity", description="Return the predicted targets for a given entity: drug (DrugBank ID) or disease (OMIM ID), with confidence scores.", edges=[ { 'subject': 'biolink:Drug', 'predicate': 'biolink:treats', + 'inverse': 'biolink:treated_by', 'object': 'biolink:Disease', }, - { - 'subject': 'biolink:Disease', - 'predicate': 'biolink:treated_by', - 'object': 'biolink:Drug', - }, ], nodes={ "biolink:Disease": { @@ -41,22 +38,19 @@ from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput } } ) -def get_predictions( - input_id: str, options: PredictOptions - ) -> PredictOutput: +def get_predictions(request: PredictInput) -> PredictOutput: + predictions = [] # Add the code the load the model and get predictions here - predictions = { - "hits": [ - { - "id": "drugbank:DB00001", - "type": "biolink:Drug", - "score": 0.12345, - "label": "Leipirudin", - } - ], - "count": 1, - } - return predictions + # Available props: request.subjects, request.objects, request.options + for subject in request.subjects: + predictions.append({ + "subject": subject, + "object": "DB00001", + "score": 0.12345, + "object_label": "Leipirudin", + "object_type": "biolink:Drug", + }) + return {"hits": predictions, "count": len(predictions)} ``` If you generated a project from the template you will find it in the `predict.py` script. diff --git a/pyproject.toml b/pyproject.toml index ee891de..e7a9f36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "SPARQLWrapper >=2.0.0,<3.0.0", "reasoner-pydantic >=3.0.1", "mlem", + "dvc", # "fairworkflows @ git+https://github.com/vemonet/fairworkflows.git", ] diff --git a/src/trapi_predict_kit/__init__.py b/src/trapi_predict_kit/__init__.py index 0b46e7e..5e97e9e 100644 --- a/src/trapi_predict_kit/__init__.py +++ b/src/trapi_predict_kit/__init__.py @@ -1,6 +1,6 @@ from .decorators import trapi_predict from .save import LoadedModel, load, save -from .types import PredictHit, PredictOptions, PredictOutput, TrainingOutput +from .types import PredictHit, PredictInput, PredictOptions, PredictOutput, TrainingOutput from .trapi import TRAPI from .config import settings from .utils import ( diff --git a/src/trapi_predict_kit/decorators.py b/src/trapi_predict_kit/decorators.py index 1bc2d85..b8c9dae 100644 --- a/src/trapi_predict_kit/decorators.py +++ b/src/trapi_predict_kit/decorators.py @@ -3,7 +3,7 @@ from reasoner_pydantic import MetaEdge, MetaNode -from trapi_predict_kit.types import PredictOptions +from trapi_predict_kit.types import PredictInput def trapi_predict( @@ -12,7 +12,7 @@ def trapi_predict( nodes: Dict[str, MetaNode], name: Optional[str] = None, description: Optional[str] = "", - default_input: Optional[str] = "drugbank:DB00394", + default_input: Optional[str] = None, default_model: Optional[str] = "openpredict_baseline", ) -> Callable: """A decorator to indicate a function is a function to generate prediction that can be integrated to TRAPI. @@ -23,9 +23,8 @@ def trapi_predict( def decorator(func: Callable) -> Any: @functools.wraps(func) - def wrapper(input_id: str, options: Optional[PredictOptions] = None) -> Any: - options = PredictOptions.parse_obj(options) if options else PredictOptions() - return func(input_id, options) + def wrapper(request: PredictInput) -> Any: + return func(PredictInput.parse_obj(request)) wrapper._trapi_predict = { "edges": edges, diff --git a/src/trapi_predict_kit/trapi.py b/src/trapi_predict_kit/trapi.py index 5595e37..d07da6c 100644 --- a/src/trapi_predict_kit/trapi.py +++ b/src/trapi_predict_kit/trapi.py @@ -9,7 +9,7 @@ from reasoner_pydantic import Query from trapi_predict_kit.trapi_parser import resolve_trapi_query -from trapi_predict_kit.types import PredictOptions +from trapi_predict_kit.types import PredictInput REQUIRED_TAGS = [ {"name": "reasoner"}, @@ -49,6 +49,9 @@ def __init__( ) self.predict_endpoints = predict_endpoints self.info = info + self.infores = self.info.get("x-translator", {}).get("infores") + if not self.infores and itrb_url_prefix: + self.infores = f"infores:{itrb_url_prefix}" self.openapi_version = openapi_version # On ITRB deployment and local dev we directly use the current server @@ -187,7 +190,9 @@ def post_reasoner_predict(request_body: Query = Body(..., example=trapi_example) } # return ({"status": 501, "title": "Not Implemented", "detail": "Multi-edges queries not yet implemented", "type": "about:blank" }, 501) - reasonerapi_response = resolve_trapi_query(request_body.dict(exclude_none=True), self.predict_endpoints) + reasonerapi_response = resolve_trapi_query( + request_body.dict(exclude_none=True), self.predict_endpoints, self.infores + ) return JSONResponse(reasonerapi_response) or ("Not found", 404) @@ -205,8 +210,25 @@ def get_meta_knowledge_graph() -> dict: """ metakg = {"edges": [], "nodes": {}} for predict_func in self.predict_endpoints: - if predict_func._trapi_predict["edges"] not in metakg["edges"]: - metakg["edges"] += predict_func._trapi_predict["edges"] + for func_edge in predict_func._trapi_predict["edges"]: + meta_edge = [ + { + "subject": func_edge.get("subject"), + "predicate": func_edge.get("predicate"), + "object": func_edge.get("object"), + } + ] + if "inverse" in predict_func._trapi_predict and predict_func._trapi_predict["inverse"]: + meta_edge.append( + { + "subject": func_edge.get("object"), + "predicate": func_edge.get("inverse"), + "object": func_edge.get("subject"), + } + ) + + if meta_edge not in metakg["edges"]: + metakg["edges"] += meta_edge # Merge nodes dict metakg["nodes"] = {**metakg["nodes"], **predict_func._trapi_predict["nodes"]} return JSONResponse(metakg) @@ -231,26 +253,9 @@ def redirect_root_to_docs(): # Generate endpoints for the loaded models def endpoint_factory(predict_func): - def prediction_endpoint( - input_id: str = predict_func._trapi_predict["default_input"], - model_id: str = predict_func._trapi_predict["default_model"], - min_score: Optional[float] = None, - max_score: Optional[float] = None, - n_results: Optional[int] = None, - ): + def prediction_endpoint(request: PredictInput): try: - return predict_func( - input_id, - PredictOptions.parse_obj( - { - "model_id": model_id, - "min_score": min_score, - "max_score": max_score, - "n_results": n_results, - # "types": ['biolink:Drug'], - } - ), - ) + return predict_func(PredictInput.parse_obj(request)) except Exception as e: return (f"Error when getting the predictions: {e}", 500) @@ -259,8 +264,7 @@ def prediction_endpoint( for predict_func in self.predict_endpoints: self.add_api_route( path=predict_func._trapi_predict["path"], - methods=["GET"], - # endpoint=copy_func(prediction_endpoint, model['path'].replace('/', '')), + methods=["POST"], endpoint=endpoint_factory(predict_func), name=predict_func._trapi_predict["name"], openapi_extra={"description": predict_func._trapi_predict["description"]}, diff --git a/src/trapi_predict_kit/trapi_parser.py b/src/trapi_predict_kit/trapi_parser.py index 9c20297..c54af85 100644 --- a/src/trapi_predict_kit/trapi_parser.py +++ b/src/trapi_predict_kit/trapi_parser.py @@ -8,10 +8,6 @@ # TODO: add evidence path to TRAPI -def is_accepted_id(id_to_check): - return id_to_check.lower().startswith("omim") or id_to_check.lower().startswith("drugbank") - - def get_biolink_parents(concept): concept_snakecase = concept.replace("biolink:", "") concept_snakecase = re.sub(r"(? 0: - try: - resolve_curies = requests.get( - "https://nodenormalization-sri.renci.org/get_normalized_nodes", - params={"curie": ids_to_normalize}, - timeout=settings.TIMEOUT, - ) - # Get corresponding OMIM IDs for MONDO IDs if match - resp = resolve_curies.json() - for resolved_id, alt_ids in resp.items(): - for alt_id in alt_ids["equivalent_identifiers"]: - if is_accepted_id(str(alt_id["identifier"])): - main_id = str(alt_id["identifier"]) - # NOTE: fix issue when NodeNorm returns OMIM.PS: instead of OMIM: - if main_id.lower().startswith("omim"): - main_id = "OMIM:" + main_id.split(":", 1)[1] - resolved_ids_list.append(main_id) - resolved_ids_object[main_id] = resolved_id - except Exception: - log.warn("Error querying the NodeNormalization API, using the original IDs") - # log.info(f"Resolved: {resolve_ids_list} to {resolved_ids_object}") - return resolved_ids_list, resolved_ids_object - - -def resolve_id(id_to_resolve, resolved_ids_object): - if id_to_resolve in resolved_ids_object: - return resolved_ids_object[id_to_resolve] - return id_to_resolve - - -def resolve_trapi_query(reasoner_query, endpoints_list): +def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""): """Main function for TRAPI Convert an array of predictions objects to ReasonerAPI format Run the get_predict to get the QueryGraph edges and nodes @@ -99,53 +54,25 @@ def resolve_trapi_query(reasoner_query, endpoints_list): model_id = str(query_options["model_id"]) query_plan = {} - resolved_ids_object = {} - - # if not similarity_embeddings or similarity_embeddings == {}: - # similarity_embeddings = None - # treatment_embeddings = None + # TODO: add a way to automatically resolve IDs passed to the prediction function? + # resolved_ids_object = {} # Parse the query_graph to build the query plan - for edge_id, qg_edge in query_graph["edges"].items(): - # Build dict with all infos of associations to predict + for edge_id, qg_edge in query_graph.get("edges", {}).items(): + qg_subject_node_id = qg_edge.get("subject") + qg_object_node_id = qg_edge.get("object") + subject_node = query_graph["nodes"].get(qg_subject_node_id) + object_node = query_graph["nodes"].get(qg_object_node_id) + # resolved_ids_object = resolve_ids_with_nodenormalization_api( + # subject_node.get("ids", []) + object_node.get("ids", []), resolved_ids_object + # ) query_plan[edge_id] = { - # 'predicates': qg_edge['predicates'], - # 'qedge_subjects': qg_edge['subject'], - "qg_source_id": qg_edge["subject"], - "qg_target_id": qg_edge["object"], + "subject": subject_node, + "predicates": qg_edge.get("predicates"), + "object": object_node, + "qg_subject_node_id": qg_subject_node_id, + "qg_object_node_id": qg_object_node_id, } - query_plan[edge_id]["predicates"] = qg_edge["predicates"] - - # If single value provided for predicate: make it an array - # if not isinstance(query_plan[edge_id]['predicate'], list): - # query_plan[edge_id]['predicate'] = [ query_plan[edge_id]['predicate'] ] - - # Get the nodes infos in the query plan object - for node_id, node in query_graph["nodes"].items(): - if node_id == qg_edge["subject"]: - query_plan[edge_id]["subject_qg_id"] = node_id - query_plan[edge_id]["subject_types"] = node.get("categories", ["biolink:NamedThing"]) - if "ids" in node: - query_plan[edge_id]["subject_kg_id"], resolved_ids_object = resolve_ids_with_nodenormalization_api( - node["ids"], resolved_ids_object - ) - query_plan[edge_id]["ids_to_predict"] = query_plan[edge_id]["subject_kg_id"] - query_plan[edge_id]["types_to_predict"] = query_plan[edge_id]["subject_types"] - query_plan[edge_id]["relation_to_predict"] = "subject" - query_plan[edge_id]["relation_predicted"] = "object" - - if node_id == qg_edge["object"]: - query_plan[edge_id]["object_qg_id"] = node_id - query_plan[edge_id]["object_types"] = node.get("categories", ["biolink:NamedThing"]) - if "ids" in node: - query_plan[edge_id]["object_kg_id"], resolved_ids_object = resolve_ids_with_nodenormalization_api( - node["ids"], resolved_ids_object - ) - if "ids_to_predict" not in query_plan[edge_id]: - query_plan[edge_id]["ids_to_predict"] = query_plan[edge_id]["object_kg_id"] - query_plan[edge_id]["types_to_predict"] = query_plan[edge_id]["object_types"] - query_plan[edge_id]["relation_to_predict"] = "object" - query_plan[edge_id]["relation_predicted"] = "subject" knowledge_graph = {"nodes": {}, "edges": {}} node_dict = {} @@ -154,184 +81,209 @@ def resolve_trapi_query(reasoner_query, endpoints_list): # Now iterates the query plan to execute each query for edge_qg_id in query_plan: - # TODO: exit if no ID provided? Or check already done before? - for predict_func in endpoints_list: - # TODO: run the functions in parallel with future.concurrent + # TODO: run the functions in parallel with future.concurrent? - for prediction_relation in predict_func._trapi_predict["edges"]: - predicate_parents = get_biolink_parents(prediction_relation["predicate"]) - subject_parents = get_biolink_parents(prediction_relation["subject"]) - object_parents = get_biolink_parents(prediction_relation["object"]) + for func_edge in predict_func._trapi_predict["edges"]: + predicate_parents = get_biolink_parents(func_edge["predicate"]) + subject_parents = get_biolink_parents(func_edge["subject"]) + object_parents = get_biolink_parents(func_edge["object"]) + subjs_to_predict = None + pred_to_predict = None + objs_to_predict = None + log.debug(f"QUERY PLAN: {query_plan[edge_qg_id]}") # TODO: add support for "qualifier_constraints" on query edges. cf. https://github.com/NCATSTranslator/testing/blob/main/ars-requests/not-none/1.2/mvp2cMetformin.json - - # Check if requested subject/predicate/object are served by the function if ( any(i in predicate_parents for i in query_plan[edge_qg_id]["predicates"]) - and any(i in subject_parents for i in query_plan[edge_qg_id]["subject_types"]) - and any(i in object_parents for i in query_plan[edge_qg_id]["object_types"]) + and any(i in subject_parents for i in query_plan[edge_qg_id]["subject"].get("categories", [])) + and any(i in object_parents for i in query_plan[edge_qg_id]["object"].get("categories", [])) ): - # TODO: pass all ids_to_predict instead of iterating - # And also pass the list of target IDs if provided: query_plan[edge_qg_id]["object_kg_id"] - # if "subject_kg_id" in query_plan[edge_id] - # and "object_kg_id" in query_plan[edge_id] - # New params are: input_ids, target_ids (target can be None, input is length 1 minimum) - for id_to_predict in query_plan[edge_id]["ids_to_predict"]: - labels_dict = get_entities_labels([id_to_predict]) - label_to_predict = None - if id_to_predict in labels_dict: - label_to_predict = labels_dict[id_to_predict]["id"]["label"] - try: - log.info(f"🔮⏳️ Getting predictions for: {id_to_predict}") - # Run function to get predictions - prediction_results = predict_func( - id_to_predict, - { + subjs_to_predict = query_plan[edge_id]["subject"] + pred_to_predict = func_edge["predicate"] + objs_to_predict = query_plan[edge_id]["object"] + + inverse = False + if "inverse" in func_edge: + inverse_parents = get_biolink_parents(func_edge["inverse"]) + if ( + any(i in inverse_parents for i in query_plan[edge_qg_id]["predicates"]) + and any(i in object_parents for i in query_plan[edge_qg_id]["subject"].get("categories", [])) + and any(i in subject_parents for i in query_plan[edge_qg_id]["object"].get("categories", [])) + ): + inverse = True + subjs_to_predict = query_plan[edge_id]["object"] + pred_to_predict = func_edge["inverse"] + objs_to_predict = query_plan[edge_id]["subject"] + # Also inverse the node binding IDs + # qg_subject_node_id, qg_object_node_id = qg_object_node_id, qg_subject_node_id + query_plan[edge_id]["qg_subject_node_id"], query_plan[edge_id]["qg_object_node_id"] = ( + query_plan[edge_id]["qg_object_node_id"], + query_plan[edge_id]["qg_subject_node_id"], + ) + + # Check if requested subject/predicate/object are served by the function + if subjs_to_predict and pred_to_predict and objs_to_predict: + subject_ids = subjs_to_predict.get("ids", []) + object_ids = objs_to_predict.get("ids", []) + labels_dict = get_entities_labels(subject_ids + object_ids) + + try: + log.info(f"🔮⏳️ Getting predictions for: {subject_ids} | {object_ids}") + # Run function to get predictions + prediction_results = predict_func( + { + "subjects": subject_ids, + "objects": object_ids, + "options": { "model_id": model_id, "min_score": min_score, "max_score": max_score, "n_results": n_results, - "types": query_plan[edge_id]["types_to_predict"], - # "types": query_plan[edge_qg_id]['from_type'], + # "subject_types": subjs_to_predict.get("categories", []), + # "object_types": objs_to_predict.get("categories", []), }, - ) - prediction_json = prediction_results["hits"] - except Exception as e: - log.error(f"Error getting the predictions: {e}") - prediction_json = [] - - for association in prediction_json: - # id/type of nodes are registered in a dict to avoid duplicate in knowledge_graph.nodes - # Build dict of node ID : label - source_node_id = resolve_id(id_to_predict, resolved_ids_object) - target_node_id = resolve_id(association["id"], resolved_ids_object) - - # TODO: XAI get path between source and target nodes (first create the function for this) - - # If the target ID is given, we filter here from the predictions - # if 'to_kg_id' in query_plan[edge_qg_id] and target_node_id not in query_plan[edge_qg_id]['to_kg_id']: - if ( - "subject_kg_id" in query_plan[edge_id] - and "object_kg_id" in query_plan[edge_id] - and target_node_id not in query_plan[edge_qg_id]["object_kg_id"] - ): - pass - + } + ) + prediction_json = prediction_results["hits"] + except Exception as e: + log.error(f"Error getting the predictions: {e}") + prediction_json = [] + + for association in prediction_json: + # id/type of nodes are registered in a dict to avoid duplicate in knowledge_graph.nodes + # Build dict of node ID : label + # log.info(resolved_ids_object) + # subject_id = resolve_id(association["subject"], resolved_ids_object) + # object_id = resolve_id(association["object"], resolved_ids_object) + subject_id = association["subject"] + object_id = association["object"] + + # TODO: XAI get path between source and target nodes (first create the function for this) + + # If the target ID is given, we filter here from the predictions + # if 'to_kg_id' in query_plan[edge_qg_id] and target_node_id not in query_plan[edge_qg_id]['to_kg_id']: + if ( + "subject_kg_id" in query_plan[edge_id] + and "object_kg_id" in query_plan[edge_id] + and object_id not in query_plan[edge_qg_id]["object_kg_id"] + ): + pass + + else: + edge_kg_id = "e" + str(kg_edge_count) + # Get the ID of the predicted entity in result association + # based on the type expected for the association "to" node + + node_dict[subject_id] = { + "type": association.get( + "subject_type", subjs_to_predict.get("categories", ["biolink:NamedThing"]) + ), + } + node_dict[object_id] = { + "type": association.get( + "object_type", objs_to_predict.get("categories", ["biolink:NamedThing"]) + ), + } + + if subject_id in labels_dict and labels_dict[subject_id]: + node_dict[subject_id]["label"] = labels_dict[subject_id]["id"]["label"] + if "object_label" in association: + node_dict[object_id]["label"] = association["object_label"] + else: + if object_id in labels_dict and labels_dict[object_id]: + node_dict[object_id]["label"] = labels_dict[object_id]["id"]["label"] + + # edge_association_type = 'biolink:ChemicalToDiseaseOrPhenotypicFeatureAssociation' + # relation = 'RO:0002434' # interacts with + # relation = 'OBOREL:0002606' + association_score = str(association["score"]) + + model_id_label = model_id + if not model_id_label: + model_id_label = "openpredict_baseline" + + edge_dict = {} + # Map the source/target of query_graph to source/target of association + # if association['source']['type'] == query_plan[edge_qg_id]['from_type']: + if inverse: + edge_dict["subject"] = object_id + edge_dict["object"] = subject_id else: - edge_kg_id = "e" + str(kg_edge_count) - # Get the ID of the predicted entity in result association - # based on the type expected for the association "to" node - # node_dict[id_to_predict] = query_plan[edge_qg_id]['from_type'] - # node_dict[association[query_plan[edge_qg_id]['to_type']]] = query_plan[edge_qg_id]['to_type'] - rel_to_predict = query_plan[edge_id]["relation_to_predict"] - rel_predicted = query_plan[edge_id]["relation_predicted"] - node_dict[source_node_id] = {"type": query_plan[edge_qg_id][f"{rel_to_predict}_types"]} - if label_to_predict: - node_dict[source_node_id]["label"] = label_to_predict - - node_dict[target_node_id] = {"type": association["type"]} - if "label" in association: - node_dict[target_node_id]["label"] = association["label"] - else: - # TODO: improve to avoid to call the resolver everytime - labels_dict = get_entities_labels([target_node_id]) - if target_node_id in labels_dict and labels_dict[target_node_id]: - node_dict[target_node_id]["label"] = labels_dict[target_node_id]["id"]["label"] - - # edge_association_type = 'biolink:ChemicalToDiseaseOrPhenotypicFeatureAssociation' - # relation = 'RO:0002434' # interacts with - # relation = 'OBOREL:0002606' - association_score = str(association["score"]) - - model_id_label = model_id - if not model_id_label: - model_id_label = "openpredict_baseline" - - # See attributes examples: https://github.com/NCATSTranslator/Evidence-Provenance-Confidence-Working-Group/blob/master/attribute_epc_examples/COHD_TRAPI1.1_Attribute_Example_2-3-21.yml - edge_dict = { - # TODO: not required anymore? 'association_type': edge_association_type, - # 'relation': relation, - # More details on attributes: https://github.com/NCATSTranslator/ReasonerAPI/blob/master/docs/reference.md#attribute- - "sources": [ - { - "resource_id": "infores:openpredict", - "resource_role": "primary_knowledge_source", - }, - {"resource_id": "infores:cohd", "resource_role": "supporting_data_source"}, - ], - "attributes": [ - { - "description": "model_id", - "attribute_type_id": "EDAM:data_1048", - "value": model_id_label, - }, - # { - # # TODO: use has_confidence_level? - # "description": "score", - # "attribute_type_id": "EDAM:data_1772", - # "value": association_score - # # https://www.ebi.ac.uk/ols/ontologies/edam/terms?iri=http%3A%2F%2Fedamontology.org%2Fdata_1772&viewMode=All&siblings=false - # }, - # https://github.com/NCATSTranslator/ReasonerAPI/blob/1.4/ImplementationGuidance/Specifications/knowledge_level_agent_type_specification.md - { - "attribute_type_id": "biolink:agent_type", - "value": "computational_model", - "attribute_source": "infores:openpredict", - }, - { - "attribute_type_id": "biolink:knowledge_level", - "value": "prediction", - "attribute_source": "infores:openpredict", - }, - ], - # "knowledge_types": knowledge_types - } - - # Map the source/target of query_graph to source/target of association - # if association['source']['type'] == query_plan[edge_qg_id]['from_type']: - edge_dict["subject"] = source_node_id - edge_dict["object"] = target_node_id - - # TODO: Define the predicate depending on the association source type returned by OpenPredict classifier - if len(query_plan[edge_qg_id]["predicates"]) > 0: - edge_dict["predicate"] = query_plan[edge_qg_id]["predicates"][0] - else: - edge_dict["predicate"] = prediction_relation["predicate"] - - # Add the association in the knowledge_graph as edge - # Use the type as key in the result association dict (for IDs) - knowledge_graph["edges"][edge_kg_id] = edge_dict - - # Add the bindings to the results object - result = { - "node_bindings": {}, - "analyses": [ - { - "resource_id": "infores:openpredict", - "score": association_score, - "scoring_method": "Model confidence between 0 and 1", - "edge_bindings": {edge_qg_id: [{"id": edge_kg_id}]}, - } - ], - # 'edge_bindings': {}, - } - # result['edge_bindings'][edge_qg_id] = [ - # { - # "id": edge_kg_id - # } - # ] - result["node_bindings"][query_plan[edge_qg_id][f"{rel_to_predict}_qg_id"]] = [ - {"id": source_node_id} - ] - result["node_bindings"][query_plan[edge_qg_id][f"{rel_predicted}_qg_id"]] = [ - {"id": target_node_id} - ] - query_results.append(result) - - kg_edge_count += 1 - if kg_edge_count == n_results: - break + edge_dict["subject"] = subject_id + edge_dict["object"] = object_id + + edge_dict["predicate"] = pred_to_predict + + # See attributes examples: https://github.com/NCATSTranslator/Evidence-Provenance-Confidence-Working-Group/blob/master/attribute_epc_examples/COHD_TRAPI1.1_Attribute_Example_2-3-21.yml + edge_dict = { + **edge_dict, + # TODO: not required anymore? 'association_type': edge_association_type, + # 'relation': relation, + # More details on attributes: https://github.com/NCATSTranslator/ReasonerAPI/blob/master/docs/reference.md#attribute- + "sources": [ + { + "resource_id": infores, + "resource_role": "primary_knowledge_source", + }, + {"resource_id": "infores:cohd", "resource_role": "supporting_data_source"}, + ], + "attributes": [ + { + "description": "model_id", + "attribute_type_id": "EDAM:data_1048", + "value": model_id_label, + }, + # { + # # TODO: use has_confidence_level? + # "description": "score", + # "attribute_type_id": "EDAM:data_1772", + # "value": association_score + # # https://www.ebi.ac.uk/ols/ontologies/edam/terms?iri=http%3A%2F%2Fedamontology.org%2Fdata_1772&viewMode=All&siblings=false + # }, + # https://github.com/NCATSTranslator/ReasonerAPI/blob/1.4/ImplementationGuidance/Specifications/knowledge_level_agent_type_specification.md + { + "attribute_type_id": "biolink:agent_type", + "value": "computational_model", + "attribute_source": infores, + }, + { + "attribute_type_id": "biolink:knowledge_level", + "value": "prediction", + "attribute_source": infores, + }, + ], + # "knowledge_types": knowledge_types + } + + # Add the association in the knowledge_graph as edge + # Use the type as key in the result association dict (for IDs) + knowledge_graph["edges"][edge_kg_id] = edge_dict + + # Add the bindings to the results object + result = { + "node_bindings": {}, + "analyses": [ + { + # TODO: pass infores_curie + "resource_id": infores, + "score": association_score, + "scoring_method": "Model confidence between 0 and 1", + "edge_bindings": {edge_qg_id: [{"id": edge_kg_id}]}, + } + ], + } + result["node_bindings"][query_plan[edge_id]["qg_subject_node_id"]] = [ + {"id": association["subject"]} + ] + result["node_bindings"][query_plan[edge_id]["qg_object_node_id"]] = [ + {"id": association["object"]} + ] + query_results.append(result) + + kg_edge_count += 1 + if kg_edge_count == n_results: + break # Generate kg nodes from the dict of nodes + result from query to resolve labels for node_id, properties in node_dict.items(): @@ -350,7 +302,7 @@ def resolve_trapi_query(reasoner_query, endpoints_list): return { "message": {"knowledge_graph": knowledge_graph, "query_graph": query_graph, "results": query_results}, "query_options": query_options, - "reasoner_id": "infores:openpredict", + "reasoner_id": infores, "schema_version": settings.TRAPI_VERSION, "biolink_version": settings.BIOLINK_VERSION, "status": "Success", @@ -363,17 +315,3 @@ def resolve_trapi_query(reasoner_query, endpoints_list): # }, # ] } - - -example_trapi = { - "message": { - "query_graph": { - "edges": {"e01": {"object": "n1", "predicates": ["biolink:treated_by", "biolink:treats"], "subject": "n0"}}, - "nodes": { - "n0": {"categories": ["biolink:Disease", "biolink:Drug"], "ids": ["OMIM:246300", "DRUGBANK:DB00394"]}, - "n1": {"categories": ["biolink:Drug", "biolink:Disease"]}, - }, - } - }, - "query_options": {"max_score": 1, "min_score": 0.5}, -} diff --git a/src/trapi_predict_kit/types.py b/src/trapi_predict_kit/types.py index 6b11762..6579e79 100644 --- a/src/trapi_predict_kit/types.py +++ b/src/trapi_predict_kit/types.py @@ -4,10 +4,17 @@ class PredictHit(BaseModel): - id: str - type: str + subject: str + object: str score: float + subject_type: Optional[str] + object_type: Optional[str] label: Optional[str] + subject_label: Optional[str] + object_label: Optional[str] + + class Config: + arbitrary_types_allowed = True class PredictOutput(BaseModel): @@ -18,11 +25,20 @@ class PredictOutput(BaseModel): class PredictOptions(BaseModel): - model_id: Optional[str] = "openpredict_baseline" + model_id: Optional[str] = None min_score: Optional[float] = None max_score: Optional[float] = None n_results: Optional[int] = None - types: Optional[List[str]] = None + # types: Optional[List[str]] = None + + class Config: + arbitrary_types_allowed = True + + +class PredictInput(BaseModel): + subjects: List[str] = [] + objects: List[str] = [] + options: PredictOptions = PredictOptions() class Config: arbitrary_types_allowed = True diff --git a/src/trapi_predict_kit/utils.py b/src/trapi_predict_kit/utils.py index 0048108..23f692d 100644 --- a/src/trapi_predict_kit/utils.py +++ b/src/trapi_predict_kit/utils.py @@ -21,6 +21,48 @@ log.addHandler(console_handler) +def is_accepted_id(id_to_check): + return id_to_check.lower().startswith("omim") or id_to_check.lower().startswith("drugbank") + + +def resolve_ids_with_nodenormalization_api(resolve_ids_list, resolved_ids_object): + ids_to_normalize = [] + for id_to_resolve in resolve_ids_list: + if is_accepted_id(id_to_resolve): + resolved_ids_object[id_to_resolve] = id_to_resolve + else: + ids_to_normalize.append(id_to_resolve) + + # Query Translator NodeNormalization API to convert IDs to OMIM/DrugBank IDs + if len(ids_to_normalize) > 0: + try: + resolve_curies = requests.get( + "https://nodenormalization-sri.renci.org/get_normalized_nodes", + params={"curie": ids_to_normalize}, + timeout=settings.TIMEOUT, + ) + # Get corresponding OMIM IDs for MONDO IDs if match + resp = resolve_curies.json() + for resolved_id, alt_ids in resp.items(): + for alt_id in alt_ids["equivalent_identifiers"]: + if is_accepted_id(str(alt_id["identifier"])): + main_id = str(alt_id["identifier"]) + # NOTE: fix issue when NodeNorm returns OMIM.PS: instead of OMIM: + if main_id.lower().startswith("omim"): + main_id = "OMIM:" + main_id.split(":", 1)[1] + resolved_ids_object[main_id] = resolved_id + except Exception: + log.warn("Error querying the NodeNormalization API, using the original IDs") + # log.info(f"Resolved: {resolve_ids_list} to {resolved_ids_object}") + return resolved_ids_object + + +def resolve_id(id_to_resolve, resolved_ids_object): + if id_to_resolve in resolved_ids_object: + return resolved_ids_object[id_to_resolve] + return id_to_resolve + + def resolve_entities(label: str) -> Any: """Use Translator SRI Name Resolution API to get the preferred Translator ID""" resp = requests.post( @@ -36,9 +78,9 @@ def normalize_id_to_translator(ids_list: List[str]) -> dict: for an ID https://nodenormalization-sri.renci.org/docs """ converted_ids_obj = {} - resolve_curies = requests.get( + resolve_curies = requests.post( "https://nodenormalization-sri.renci.org/get_normalized_nodes", - params={"curie": ids_list}, + json={"curies": ids_list}, timeout=settings.TIMEOUT, ) # Get corresponding OMIM IDs for MONDO IDs if match @@ -184,7 +226,7 @@ def get_run_metadata(scores: dict, model_features: dict, hyper_params: dict, run # Add scores as EvaluationMeasures g.add((evaluation_uri, RDF.type, MLS["ModelEvaluation"])) for key in scores: - key_uri = URIRef(run_prop_prefix + key) + key_uri = URIRef(f"{run_prop_prefix}{key}") g.add((evaluation_uri, MLS["specifiedBy"], key_uri)) g.add((key_uri, RDF.type, MLS["EvaluationMeasure"])) g.add((key_uri, RDFS.label, Literal(key))) diff --git a/tests/conftest.py b/tests/conftest.py index 4fd0cc5..32c6e9e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,7 @@ import logging import os -from typing import Optional -from trapi_predict_kit import TRAPI, PredictOptions, PredictOutput, trapi_predict +from trapi_predict_kit import TRAPI, PredictInput, PredictOutput, trapi_predict from trapi_predict_kit.config import settings # Setup logger @@ -21,24 +20,21 @@ { "subject": "biolink:Drug", "predicate": "biolink:treats", + "inverse": "biolink:treated_by", "object": "biolink:Disease", }, - { - "subject": "biolink:Disease", - "predicate": "biolink:treated_by", - "object": "biolink:Drug", - }, ], nodes={"biolink:Disease": {"id_prefixes": ["OMIM"]}, "biolink:Drug": {"id_prefixes": ["DRUGBANK"]}}, ) -def get_predictions(input_id: str, options: Optional[PredictOptions] = None) -> PredictOutput: +def get_predictions(request: PredictInput) -> PredictOutput: # Predictions results should be a list of entities # for which there is a predicted association with the input entity predictions = { "hits": [ { - "id": "drugbank:DB00001", - "type": "biolink:Drug", + "subject": "drugbank:DB00001", + "object": "OMIM:246300", + "subject_type": "biolink:Drug", "score": 0.12345, "label": "Leipirudin", } diff --git a/tests/test_trapi.py b/tests/test_trapi.py index 872cc3c..dd356af 100644 --- a/tests/test_trapi.py +++ b/tests/test_trapi.py @@ -32,8 +32,8 @@ def check_trapi_compliance(response): if validator: - # validator.check_compliance_of_trapi_response(response.json()["message"]) - validator.check_compliance_of_trapi_response(response.json()) + # validator.check_compliance_of_trapi_response(response["message"]) + validator.check_compliance_of_trapi_response(response) validator_resp = validator.get_messages() print("⚠️ REASONER VALIDATOR WARNINGS:") print(validator_resp["warnings"]) @@ -47,17 +47,23 @@ def check_trapi_compliance(response): def test_get_predict_drug(): """Test predict API GET operation for a drug""" - url = "/predict?input_id=DRUGBANK:DB00394&n_results=42" - response = client.get(url).json() - assert len(response["hits"]) == 1 - assert response["count"] == 1 - assert response["hits"][0]["id"] == "drugbank:DB00001" + response = client.post( + "/predict", + json={ + "subjects": ["DRUGBANK:DB00394"], + "options": { + "model_id": "openpredict_baseline", + }, + }, + ).json() + assert len(response["hits"]) >= 1 + assert response["count"] >= 1 + assert response["hits"][0]["subject"] == "drugbank:DB00001" def test_get_meta_kg(): """Get the metakg""" - url = "/meta_knowledge_graph" - response = client.get(url).json() + response = client.get("/meta_knowledge_graph").json() assert len(response["edges"]) >= 1 assert len(response["nodes"]) >= 1 @@ -82,11 +88,11 @@ def test_post_trapi(): } response = client.post( "/query", - data=json.dumps(trapi_query), + json=trapi_query, headers={"Content-Type": "application/json"}, - ) - edges = response.json()["message"]["knowledge_graph"]["edges"].items() - assert len(edges) == 1 + ).json() + edges = response["message"]["knowledge_graph"]["edges"].items() + assert len(edges) >= 1 check_trapi_compliance(response) diff --git a/tests/test_utils.py b/tests/test_utils.py index 89f8796..a418868 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,16 +5,15 @@ def test_resolve_entities(): """Test the function to resolve entities using the Name Resolution API""" - expect = 3 resp = resolve_entities("alzheimer") - assert len(resp) == expect - assert "MONDO:0004975" in resp + curie_list = [item["curie"] for item in resp] + assert "MONDO:0004975" in curie_list def test_normalize_id_to_translator(): to_convert = "OMIM:104300" normalized = normalize_id_to_translator([to_convert]) - assert normalized[to_convert] == "MONDO:0004975" + assert normalized[to_convert] == "MONDO:0007088" def test_get_entity_types(): @@ -28,7 +27,6 @@ def test_get_run_metadata(): def test_trapi_predict_decorator(): - expect = 2 - res = get_predictions("drugbank:DB00002", {}) - assert len(get_predictions._trapi_predict["edges"]) == expect + res = get_predictions({"subjects": ["drugbank:DB00002"]}) + assert get_predictions._trapi_predict["edges"][0]["subject"] == "biolink:Drug" assert len(res["hits"]) == 1