create script to generate embeddings for entities
vemonet committed Nov 5, 2024
1 parent c5d421f commit 4975997
Showing 7 changed files with 129 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -8,3 +8,5 @@ __pycache__/

notebooks/*.csv
notebooks/*.txt
+entities_embeddings.csv
+uv.lock
2 changes: 1 addition & 1 deletion Dockerfile
@@ -13,7 +13,7 @@ RUN pip install --upgrade pip
COPY . /app/
# COPY ./scripts/prestart.sh /app/

-RUN pip install -e ".[chat]"
+RUN pip install -e ".[chat,cpu]"

ENV PYTHONPATH=/app
ENV MODULE_NAME=src.sparql_llm.api
Expand Down
10 changes: 9 additions & 1 deletion pyproject.toml
@@ -49,7 +49,7 @@ chat = [
"openai",
"azure-ai-inference",
"qdrant_client",
"fastembed",
# "fastembed",
]
test = [
"pytest",
@@ -68,6 +68,10 @@ test = [
# "ipywidgets",
# "tqdm",
]
cpu = [
# "onnxruntime",
"fastembed",
]
gpu = [
"onnxruntime-gpu",
"fastembed-gpu",
@@ -94,6 +98,10 @@ features = [
]
post-install-commands = []


+# uv venv
+# uv pip install ".[chat,gpu]"
+# uv run python src/sparql_llm/embed_entities.py
[tool.hatch.envs.default.scripts]
fmt = [
    "ruff format",
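The new cpu and gpu extras make the ONNX runtime flavor explicit: fastembed pulls in the CPU onnxruntime, while fastembed-gpu relies on onnxruntime-gpu. A minimal sketch of how the two flavors differ at runtime, assuming the documented fastembed API (the model name here is illustrative, not necessarily settings.embedding_model):

from fastembed import TextEmbedding

# CPU flavor, installed via `pip install -e ".[chat,cpu]"`
model = TextEmbedding("BAAI/bge-small-en-v1.5")

# GPU flavor, installed via `pip install -e ".[chat,gpu]"`:
# request the CUDA execution provider explicitly
# model = TextEmbedding("BAAI/bge-small-en-v1.5", providers=["CUDAExecutionProvider"])

vectors = list(model.embed(["brain", "liver"]))  # embed() yields numpy arrays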
3 changes: 2 additions & 1 deletion src/sparql_llm/api.py
@@ -120,6 +120,7 @@ async def stream_openai(response: Stream[ChatCompletionChunk], docs, full_prompt
        yield f"data: {json.dumps(resp_chunk)}\n\n"



@app.post("/chat")
async def chat(request: ChatCompletionRequest):
    if settings.expasy_api_key and request.api_key != settings.expasy_api_key:
@@ -159,7 +160,7 @@ async def chat(request: ChatCompletionRequest):

    # Get the most relevant documents other than SPARQL query examples from the search engine (ShEx shapes, general infos)
    docs_hits = vectordb.search(
-        collection_name=settings.docs_collection_name,
+        collection_name="entities",
        query_vector=query_embeddings,
        query_filter=Filter(
            should=[
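For context, a minimal sketch of how this filtered search over the new entities collection can be exercised end to end, reusing the helpers from embed.py (the "brain" query and the endpoint value are illustrative; the payload field names match what this commit indexes):

from qdrant_client.models import FieldCondition, Filter, MatchValue

from sparql_llm.embed import get_embedding_model, get_vectordb

# Embed a user question with the same model used at indexing time
query_embeddings = next(iter(get_embedding_model().embed(["brain"]))).tolist()

vectordb = get_vectordb()
hits = vectordb.search(
    collection_name="entities",
    query_vector=query_embeddings,
    # `should` matches any of the listed conditions,
    # here biasing results toward a single endpoint
    query_filter=Filter(
        should=[
            FieldCondition(key="endpoint_url", match=MatchValue(value="https://www.bgee.org/sparql/")),
        ]
    ),
    limit=10,
)
for hit in hits:
    print(hit.score, hit.payload["label"], hit.payload["uri"])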
2 changes: 2 additions & 0 deletions src/sparql_llm/config.py
@@ -74,6 +74,8 @@ class Settings(BaseSettings):
    glhf_api_key: str = ""
    expasy_api_key: str = ""
    logs_api_key: str = ""
+    azure_inference_credential: str = ""
+    azure_inference_endpoint: str = ""
    # llm_model: str = "gpt-4o"
    # cheap_llm_model: str = "gpt-4o-mini"

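The two new settings pair with the azure-ai-inference package already listed in the chat extra. This commit only declares them; a hedged sketch of how they would typically be wired into a client (the wiring itself is an assumption, not code from this commit):

from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

from sparql_llm.config import settings

# Hypothetical client construction from the two new settings
client = ChatCompletionsClient(
    endpoint=settings.azure_inference_endpoint,
    credential=AzureKeyCredential(settings.azure_inference_credential),
)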
29 changes: 27 additions & 2 deletions src/sparql_llm/embed.py
@@ -15,8 +15,7 @@
from sparql_llm.config import settings
from sparql_llm.sparql_examples_loader import SparqlExamplesLoader
from sparql_llm.sparql_void_shapes_loader import SparqlVoidShapesLoader
-from sparql_llm.utils import get_prefixes_for_endpoints
-
+from sparql_llm.utils import get_prefixes_for_endpoints, query_sparql

def get_embedding_model() -> TextEmbedding:
    return TextEmbedding(settings.embedding_model)
@@ -205,6 +204,32 @@ def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None:
    )
    print(f"Done generating and indexing {len(docs)} documents into the vectordb in {time.time() - start_time} seconds")

+    docs = []
+    # TODO: Add entities list to the vectordb (entities_list is expected to match the dict defined in embed_entities.py)
+    for entity in entities_list.values():
+        res = query_sparql(entity["query"], entity["endpoint"])
+        for entity_res in res["results"]["bindings"]:
+            docs.append(
+                Document(
+                    page_content=entity_res["label"]["value"],
+                    metadata={
+                        "label": entity_res["label"]["value"],
+                        "uri": entity_res["uri"]["value"],
+                        "endpoint_url": entity["endpoint"],
+                        "entity_type": entity["uri"],
+                    },
+                )
+            )
+    print(f"Generating embeddings for {len(docs)} entities")
+    # Embed the entity labels before indexing them
+    embeddings = [e.tolist() for e in get_embedding_model().embed([doc.page_content for doc in docs])]
+    vectordb.upsert(
+        collection_name="entities",
+        points=models.Batch(
+            ids=list(range(1, len(docs) + 1)),
+            vectors=embeddings,
+            payloads=[doc.metadata for doc in docs],
+        ),
+        # wait=False, # Waiting for indexing to finish or not
+    )

if __name__ == "__main__":
    init_vectordb()
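Both this new block and the standalone script below depend on the SPARQL 1.1 JSON results layout returned by query_sparql: each binding maps a variable name to a dict carrying a "value" key, which is why labels are read as entity_res["label"]["value"]. A runnable sketch against the Bgee endpoint (the LIMIT is added here for brevity):

from sparql_llm.utils import query_sparql

res = query_sparql(
    """PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX genex: <http://purl.org/genex#>
SELECT DISTINCT ?uri ?label WHERE {
    ?uri a genex:AnatomicalEntity ;
        rdfs:label ?label .
} LIMIT 5""",
    "https://www.bgee.org/sparql/",
)
# res["results"]["bindings"] looks like:
# [{"uri": {"type": "uri", "value": "http://purl.obolibrary.org/obo/..."},
#   "label": {"type": "literal", "value": "brain"}}, ...]
for binding in res["results"]["bindings"]:
    print(binding["uri"]["value"], binding["label"]["value"])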
86 changes: 86 additions & 0 deletions src/sparql_llm/embed_entities.py
@@ -0,0 +1,86 @@
from langchain_core.documents import Document
from qdrant_client import models

from sparql_llm.embed import get_embedding_model, get_vectordb
from sparql_llm.utils import query_sparql
import csv


embedding_model = get_embedding_model()

entities_list = {
    "genex:AnatomicalEntity": {
        "label": "Anatomical entity",
        "uri": "http://purl.org/genex#AnatomicalEntity",
        "description": "An anatomical entity can be an organism part (e.g. brain, blood, liver and so on) or a material anatomical entity such as a cell.",
        "endpoint": "https://www.bgee.org/sparql/",
        "query": """PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX genex: <http://purl.org/genex#>
SELECT DISTINCT ?uri ?label
WHERE {
    ?uri a genex:AnatomicalEntity ;
        rdfs:label ?label .
}"""
    },
    "bgee_species": {
        "label": "Species",
        "uri": "http://purl.uniprot.org/core/Species",
        "description": "A species of living organism, such as human (Homo sapiens) or mouse (Mus musculus).",
        "endpoint": "https://www.bgee.org/sparql/",
        "query": """PREFIX up: <http://purl.uniprot.org/core/>
SELECT ?uri ?label
WHERE {
    ?uri a up:Taxon ;
        up:rank up:Species ;
        up:scientificName ?label .
}"""
    },
}

docs: list[Document] = []
for entity in entities_list.values():
    res = query_sparql(entity["query"], entity["endpoint"])
    for entity_res in res["results"]["bindings"]:
        docs.append(
            Document(
                page_content=entity_res["label"]["value"],
                metadata={
                    "label": entity_res["label"]["value"],
                    "uri": entity_res["uri"]["value"],
                    "endpoint_url": entity["endpoint"],
                    "entity_type": entity["uri"],
                },
            )
        )
print(f"Generating embeddings for {len(docs)} entities")

# To test with a smaller number of entities
docs = docs[:10]

embeddings = embedding_model.embed([q.page_content for q in docs])

with open('entities_embeddings.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    header = ["label", "uri", "endpoint_url", "entity_type", "embedding"]
    writer.writerow(header)

    for doc, embedding in zip(docs, embeddings):
        row = [
            doc.metadata["label"],
            doc.metadata["uri"],
            doc.metadata["endpoint_url"],
            doc.metadata["entity_type"],
            embedding.tolist(),
        ]
        writer.writerow(row)

# NOTE: embeddings is a generator and is consumed by the CSV loop above,
# so it would need to be recomputed before running this upsert
# vectordb = get_vectordb()
# vectordb.upsert(
#     collection_name="entities",
#     points=models.Batch(
#         ids=list(range(1, len(docs) + 1)),
#         vectors=embeddings,
#         payloads=[doc.metadata for doc in docs],
#     ),
#     # wait=False, # Waiting for indexing to finish or not
# )
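Since the direct Qdrant upsert is left commented out, the CSV becomes the handoff artifact. A minimal sketch for loading it back and indexing it later, assuming the columns written above and an existing entities collection on a local Qdrant (connection details are illustrative):

import ast
import csv

from qdrant_client import QdrantClient, models

client = QdrantClient(host="localhost", port=6333)
with open("entities_embeddings.csv") as f:
    rows = list(csv.DictReader(f))

# The embedding column holds the list repr written by tolist() above
vectors = [ast.literal_eval(row.pop("embedding")) for row in rows]
client.upsert(
    collection_name="entities",
    points=models.Batch(
        ids=list(range(1, len(rows) + 1)),
        vectors=vectors,
        payloads=rows,  # remaining columns: label, uri, endpoint_url, entity_type
    ),
)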
