Composite Indexer
artitw committed Oct 15, 2024
1 parent bb4e810 commit 1d1c5c0
Showing 8 changed files with 91 additions and 9 deletions.
setup.py: 2 changes (1 addition, 1 deletion)

@@ -5,7 +5,7 @@

setuptools.setup(
  name="text2text",
-  version="1.7.8",
+  version="1.7.9",
  author="artitw",
  author_email="artitw@gmail.com",
  description="Text2Text: Crosslingual NLP/G toolkit",
text2text/__init__.py: 1 change (1 addition, 0 deletions)

@@ -9,6 +9,7 @@
from .tfidfer import Tfidfer
from .bm25er import Bm25er
from .indexer import Indexer
+from .composite_indexer import CompositeIndexer
from .rag_assistant import RagAssistant
from .variator import Variator
from .identifier import Identifier
text2text/composite_indexer.py: 60 changes (60 additions, 0 deletions; new file)

@@ -0,0 +1,60 @@
import text2text as t2t

import pandas as pd

from collections import Counter

CORPUS_COLUMNS = ["document", "embedding"]

def aggregate_and_sort(x):
  # Flatten the list of lists
  flat_list = [item for sublist in x for item in sublist]
  # Count occurrences
  counts = Counter(flat_list)
  # Sort items by count (descending) and then by item (ascending)
  sorted_items = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
  # Return only the items in sorted order
  return [item[0] for item in sorted_items]

class CompositeIndexer(object):
  def __init__(self):
    index_sem = t2t.Indexer(encoders=[t2t.Vectorizer()]).transform([])
    index_syn = t2t.Indexer(encoders=[t2t.Tfidfer()]).transform([])
    self.indexes = [index_sem, index_syn]
    self.corpus = pd.DataFrame(columns=CORPUS_COLUMNS)

  def size(self, **kwargs):
    return len(self.corpus.index)

  def update_corpus(self):
    new_rows = pd.DataFrame(columns=CORPUS_COLUMNS)
    for index in self.indexes:
      new_rows = pd.concat([new_rows, index.corpus], ignore_index=False)
    self.corpus = new_rows.groupby(new_rows.index).agg({"document": "max", "embedding": list}).reset_index()

  def add(self, texts, **kwargs):
    embeddings = kwargs.get("embeddings", [None]*len(self.indexes))
    for i, index in enumerate(self.indexes):
      index.add(texts, embeddings=embeddings[i])
    self.update_corpus()
    return self

  def remove(self, ids, faiss_index=None, **kwargs):
    for index in self.indexes:
      index.remove(ids)
    self.update_corpus()

  def retrieve(self, input_lines, **kwargs):
    df = pd.DataFrame({"document": []})
    for index in self.indexes:
      res = index.retrieve(input_lines, **kwargs)
      df2 = pd.DataFrame({"document": res})
      df = pd.concat([df, df2], axis=0)
    df = df.groupby(df.index).agg(aggregate_and_sort)
    df.reset_index(drop=True, inplace=True)
    return df["document"].tolist()

  def transform(self, input_lines, src_lang='en', **kwargs):
    if not input_lines:
      return self
    return self.add(input_lines, src_lang=src_lang, **kwargs)
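
For context, a minimal usage sketch of the new class (not part of the commit; assumes the package is installed and the Vectorizer's assistant backend plus the Tfidfer encoder are available locally):

  import text2text as t2t

  # One semantic (Vectorizer) and one syntactic (Tfidfer) FAISS index
  # behind a single interface.
  indexer = t2t.CompositeIndexer()
  indexer.add([
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "Berlin is the capital of Germany.",
  ])

  # Results from both indexes are merged by aggregate_and_sort: documents
  # retrieved by more indexes rank first; ties break lexicographically.
  print(indexer.retrieve(["capital of France"], k=2))
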
text2text/handler.py: 1 change (1 addition, 0 deletions)

@@ -19,6 +19,7 @@ class Handler(object):
"translate": t2t.Translator,
"variate": t2t.Variator,
"vectorize": t2t.Vectorizer,
"composite_index": t2t.CompositeIndexer,
}

transformer_instances = {}
text2text/indexer.py: 9 changes (5 additions, 4 deletions)

@@ -10,8 +10,10 @@ class Indexer(t2t.Transformer):

  def __init__(self, **kwargs):
    self.input_lines = []
-    self.encoders = kwargs.get("encoders", [t2t.Vectorizer()])
-
+    self.encoders = kwargs.get("encoders", [t2t.Tfidfer()])
+    columns = ["document", "embedding"]
+    self.corpus = pd.DataFrame(columns=columns)
+
  def get_formatted_matrix(self, input_lines, src_lang='en', **kwargs):
    res = np.array([[]]*len(input_lines))
    for encoder in self.encoders:
@@ -65,11 +67,10 @@ def retrieve(self, input_lines, k=3, **kwargs):
    return [self.corpus["document"].loc[[i for i in line_ids if i >= 0]].tolist() for line_ids in pred_ids]

  def transform(self, input_lines, src_lang='en', **kwargs):
-    super().transform(input_lines, src_lang='en', **kwargs)
+    super().transform(input_lines, src_lang=src_lang, **kwargs)
    self.src_lang = src_lang
    d = self.get_formatted_matrix(["DUMMY"], src_lang=src_lang, **kwargs).shape[-1]
    self.index = faiss.IndexIDMap2(faiss.IndexFlatL2(d))
-    self.corpus = pd.DataFrame({"document": [], "embedding": []})
    if not input_lines:
      return self
    return self.add(input_lines, src_lang=src_lang, **kwargs)
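
Worth noting for callers: the default encoder flips from Vectorizer to Tfidfer, and transform now forwards the caller's src_lang instead of hardcoding 'en'. A sketch of how the two index flavors would presumably be requested after this change (not part of the commit):

  import text2text as t2t

  # A bare Indexer is now TF-IDF (syntactic) by default;
  # semantic search requires passing the encoder explicitly.
  syntactic_index = t2t.Indexer().transform([])
  semantic_index = t2t.Indexer(encoders=[t2t.Vectorizer()]).transform([])
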
text2text/rag_assistant.py: 8 changes (5 additions, 3 deletions)

@@ -55,8 +55,7 @@ def __init__(self, **kwargs):
    texts = kwargs.get("texts", [])
    urls = kwargs.get("urls", [])
    sqlite_path = kwargs.get("sqlite_path", None)
-    encoders = kwargs.get("encoders", [t2t.Vectorizer()])
-    self.index = t2t.Indexer(encoders=encoders).transform([])
+    self.index = t2t.CompositeIndexer()

    if urls:
      for u in tqdm(urls, desc='Scrape HTML'):
@@ -95,11 +94,14 @@ def __init__(self, **kwargs):
fields = ", ".join(db_fields)
query = f"SELECT {fields} FROM {RAG_TABLE_NAME}"
db_records = pd.read_sql_query(query, conn)
db_records.dropna(subset=["document", "embedding"], inplace=True)
conn.close()
embeddings = db_records["embedding"].apply(lambda x: pickle.loads(x))
embeddings = pd.DataFrame(embeddings.to_list())
embeddings = [np.vstack(embeddings[col]) for col in embeddings.columns]
self.index.add(
db_records["document"].tolist(),
embeddings=np.vstack(embeddings)
embeddings=embeddings
)
self.records = pd.concat([self.records, db_records], ignore_index=True)

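Each sqlite record is now expected to hold a pickled list of embeddings, one per sub-index, which the loader splits into one stacked matrix per index before calling self.index.add. A standalone sketch of that reshaping with hypothetical shapes (not part of the commit):

  import numpy as np
  import pandas as pd

  # Two records; each stores [semantic_vector, tfidf_vector] after unpickling.
  unpickled = pd.Series([
    [np.ones(3), np.zeros(5)],
    [np.full(3, 2.0), np.ones(5)],
  ])
  per_index = pd.DataFrame(unpickled.to_list())  # column 0: semantic, column 1: TF-IDF
  embeddings = [np.vstack(per_index[col]) for col in per_index.columns]
  print(embeddings[0].shape, embeddings[1].shape)  # (2, 3) (2, 5)
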
text2text/tokenizer.py: 7 changes (6 additions, 1 deletion)

@@ -1,11 +1,16 @@
import text2text as t2t

+import warnings
+
from transformers import AutoTokenizer

class Tokenizer(t2t.Transformer):

  def __init__(self, **kwargs):
    pretrained_translator = self.__class__.PRETRAINED_TRANSLATOR
-    self.__class__.tokenizer = AutoTokenizer.from_pretrained(pretrained_translator)
+    with warnings.catch_warnings():
+      warnings.simplefilter("ignore")
+      self.__class__.tokenizer = AutoTokenizer.from_pretrained(pretrained_translator)

  def transform(self, input_lines, src_lang='en', output='tokens', **kwargs):
+    input_lines = t2t.Transformer.transform(self, input_lines, src_lang=src_lang, **kwargs)
text2text/vectorizer.py: 12 changes (12 additions, 0 deletions)

@@ -1,6 +1,18 @@
import text2text as t2t
+from tqdm.auto import tqdm

class Vectorizer(t2t.Assistant):

  def transform(self, input_lines, **kwargs):
+    sentencize = kwargs.get("sentencize", True)
+    if sentencize and input_lines != ["DUMMY"]:
+      sentences = []
+      for text in tqdm(input_lines, desc='Summarize'):
+        if len(text) > 100:
+          prompt = f'Summarize the following text to a single sentence:\n\n{text}'
+          result = self.chat_completion([{"role": "user", "content": prompt}])
+          sentences.append(result["message"]["content"])
+        else:
+          sentences.append(text)
+      return self.embed(sentences)
    return self.embed(input_lines)
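
With sentencize defaulting to True, any input longer than 100 characters is first condensed to one sentence via the assistant's chat_completion before embedding. A usage sketch (not part of the commit; assumes the assistant backend is reachable):

  import text2text as t2t

  v = t2t.Vectorizer()
  # Long texts are summarized first, then embedded.
  summarized = v.transform(["a very long document " * 20])
  # sentencize=False embeds the raw text unchanged.
  raw = v.transform(["a very long document " * 20], sentencize=False)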
