Skip to content

Commit

Permalink
change the predict endpoint params to take one JSON request, can now …
Browse files Browse the repository at this point in the history
…take list of subjects and objects to compute predictions for
  • Loading branch information
vemonet committed Oct 16, 2023
1 parent 6bc79e8 commit 5c704da
Show file tree
Hide file tree
Showing 12 changed files with 382 additions and 381 deletions.
51 changes: 29 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,19 @@ The `trapi_predict_kit` package provides a decorator `@trapi_predict` to annotat
The annotated predict functions are expected to take a single `PredictInput` argument, which provides the lists of subjects and objects to compute predictions for, along with prediction options. They should return a dictionary with a list of predicted associated entity hits. Here is an example:

```python
from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput
from trapi_predict_kit import trapi_predict, PredictInput, PredictOutput

@trapi_predict(path='/predict',
@trapi_predict(
path='/predict',
name="Get predicted targets for a given entity",
description="Return the predicted targets for a given entity: drug (DrugBank ID) or disease (OMIM ID), with confidence scores.",
edges=[
{
'subject': 'biolink:Drug',
'predicate': 'biolink:treats',
'inverse': 'biolink:treated_by',
'object': 'biolink:Disease',
},
{
'subject': 'biolink:Disease',
'predicate': 'biolink:treated_by',
'object': 'biolink:Drug',
},
],
nodes={
"biolink:Disease": {
Expand All @@ -88,22 +85,19 @@ from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput
}
}
)
def get_predictions(
input_id: str, options: PredictOptions
) -> PredictOutput:
def get_predictions(request: PredictInput) -> PredictOutput:
predictions = []
    # Add the code to load the model and get predictions here
predictions = {
"hits": [
{
"id": "DB00001",
"type": "biolink:Drug",
"score": 0.12345,
"label": "Leipirudin",
}
],
"count": 1,
}
return predictions
# Available props: request.subjects, request.objects, request.options
for subject in request.subjects:
predictions.append({
"subject": subject,
"object": "DB00001",
"score": 0.12345,
"object_label": "Leipirudin",
"object_type": "biolink:Drug",
})
return {"hits": predictions, "count": len(predictions)}
```

### Define the TRAPI object
Expand Down Expand Up @@ -293,3 +287,16 @@ The deployment of new releases is done automatically by a GitHub Action workflow
3. Create a new release on GitHub, which will automatically trigger the publish workflow, and publish the new release to PyPI.

You can also manually trigger the workflow from the Actions tab in your GitHub repository webpage.

Or use `hatch`:

```bash
hatch build
hatch publish -u "__token__"
```

And create the release with `gh`:

```bash
gh release create
```
38 changes: 16 additions & 22 deletions docs/getting-started/expose-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,19 @@ The annotated predict functions are expected to take 2 input arguments: the inpu
Here is an example:

```python
from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput
from trapi_predict_kit import trapi_predict, PredictInput, PredictOutput

@trapi_predict(path='/predict',
@trapi_predict(
path='/predict',
name="Get predicted targets for a given entity",
description="Return the predicted targets for a given entity: drug (DrugBank ID) or disease (OMIM ID), with confidence scores.",
edges=[
{
'subject': 'biolink:Drug',
'predicate': 'biolink:treats',
'inverse': 'biolink:treated_by',
'object': 'biolink:Disease',
},
{
'subject': 'biolink:Disease',
'predicate': 'biolink:treated_by',
'object': 'biolink:Drug',
},
],
nodes={
"biolink:Disease": {
Expand All @@ -41,22 +38,19 @@ from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput
}
}
)
def get_predictions(
input_id: str, options: PredictOptions
) -> PredictOutput:
def get_predictions(request: PredictInput) -> PredictOutput:
predictions = []
    # Add the code to load the model and get predictions here
predictions = {
"hits": [
{
"id": "drugbank:DB00001",
"type": "biolink:Drug",
"score": 0.12345,
"label": "Leipirudin",
}
],
"count": 1,
}
return predictions
# Available props: request.subjects, request.objects, request.options
for subject in request.subjects:
predictions.append({
"subject": subject,
"object": "DB00001",
"score": 0.12345,
"object_label": "Leipirudin",
"object_type": "biolink:Drug",
})
return {"hits": predictions, "count": len(predictions)}
```

If you generated a project from the template you will find it in the `predict.py` script.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"SPARQLWrapper >=2.0.0,<3.0.0",
"reasoner-pydantic >=3.0.1",
"mlem",
"dvc",
# "fairworkflows @ git+https://github.com/vemonet/fairworkflows.git",
]

Expand Down
2 changes: 1 addition & 1 deletion src/trapi_predict_kit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .decorators import trapi_predict
from .save import LoadedModel, load, save
from .types import PredictHit, PredictOptions, PredictOutput, TrainingOutput
from .types import PredictHit, PredictInput, PredictOptions, PredictOutput, TrainingOutput
from .trapi import TRAPI
from .config import settings
from .utils import (
Expand Down
9 changes: 4 additions & 5 deletions src/trapi_predict_kit/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from reasoner_pydantic import MetaEdge, MetaNode

from trapi_predict_kit.types import PredictOptions
from trapi_predict_kit.types import PredictInput


def trapi_predict(
Expand All @@ -12,7 +12,7 @@ def trapi_predict(
nodes: Dict[str, MetaNode],
name: Optional[str] = None,
description: Optional[str] = "",
default_input: Optional[str] = "drugbank:DB00394",
default_input: Optional[str] = None,
default_model: Optional[str] = "openpredict_baseline",
) -> Callable:
"""A decorator to indicate a function is a function to generate prediction that can be integrated to TRAPI.
Expand All @@ -23,9 +23,8 @@ def trapi_predict(

def decorator(func: Callable) -> Any:
@functools.wraps(func)
def wrapper(input_id: str, options: Optional[PredictOptions] = None) -> Any:
options = PredictOptions.parse_obj(options) if options else PredictOptions()
return func(input_id, options)
def wrapper(request: PredictInput) -> Any:
return func(PredictInput.parse_obj(request))

wrapper._trapi_predict = {
"edges": edges,
Expand Down
54 changes: 29 additions & 25 deletions src/trapi_predict_kit/trapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from reasoner_pydantic import Query

from trapi_predict_kit.trapi_parser import resolve_trapi_query
from trapi_predict_kit.types import PredictOptions
from trapi_predict_kit.types import PredictInput

REQUIRED_TAGS = [
{"name": "reasoner"},
Expand Down Expand Up @@ -49,6 +49,9 @@ def __init__(
)
self.predict_endpoints = predict_endpoints
self.info = info
self.infores = self.info.get("x-translator", {}).get("infores")
if not self.infores and itrb_url_prefix:
self.infores = f"infores:{itrb_url_prefix}"
self.openapi_version = openapi_version

# On ITRB deployment and local dev we directly use the current server
Expand Down Expand Up @@ -187,7 +190,9 @@ def post_reasoner_predict(request_body: Query = Body(..., example=trapi_example)
}
# return ({"status": 501, "title": "Not Implemented", "detail": "Multi-edges queries not yet implemented", "type": "about:blank" }, 501)

reasonerapi_response = resolve_trapi_query(request_body.dict(exclude_none=True), self.predict_endpoints)
reasonerapi_response = resolve_trapi_query(
request_body.dict(exclude_none=True), self.predict_endpoints, self.infores
)

return JSONResponse(reasonerapi_response) or ("Not found", 404)

Expand All @@ -205,8 +210,25 @@ def get_meta_knowledge_graph() -> dict:
"""
metakg = {"edges": [], "nodes": {}}
for predict_func in self.predict_endpoints:
if predict_func._trapi_predict["edges"] not in metakg["edges"]:
metakg["edges"] += predict_func._trapi_predict["edges"]
for func_edge in predict_func._trapi_predict["edges"]:
meta_edge = [
{
"subject": func_edge.get("subject"),
"predicate": func_edge.get("predicate"),
"object": func_edge.get("object"),
}
]
if "inverse" in predict_func._trapi_predict and predict_func._trapi_predict["inverse"]:
meta_edge.append(
{
"subject": func_edge.get("object"),
"predicate": func_edge.get("inverse"),
"object": func_edge.get("subject"),
}
)

if meta_edge not in metakg["edges"]:
metakg["edges"] += meta_edge
# Merge nodes dict
metakg["nodes"] = {**metakg["nodes"], **predict_func._trapi_predict["nodes"]}
return JSONResponse(metakg)
Expand All @@ -231,26 +253,9 @@ def redirect_root_to_docs():

# Generate endpoints for the loaded models
def endpoint_factory(predict_func):
def prediction_endpoint(
input_id: str = predict_func._trapi_predict["default_input"],
model_id: str = predict_func._trapi_predict["default_model"],
min_score: Optional[float] = None,
max_score: Optional[float] = None,
n_results: Optional[int] = None,
):
def prediction_endpoint(request: PredictInput):
try:
return predict_func(
input_id,
PredictOptions.parse_obj(
{
"model_id": model_id,
"min_score": min_score,
"max_score": max_score,
"n_results": n_results,
# "types": ['biolink:Drug'],
}
),
)
return predict_func(PredictInput.parse_obj(request))
except Exception as e:
return (f"Error when getting the predictions: {e}", 500)

Expand All @@ -259,8 +264,7 @@ def prediction_endpoint(
for predict_func in self.predict_endpoints:
self.add_api_route(
path=predict_func._trapi_predict["path"],
methods=["GET"],
# endpoint=copy_func(prediction_endpoint, model['path'].replace('/', '')),
methods=["POST"],
endpoint=endpoint_factory(predict_func),
name=predict_func._trapi_predict["name"],
openapi_extra={"description": predict_func._trapi_predict["description"]},
Expand Down
Loading

0 comments on commit 5c704da

Please sign in to comment.