diff --git a/mira/dkg/api.py b/mira/dkg/api.py
index 1278df02c..bf1b3de3b 100644
--- a/mira/dkg/api.py
+++ b/mira/dkg/api.py
@@ -453,48 +453,63 @@ def common_parent(
     return entity
 
 
-class Distance(BaseModel):
-    """Represents the distance between two entities."""
+class NormalizedCosineSimilarity(BaseModel):
+    """Represents the normalized cosine similarity between two entities.
+
+    The cosine similarity between two vectors is defined as the dot product
+    between the vectors divided by the L2 norm (i.e., magnitude) of each
+    vector. It ranges from [-1,1], where -1 represents two entities that are
+    very dissimilar, 0 represents entities that are not similar, and 1 represents
+    entities that are similar. Note that :func:`scipy.spatial.distance.cosine`
+    returns the cosine *distance* (i.e., one minus the cosine similarity), which
+    the transform below accounts for.
+
+    We normalize this onto a range of [0,1] such that 0 means very dissimilar, 0.5
+    means not similar, and 1 means similar. This is accomplished with the transform:
+
+    .. code:: python
+
+        normalized_cosine = (2 - scipy.spatial.distance.cosine(X, Y)) / 2
+    """
 
     source: str = Field(..., title="source CURIE")
     target: str = Field(..., title="target CURIE")
-    distance: float = Field(..., title="cosine distance")
+    similarity: float = Field(
+        ..., title="normalized cosine similarity", ge=0.0, le=1.0
+    )
 
 
 @api_blueprint.post(
-    "/entity_similarity", response_model=List[Distance], tags=["entities"]
+    "/entity_similarity",
+    response_model=List[NormalizedCosineSimilarity],
+    tags=["entities"],
 )
 def entity_similarity(
     request: Request,
     sources: List[str] = Body(
         ...,
-        description="A list of CURIEs to use as sources",
+        description="A list of CURIEs corresponding to DKG terms to use as sources",
         title="source CURIEs",
         examples=[["ido:0000511", "ido:0000592", "ido:0000597", "ido:0000514"]],
     ),
     targets: Optional[List[str]] = Body(
         default=None,
         title="target CURIEs",
-        description="If not given, source queries used for all-by-all comparison",
+        description="A list of CURIEs corresponding to DKG terms to use as targets. "
+        "If not given, source CURIEs are used in an all-by-all comparison",
         examples=[["ido:0000566", "ido:0000567"]],
     ),
 ):
-    """Get the pairwise similarities between elements referenced by CURIEs in the first list and second list."""
-    """Test locally with:
-
-    import requests
-
-    def main():
-        curies = ["probonto:k0000000", "probonto:k0000007", "probonto:k0000008"]
-        res = requests.post(
-            "http://0.0.0.0:8771/api/entity_similarity",
-            json={"sources": curies, "targets": curies},
-        )
-        res.raise_for_status()
-        print(res.json())
+    """Get normalized cosine similarities between source and target entities.
+
+    Similarity is calculated based on topological embeddings of terms in the DKG
+    produced by the Second-order LINE algorithm described in
+    `LINE: Large-scale Information Network Embedding <https://arxiv.org/pdf/1503.03578>`_.
+    This means that the relationships (i.e., edges) between nodes are used to make
+    nodes that are connected to similar nodes more similar in vector space.
 
-    if __name__ == "__main__":
-        main()
+    .. warning::
+
+        The current embedding approach does **not** take into account entities'
+        lexical features (labels, descriptions, and synonyms).
""" vectors = request.app.state.vectors if not vectors: @@ -505,8 +520,6 @@ def main(): targets = sources rv = [] for source, target in itt.product(sources, targets): - if source == target: - continue source_vector = vectors.get(source) if source_vector is None: continue @@ -514,7 +527,10 @@ def main(): if target_vector is None: continue cosine_distance = distance.cosine(source_vector, target_vector) + cosine_similarity = (2 - cosine_distance) / 2 rv.append( - Distance(source=source, target=target, distance=cosine_distance) + NormalizedCosineSimilarity( + source=source, target=target, similarity=cosine_similarity + ) ) return rv diff --git a/mira/dkg/construct.py b/mira/dkg/construct.py index ed74a14d5..c1b0b0465 100644 --- a/mira/dkg/construct.py +++ b/mira/dkg/construct.py @@ -36,6 +36,7 @@ import pyobo import pystow from bioontologies import obograph +from bioontologies.obograph import Xref from bioregistry import manager from pydantic import BaseModel, Field from pyobo.struct import part_of @@ -231,7 +232,7 @@ def main( config=config, refresh=refresh, do_upload=do_upload, - add_xref_edges=add_xref_edges, + add_xref_edges=True, summaries=summaries ) @@ -642,6 +643,10 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str: ] _results_pickle_path.write_bytes(pickle.dumps(parse_results)) + if parse_results.graph_document is None: + click.secho(f"No graphs in {prefix}, skipping", fg="red") + continue + _graphs = parse_results.graph_document.graphs click.secho( f"{manager.get_name(prefix)} ({len(_graphs)} graphs)", fg="green", bold=True @@ -759,17 +764,17 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str: if add_xref_edges: for xref in node.xrefs: - try: - xref_curie = xref.curie - except ValueError: + if not isinstance(xref, Xref): + raise TypeError(f"Invalid type: {type(xref)}: {xref}") + if not xref.value: continue - if xref_curie.split(":", 1)[0] in obograph.PROVENANCE_PREFIXES: + if xref.value.prefix in obograph.PROVENANCE_PREFIXES: # Don't add provenance information as xrefs continue edges.append( ( node.curie, - xref.curie, + xref.value.curie, "xref", "oboinowl:hasDbXref", prefix, @@ -777,11 +782,11 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str: version or "", ) ) - if xref_curie not in nodes: + if xref.value.curie not in nodes: node_sources[node.replaced_by].add(prefix) - nodes[xref_curie] = NodeInfo( - curie=xref.curie, - prefix=xref.prefix, + nodes[xref.value.curie] = NodeInfo( + curie=xref.value.curie, + prefix=xref.value.prefix, label="", synonyms="", deprecated="false", @@ -798,7 +803,7 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str: for provenance in node.get_provenance(): if ":" in provenance.identifier: - tqdm.write(f"Malformed provenance for {node.curie}") + tqdm.write(f"Malformed provenance for {node.curie}: {provenance}") provenance_curie = provenance.curie node_sources[provenance_curie].add(prefix) if provenance_curie not in nodes: diff --git a/mira/dkg/construct_embeddings.py b/mira/dkg/construct_embeddings.py index 3e49689fb..0c966f430 100644 --- a/mira/dkg/construct_embeddings.py +++ b/mira/dkg/construct_embeddings.py @@ -15,7 +15,9 @@ def _construct_embeddings(upload: bool, use_case_paths: UseCasePaths) -> None: with TemporaryDirectory() as directory: path = os.path.join(directory, use_case_paths.EDGES_PATH.stem) - with gzip.open(use_case_paths.EDGES_PATH, "rb") as f_in, open(path, "wb") as f_out: + with gzip.open(use_case_paths.EDGES_PATH, "rb") as f_in, open( + path, "wb" + ) as f_out: 
             shutil.copyfileobj(f_in, f_out)
         graph = Graph.from_csv(
             edge_path=path,
@@ -26,12 +28,16 @@ def _construct_embeddings(upload: bool, use_case_paths: UseCasePaths) -> None:
             directed=True,
             name="MIRA-DKG",
         )
+        # TODO remove disconnected nodes
+        # graph = graph.remove_disconnected_nodes()
         embedding = SecondOrderLINEEnsmallen(embedding_size=32).fit_transform(graph)
         df = embedding.get_all_node_embedding()[0].sort_index()
         df.index.name = "node"
         df.to_csv(use_case_paths.EMBEDDINGS_PATH, sep="\t")
         if upload:
-            upload_s3(use_case_paths.EMBEDDINGS_PATH, use_case=use_case_paths.use_case)
+            upload_s3(
+                use_case_paths.EMBEDDINGS_PATH, use_case=use_case_paths.use_case
+            )
 
 
 @click.command()
diff --git a/mira/dkg/construct_rdf.py b/mira/dkg/construct_rdf.py
index 76aaa15f1..c33ffd854 100644
--- a/mira/dkg/construct_rdf.py
+++ b/mira/dkg/construct_rdf.py
@@ -133,10 +133,15 @@ def _construct_rdf(upload: bool, *, use_case_paths: UseCasePaths):
             graph.add((_ref(s), p_ref, _ref(o)))
 
     tqdm.write("serializing to turtle")
-    with gzip.open(use_case_paths.RDF_TTL_PATH, "wb") as file:
-        graph.serialize(file, format="turtle")
-    tqdm.write("done")
 
+    try:
+        with gzip.open(use_case_paths.RDF_TTL_PATH, "wb") as file:
+            graph.serialize(file, format="turtle")
+    except Exception as e:
+        click.secho("Failed to serialize RDF", fg="red")
+        click.echo(str(e))
+        return
+
+    tqdm.write("done")
     if upload:
         upload_s3(use_case_paths.RDF_TTL_PATH, use_case=use_case_paths.use_case)
diff --git a/notebooks/Entity Similarity Demo.ipynb b/notebooks/Entity Similarity Demo.ipynb
new file mode 100644
index 000000000..6f80c402b
--- /dev/null
+++ b/notebooks/Entity Similarity Demo.ipynb
@@ -0,0 +1,541 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7d5e6ab6-38e6-4c11-b1e5-424ba9c1be74",
+   "metadata": {},
+   "source": [
+    "# Entity Similarity Demo\n",
+    "\n",
+    "Entity similarity is calculated based on topological embeddings of terms in the DKG\n",
+    "produced by the Second-order LINE algorithm described in\n",
+    "[LINE: Large-scale Information Network Embedding](https://arxiv.org/pdf/1503.03578).\n",
+    "This means that the relationships (i.e., edges) between nodes are used to make nodes\n",
+    "that are connected to similar nodes more similar in dense vector space.\n",
+    "Note: the current embedding approach does **not** take into account entities' lexical features (labels, descriptions, and synonyms).\n",
+    "\n",
+    "The cosine similarity between two embedding vectors is defined as the dot product\n",
+    "between the vectors divided by the L2 norm (i.e., magnitude) of each\n",
+    "vector. It ranges from [-1,1], where -1 represents two entities that are\n",
+    "very dissimilar, 0 represents entities that are not similar, and 1 represents\n",
+    "entities that are similar. Note that `scipy.spatial.distance.cosine` returns the\n",
+    "cosine *distance* (one minus the cosine similarity), which the transform below accounts for.\n",
+    "\n",
+    "We normalize this onto a range of [0,1] such that 0 means very dissimilar, 0.5\n",
+    "means not similar, and 1 means similar. This is accomplished with the transform:\n",
+    "\n",
+    "> `normalized_cosine = (2 - scipy.spatial.distance.cosine(X, Y)) / 2`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "ec636930-4048-4bdd-abc1-1d947b0f5176",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import requests\n",
+    "import pandas as pd"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5fe79eed-1cc3-4338-922b-5b8c8b70f722",
+   "metadata": {},
+   "source": [
+    "Documentation for the entity similarity endpoint can be found at http://34.230.33.149:8771/docs#/entities/entity_similarity_api_entity_similarity_post. It takes in compact URIs (CURIEs), which are the \"primary keys\" for terms in the DKG. It then performs an all-by-all comparison of sources and targets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "6019d93b-de7a-40d2-a5ca-e66e12aeb448",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "URL = \"http://127.0.0.1:8771/api/entity_similarity\"\n",
+    "\n",
+    "def get_similarities_df(sources, targets=None):\n",
+    "    if targets is None:\n",
+    "        targets = sources\n",
+    "    res = requests.post(URL, json={\"sources\": sources, \"targets\": targets})\n",
+    "    res.raise_for_status()\n",
+    "    df = pd.DataFrame(res.json())\n",
+    "\n",
+    "    curies = \",\".join(sorted(set(df.source).union(df.target)))\n",
+    "    res = requests.get(f\"http://127.0.0.1:8771/api/entities/{curies}\").json()\n",
+    "    names = {record[\"id\"]: record[\"name\"] for record in res}\n",
+    "\n",
+    "    assert \"similarity\" in df.columns\n",
+    "    df[\"source_name\"] = df[\"source\"].map(names)\n",
+    "    df[\"target_name\"] = df[\"target\"].map(names)\n",
+    "    return df[[\"source\", \"source_name\", \"target\", \"target_name\", \"similarity\"]]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "15da50ec-15af-4ee0-bea6-6424faa22438",
+   "metadata": {},
+   "source": [
+    "Tom's example ([in this thread](https://askemgroup.slack.com/archives/C03THCGK2DU/p1704310487727779)) has us comparing `ido:0000514` (susceptible population) and `ido:0000511` (infected population). We see that these nodes are related, so their cross-comparison has a value over 0.5. The self-comparison will always come out to 1.0."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "8e3c01c5-e48d-4663-a850-481f13a3a2b2",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcesource_nametargettarget_namesimilarity
0ido:0000514susceptible populationido:0000514susceptible population1.000000
1ido:0000514susceptible populationido:0000592immune population0.537801
2ido:0000514susceptible populationvo:0004921human age0.565110
3ido:0000514susceptible populationido:0000511infected population0.555182
4ido:0000514susceptible populationido:0000512diseased population0.639745
5ido:0000514susceptible populationapollosv:00000233infected population0.516487
6ido:0000592immune populationido:0000514susceptible population0.537801
7ido:0000592immune populationido:0000592immune population1.000000
8ido:0000592immune populationvo:0004921human age0.374794
9ido:0000592immune populationido:0000511infected population0.484620
10ido:0000592immune populationido:0000512diseased population0.647143
11ido:0000592immune populationapollosv:00000233infected population0.420608
12vo:0004921human ageido:0000514susceptible population0.565110
13vo:0004921human ageido:0000592immune population0.374794
14vo:0004921human agevo:0004921human age1.000000
15vo:0004921human ageido:0000511infected population0.609507
16vo:0004921human ageido:0000512diseased population0.405982
17vo:0004921human ageapollosv:00000233infected population0.572247
18ido:0000511infected populationido:0000514susceptible population0.555182
19ido:0000511infected populationido:0000592immune population0.484620
20ido:0000511infected populationvo:0004921human age0.609507
21ido:0000511infected populationido:0000511infected population1.000000
22ido:0000511infected populationido:0000512diseased population0.589701
23ido:0000511infected populationapollosv:00000233infected population0.479230
24ido:0000512diseased populationido:0000514susceptible population0.639745
25ido:0000512diseased populationido:0000592immune population0.647143
26ido:0000512diseased populationvo:0004921human age0.405982
27ido:0000512diseased populationido:0000511infected population0.589701
28ido:0000512diseased populationido:0000512diseased population1.000000
29ido:0000512diseased populationapollosv:00000233infected population0.538134
30apollosv:00000233infected populationido:0000514susceptible population0.516487
31apollosv:00000233infected populationido:0000592immune population0.420608
32apollosv:00000233infected populationvo:0004921human age0.572247
33apollosv:00000233infected populationido:0000511infected population0.479230
34apollosv:00000233infected populationido:0000512diseased population0.538134
35apollosv:00000233infected populationapollosv:00000233infected population1.000000
\n", + "
" + ], + "text/plain": [ + " source source_name target \\\n", + "0 ido:0000514 susceptible population ido:0000514 \n", + "1 ido:0000514 susceptible population ido:0000592 \n", + "2 ido:0000514 susceptible population vo:0004921 \n", + "3 ido:0000514 susceptible population ido:0000511 \n", + "4 ido:0000514 susceptible population ido:0000512 \n", + "5 ido:0000514 susceptible population apollosv:00000233 \n", + "6 ido:0000592 immune population ido:0000514 \n", + "7 ido:0000592 immune population ido:0000592 \n", + "8 ido:0000592 immune population vo:0004921 \n", + "9 ido:0000592 immune population ido:0000511 \n", + "10 ido:0000592 immune population ido:0000512 \n", + "11 ido:0000592 immune population apollosv:00000233 \n", + "12 vo:0004921 human age ido:0000514 \n", + "13 vo:0004921 human age ido:0000592 \n", + "14 vo:0004921 human age vo:0004921 \n", + "15 vo:0004921 human age ido:0000511 \n", + "16 vo:0004921 human age ido:0000512 \n", + "17 vo:0004921 human age apollosv:00000233 \n", + "18 ido:0000511 infected population ido:0000514 \n", + "19 ido:0000511 infected population ido:0000592 \n", + "20 ido:0000511 infected population vo:0004921 \n", + "21 ido:0000511 infected population ido:0000511 \n", + "22 ido:0000511 infected population ido:0000512 \n", + "23 ido:0000511 infected population apollosv:00000233 \n", + "24 ido:0000512 diseased population ido:0000514 \n", + "25 ido:0000512 diseased population ido:0000592 \n", + "26 ido:0000512 diseased population vo:0004921 \n", + "27 ido:0000512 diseased population ido:0000511 \n", + "28 ido:0000512 diseased population ido:0000512 \n", + "29 ido:0000512 diseased population apollosv:00000233 \n", + "30 apollosv:00000233 infected population ido:0000514 \n", + "31 apollosv:00000233 infected population ido:0000592 \n", + "32 apollosv:00000233 infected population vo:0004921 \n", + "33 apollosv:00000233 infected population ido:0000511 \n", + "34 apollosv:00000233 infected population ido:0000512 \n", + "35 apollosv:00000233 infected population apollosv:00000233 \n", + "\n", + " target_name similarity \n", + "0 susceptible population 1.000000 \n", + "1 immune population 0.537801 \n", + "2 human age 0.565110 \n", + "3 infected population 0.555182 \n", + "4 diseased population 0.639745 \n", + "5 infected population 0.516487 \n", + "6 susceptible population 0.537801 \n", + "7 immune population 1.000000 \n", + "8 human age 0.374794 \n", + "9 infected population 0.484620 \n", + "10 diseased population 0.647143 \n", + "11 infected population 0.420608 \n", + "12 susceptible population 0.565110 \n", + "13 immune population 0.374794 \n", + "14 human age 1.000000 \n", + "15 infected population 0.609507 \n", + "16 diseased population 0.405982 \n", + "17 infected population 0.572247 \n", + "18 susceptible population 0.555182 \n", + "19 immune population 0.484620 \n", + "20 human age 0.609507 \n", + "21 infected population 1.000000 \n", + "22 diseased population 0.589701 \n", + "23 infected population 0.479230 \n", + "24 susceptible population 0.639745 \n", + "25 immune population 0.647143 \n", + "26 human age 0.405982 \n", + "27 infected population 0.589701 \n", + "28 diseased population 1.000000 \n", + "29 infected population 0.538134 \n", + "30 susceptible population 0.516487 \n", + "31 immune population 0.420608 \n", + "32 human age 0.572247 \n", + "33 infected population 0.479230 \n", + "34 diseased population 0.538134 \n", + "35 infected population 1.000000 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "get_similarities_df(\n", + " [\n", + " \"ido:0000514\", # susceptible population\n", + " \"ido:0000592\", # immune population\n", + " \"vo:0004921\", # = human age\n", + " \"ido:0000511\", # = infected population\n", + " \"ido:0000512\", # = diseased population\n", + " \"apollosv:00000233\", # = infected population\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "50df354b-1b3d-4172-b01d-2df30fdc1964", + "metadata": {}, + "source": [ + "Unfortunately, we see that the similarity between apollosv:00000233 (infected population) and ido:0000511 (infected population), which are two different terms from different ontologies describing the same concept, do not have a high similarity. This is probably due to the fact that the edge annotations on IDO terms are much more prevalent than APOLLO_SV terms, and therefore the topological similarity wasn't able to reflect that.\n", + "\n", + "Here's a few ideas on how to remedy this:\n", + "\n", + "1. Include the equivalence edges into the DKG embedding step (they are currently just properties of nodes)\n", + "2. Use SeMRA to automatically collapse nodes together either during the whole DKG build or during the embedding step\n", + "3. Include lexical information in the entity similarity in addition to toplogical similarity" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}