Skip to content

Commit

Permalink
RAG URL reference
Browse files — browse the repository at this point in the history
  • Loading branch information
artitw committed Oct 12, 2024
1 parent 9e7ea7c commit d567df1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.7.6",
version="1.7.7",
author="artitw",
author_email="artitw@gmail.com",
description="Text2Text: Crosslingual NLP/G toolkit",
Expand Down
2 changes: 1 addition & 1 deletion text2text/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __del__(self):
warnings.warn(str(e))

def load_model(self):
pbar = tqdm(total=6, desc='Model setup')
pbar = tqdm(total=6, desc='Model Setup')
if not ollama_version():
self.__del__()
pbar.update(1)
Expand Down
10 changes: 6 additions & 4 deletions text2text/rag_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def __init__(self, **kwargs):
texts = kwargs.get("texts", [])
urls = kwargs.get("urls", [])
input_lines = []
for u in tqdm(urls, desc='Scrape URLs'):
for u in tqdm(urls, desc='Scrape HTML'):
if is_valid_url(u):
try:
texts.append(get_cleaned_html(u))
texts.append(get_cleaned_html(u) + f"\nURL: {u}")
except Exception as e:
warnings.warn(f"Skipping URL with errors: {u}")
else:
Expand All @@ -81,7 +81,7 @@ def __init__(self, **kwargs):
if schema:
column_names = schema.model_fields.keys()
self.records = pd.DataFrame(columns=column_names)
for t in tqdm(texts, desc='Schema extraction'):
for t in tqdm(texts, desc='Extract Schema'):
fields = ", ".join(column_names)
prompt = f'Extract {fields} from the following text:\n\n{t}'
res = t2t.Assistant.chat_completion(self, [{"role": "user", "content": prompt}], schema=schema)
Expand All @@ -93,6 +93,8 @@ def __init__(self, **kwargs):
input_lines = texts
self.records = pd.DataFrame({"text": texts})

print(input_lines)

self.index = t2t.Indexer().transform(input_lines, encoders=[t2t.Vectorizer()])
self.records = pd.concat([self.records, self.index.corpus], axis=1)
self.records["embedding"] = self.records["embedding"].apply(lambda x: pickle.dumps(x))
Expand All @@ -115,6 +117,6 @@ def chat_completion(self, messages=[{"role": "user", "content": "hello"}], strea
docs = self.index.retrieve([demand], k=k)[0]
else:
docs = self.index.retrieve([query], k=k)[0]
grounding_prompt = "Base your response on the following information:\n\n" + "\n- ".join(docs)
grounding_prompt = "Base your response on the following information:\n\n" + "\n\n".join(docs)
messages[-1] = {"role": "user", "content": query + "\n\n" + grounding_prompt}
return t2t.Assistant.chat_completion(self, messages=messages, stream=stream, schema=schema, **kwargs)

0 comments on commit d567df1

Please sign in to comment.