From 1d1c5c0a423b7cdb9659f3fdf445197352c3482f Mon Sep 17 00:00:00 2001
From: artitw <artitw@gmail.com>
Date: Tue, 15 Oct 2024 06:20:18 +0000
Subject: [PATCH] Composite Indexer

---
 setup.py                       |  2 +-
 text2text/__init__.py          |  1 +
 text2text/composite_indexer.py | 60 ++++++++++++++++++++++++++++++++++
 text2text/handler.py           |  1 +
 text2text/indexer.py           |  9 ++---
 text2text/rag_assistant.py     |  8 +++--
 text2text/tokenizer.py         |  7 +++-
 text2text/vectorizer.py        | 12 +++++++
 8 files changed, 91 insertions(+), 9 deletions(-)
 create mode 100644 text2text/composite_indexer.py

diff --git a/setup.py b/setup.py
index 8e67038..7b9fc01 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
   name="text2text",
-  version="1.7.8",
+  version="1.7.9",
   author="artitw",
   author_email="artitw@gmail.com",
   description="Text2Text: Crosslingual NLP/G toolkit",
diff --git a/text2text/__init__.py b/text2text/__init__.py
index a559f44..881832d 100644
--- a/text2text/__init__.py
+++ b/text2text/__init__.py
@@ -9,6 +9,7 @@
 from .tfidfer import Tfidfer
 from .bm25er import Bm25er
 from .indexer import Indexer
+from .composite_indexer import CompositeIndexer
 from .rag_assistant import RagAssistant
 from .variator import Variator
 from .identifier import Identifier
diff --git a/text2text/composite_indexer.py b/text2text/composite_indexer.py
new file mode 100644
index 0000000..571c605
--- /dev/null
+++ b/text2text/composite_indexer.py
@@ -0,0 +1,60 @@
+import text2text as t2t
+
+import pandas as pd
+
+from collections import Counter
+
+CORPUS_COLUMNS = ["document", "embedding"]
+
+def aggregate_and_sort(x):
+  # Flatten the list of lists
+  flat_list = [item for sublist in x for item in sublist]
+  # Count occurrences
+  counts = Counter(flat_list)
+  # Sort items by count (descending) and then by item (ascending)
+  sorted_items = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
+  # Return only the items in sorted order
+  return [item[0] for item in sorted_items]
+
+class CompositeIndexer(object):
+  def __init__(self):
+    index_sem = t2t.Indexer(encoders=[t2t.Vectorizer()]).transform([])
+    index_syn = t2t.Indexer(encoders=[t2t.Tfidfer()]).transform([])
+    self.indexes = [index_sem, index_syn]
+    self.corpus = pd.DataFrame(columns=CORPUS_COLUMNS)
+
+  def size(self, **kwargs):
+    return len(self.corpus.index)
+
+  def update_corpus(self):
+    new_rows = pd.DataFrame(columns=CORPUS_COLUMNS)
+    for index in self.indexes:
+      new_rows = pd.concat([new_rows, index.corpus], ignore_index=False)
+    self.corpus = new_rows.groupby(new_rows.index).agg({"document": "max", "embedding": list}).reset_index()
+
+  def add(self, texts, **kwargs):
+    embeddings = kwargs.get("embeddings", [None]*len(self.indexes))
+    for i, index in enumerate(self.indexes):
+      index.add(texts, embeddings=embeddings[i])
+    self.update_corpus()
+    return self
+
+  def remove(self, ids, faiss_index=None, **kwargs):
+    for index in self.indexes:
+      index.remove(ids)
+    self.update_corpus()
+
+  def retrieve(self, input_lines, **kwargs):
+    df = pd.DataFrame({"document": []})
+    for index in self.indexes:
+      res = index.retrieve(input_lines, **kwargs)
+      df2 = pd.DataFrame({"document": res})
+      df = pd.concat([df, df2], axis=0)
+    df = df.groupby(df.index).agg(aggregate_and_sort)
+    df.reset_index(drop=True, inplace=True)
+    return df["document"].tolist()
+
+  def transform(self, input_lines, src_lang='en', **kwargs):
+    if not input_lines:
+      return self
+    return self.add(input_lines, src_lang=src_lang, **kwargs)
\ No newline at end of file
diff --git a/text2text/handler.py b/text2text/handler.py
index 4283ff6..f55e707 100644
--- a/text2text/handler.py
+++ b/text2text/handler.py
@@ -19,6 +19,7 @@ class Handler(object):
     "translate": t2t.Translator,
    "variate": t2t.Variator,
     "vectorize": t2t.Vectorizer,
+    "composite_index": t2t.CompositeIndexer,
   }
 
   transformer_instances = {}
diff --git a/text2text/indexer.py b/text2text/indexer.py
index 769696a..310aa5e 100644
--- a/text2text/indexer.py
+++ b/text2text/indexer.py
@@ -10,8 +10,10 @@ class Indexer(t2t.Transformer):
 
   def __init__(self, **kwargs):
     self.input_lines = []
-    self.encoders = kwargs.get("encoders", [t2t.Vectorizer()])
-
+    self.encoders = kwargs.get("encoders", [t2t.Tfidfer()])
+    columns = ["document", "embedding"]
+    self.corpus = pd.DataFrame(columns=columns)
+
   def get_formatted_matrix(self, input_lines, src_lang='en', **kwargs):
     res = np.array([[]]*len(input_lines))
     for encoder in self.encoders:
@@ -65,11 +67,10 @@ def retrieve(self, input_lines, k=3, **kwargs):
     return [self.corpus["document"].loc[[i for i in line_ids if i >= 0]].tolist() for line_ids in pred_ids]
 
   def transform(self, input_lines, src_lang='en', **kwargs):
-    super().transform(input_lines, src_lang='en', **kwargs)
+    super().transform(input_lines, src_lang=src_lang, **kwargs)
     self.src_lang = src_lang
     d = self.get_formatted_matrix(["DUMMY"], src_lang=src_lang, **kwargs).shape[-1]
     self.index = faiss.IndexIDMap2(faiss.IndexFlatL2(d))
-    self.corpus = pd.DataFrame({"document": [], "embedding": []})
     if not input_lines:
       return self
     return self.add(input_lines, src_lang=src_lang, **kwargs)
diff --git a/text2text/rag_assistant.py b/text2text/rag_assistant.py
index e8f5ac2..bbd5301 100644
--- a/text2text/rag_assistant.py
+++ b/text2text/rag_assistant.py
@@ -55,8 +55,7 @@ def __init__(self, **kwargs):
     texts = kwargs.get("texts", [])
     urls = kwargs.get("urls", [])
     sqlite_path = kwargs.get("sqlite_path", None)
-    encoders = kwargs.get("encoders", [t2t.Vectorizer()])
-    self.index = t2t.Indexer(encoders=encoders).transform([])
+    self.index = t2t.CompositeIndexer()
 
     if urls:
       for u in tqdm(urls, desc='Scrape HTML'):
@@ -95,11 +94,14 @@
       fields = ", ".join(db_fields)
       query = f"SELECT {fields} FROM {RAG_TABLE_NAME}"
       db_records = pd.read_sql_query(query, conn)
+      db_records.dropna(subset=["document", "embedding"], inplace=True)
       conn.close()
 
       embeddings = db_records["embedding"].apply(lambda x: pickle.loads(x))
+      embeddings = pd.DataFrame(embeddings.to_list())
+      embeddings = [np.vstack(embeddings[col]) for col in embeddings.columns]
       self.index.add(
         db_records["document"].tolist(),
-        embeddings=np.vstack(embeddings)
+        embeddings=embeddings
       )
       self.records = pd.concat([self.records, db_records], ignore_index=True)
diff --git a/text2text/tokenizer.py b/text2text/tokenizer.py
index c1cbf7e..d245c21 100644
--- a/text2text/tokenizer.py
+++ b/text2text/tokenizer.py
@@ -1,11 +1,16 @@
 import text2text as t2t
+
+import warnings
+
 from transformers import AutoTokenizer
 
 class Tokenizer(t2t.Transformer):
 
   def __init__(self, **kwargs):
     pretrained_translator = self.__class__.PRETRAINED_TRANSLATOR
-    self.__class__.tokenizer = AutoTokenizer.from_pretrained(pretrained_translator)
+    with warnings.catch_warnings():
+      warnings.simplefilter("ignore")
+      self.__class__.tokenizer = AutoTokenizer.from_pretrained(pretrained_translator)
 
   def transform(self, input_lines, src_lang='en', output='tokens', **kwargs):
     input_lines = t2t.Transformer.transform(self, input_lines, src_lang=src_lang, **kwargs)
diff --git a/text2text/vectorizer.py b/text2text/vectorizer.py
index c8a907b..738763e 100644
--- a/text2text/vectorizer.py
+++ b/text2text/vectorizer.py
@@ -1,6 +1,18 @@
 import text2text as t2t
+from tqdm.auto import tqdm
 
 class Vectorizer(t2t.Assistant):
 
   def transform(self, input_lines, **kwargs):
+    sentencize = kwargs.get("sentencize", True)
+    if sentencize and input_lines != ["DUMMY"]:
+      sentences = []
+      for text in tqdm(input_lines, desc='Summarize'):
+        if len(text) > 100:
+          prompt = f'Summarize the following text to a single sentence:\n\n{text}'
+          result = self.chat_completion([{"role": "user", "content": prompt}])
+          sentences.append(result["message"]["content"])
+        else:
+          sentences.append(text)
+      return self.embed(sentences)
     return self.embed(input_lines)
\ No newline at end of file
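
A minimal usage sketch of the CompositeIndexer introduced above, assuming
text2text is installed and an LLM backend is available to t2t.Vectorizer
(which, per the vectorizer.py hunk, summarizes texts longer than 100
characters before embedding them); the sample documents and query are
illustrative only:

    import text2text as t2t

    # __init__ builds two sub-indexes: a semantic one over t2t.Vectorizer
    # embeddings and a syntactic one over t2t.Tfidfer weights; transform()
    # adds the documents to both.
    index = t2t.CompositeIndexer().transform([
      "Oktoberfest is held annually in Munich.",
      "Munich is the capital of Bavaria.",
    ])

    # retrieve() merges per-query results across the sub-indexes; documents
    # returned by more sub-indexes rank first, with ties broken
    # lexicographically (see aggregate_and_sort above).
    print(index.retrieve(["Where is Oktoberfest held?"]))

With the handler.py hunk, the same component should also be reachable as
t2t.Handler(texts).composite_index(), following the pattern of the other
registered transformers.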
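
The rag_assistant.py hunk assumes each pickled "embedding" cell in the
SQLite table holds one embedding per sub-index, and pivots those rows into
the per-index layout that CompositeIndexer.add expects: a list with one
stacked matrix per sub-index. A toy illustration of that reshaping, with
made-up 2-dimensional vectors:

    import numpy as np
    import pandas as pd

    # Two documents, each stored with [semantic, syntactic] embeddings.
    embeddings = pd.Series([
      [np.array([0.1, 0.2]), np.array([1.0, 0.0])],
      [np.array([0.3, 0.4]), np.array([0.0, 1.0])],
    ])
    # One column per sub-index.
    embeddings = pd.DataFrame(embeddings.to_list())
    # One (n_docs x dim) matrix per sub-index, matching embeddings[i]
    # in CompositeIndexer.add.
    embeddings = [np.vstack(embeddings[col]) for col in embeddings.columns]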
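
Note that the vectorizer.py hunk makes summarization opt-out: transform()
now condenses any input longer than 100 characters to a single sentence via
chat_completion before embedding. Callers that want raw-text embeddings can
pass the new keyword explicitly, e.g.:

    import text2text as t2t

    docs = ["A long passage that would otherwise be summarized first ..."]
    # Skip the summarization pass and embed the inputs as-is.
    matrix = t2t.Vectorizer().transform(docs, sentencize=False)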