Merge pull request #271 from gyorilab/improve-entity-similarity-docs
Improve entity similarity endpoint
bgyori authored Jan 24, 2024
2 parents 253ee42 + 0642aa0 commit f42a5d3
Showing 5 changed files with 613 additions and 40 deletions.
64 changes: 40 additions & 24 deletions mira/dkg/api.py
@@ -453,48 +453,63 @@ def common_parent(
    return entity


class Distance(BaseModel):
    """Represents the distance between two entities."""
class NormalizedCosineSimilarity(BaseModel):
    """Represents the normalized cosine similarity between two entities.

    The cosine similarity between two vectors is defined as their dot product
    divided by the product of their L2 norms (i.e., magnitudes). It ranges
    over [-1, 1], where -1 represents two entities that are very dissimilar,
    0 represents entities that are unrelated, and 1 represents entities that
    are similar. The cosine *distance* (one minus the cosine similarity) is
    computed with :func:`scipy.spatial.distance.cosine`; we normalize it onto
    the range [0, 1] such that 0 means very dissimilar, 0.5 means unrelated,
    and 1 means similar. This is accomplished with the transform:

    .. code:: python

        normalized_cosine = (2 - scipy.spatial.distance.cosine(X, Y)) / 2
    """

    source: str = Field(..., title="source CURIE")
    target: str = Field(..., title="target CURIE")
    distance: float = Field(..., title="cosine distance")
    similarity: float = Field(
        ..., title="normalized cosine similarity", ge=0.0, le=1.0
    )


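As a quick sanity check on the normalization above, here is a standalone sketch (not part of this diff), assuming only scipy is installed:

.. code:: python

    from scipy.spatial import distance

    # scipy returns the cosine *distance* (1 - cosine similarity), in [0, 2]
    X, Y = [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]
    cosine_distance = distance.cosine(X, Y)  # 0.5 for these vectors

    # map onto [0, 1]: 0 means very dissimilar, 0.5 orthogonal, 1 similar
    normalized = (2 - cosine_distance) / 2  # 0.75 here
    assert 0.0 <= normalized <= 1.0
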
@api_blueprint.post(
    "/entity_similarity", response_model=List[Distance], tags=["entities"]
    "/entity_similarity",
    response_model=List[NormalizedCosineSimilarity],
    tags=["entities"],
)
def entity_similarity(
    request: Request,
    sources: List[str] = Body(
        ...,
        description="A list of CURIEs to use as sources",
        description="A list of CURIEs corresponding to DKG terms to use as sources",
        title="source CURIEs",
        examples=[["ido:0000511", "ido:0000592", "ido:0000597", "ido:0000514"]],
    ),
    targets: Optional[List[str]] = Body(
        default=None,
        title="target CURIEs",
        description="If not given, source queries used for all-by-all comparison",
        description="A list of CURIEs corresponding to DKG terms to use as targets. "
        "If not given, source CURIEs are used in all-by-all comparison",
        examples=[["ido:0000566", "ido:0000567"]],
    ),
):
"""Get the pairwise similarities between elements referenced by CURIEs in the first list and second list."""
"""Test locally with:
import requests
def main():
curies = ["probonto:k0000000", "probonto:k0000007", "probonto:k0000008"]
res = requests.post(
"http://0.0.0.0:8771/api/entity_similarity",
json={"sources": curies, "targets": curies},
)
res.raise_for_status()
print(res.json())
"""Get normalized cosine similarities between source and target entities.
Similarity is calculated based on topological embeddings of terms in the DKG
produced by the Second-order LINE algorithm described in
`LINE: Large-scale Information Network Embedding <https://arxiv.org/pdf/1503.03578>`_.
This means that the relationships (i.e., edges) between edges are used to make nodes
that are connected to similar nodes more similar in vector space.
if __name__ == "__main__":
main()
.. warning::
The current embedding approach does **not** take into account entities'
lexical features (labels, descriptions, and synonyms).
"""
    vectors = request.app.state.vectors
    if not vectors:
@@ -505,16 +520,17 @@ def main():
        targets = sources
    rv = []
    for source, target in itt.product(sources, targets):
        if source == target:
            continue
        source_vector = vectors.get(source)
        if source_vector is None:
            continue
        target_vector = vectors.get(target)
        if target_vector is None:
            continue
        cosine_distance = distance.cosine(source_vector, target_vector)
        cosine_similarity = (2 - cosine_distance) / 2
        rv.append(
            Distance(source=source, target=target, distance=cosine_distance)
            NormalizedCosineSimilarity(
                source=source, target=target, similarity=cosine_similarity
            )
        )
    return rv
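
With these changes, the endpoint reports a `similarity` field instead of `distance`. A client sketch against a hypothetical local deployment, mirroring the removed test snippet (response values are illustrative):

.. code:: python

    import requests

    res = requests.post(
        "http://0.0.0.0:8771/api/entity_similarity",
        json={"sources": ["ido:0000511"], "targets": ["ido:0000592"]},
    )
    res.raise_for_status()
    # each element now carries "similarity" in [0, 1] rather than "distance", e.g.
    # [{"source": "ido:0000511", "target": "ido:0000592", "similarity": 0.83}]
    print(res.json())
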
27 changes: 16 additions & 11 deletions mira/dkg/construct.py
@@ -36,6 +36,7 @@
import pyobo
import pystow
from bioontologies import obograph
from bioontologies.obograph import Xref
from bioregistry import manager
from pydantic import BaseModel, Field
from pyobo.struct import part_of
@@ -231,7 +232,7 @@ def main(
        config=config,
        refresh=refresh,
        do_upload=do_upload,
        add_xref_edges=add_xref_edges,
        add_xref_edges=True,
        summaries=summaries
    )

@@ -642,6 +643,10 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str:
    ]
    _results_pickle_path.write_bytes(pickle.dumps(parse_results))

    if parse_results.graph_document is None:
        click.secho(f"No graphs in {prefix}, skipping", fg="red")
        continue

    _graphs = parse_results.graph_document.graphs
    click.secho(
        f"{manager.get_name(prefix)} ({len(_graphs)} graphs)", fg="green", bold=True
@@ -759,29 +764,29 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str:

if add_xref_edges:
    for xref in node.xrefs:
        try:
            xref_curie = xref.curie
        except ValueError:
        if not isinstance(xref, Xref):
            raise TypeError(f"Invalid type: {type(xref)}: {xref}")
        if not xref.value:
            continue
        if xref_curie.split(":", 1)[0] in obograph.PROVENANCE_PREFIXES:
        if xref.value.prefix in obograph.PROVENANCE_PREFIXES:
            # Don't add provenance information as xrefs
            continue
        edges.append(
            (
                node.curie,
                xref.curie,
                xref.value.curie,
                "xref",
                "oboinowl:hasDbXref",
                prefix,
                graph_id,
                version or "",
            )
        )
        if xref_curie not in nodes:
        if xref.value.curie not in nodes:
            node_sources[node.replaced_by].add(prefix)
            nodes[xref_curie] = NodeInfo(
                curie=xref.curie,
                prefix=xref.prefix,
            nodes[xref.value.curie] = NodeInfo(
                curie=xref.value.curie,
                prefix=xref.value.prefix,
                label="",
                synonyms="",
                deprecated="false",
@@ -798,7 +803,7 @@ def _get_edge_name(curie_: str, strict: bool = False) -> str:

for provenance in node.get_provenance():
    if ":" in provenance.identifier:
        tqdm.write(f"Malformed provenance for {node.curie}")
        tqdm.write(f"Malformed provenance for {node.curie}: {provenance}")
    provenance_curie = provenance.curie
    node_sources[provenance_curie].add(prefix)
    if provenance_curie not in nodes:
10 changes: 8 additions & 2 deletions mira/dkg/construct_embeddings.py
@@ -15,7 +15,9 @@
def _construct_embeddings(upload: bool, use_case_paths: UseCasePaths) -> None:
    with TemporaryDirectory() as directory:
        path = os.path.join(directory, use_case_paths.EDGES_PATH.stem)
        with gzip.open(use_case_paths.EDGES_PATH, "rb") as f_in, open(path, "wb") as f_out:
        with gzip.open(use_case_paths.EDGES_PATH, "rb") as f_in, open(
            path, "wb"
        ) as f_out:
            shutil.copyfileobj(f_in, f_out)
        graph = Graph.from_csv(
            edge_path=path,
@@ -26,12 +28,16 @@ def _construct_embeddings(upload: bool, use_case_paths: UseCasePaths) -> None:
            directed=True,
            name="MIRA-DKG",
        )
        # TODO remove disconnected nodes
        # graph = graph.remove_disconnected_nodes()
        embedding = SecondOrderLINEEnsmallen(embedding_size=32).fit_transform(graph)
        df = embedding.get_all_node_embedding()[0].sort_index()
        df.index.name = "node"
        df.to_csv(use_case_paths.EMBEDDINGS_PATH, sep="\t")
        if upload:
            upload_s3(use_case_paths.EMBEDDINGS_PATH, use_case=use_case_paths.use_case)
            upload_s3(
                use_case_paths.EMBEDDINGS_PATH, use_case=use_case_paths.use_case
            )
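
The TSV written here is what the API later serves as `request.app.state.vectors`. A minimal sketch of reading it back into a CURIE-to-vector mapping (assumes pandas; the actual loading code in MIRA may differ):

.. code:: python

    import pandas as pd

    # one row per node CURIE, 32 embedding columns (embedding_size=32 above)
    df = pd.read_csv("embeddings.tsv", sep="\t", index_col="node")
    vectors = {curie: row.to_numpy() for curie, row in df.iterrows()}
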


@click.command()
11 changes: 8 additions & 3 deletions mira/dkg/construct_rdf.py
@@ -133,10 +133,15 @@ def _construct_rdf(upload: bool, *, use_case_paths: UseCasePaths):
            graph.add((_ref(s), p_ref, _ref(o)))

    tqdm.write("serializing to turtle")
    with gzip.open(use_case_paths.RDF_TTL_PATH, "wb") as file:
        graph.serialize(file, format="turtle")
    tqdm.write("done")
    try:
        with gzip.open(use_case_paths.RDF_TTL_PATH, "wb") as file:
            graph.serialize(file, format="turtle")
    except Exception as e:
        click.secho("Failed to serialize RDF", fg="red")
        click.echo(str(e))
        return

    tqdm.write("done")
    if upload:
        upload_s3(use_case_paths.RDF_TTL_PATH, use_case=use_case_paths.use_case)
