-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathindex.py
103 lines (92 loc) · 4.42 KB
/
index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import torch
from sentence_transformers import SentenceTransformer
from collections import defaultdict
from typing import List, Dict, Tuple, Union
import torch
from PIL import Image
import pickle
from openai import OpenAI
import os
import torch
import time
import yaml
class MemoryIndex:
def __init__(self,number_of_neighbours,use_openai=False):
self.documents = {}
self.document_vectors = {}
self.use_openai=use_openai
if use_openai:
api_key = os.getenv("OPENAI_API_KEY")
self.client = OpenAI(api_key=api_key)
self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
# self.model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
with open('test_configs/llama2_test_config.yaml') as file:
config = yaml.load(file, Loader=yaml.FullLoader)
embedding_gpu_id=config['model']['minigpt4_gpu_id']
self.device = f"cuda:{embedding_gpu_id}" if torch.cuda.is_available() else "cpu"
self.number_of_neighbours=int(number_of_neighbours)
def load_documents_from_json(self, file_path,emdedding_path=""):
with open(file_path, 'r') as file:
data = json.load(file)
for doc_id, doc_data in data.items():
self.documents[doc_id] = doc_data
self.document_vectors[doc_id] = self._compute_sentence_embedding(doc_data)
# save self.documents and self.document_vectors to pkl file
m=[self.documents,self.document_vectors]
with open(emdedding_path, 'wb') as file:
pickle.dump(m, file)
return emdedding_path
def load_embeddings_from_pkl(self, pkl_file_path):
#read the pkl file
with open(pkl_file_path, 'rb') as file:
data = pickle.load(file)
self.documents=data[0]
self.document_vectors=data[1]
def load_data_from_pkl(self, pkl_file_path):
with open(pkl_file_path, 'rb') as file:
data = pickle.load(file)
for doc_id, doc_data in data.items():
self.documents[doc_id] = doc_data
self.document_vectors[doc_id] = doc_data
def _compute_sentence_embedding(self, text: str) -> torch.Tensor:
if self.use_openai:
done=False
while not done:
try:
embedding=self.client.embeddings.create(input = [text], model="text-embedding-3-small").data[0].embedding
# Convert the list to a PyTorch tensor
embedding = torch.tensor(embedding)
done=True
except Exception as e:
print("error",e)
print("text",text)
# sleep for 5 seconds and try again
time.sleep(5)
continue
else:
return self.model.encode(text, convert_to_tensor=True).to(self.device)
return embedding
def search_by_similarity(self, query: str) -> List[str]:
query_vector = self._compute_sentence_embedding(query)
scores = {doc_id: torch.nn.functional.cosine_similarity(query_vector, doc_vector, dim=0).item()
for doc_id, doc_vector in self.document_vectors.items()}
sorted_doc_ids = sorted(scores, key=scores.get, reverse=True)
sorted_documents=[self.documents[doc_id] for doc_id in sorted_doc_ids]
if self.number_of_neighbours == -1:
return list(self.documents.values()), list(self.documents.keys())
if self.number_of_neighbours > len(sorted_documents):
return sorted_documents, sorted_doc_ids
# if the retrieved document is the summary, return the summary and the next document to grauntee that always retieve clip name.
if self.number_of_neighbours==1 and sorted_doc_ids[0]=='summary':
return sorted_documents[0:2], sorted_doc_ids[:2]
print("Number of neighbours",self.number_of_neighbours)
return sorted_documents[:self.number_of_neighbours], sorted_doc_ids[:self.number_of_neighbours]
# # main function
# if __name__ == "__main__":
# memory_index = MemoryIndex(-1,use_openai=True)
# memory_index.load_documents_from_json('workspace/results/llama_vid/tt0035423.json')
# print(memory_index.documents.keys())
# docs,keys=memory_index.search_by_similarity('kerolos')