
Commit

improve entity linking
vemonet committed Nov 7, 2024
1 parent 5685ccb commit bdafaf2
Showing 5 changed files with 130 additions and 63 deletions.
1 change: 1 addition & 0 deletions compose.yml
@@ -28,6 +28,7 @@ services:
volumes:
- ./data/fastembed_cache:/tmp/fastembed_cache
- ./data/logs:/logs
- ./data/embeddings:/app/data/embeddings
- ./src:/app/src
# entrypoint: uvicorn src.sparql_llm.api:app --host 0.0.0.0 --port 80
env_file:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -49,6 +49,7 @@ chat = [
"openai",
"azure-ai-inference",
"qdrant_client",
"tqdm",
# "fastembed",
]
test = [
59 changes: 58 additions & 1 deletion src/sparql_llm/api.py
@@ -120,6 +120,50 @@ async def stream_openai(response: Stream[ChatCompletionChunk], docs, full_prompt
yield f"data: {json.dumps(resp_chunk)}\n\n"


def extract_entities(sentence: str) -> list[dict[str, str]]:
score_threshold = 0.8
sentence_splitted = re.findall(r"\b\w+\b", sentence)
window_size = len(sentence_splitted)
entities_list = []
while window_size > 0 and window_size <= len(sentence_splitted):
window_start = 0
window_end = window_start + window_size
while window_end <= len(sentence_splitted):
term = sentence_splitted[window_start:window_end]
print("term", term)
term_embeddings = next(iter(embedding_model.embed([" ".join(term)])))
query_hits = vectordb.search(
collection_name=settings.entities_collection_name,
query_vector=term_embeddings,
limit=10,
)
matchs = []
for query_hit in query_hits:
if query_hit.score > score_threshold:
matchs.append(query_hit)
if len(matchs) > 0:
entities_list.append(
{
"matchs": matchs,
"term": term,
"start_index": window_start,
"end_index": window_end,
}
)
# term_search = reduce(lambda x, y: "{} {}".format(x, y), sentence_splitted[window_start:window_end])
# resultSearch = index.search(term_search)
# if resultSearch is not None and len(resultSearch) > 0:
# selected_hit = resultSearch[0]
# if selected_hit['score'] > MAX_SCORE_PARSER_TRIPLES:
# selected_hit = None
# if selected_hit is not None and selected_hit not in matchs:
# matchs.append(selected_hit)
window_start += 1
window_end = window_start + window_size
window_size -= 1
return entities_list


@app.post("/chat")
async def chat(request: ChatCompletionRequest):
if settings.expasy_api_key and request.api_key != settings.expasy_api_key:
@@ -159,7 +203,7 @@ async def chat(request: ChatCompletionRequest):

# Get the most relevant documents other than SPARQL query examples from the search engine (ShEx shapes, general infos)
docs_hits = vectordb.search(
collection_name="entities",
collection_name=settings.docs_collection_name,
query_vector=query_embeddings,
query_filter=Filter(
should=[
@@ -190,7 +234,20 @@ async def chat(request: ChatCompletionRequest):
else:
prompt_with_context += f"Information about: {docs_hit.payload['question']}\nRelated to SPARQL endpoint {docs_hit.payload['endpoint_url']}\n\n{docs_hit.payload['answer']}\n\n"

# Now extract entities from the user question
entities_list = extract_entities(question)
for entity in entities_list:
prompt_with_context += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
for match in entity["matchs"]:
prompt_with_context += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"

if len(entities_list) == 0:
prompt_with_context += "\nNo entities found in the user question that matches entities in the endpoints. "

prompt_with_context += "\nIf the user is asking for a named entity, and this entity cannot be found in the endpoint, warn them about the fact we could not find it in the endpoints.\n\n"

prompt_with_context += f"\n{INTRO_USER_QUESTION_PROMPT}\n{question}"
print(prompt_with_context)

# Use messages from the request to keep memory of previous messages sent by the client
# Replace the question asked by the user with the big prompt with all contextual infos
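The new extract_entities() helper does the entity linking that gives the commit its name: it slides a window over the words of the user question, embeds every n-gram, and keeps the Qdrant hits whose score passes a threshold; chat() then appends each match's label, IRI and endpoint to the prompt. Below is a minimal sketch of that flow, not the committed implementation: the function name link_entities, the example question and the 0.8 threshold default are illustrative, and it assumes the package's get_embedding_model/get_vectordb helpers, a running Qdrant instance and an already populated entities collection.

import re

from sparql_llm.config import get_embedding_model, get_vectordb, settings

embedding_model = get_embedding_model()
vectordb = get_vectordb(settings.vectordb_host)


def link_entities(sentence: str, score_threshold: float = 0.8) -> list[dict]:
    """Match every n-gram of the sentence against the indexed entity labels."""
    words = re.findall(r"\b\w+\b", sentence)
    linked = []
    # Start with the widest window (the whole sentence) and shrink it one word at a time
    for window_size in range(len(words), 0, -1):
        for start in range(len(words) - window_size + 1):
            term = " ".join(words[start : start + window_size])
            # fastembed returns a generator, so pull out the single vector
            term_embedding = next(iter(embedding_model.embed([term])))
            hits = vectordb.search(
                collection_name=settings.entities_collection_name,
                query_vector=term_embedding,
                limit=10,
            )
            matches = [hit for hit in hits if hit.score > score_threshold]
            if matches:
                linked.append({"term": term, "matches": matches})
    return linked


# The matches are then folded into the prompt context, as in chat():
prompt_context = ""
for entity in link_entities("What are the rat orthologs of the human TP53 gene?"):
    prompt_context += f'\nEntities found in the user question for "{entity["term"]}":\n'
    for match in entity["matches"]:
        prompt_context += (
            f"- {match.payload['label']} with IRI <{match.payload['uri']}> "
            f"in endpoint {match.payload['endpoint_url']}\n"
        )

Note that every window size from the full sentence down to single words is searched, so a long question triggers one embedding and one vector search per n-gram; the score threshold is what keeps weak matches out of the prompt.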
99 changes: 50 additions & 49 deletions src/sparql_llm/embed.py
@@ -131,67 +131,68 @@ def load_ontology(endpoint: dict[str, str]) -> list[Document]:
def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None:
"""Initialize the vectordb with example queries and ontology descriptions from the SPARQL endpoints"""
vectordb = get_vectordb(vectordb_host)
embedding_model = get_embedding_model()
docs: list[Document] = []

endpoints_urls = [endpoint["endpoint_url"] for endpoint in settings.endpoints]
prefix_map = get_prefixes_for_endpoints(endpoints_urls)
if not vectordb.collection_exists(settings.docs_collection_name):
vectordb.create_collection(
collection_name=settings.docs_collection_name,
vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
)
embedding_model = get_embedding_model()
docs: list[Document] = []

for endpoint in settings.endpoints:
print(f"\n 🔎 Getting metadata for {endpoint['label']} at {endpoint['endpoint_url']}")
queries_loader = SparqlExamplesLoader(endpoint["endpoint_url"], verbose=True)
docs += queries_loader.load()
endpoints_urls = [endpoint["endpoint_url"] for endpoint in settings.endpoints]
prefix_map = get_prefixes_for_endpoints(endpoints_urls)

void_loader = SparqlVoidShapesLoader(
endpoint["endpoint_url"],
prefix_map=prefix_map,
verbose=True,
)
docs += void_loader.load()
for endpoint in settings.endpoints:
print(f"\n 🔎 Getting metadata for {endpoint['label']} at {endpoint['endpoint_url']}")
queries_loader = SparqlExamplesLoader(endpoint["endpoint_url"], verbose=True)
docs += queries_loader.load()

docs += load_schemaorg_description(endpoint)
# NOTE: we dont use the ontology for now, schema from shex is better
# docs += load_ontology(endpoint)
void_loader = SparqlVoidShapesLoader(
endpoint["endpoint_url"],
prefix_map=prefix_map,
verbose=True,
)
docs += void_loader.load()

# NOTE: Manually add infos for UniProt since we cant retrieve it for now. Taken from https://www.uniprot.org/help/about
uniprot_description_question = "What is the SIB resource UniProt about?"
docs.append(
Document(
page_content=uniprot_description_question,
metadata={
"question": uniprot_description_question,
"answer": """The Universal Protein Resource (UniProt) is a comprehensive resource for protein sequence and annotation data. The UniProt databases are the UniProt Knowledgebase (UniProtKB), the UniProt Reference Clusters (UniRef), and the UniProt Archive (UniParc). The UniProt consortium and host institutions EMBL-EBI, SIB and PIR are committed to the long-term preservation of the UniProt databases.
docs += load_schemaorg_description(endpoint)
# NOTE: we dont use the ontology for now, schema from shex is better
# docs += load_ontology(endpoint)

# NOTE: Manually add infos for UniProt since we cant retrieve it for now. Taken from https://www.uniprot.org/help/about
uniprot_description_question = "What is the SIB resource UniProt about?"
docs.append(
Document(
page_content=uniprot_description_question,
metadata={
"question": uniprot_description_question,
"answer": """The Universal Protein Resource (UniProt) is a comprehensive resource for protein sequence and annotation data. The UniProt databases are the UniProt Knowledgebase (UniProtKB), the UniProt Reference Clusters (UniRef), and the UniProt Archive (UniParc). The UniProt consortium and host institutions EMBL-EBI, SIB and PIR are committed to the long-term preservation of the UniProt databases.
UniProt is a collaboration between the European Bioinformatics Institute (EMBL-EBI), the SIB Swiss Institute of Bioinformatics and the Protein Information Resource (PIR). Across the three institutes more than 100 people are involved through different tasks such as database curation, software development and support.
UniProt is a collaboration between the European Bioinformatics Institute (EMBL-EBI), the SIB Swiss Institute of Bioinformatics and the Protein Information Resource (PIR). Across the three institutes more than 100 people are involved through different tasks such as database curation, software development and support.
EMBL-EBI and SIB together used to produce Swiss-Prot and TrEMBL, while PIR produced the Protein Sequence Database (PIR-PSD). These two data sets coexisted with different protein sequence coverage and annotation priorities. TrEMBL (Translated EMBL Nucleotide Sequence Data Library) was originally created because sequence data was being generated at a pace that exceeded Swiss-Prot's ability to keep up. Meanwhile, PIR maintained the PIR-PSD and related databases, including iProClass, a database of protein sequences and curated families. In 2002 the three institutes decided to pool their resources and expertise and formed the UniProt consortium.
EMBL-EBI and SIB together used to produce Swiss-Prot and TrEMBL, while PIR produced the Protein Sequence Database (PIR-PSD). These two data sets coexisted with different protein sequence coverage and annotation priorities. TrEMBL (Translated EMBL Nucleotide Sequence Data Library) was originally created because sequence data was being generated at a pace that exceeded Swiss-Prot's ability to keep up. Meanwhile, PIR maintained the PIR-PSD and related databases, including iProClass, a database of protein sequences and curated families. In 2002 the three institutes decided to pool their resources and expertise and formed the UniProt consortium.
The UniProt consortium is headed by Alex Bateman, Alan Bridge and Cathy Wu, supported by key staff, and receives valuable input from an independent Scientific Advisory Board.
""",
"endpoint_url": "https://sparql.uniprot.org/sparql/",
"doc_type": "schemaorg_description",
},
The UniProt consortium is headed by Alex Bateman, Alan Bridge and Cathy Wu, supported by key staff, and receives valuable input from an independent Scientific Advisory Board.
""",
"endpoint_url": "https://sparql.uniprot.org/sparql/",
"doc_type": "schemaorg_description",
},
)
)
)

if not vectordb.collection_exists(settings.docs_collection_name):
vectordb.create_collection(
print(f"Generating embeddings for {len(docs)} documents")
embeddings = embedding_model.embed([q.page_content for q in docs])
start_time = time.time()
vectordb.upsert(
collection_name=settings.docs_collection_name,
vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
# wait=False, # Waiting for indexing to finish or not
)
print(f"Generating embeddings for {len(docs)} documents")
embeddings = embedding_model.embed([q.page_content for q in docs])
start_time = time.time()
vectordb.upsert(
collection_name=settings.docs_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
# wait=False, # Waiting for indexing to finish or not
)
print(f"Done generating and indexing {len(docs)} documents into the vectordb in {time.time() - start_time} seconds")
print(f"Done generating and indexing {len(docs)} documents into the vectordb in {time.time() - start_time} seconds")

if not vectordb.collection_exists(settings.entities_collection_name):
vectordb.create_collection(
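The main change to embed.py is structural: fetching endpoint metadata, generating embeddings and upserting them now all happen inside the collection_exists() guard, so the docs collection is only rebuilt when it is missing instead of on every startup. The following is a small, self-contained sketch of that create-if-missing-then-upsert pattern with qdrant-client and fastembed; the collection name, model and example documents are illustrative, and it assumes a Qdrant server on localhost:6333.

from fastembed import TextEmbedding
from qdrant_client import QdrantClient, models

client = QdrantClient(host="localhost", port=6333)
embedder = TextEmbedding("BAAI/bge-small-en-v1.5")  # produces 384-dimensional vectors
collection = "docs"

if not client.collection_exists(collection):
    client.create_collection(
        collection_name=collection,
        vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE),
    )
    docs = [
        {"question": "What is UniProt about?", "answer": "A protein sequence and annotation knowledgebase."},
        {"question": "What is Bgee about?", "answer": "A gene expression database."},
    ]
    # Embed the questions; fastembed returns a generator of numpy arrays
    vectors = [v.tolist() for v in embedder.embed([d["question"] for d in docs])]
    client.upsert(
        collection_name=collection,
        points=models.Batch(
            ids=list(range(1, len(docs) + 1)),
            vectors=vectors,
            payloads=docs,
        ),
    )

One side effect of this guard worth noting: the manually added UniProt description now lives inside the same block, so updating it requires dropping the docs collection (or recreating the Qdrant volume) so the block runs again.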
33 changes: 20 additions & 13 deletions src/sparql_llm/embed_entities.py
@@ -1,23 +1,24 @@
import csv
import os
from ast import literal_eval
import time
from ast import literal_eval

from langchain_core.documents import Document
from qdrant_client import models

from sparql_llm.config import get_embedding_model, get_vectordb
from sparql_llm.config import get_embedding_model, get_vectordb, settings
from sparql_llm.utils import query_sparql
from tqdm import tqdm

entities_embeddings_dir = os.path.join("data", "embeddings")
entities_embeddings_filepath = os.path.join(entities_embeddings_dir, "entities_embeddings.csv")

def retrieve_index_data(entity: dict , docs: list[Document], pagination : (int, int) | None ):
if pagination:
query = entity["query"] + " LIMIT " + pagination[0] + " OFFSET " + pagination[1]
else:
query = entity["query"]
entities_res = query_sparql(query, entity["endpoint"])["results"]["bindings"]
def retrieve_index_data(entity: dict, docs: list[Document], pagination: (int, int) = None):
query = f"{entity['query']} LIMIT {pagination[0]} OFFSET {pagination[1]}" if pagination else entity["query"]
try:
entities_res = query_sparql(query, entity["endpoint"])["results"]["bindings"]
except Exception as _e:
return None
print(f"Found {len(entities_res)} entities for {entity['label']} in {entity['endpoint']}")
for entity_res in entities_res:
docs.append(
@@ -32,6 +33,8 @@ def retrieve_index_data(entity: dict , docs: list[Document], pagination : (int,
)
)
return entities_res


def generate_embeddings_for_entities():
start_time = time.time()
embedding_model = get_embedding_model()
@@ -42,6 +45,7 @@ def generate_embeddings_for_entities():
"label": "Anatomical entity",
"description": "An anatomical entity can be an organism part (e.g. brain, blood, liver and so on) or a material anatomical entity such as a cell.",
"endpoint": "https://www.bgee.org/sparql/",
"pagination": False,
"query": """PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX genex: <http://purl.org/genex#>
SELECT DISTINCT ?uri ?label
@@ -175,15 +179,17 @@ def generate_embeddings_for_entities():
docs: list[Document] = []
for entity in entities_list.values():
if entity["pagination"]:
max_results = 100000
max_results = 200000
pagination = (max_results, 0)
while retrieve_index_data(entity, docs, pagination):
pagination = (pagination[0], pagination[1] + max_results)
else:
retrieve_index_data(entity, docs)

# entities_res = query_sparql(entity["query"], entity["endpoint"])["results"]["bindings"]
# print(f"Found {len(entities_res)} entities for {entity['label']} in {entity['endpoint']}")

print(f"Generating embeddings for {len(docs)} entities")
print(f"Done querying in {time.time() - start_time} seconds, generating embeddings for {len(docs)} entities")

# To test with a smaller number of entities
# docs = docs[:10]
Expand Down Expand Up @@ -228,9 +234,10 @@ def load_entities_embeddings_to_vectordb():
docs = []
embeddings = []

print("Reading entities embeddings from the .csv file")
with open(entities_embeddings_filepath) as file:
reader = csv.DictReader(file)
for _i, row in enumerate(reader):
for row in tqdm(reader, desc="Extracting embeddings from CSV file"):
docs.append(
Document(
page_content=row["label"],
@@ -243,9 +250,9 @@ def load_entities_embeddings_to_vectordb():
)
)
embeddings.append(literal_eval(row["embedding"]))

print(f"Found embeddings for {len(docs)} entities in {time.time() - start_time} seconds. Now adding them to the vectordb")
vectordb.upsert(
collection_name="entities",
collection_name=settings.entities_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
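embed_entities.py now paginates the large entity queries: retrieve_index_data() appends LIMIT/OFFSET to the SPARQL query when a pagination tuple is given, returns None on endpoint errors, and the caller keeps advancing the offset (in steps of 200000) until a page comes back empty. Here is a hedged sketch of that loop reusing the package's query_sparql helper; the fetch_page helper, the endpoint and the query are illustrative, not part of the commit.

from sparql_llm.utils import query_sparql


def fetch_page(query: str, endpoint: str, limit: int, offset: int) -> list[dict]:
    paginated_query = f"{query} LIMIT {limit} OFFSET {offset}"
    try:
        return query_sparql(paginated_query, endpoint)["results"]["bindings"]
    except Exception:
        # A failing page (e.g. an endpoint timeout) ends the pagination,
        # mirroring the None returned by retrieve_index_data() above
        return []


endpoint = "https://www.bgee.org/sparql/"
query = """PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT DISTINCT ?uri ?label WHERE { ?uri rdfs:label ?label }"""

page_size = 200000
offset = 0
bindings: list[dict] = []
while True:
    page = fetch_page(query, endpoint, page_size, offset)
    bindings.extend(page)
    if not page:
        break
    offset += page_size
print(f"Retrieved {len(bindings)} entity bindings in pages of {page_size}")

The embeddings themselves are still written to data/embeddings/entities_embeddings.csv and re-read with literal_eval by load_entities_embeddings_to_vectordb(), presumably so the expensive embedding step does not have to be repeated when only the Qdrant collection is rebuilt; the new ./data/embeddings volume in compose.yml appears to exist to persist that file across containers.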
