Skip to content

Commit

Permalink
fix how we get labels
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Oct 16, 2023
1 parent 5c704da commit ea010ae
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/trapi_predict_kit/trapi_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""):
if subjs_to_predict and pred_to_predict and objs_to_predict:
subject_ids = subjs_to_predict.get("ids", [])
object_ids = objs_to_predict.get("ids", [])
labels_dict = get_entities_labels(subject_ids + object_ids)

try:
log.info(f"🔮⏳️ Getting predictions for: {subject_ids} | {object_ids}")
Expand All @@ -150,6 +149,12 @@ def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""):
log.error(f"Error getting the predictions: {e}")
prediction_json = []

# Get the labels of all entities returned by the prediction function
all_ids = [pred["subject"] for pred in prediction_json] + [
pred["subject"] for pred in prediction_json
]
labels_dict = get_entities_labels(list(set(all_ids)))

for association in prediction_json:
# id/type of nodes are registered in a dict to avoid duplicate in knowledge_graph.nodes
# Build dict of node ID : label
Expand Down Expand Up @@ -186,8 +191,12 @@ def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""):
),
}

if subject_id in labels_dict and labels_dict[subject_id]:
node_dict[subject_id]["label"] = labels_dict[subject_id]["id"]["label"]
if "subject_label" in association:
node_dict[subject_id]["label"] = association["subject_label"]
else:
if subject_id in labels_dict and labels_dict[subject_id]:
node_dict[subject_id]["label"] = labels_dict[subject_id]["id"]["label"]

if "object_label" in association:
node_dict[object_id]["label"] = association["object_label"]
else:
Expand Down

0 comments on commit ea010ae

Please sign in to comment.