Skip to content

Commit

Permalink
Merge pull request #14 from VectorInstitute/add_embedding_service
Browse files Browse the repository at this point in the history
Add embedding service
  • Loading branch information
amrit110 authored Sep 23, 2024
2 parents bc96103 + dd550ce commit 3aee52b
Show file tree
Hide file tree
Showing 34 changed files with 4,401 additions and 324 deletions.
9 changes: 8 additions & 1 deletion .env.development
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ BACKEND_PORT=8002
MEDCAT_MODELS_DIR=/mnt/data/medcat
NER_SERVICE_PORT=8003

EMBEDDING_SERVICE_PORT=8004
EMBEDDING_SERVICE_HOST=embedding-service-dev
BATCH_SIZE=1

MILVUS_HOST=milvus-standalone
MILVUS_PORT=19530

LLM_SERVICE_HOST=gpu043
LLM_SERVICE_PORT=8080

Expand All @@ -15,4 +22,4 @@ JWT_SECRET_KEY=a7c67720790687c81faddeeeb70d6cbf3820352b3567ba4c47593afe65d956b5
MONGO_USERNAME=root
MONGO_PASSWORD=password

MEDS_DATA_DIR=/mnt/data/odyssey/meds/merge_to_MEDS_cohort/train
MEDS_DATA_DIR=/mnt/data/odyssey/meds/hosp/merge_to_MEDS_cohort/train
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,4 @@ next-env.d.ts

# result files
validation_results.json
volumes
13 changes: 13 additions & 0 deletions backend/api/patients/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@
from pydantic import BaseModel, Field


class Query(BaseModel):
"""
Represents a query.
Attributes
----------
query : str
The query.
"""

query: str


class QAPair(BaseModel):
"""
Represents a question-answer pair.
Expand Down
153 changes: 153 additions & 0 deletions backend/api/patients/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""RAG (Retrieval-Augmented Generation) API for patient data."""

import asyncio
from typing import Any, Dict, List

import httpx
from pymilvus import Collection, connections, utility


COLLECTION_NAME = "patient_notes"


class EmbeddingManager:
"""A class to manage embeddings."""

def __init__(self, embedding_service_url: str):
"""Initialize the EmbeddingManager.
Parameters
----------
embedding_service_url : str
The URL of the embedding service.
"""
self.embedding_service_url = embedding_service_url
self.client = httpx.AsyncClient(timeout=60.0)

async def get_embedding(self, text: str) -> List[float]:
"""Get the embedding for a given text.
Parameters
----------
text : str
The text to embed.
Returns
-------
List[float]
The embedding for the given text.
"""
response = await self.client.post(
self.embedding_service_url,
json={"texts": [text], "instruction": "Represent the query for retrieval:"},
)
response.raise_for_status()
return response.json()["embeddings"][0]

async def close(self):
"""Close the client."""
await self.client.aclose()


class MilvusManager:
"""A class to manage Milvus."""

def __init__(self, host: str, port: int):
"""Initialize the MilvusManager.
Parameters
----------
host : str
The host of the Milvus server.
port : int
The port of the Milvus server.
"""
self.host = host
self.port = port
self.collection_name = COLLECTION_NAME
self.collection = None

def connect(self):
"""Connect to the Milvus server.
Raises
------
ValueError
If the collection does not exist in Milvus.
"""
connections.connect(host=self.host, port=self.port)
if not utility.has_collection(self.collection_name):
raise ValueError(
f"Collection {self.collection_name} does not exist in Milvus"
)

def get_collection(self) -> Collection:
"""Get the collection from Milvus.
Returns
-------
Collection
The collection from Milvus.
"""
if self.collection is None:
self.collection = Collection(self.collection_name)
return self.collection

def load_collection(self):
"""Load the collection from Milvus.
Raises
------
ValueError
If the collection is not loaded.
"""
collection = self.get_collection()
collection.load()

async def ensure_collection_loaded(self):
"""Ensure the collection is loaded from Milvus.
Raises
------
ValueError
If the collection is not loaded.
"""
collection = self.get_collection()
# The load() method is synchronous and blocks until the collection is loaded
await asyncio.to_thread(collection.load)

async def search(
self, query_vector: List[float], top_k: int
) -> List[Dict[str, Any]]:
"""Search for the nearest neighbors in Milvus.
Parameters
----------
query_vector : List[float]
The query vector.
top_k : int
The number of nearest neighbors to return.
Returns
-------
List[Dict[str, Any]]
The nearest neighbors.
"""
await self.ensure_collection_loaded()
collection = self.get_collection()
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
results = collection.search(
data=[query_vector],
anns_field="embedding",
param=search_params,
limit=top_k,
output_fields=["patient_id", "note_id"],
)
return [
{
"patient_id": hit.entity.get("patient_id"),
"note_id": hit.entity.get("note_id"),
"distance": hit.distance,
}
for hit in results[0]
]
86 changes: 84 additions & 2 deletions backend/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from typing import Any, Dict, List

import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, Body, Depends, HTTPException, Request, status
from motor.motor_asyncio import AsyncIOMotorDatabase
from sqlalchemy.ext.asyncio import AsyncSession

from api.patients.data import ClinicalNote, NERResponse, PatientData, QAPair
from api.patients.data import ClinicalNote, NERResponse, PatientData, QAPair, Query
from api.patients.db import get_database
from api.patients.ehr import fetch_patient_events, init_lazy_df
from api.patients.rag import EmbeddingManager, MilvusManager
from api.users.auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
authenticate_user,
Expand Down Expand Up @@ -48,11 +49,25 @@
MEDS_DATA_DIR = os.getenv(
"MEDS_DATA_DIR", "/mnt/data/odyssey/meds/merge_to_MEDS_cohort/train"
)
MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost")
MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
EMBEDDING_SERVICE_HOST = os.getenv("EMBEDDING_SERVICE_HOST", "localhost")
EMBEDDING_SERVICE_PORT = os.getenv("EMBEDDING_SERVICE_PORT", "8004")
EMBEDDING_SERVICE_URL = (
f"http://{EMBEDDING_SERVICE_HOST}:{EMBEDDING_SERVICE_PORT}/embeddings"
)
COLLECTION_NAME = "patient_notes"
TOP_K = 5

# Initialize the lazy DataFrame
init_lazy_df(MEDS_DATA_DIR)


EMBEDDING_MANAGER = EmbeddingManager(EMBEDDING_SERVICE_URL)
MILVUS_MANAGER = MilvusManager(MILVUS_HOST, MILVUS_PORT)
MILVUS_MANAGER.connect()


@router.get("/database_summary", response_model=Dict[str, Any])
async def get_database_summary(
db: AsyncIOMotorDatabase[Any] = Depends(get_database), # noqa: B008
Expand Down Expand Up @@ -119,6 +134,73 @@ async def get_database_summary(
) from e


@router.post("/retrieve")
async def retrieve_relevant_patients(
query: Query = Body(...), # noqa: B008
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> List[PatientData]:
"""
Retrieve relevant patients based on the query.
Parameters
----------
query : Query
The query to retrieve relevant patients.
db : AsyncIOMotorDatabase
The database connection.
current_user : User
The current authenticated user.
Returns
-------
List[PatientData]
The list of relevant patients.
"""
try:
await MILVUS_MANAGER.ensure_collection_loaded()
# Get query embedding
query_embedding = await EMBEDDING_MANAGER.get_embedding(query.query)

# Search Milvus for similar notes
search_results = await MILVUS_MANAGER.search(query_embedding, TOP_K)

# Retrieve patient data for the top results
patient_ids = list({result["patient_id"] for result in search_results})
patients = await db.patients.find({"patient_id": {"$in": patient_ids}}).to_list(
None
)

patient_dict = {patient["patient_id"]: patient for patient in patients}

patient_data_list = []
for result in search_results:
patient_id = result["patient_id"]
note_id = result["note_id"]

patient = patient_dict.get(patient_id)
if patient:
notes = [
ClinicalNote(**note)
for note in patient.get("notes", [])
if note["note_id"] == note_id
]
qa_pairs = [QAPair(**qa) for qa in patient.get("qa_pairs", [])]
events = fetch_patient_events(patient_id)

patient_data = PatientData(
patient_id=patient_id, notes=notes, qa_data=qa_pairs, events=events
)
patient_data_list.append(patient_data)

return patient_data_list

except Exception as e:
raise HTTPException(
status_code=500, detail=f"An error occurred: {str(e)}"
) from e


@router.get("/patient_data/{patient_id}", response_model=PatientData)
async def get_patient_data(
patient_id: int,
Expand Down
2 changes: 1 addition & 1 deletion backend/api/users/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if not SECRET_KEY:
raise ValueError("JWT_SECRET_KEY environment variable is not set")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
ACCESS_TOKEN_EXPIRE_MINUTES = 60

# OAuth2 scheme
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/signin")
Expand Down
Loading

0 comments on commit 3aee52b

Please sign in to comment.