-
Notifications
You must be signed in to change notification settings - Fork 0
/
cerebrusrvd.py
201 lines (153 loc) · 8.05 KB
/
cerebrusrvd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Necessary imports
import streamlit as st
from cerebras.cloud.sdk import Cerebras
import ast
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# --- Page setup -------------------------------------------------------------
st.set_page_config(page_icon="↗️", page_title="Cerebras RVD")
st.title("Cerebras RVD ↗️")

# The Cerebras API key is read from Streamlit's secrets store rather than
# being hard-coded; define `cerebras_api_key` there before running the app.
CEREBRAS_API_KEY = st.secrets["cerebras_api_key"]
client = Cerebras(api_key=CEREBRAS_API_KEY)
# --- Cerebras LLM call ------------------------------------------------------
@st.cache_data
def cerebras_rvd_api(user_query):
    """Ask the Cerebras-hosted Llama model for prioritized search terms.

    The system prompt instructs the model to answer with only a bracketed
    list of strings (the priority terms for ``user_query``). The raw text of
    the first completion choice is returned unparsed; callers are expected to
    extract the list themselves. ``st.cache_data`` memoizes the reply per
    distinct ``user_query``.
    """
    instructions = """\
You are an AI assistant designed to help improve semantic vector search by prioritizing key terms or suggesting changes to disambiguate queries. Your task is to analyze the input and provide a list of essential terms or concepts that best capture the core meaning, while reducing ambiguity. In cases of idiomatic expressions or colloquialisms, provide the intended meaning instead of the literal words.
Follow these guidelines:
1. Identify and prioritize the most important terms or concepts in the query.
2. Remove common words that don't add significant meaning.
3. If the input contains idioms or colloquialisms, replace them with their intended meaning.
4. For ambiguous terms, try to provide context or use more specific alternatives.
5. Do not use words that have multiple meanings
6. Output is a list of strings in square brackets, separated by commas. Directly give the list and talk nothing else.
Examples:
Input:
Italian food serving restaurant having outdoor seating
Output:
["restaurant","Italian food","outdoor seating"]
Input:
Just go out there and break a leg!
Output:
["wishing good luck"]
Input:
I need a new mouse for my computer
Output:
["computer mouse","input device"]
Input:
The fisherman went to the bank to deposit his catch
Output:
["fisherman","financial institution","deposit","fish catch"]
Input:
How to make a vegan apple pie without added sugar
Output:
["vegan apple pie recipe","sugar-free dessert","healthy baking"]
Input:
I need tips for growing tomatoes in a small urban garden
Output:
["tomato gardening tips","urban gardening","small space cultivation"]"""
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": f"Input: {user_query}"},
    ]
    completion = client.chat.completions.create(
        # The original author notes "llama3.1-8b" as the smaller alternative.
        model="llama3.1-70b",
        messages=conversation,
        max_tokens=1000,
        # Zero temperature keeps replies as repeatable as the API allows.
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
# --- Embedding model --------------------------------------------------------
@st.cache_resource
def load_model():
    """Return the ``all-MiniLM-L6-v2`` SentenceTransformer.

    ``st.cache_resource`` keeps a single instance across reruns, so the model
    is fetched and loaded only once per server process.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
# Markdown shown in the "Recursive Vector Disambiguation (RVD)" expander.
# Paragraphs are separated by blank lines so Markdown renders them as separate
# paragraphs, and the `-----` rule is surrounded by blank lines: without the
# blank line above it, Markdown treats `-----` as a setext underline and turns
# the whole preceding paragraph into a heading.
# NOTE(review): the text describes term-by-term re-ranking, while the search
# code scores documents with a linearly weighted sum over all priority terms.
rvd_info = """\
`Recursive Vector Disambiguation (RVD)` is a technique to improve
semantic vector search process by using LLMs like Cerebras Llama-3.1-70B
to prioritize terms in the query and then matching for similar vectors
in the order of prioritized terms.

The terms may be from the query or generated by the `Cerebras Llama-3.1-70B` to
better encapsulate the intent of the query.

So, you first get similar matches for the highest priority term and then
rerank the list of vectors based on the remaining terms. This ensures **high
quality results**/top-matches.

-----

**For Ex**: When not using RVD,
the query `'go break a leg'` might match sentences related to
*literal leg injuries* instead of the correct interpretation of
*wishing good luck*.

Similarly, the query `Italian food serving restaurant having outdoor seating`
might wrongly match sentences that contain entities related to Italy or
restaurant while the user has specifically constrained the search
to the mentioned terms. A lot depends on what the training objective
of the embedding model was. RVD tries to improve this search experience.

You can try the demo below.
"""
# Collapsible explainer for the RVD technique (collapsed by default).
st.expander("Recursive Vector Disambiguation (RVD)").write(rvd_info)
# --- Sample templates -------------------------------------------------------
# Maps each template name to (corpus text, query text). Any name missing from
# the table ("Custom") starts with both fields empty.
_SAMPLE_TEMPLATES = {
    "Template 1": (
        "Pasta inside cafe\n"
        "Best pizza places in France with indoor seating only\n"
        "Romantic dinner spots with a view\n"
        "Cafes with outdoor patios serving Italian cuisine\n"
        "High-end Italian indoor dining experiences\n"
        "Casual Italian eateries with al fresco options\n"
        "Top-rated chinese restaurants with garden seating\n"
        "Authentic Italian cuisine in an intimate setting\n"
        "Italian Restaurants with rooftop dining areas\n"
        "Budget-friendly Italian food trucks\n"
        "Michelin-starred Japanese restaurants with terraces\n"
        "Quick Italian takeout options\n"
        "French bakery\n"
        "Historic Italian restaurants with courtyard seating",
        "restaurants that serve Italian food and have outdoor seating",
    ),
    "Template 2": (
        "cats and dogs are falling from the sky\n"
        "heavy rain\n"
        "lot of animals in that vet house\n"
        "torrential downpour expected\n"
        "water gushing from gutter",
        "raining cats and dogs",
    ),
    "Template 3": (
        "Best coffee beans for brewing the perfect cup of Java\n"
        "Setting up Eclipse IDE for Java programming on Windows 10\n"
        "Java island's rich cultural heritage and tourist attractions\n"
        "Top 5 Java cafes in Seattle with great working environments\n"
        "Debugging techniques for identifying and resolving system errors",
        "Java development environment setup",
    ),
}

template = st.selectbox("Select Sample Template", tuple(_SAMPLE_TEMPLATES) + ("Custom",))
ta_value, uq_value = _SAMPLE_TEMPLATES.get(template, ("", ""))

corpus = st.text_area("Corpus Sentences (New line separated)", value=ta_value, height=300)
user_query = st.text_input("Query", value=uq_value)
# --- Search helpers ---------------------------------------------------------
def _parse_priority_terms(llm_response):
    """Return the list of priority terms contained in an LLM reply.

    The prompt asks for a bare list such as ``["restaurant", "Italian food"]``,
    but the reply may carry extra text around it, so only the span from the
    first ``[`` to the last ``]`` is parsed. ``ast.literal_eval`` (never
    ``eval``) is used because the reply is external input. Returns an empty
    list when the reply is missing or holds no list of non-blank strings.
    """
    if not llm_response:
        return []
    start = llm_response.find("[")
    end = llm_response.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        parsed = ast.literal_eval(llm_response[start:end + 1])
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return []
    if not isinstance(parsed, (list, tuple)):
        return []
    return [term.strip() for term in parsed if isinstance(term, str) and term.strip()]


def get_top_matches(query_embedding, embeddings, top_k=10):
    """Rank corpus rows by cosine similarity to one query embedding.

    Returns ``(indices, scores)`` for at most ``top_k`` rows, best first.
    """
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    top_indices = similarities.argsort()[::-1][:top_k]
    return top_indices, similarities[top_indices]


def rvd_search(priority_terms, documents, corpus_embeddings, model, top_k=10):
    """Rank corpus rows by a weighted sum of per-term cosine similarities.

    All priority terms are embedded in a single ``encode`` call and compared
    with every corpus row. The per-term similarity rows are combined with
    weights falling linearly from 1.0 (first, highest-priority term) to 0.8
    (last term). The scores are a weighted *sum*, so with several terms they
    can exceed 1 and are not on the same scale as plain cosine similarity.
    ``documents`` is unused and kept only for signature compatibility.

    Returns ``(indices, scores)`` for at most ``top_k`` rows, best first.
    """
    term_embeddings = model.encode(list(priority_terms))
    # One similarity row per priority term: shape (n_terms, n_documents).
    term_similarities = cosine_similarity(term_embeddings, corpus_embeddings)
    weights = np.linspace(1, 0.8, len(priority_terms))
    weighted_similarities = weights @ term_similarities
    top_indices = weighted_similarities.argsort()[::-1][:top_k]
    return top_indices, weighted_similarities[top_indices]


# --- Search action ----------------------------------------------------------
if st.button("Vector Search!"):
    # Validate what the user actually entered. The template defaults
    # (ta_value / uq_value) are always empty for "Custom" and stale once the
    # user edits the fields, so they must not drive this check.
    documents = [doc.strip() for doc in corpus.split("\n") if doc.strip()]
    if len(documents) < 2 or not user_query.strip():
        st.error("Please enter atleast 2 corpus sentences and a query")
    else:
        model = load_model()
        llm_response = cerebras_rvd_api(user_query)
        priority_terms = _parse_priority_terms(llm_response)
        if not priority_terms:
            # Without this fallback an unparseable reply would crash the app
            # or leave the RVD column silently empty.
            st.warning(
                "Could not parse priority terms from the LLM response; "
                "falling back to the query itself."
            )
            priority_terms = [user_query]

        corpus_embeddings = model.encode(documents)

        st.subheader("Cerebras LLM Generated Priority Terms")
        st.write(priority_terms)

        col1, col2 = st.columns(2)

        # Baseline: embed the raw query and rank by cosine similarity.
        col1.subheader("Direct Similarity Search")
        with st.spinner("Performing Direct Search..."):
            query_embedding = model.encode([user_query])[0]
            top_indices, similarities = get_top_matches(query_embedding, corpus_embeddings)
        for idx, sim in zip(top_indices, similarities):
            col1.write(f"**{sim:.3f}**: {documents[idx]}")

        # RVD: rank by the weighted combination of the LLM's priority terms.
        col2.subheader("RVD Search")
        with st.spinner("Performing RVD Search..."):
            rvd_results, rvd_similarities = rvd_search(
                priority_terms, documents, corpus_embeddings, model
            )
        for idx, sim in zip(rvd_results, rvd_similarities):
            col2.write(f"**{sim:.3f}**: {documents[idx]}")