From e436f631933ac7a5ffd659d5cfdd46bcf8569b45 Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Wed, 27 Sep 2023 20:29:02 +0300
Subject: [PATCH 01/10] update version

---
 splade4elastic/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/splade4elastic/__init__.py b/splade4elastic/__init__.py
index adb8105..4c74b77 100644
--- a/splade4elastic/__init__.py
+++ b/splade4elastic/__init__.py
@@ -1,2 +1,2 @@
-__version__ = "0.0.22"
+__version__ = "0.0.23"
 from .elastic_splade import SpladeRewriter, MLMBaseRewriter, LinearMLMRewriter

From 5c1cc991a9a91e87a678be20880117ad9bdc4239 Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Wed, 27 Sep 2023 20:30:38 +0300
Subject: [PATCH 02/10] add 'ignore' multi words

---
 splade4elastic/elastic_splade.py | 20 ++++++++++++--------
 test/basic_test.py               | 10 ++++++++--
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py
index 5429513..375ca53 100644
--- a/splade4elastic/elastic_splade.py
+++ b/splade4elastic/elastic_splade.py
@@ -7,17 +7,19 @@ class MLMBaseRewriter:
     """Elastic SPLADE model"""

-    def __init__(self, model_name: str, expansions_per_word:int = 10):
+    def __init__(self, model_name: str, expansions_per_word:int = 10, multi_word="split", exluded_words=[]):
         """Initialize the model

         Args:
             model_name (str): name of the model
+            multi_word (str, optional): How to handle multi-word tokens. Defaults to "split". Can be "filter" or "ignore"
         """
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token="")
         self.model = AutoModelForMaskedLM.from_pretrained(model_name)
         self.k=expansions_per_word
         self.exluded_words = exluded_words
         self.const_weight = 1
+        self.multi_word = multi_word

     def __tokenize_to_words(self, sentence):
         return sentence.translate(
@@ -61,12 +63,14 @@ def mask_expansion(self, txt):
         X = self.tokenizer.encode(txt, return_tensors="pt")
         word2token = self.__tokenize_and_preserve(txt)
         words = self.__tokenize_to_words(txt)
-
         for wi, lst in word2token:
             if not self.do_expansion(words[wi]):
                 # skip this word
                 ret[wi].append((words[wi], self.const_weight))
                 continue
+            if self.multi_word == "ignore" and len(lst) > 1: # skip multi-word tokens
+                ret[wi].append((words[wi], self.const_weight))
+                continue
             X_m = X.clone()
             for mask_token_index, token, _ in lst:
                 ti = mask_token_index
@@ -84,7 +88,7 @@ def mask_expansion(self, txt):
             ret[wi].extend(zip(max_tokens, max_scores))

         ret = dict(ret)
-        if self.tokenizer.bos_token:
+        if self.tokenizer.bos_token == ret[0][0]:
             del ret[0]
         return list(ret.values())

@@ -101,8 +105,11 @@ def __force_weights(self, word_logits, txt):
     def __elastic_format(self, expanded_list, text):
         ret = []
         text = text.lower().split()
+        # print(text)
         for words in expanded_list:
+            # print(words)
             words = self.logits2weights(words)
+            # print(words)
             # The original word should have a higher score
             words = self.__force_weights(words, text)
             words = [(self.__only_alpha(w[0]).lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
@@ -111,11 +118,8 @@ def __elastic_format(self, expanded_list, text):
             words = [(unique, sum(w[1] for w in words if w[0] == unique)) for unique in unique_words]
             # sort by score
             words = sorted(words, key=lambda x: x[1], reverse=True)
-            # print(words)
-            or_statement = []
-            for w in words:
-                or_statement.append(f"{w[0]}^{round(float(w[1]), 2)}")
-            or_statement = " OR ".join(or_statement)
+            or_statement_list = [f"{w[0]}^{round(float(w[1]), 2)}" for w in words]
+            or_statement = " OR ".join(or_statement_list)
= f"({or_statement})" ret.append(or_statement) return " ".join(ret) diff --git a/test/basic_test.py b/test/basic_test.py index b95619d..8a55d7c 100644 --- a/test/basic_test.py +++ b/test/basic_test.py @@ -9,8 +9,14 @@ class BasicTest(unittest.TestCase): def setUp(self): pass - def test_base(self): - splade = MLMBaseRewriter("roberta-base", expansions_per_word=3) + def test_base_with_ignore(self): + splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="ignore") + text = "Coffee is good for you" + print(splade.query_expand(text)) + self.assertTrue(True) + + def test_base_with_split(self): + splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="split") text = "Coffee is good for you" print(splade.query_expand(text)) self.assertTrue(True) From 0d88b48d873cd29e8d0f658999849cadefc13cd0 Mon Sep 17 00:00:00 2001 From: Aviad Hashuel Date: Thu, 28 Sep 2023 00:48:35 +0300 Subject: [PATCH 03/10] split and combine all --- splade4elastic/elastic_splade.py | 56 ++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py index 375ca53..bd149e0 100644 --- a/splade4elastic/elastic_splade.py +++ b/splade4elastic/elastic_splade.py @@ -60,6 +60,7 @@ def do_expansion(self, word): def mask_expansion(self, txt): ret = collections.defaultdict(list) + special_tokens = [self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.mask_token] X = self.tokenizer.encode(txt, return_tensors="pt") word2token = self.__tokenize_and_preserve(txt) words = self.__tokenize_to_words(txt) @@ -78,6 +79,7 @@ def mask_expansion(self, txt): # ti += 1 X_m[0, ti] = self.tokenizer.mask_token_id logits = self.model(X_m).logits + all_combinations = [] for mask_token_index, token, _ in lst: mask_token_logits = logits[0, mask_token_index, :] max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[ @@ -85,13 +87,61 @@ def mask_expansion(self, txt): ][:self.k] max_tokens = self.tokenizer.convert_ids_to_tokens(max_ids) max_scores = np.sort(mask_token_logits.to("cpu").detach().numpy())[::-1][ :self.k] + tokens_tuple = zip(max_tokens, max_scores) + sub_words = [(t, s) for t, s in tokens_tuple if t not in special_tokens] + all_combinations.append(sub_words) - ret[wi].extend(zip(max_tokens, max_scores)) + # create all possible combinations of sub-words and normalize their scores + # for example: [('Why', 16.411238), ('C', 16.230715)] + # [('off', 19.477955), ('Employ', 14.500462), ('ut', 14.452579)] + # should be combined to: + # [('Why off', 16.411238*19.477955), ('Why Employ', 16.411238*14.500462), ('Why ut', 16.411238*14.452579), ('C off', 16.230715*19.477955), ('C Employ', 16.230715*14.500462), ('C ut', 16.230715*14.452579)] + all_combinations = self.combine_and_normalize(all_combinations) + ret[wi].extend(all_combinations) + # ret[wi].extend(zip(max_tokens, max_scores)) ret = dict(ret) - if self.tokenizer.bos_token == ret[0][0]: - del ret[0] return list(ret.values()) + + def combine_and_normalize(self, all_combinations): + result = [] + + # Filter out empty sub-lists + non_empty_combinations = [sub_list for sub_list in all_combinations if sub_list] + + # Check if there are any non-empty sub-lists + if not non_empty_combinations: + return result + + # Initialize with the first non-empty sub-list + initial_combination = non_empty_combinations[0] + + # Check if there's only one non-empty sub-list + if len(non_empty_combinations) == 1: + return initial_combination + + # Create a 
+        # Create a dictionary to store maximum scores for each word
+        max_scores = {word: max(score for _, score in sub_list) for sub_list in non_empty_combinations for word, _ in sub_list}
+
+        # Iterate through all possible combinations of sub-words
+        for sub_word_combination in itertools.product(*non_empty_combinations):
+            combined_sub_words = ' '.join(word for word, _ in sub_word_combination)
+            combined_sub_words_no_space = ''.join(word for word, _ in sub_word_combination)
+
+            # Calculate the product of scores for the sub-words in the combination
+            combined_score = 1.0  # Initialize with a score of 1.0
+            for word, score in sub_word_combination:
+                combined_score *= score / max_scores[word]  # Normalize the score
+
+            # Append the combined sub-words and their normalized score
+            result.append((combined_sub_words, combined_score))
+            result.append((combined_sub_words_no_space, combined_score))
+
+        # Sort the result by normalized scores (optional)
+        result.sort(key=lambda x: x[1], reverse=True)
+
+        return result
+
     def __only_alpha(self, txt):
         return "".join(c for c in txt if c in string.ascii_letters)

From d3a15a342b6b7a49af1dcced2e0cf857a90997b8 Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 10:05:12 +0300
Subject: [PATCH 04/10] add 'filter' to multi word

---
 splade4elastic/__init__.py       |  2 +-
 splade4elastic/elastic_splade.py | 33 +++++++++++++++++++++-----------
 test/basic_test.py               | 19 ++++++++++++++----
 3 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/splade4elastic/__init__.py b/splade4elastic/__init__.py
index 4c74b77..153e628 100644
--- a/splade4elastic/__init__.py
+++ b/splade4elastic/__init__.py
@@ -1,2 +1,2 @@
-__version__ = "0.0.23"
+__version__ = "0.0.24"
 from .elastic_splade import SpladeRewriter, MLMBaseRewriter, LinearMLMRewriter

diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py
index bd149e0..662b87c 100644
--- a/splade4elastic/elastic_splade.py
+++ b/splade4elastic/elastic_splade.py
@@ -20,6 +20,17 @@ def __init__(self, model_name: str, expansions_per_word:int = 10, multi_word="sp
         self.exluded_words = exluded_words
         self.const_weight = 1
         self.multi_word = multi_word
+        self.vocab = self.read_vocab() if multi_word == "filter" else None
+
+    def read_vocab(self, vocab='/usr/share/dict/words'):
+        try:
+            with open(vocab, 'r') as f:
+                words = [l.strip() for l in f.readlines()]
+        except FileNotFoundError:
+            print(f"Could not find {vocab} file, using empty vocab")
+            return set()
+        words = [w.lower() for w in words if len(w)>1]
+        return frozenset(words)

     def __tokenize_to_words(self, sentence):
         return sentence.translate(
@@ -54,9 +65,7 @@ def __tokenize_and_preserve(self, sentence, text_labels=None):
         ]

     def do_expansion(self, word):
-        if word in self.exluded_words:
-            return False
-        return True
+        return word not in self.exluded_words # expand all words except for the excluded ones

     def mask_expansion(self, txt):
         ret = collections.defaultdict(list)
@@ -86,19 +95,23 @@ def mask_expansion(self, txt):
                 ::-1
                 ][:self.k]
                 max_tokens = self.tokenizer.convert_ids_to_tokens(max_ids)
+                # max_tokens = [t[1:] if t.startswith("Ġ") else t for t in max_tokens] # remove the leading space
                 max_scores = np.sort(mask_token_logits.to("cpu").detach().numpy())[::-1][
                 :self.k]
                 tokens_tuple = zip(max_tokens, max_scores)
                 sub_words = [(t, s) for t, s in tokens_tuple if t not in special_tokens]
                 all_combinations.append(sub_words)

             # create all possible combinations of sub-words and normalize their scores
-            # for example: [('Why', 16.411238), ('C', 16.230715)]
-            # [('off', 19.477955), ('Employ', 14.500462), ('ut', 14.452579)]
-            # should be combined to:
-            # [('Why off', 16.411238*19.477955), ('Why Employ', 16.411238*14.500462), ('Why ut', 16.411238*14.452579), ('C off', 16.230715*19.477955), ('C Employ', 16.230715*14.500462), ('C ut', 16.230715*14.452579)]
             all_combinations = self.combine_and_normalize(all_combinations)
+            all_combinations = [(w[1:], s) if w.startswith("Ġ") else (w, s) for w, s in all_combinations]
+
+            if self.multi_word == "filter":
+                # filter out sub-words that are not in linux built-in dictionary
+                all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab]
+
             ret[wi].extend(all_combinations)
-            # ret[wi].extend(zip(max_tokens, max_scores))
+
         ret = dict(ret)
         return list(ret.values())
@@ -125,16 +138,14 @@ def combine_and_normalize(self, all_combinations):
         # Iterate through all possible combinations of sub-words
         for sub_word_combination in itertools.product(*non_empty_combinations):
-            combined_sub_words = ' '.join(word for word, _ in sub_word_combination)
-            combined_sub_words_no_space = ''.join(word for word, _ in sub_word_combination)
+            combined_sub_words = ''.join(word.replace('Ġ', ' ') for word, _ in sub_word_combination)

             # Calculate the product of scores for the sub-words in the combination
             combined_score = 1.0  # Initialize with a score of 1.0
             for word, score in sub_word_combination:
                 combined_score *= score / max_scores[word]  # Normalize the score

             # Append the combined sub-words and their normalized score
             result.append((combined_sub_words, combined_score))
-            result.append((combined_sub_words_no_space, combined_score))

         # Sort the result by normalized scores (optional)
         result.sort(key=lambda x: x[1], reverse=True)

diff --git a/test/basic_test.py b/test/basic_test.py
index 8a55d7c..5d414fe 100644
--- a/test/basic_test.py
+++ b/test/basic_test.py
@@ -12,25 +12,36 @@ def setUp(self):
     def test_base_with_ignore(self):
         splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="ignore")
         text = "Coffee is good for you"
-        print(splade.query_expand(text))
+        print("Testing MLMBaseRewriter with ignore multi-word option")
+        print(splade.query_expand(text), end="\n\n")
         self.assertTrue(True)

     def test_base_with_split(self):
         splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="split")
         text = "Coffee is good for you"
-        print(splade.query_expand(text))
+        print("Testing MLMBaseRewriter with split multi-word option")
+        print(splade.query_expand(text), end="\n\n")
+        self.assertTrue(True)
+
+    def test_base_with_filter(self):
+        splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="filter")
+        text = "Coffee is good for you"
+        print("Testing MLMBaseRewriter with filter multi-word option")
+        print(splade.query_expand(text), end="\n\n")
         self.assertTrue(True)

     def test_splade(self):
         splade = SpladeRewriter("roberta-base", expansions_per_word=3)
         text = "Coffee is good for you"
-        print(splade.query_expand(text))
+        print("Testing SpladeRewriter")
+        print(splade.query_expand(text), end="\n\n")
         self.assertTrue(True)

     def test_linear(self):
         splade = LinearMLMRewriter("roberta-base", expansions_per_word=3)
         text = "Coffee is good for you"
-        print(splade.query_expand(text))
+        print("Testing LinearMLMRewriter")
+        print(splade.query_expand(text), end="\n\n")
         self.assertTrue(True)
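For readers following patches 03-04, the sub-word recombination they introduce can be summarized with a small standalone sketch. This is not the patched combine_and_normalize verbatim; the function name and the example tokens and scores are made up for illustration. Candidate lists, one per masked sub-token, are crossed with itertools.product, each score is divided by the best score in its own list, and RoBERTa's leading "Ġ" marker decides where a space belongs:

import itertools

def combine_candidates(candidate_lists):
    """Cross per-sub-token candidate lists into whole-word candidates (illustrative sketch)."""
    candidate_lists = [lst for lst in candidate_lists if lst]   # drop empty slots
    if not candidate_lists:
        return []
    if len(candidate_lists) == 1:
        return list(candidate_lists[0])
    combined = []
    for combo in itertools.product(*candidate_lists):
        # RoBERTa's BPE marks "starts a new word" with a leading 'Ġ'; turn it into a space
        text = "".join(tok.replace("Ġ", " ") for tok, _ in combo).strip()
        score = 1.0
        for lst, (tok, s) in zip(candidate_lists, combo):
            score *= s / max(v for _, v in lst)                 # normalize by the list's best score
        combined.append((text, score))
    return sorted(combined, key=lambda item: item[1], reverse=True)

# e.g. combine_candidates([[("Ġnote", 9.0)], [("book", 8.0), ("books", 6.0)]])
#      -> [("notebook", 1.0), ("notebooks", 0.75)]

Under this normalization the top candidate from every list contributes a factor of 1.0, so the strongest full-word combination keeps a score of 1.0 and weaker combinations fall below it, which keeps the combined scores comparable across words of different sub-token lengths.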
From 5834b73f7f80a4a274c6fe58616e635aaef00ca4 Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 15:00:51 +0300
Subject: [PATCH 05/10] fix BOS

---
 splade4elastic/elastic_splade.py | 33 +++++++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py
index 662b87c..7e8100f 100644
--- a/splade4elastic/elastic_splade.py
+++ b/splade4elastic/elastic_splade.py
@@ -33,9 +33,26 @@ def read_vocab(self, vocab='/usr/share/dict/words'):
         return frozenset(words)

     def __tokenize_to_words(self, sentence):
-        return sentence.translate(
-            {ord(c): " " for c in string.punctuation}
-        ).split()
+        # Split the sentence into words
+        words = sentence.split()
+
+        # Define a translation table to replace punctuation marks with spaces
+        translation_table = str.maketrans('', '', string.punctuation)
+
+        # Initialize a list to store the cleaned words
+        cleaned_words = []
+
+        for word in words:
+            # Check if the word is not a special token
+            if word not in self.tokenizer.all_special_tokens:
+                # Remove punctuation marks from the word
+                cleaned_word = word.translate(translation_table)
+                cleaned_words.append(cleaned_word)
+            else:
+                # If the word is a special token, add it as is
+                cleaned_words.append(word)
+
+        return cleaned_words

     def __tokenize_and_preserve(self, sentence, text_labels=None):
         if type(sentence) == str:
@@ -69,7 +86,9 @@ def do_expansion(self, word):

     def mask_expansion(self, txt):
         ret = collections.defaultdict(list)
-        special_tokens = [self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.mask_token]
+        special_tokens = self.tokenizer.all_special_tokens
+        if self.tokenizer.bos_token:
+            txt = self.tokenizer.bos_token + ' ' + txt
         X = self.tokenizer.encode(txt, return_tensors="pt")
         word2token = self.__tokenize_and_preserve(txt)
         words = self.__tokenize_to_words(txt)
@@ -104,12 +123,12 @@ def mask_expansion(self, txt):
             # create all possible combinations of sub-words and normalize their scores
             all_combinations = self.combine_and_normalize(all_combinations)
             all_combinations = [(w[1:], s) if w.startswith("Ġ") else (w, s) for w, s in all_combinations]
-
+
             if self.multi_word == "filter":
                 # filter out sub-words that are not in linux built-in dictionary
                 all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab]
-
+
+            all_combinations = [(w, s) for w, s in all_combinations if len(w) > 0] # filter out empty sub-words
             ret[wi].extend(all_combinations)

         ret = dict(ret)
@@ -138,7 +157,7 @@ def combine_and_normalize(self, all_combinations):
         # Iterate through all possible combinations of sub-words
         for sub_word_combination in itertools.product(*non_empty_combinations):
-            combined_sub_words = ''.join(word.replace('Ġ', ' ') for word, _ in sub_word_combination)
+            combined_sub_words = ''.join(word.replace('Ġ', ' ') for word, _ in sub_word_combination) # Combine the sub-words and use 'Ġ' to decide where to add spaces

             # Calculate the product of scores for the sub-words in the combination
             combined_score = 1.0  # Initialize with a score of 1.0
             for word, score in sub_word_combination:

From 81c6bfedb2260fdf0b290469669af3a26de2718f Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 15:01:03 +0300
Subject: [PATCH 06/10] fix notebook example

---
 notebooks/elastic_splade.ipynb | 61 ++++++++++++++++++++++++++--------
 1 file changed, 47 insertions(+), 14 deletions(-)

diff --git a/notebooks/elastic_splade.ipynb b/notebooks/elastic_splade.ipynb
index 2e32cc4..c9b94a3 100644
--- a/notebooks/elastic_splade.ipynb
+++ b/notebooks/elastic_splade.ipynb
@@ -254,7 +254,7 @@
     "\n",
     "sys.path.append(\"../simple_splade\")\n",
     "\n",
-    "from elastic_splade import splade"
+    "from splade4elastic import SpladeRewriter"
    ]
   },
   {
@@ -269,7 +269,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+ "execution_count": 10, "metadata": {}, "outputs": [ { @@ -284,16 +284,16 @@ } ], "source": [ - "spalde_model = splade(model_name, model_name)" + "spalde_model = SpladeRewriter(model_name, expansions_per_word=3)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "test_texts = [\n", + "texts = [\n", " \"My name is John\",\n", " \"The quick brown fox jumps over the lazy dog\",\n", " \"I like to eat apples\",\n", @@ -302,26 +302,59 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 14, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using eos_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "My name is John\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using eos_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(^0.79 OR the^0.21) (my^1.0 OR his^0.71 OR her^0.08) (name^1.0 OR father^0.0 OR husband^0.0) (is^1.0 OR was^0.04 OR means^0.0)\n", + "The quick brown fox jumps over the lazy dog\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using eos_token, but it is not set yet.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "My name is John\n", - "(my^0.25 OR his^0.13 OR her^0.11 OR the^0.09 OR your^0.09 OR their^0.08 OR its^0.07 OR our^0.07 OR last^0.06 OR another^0.06) (name^0.3 OR father^0.09 OR husband^0.08 OR dad^0.08 OR brother^0.08 OR surname^0.08 OR nickname^0.07 OR title^0.07 OR boyfriend^0.07 OR son^0.07) (is^0.33 OR ^0.27 OR was^0.13 OR means^0.08 OR says^0.06 OR are^0.06 OR goes^0.06)\n", - "The quick brown fox jumps over the lazy dog\n", - "(the^0.29 OR a^0.15 OR one^0.09 OR some^0.08 OR little^0.07 OR this^0.07 OR his^0.06 OR another^0.06 OR no^0.06 OR my^0.06) (lazy^0.21 OR little^0.12 OR fat^0.09 OR young^0.09 OR big^0.09 OR great^0.08 OR hungry^0.08 OR small^0.08 OR large^0.08 OR old^0.08) (thinking^0.11 OR little^0.1 OR old^0.1 OR ^0.1 OR talking^0.1 OR y^0.1 OR ie^0.1 OR ing^0.1 OR en^0.09 OR ening^0.09) (dog^0.2 OR cat^0.1 OR ie^0.1 OR bear^0.09 OR man^0.09 OR one^0.09 OR boy^0.09 OR girl^0.08 OR guy^0.08 OR wolf^0.08) (took^0.12 OR takes^0.11 OR watched^0.11 OR watches^0.1 OR loomed^0.1 OR looked^0.1 OR ran^0.09 OR stood^0.09 OR watching^0.09 OR wins^0.09) (over^0.2 OR ^0.18 OR at^0.11 OR on^0.1 OR and^0.09 OR after^0.09 OR from^0.08 OR off^0.08 OR behind^0.08) (the^0.27 OR a^0.11 OR his^0.11 OR her^0.09 OR their^0.08 OR this^0.08 OR my^0.07 OR its^0.07 OR another^0.07 OR that^0.06) (little^0.12 OR big^0.12 OR small^0.11 OR other^0.1 OR wild^0.1 OR dead^0.1 OR hot^0.09 OR startled^0.09 OR barking^0.09 OR old^0.09)\n", + "(the^1.0 OR ^0.77) (the^1.0 OR a^0.69 OR one^0.0) (little^0.87 OR fat^0.07 OR young^0.06) (thinking^0.49 OR little^0.26 OR old^0.26) (dog^1.0 OR cat^0.33 OR ie^0.27) (took^0.48 OR takes^0.26 OR watched^0.26) (at^0.58 OR on^0.21 OR ^0.2) (the^1.0 OR a^0.07 OR his^0.06) (little^0.43 OR big^0.38 OR small^0.18)\n", "I like to eat apples\n", - "(i^0.27 OR they^0.11 OR we^0.1 OR you^0.1 OR people^0.09 OR and^0.07 OR just^0.07 OR some^0.07 OR he^0.06 OR women^0.06) (like^0.2 OR want^0.1 OR wanted^0.09 OR had^0.09 OR need^0.09 OR have^0.09 OR used^0.09 OR needed^0.09 OR forgot^0.08 OR refuse^0.08) (to^0.37 OR i^0.23 OR and^0.08 OR ta^0.06 OR you^0.06 OR or^0.06 OR they^0.05 OR too^0.05 OR not^0.05) (eat^0.21 OR like^0.18 OR 
+      "(^0.79 OR the^0.21) (i^1.0 OR they^0.1 OR we^0.05) (want^0.6 OR wanted^0.21 OR had^0.19) (to^1.0 OR i^1.0 OR and^0.0) (eat^1.0 OR grow^0.13 OR pick^0.09)\n"
     ]
    }
   ],
   "source": [
-    "for test_text in test_texts:\n",
-    "    print(test_text)\n",
-    "    print(spalde_model.splade_it(test_text))"
+    "for text in texts:\n",
+    "    print(text)\n",
+    "    print(spalde_model.query_expand(text))"
   ]
  }
 ],

From 2239691ee231addfa25163a766e439d180e4937f Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 15:01:17 +0300
Subject: [PATCH 07/10] add prints in tests

---
 test/basic_test.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/test/basic_test.py b/test/basic_test.py
index 5d414fe..3d389fb 100644
--- a/test/basic_test.py
+++ b/test/basic_test.py
@@ -7,41 +7,37 @@ class BasicTest(unittest.TestCase):
     def setUp(self):
-        pass
+        # self.text = "Coffee is good for you"
+        self.text = "My name is John"

     def test_base_with_ignore(self):
         splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="ignore")
-        text = "Coffee is good for you"
         print("Testing MLMBaseRewriter with ignore multi-word option")
-        print(splade.query_expand(text), end="\n\n")
+        print(splade.query_expand(self.text), end="\n\n")
         self.assertTrue(True)

     def test_base_with_split(self):
         splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="split")
-        text = "Coffee is good for you"
         print("Testing MLMBaseRewriter with split multi-word option")
-        print(splade.query_expand(text), end="\n\n")
+        print(splade.query_expand(self.text), end="\n\n")
         self.assertTrue(True)

     def test_base_with_filter(self):
         splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="filter")
-        text = "Coffee is good for you"
         print("Testing MLMBaseRewriter with filter multi-word option")
-        print(splade.query_expand(text), end="\n\n")
+        print(splade.query_expand(self.text), end="\n\n")
         self.assertTrue(True)

     def test_splade(self):
         splade = SpladeRewriter("roberta-base", expansions_per_word=3)
-        text = "Coffee is good for you"
         print("Testing SpladeRewriter")
-        print(splade.query_expand(text), end="\n\n")
+        print(splade.query_expand(self.text), end="\n\n")
         self.assertTrue(True)

     def test_linear(self):
         splade = LinearMLMRewriter("roberta-base", expansions_per_word=3)
-        text = "Coffee is good for you"
         print("Testing LinearMLMRewriter")
-        print(splade.query_expand(text), end="\n\n")
+        print(splade.query_expand(self.text), end="\n\n")
         self.assertTrue(True)

From 162b821deb1c6b1c8a3b425e28f6dff7dc5abf15 Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 15:12:01 +0300
Subject: [PATCH 08/10] fix last word not returning

---
 splade4elastic/elastic_splade.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py
index 7e8100f..92f05e0 100644
--- a/splade4elastic/elastic_splade.py
+++ b/splade4elastic/elastic_splade.py
@@ -109,7 +109,7 @@ def mask_expansion(self, txt):
             logits = self.model(X_m).logits
             all_combinations = []
             for mask_token_index, token, _ in lst:
-                mask_token_logits = logits[0, mask_token_index, :]
+                mask_token_logits = logits[0, mask_token_index + 1, :] # need to add 1 because of the bos token we added
                 max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[
                 ::-1
                 ][:self.k]
@@ -227,7 +227,7 @@ def transform(self, X: List[str]):

 class LinearMLMRewriter(MLMBaseRewriter):
     def logits2weights(self, word_logits):
-        min_score = min(w[1] for w in word_logits)
+        min_score = min(w[1] for w in word_logits)
         ret = [(w[0], w[1] - min_score) for w in word_logits]
         norm_factor = sum(w[1] for w in ret)
         ret = [(w[0], w[1]/norm_factor) for w in ret]

From c519b2a1e5f659ce1974830f19ada8e1a5918602 Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 15:15:15 +0300
Subject: [PATCH 09/10] fix logits2weights if word_logits is empty list

---
 splade4elastic/elastic_splade.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py
index 92f05e0..3b281e2 100644
--- a/splade4elastic/elastic_splade.py
+++ b/splade4elastic/elastic_splade.py
@@ -227,6 +227,8 @@ def transform(self, X: List[str]):

 class LinearMLMRewriter(MLMBaseRewriter):
     def logits2weights(self, word_logits):
+        if len(word_logits) == 0:
+            return word_logits # empty list
         min_score = min(w[1] for w in word_logits)
         ret = [(w[0], w[1] - min_score) for w in word_logits]
         norm_factor = sum(w[1] for w in ret)
         ret = [(w[0], w[1]/norm_factor) for w in ret]
         return ret

 class SpladeRewriter(MLMBaseRewriter):
     def logits2weights(self, word_logits):
+        if len(word_logits) == 0:
+            return word_logits # empty list
         ret = [(w[0], np.exp(w[1])) for w in word_logits]
         norm_factor = sum(w[1] for w in ret)
         ret = [(w[0], w[1]/norm_factor) for w in ret]
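Patch 09 above guards the two weighting schemes against empty candidate lists. As a side-by-side illustration of what those schemes compute, here is a standalone restatement of LinearMLMRewriter.logits2weights and SpladeRewriter.logits2weights; the candidate words and logit values below are hypothetical, chosen only to show the shape of the output:

import numpy as np

candidates = [("coffee", 9.0), ("tea", 7.0), ("espresso", 6.0)]   # hypothetical logits

def linear_weights(word_logits):
    # LinearMLMRewriter: shift by the minimum logit, then normalize to sum to 1
    if len(word_logits) == 0:
        return word_logits  # empty list, nothing to weight
    min_score = min(s for _, s in word_logits)
    shifted = [(w, s - min_score) for w, s in word_logits]
    norm_factor = sum(s for _, s in shifted)
    return [(w, s / norm_factor) for w, s in shifted]

def splade_weights(word_logits):
    # SpladeRewriter: exponentiate the logits (softmax-style), then normalize
    if len(word_logits) == 0:
        return word_logits  # empty list, nothing to weight
    exped = [(w, np.exp(s)) for w, s in word_logits]
    norm_factor = sum(s for _, s in exped)
    return [(w, s / norm_factor) for w, s in exped]

print(linear_weights(candidates))  # roughly 0.75 / 0.25 / 0.0: the weakest candidate drops to zero
print(splade_weights(candidates))  # roughly 0.84 / 0.11 / 0.04: every candidate keeps some weight

The practical difference is visible in the toy numbers: the linear scheme always assigns weight 0.0 to the lowest-scoring candidate, while the exponential scheme keeps a small positive weight for it.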
From bf4ede9d41549ecc2857e14ee93649d27c7f4f5a Mon Sep 17 00:00:00 2001
From: Aviad Hashuel
Date: Thu, 28 Sep 2023 17:02:06 +0300
Subject: [PATCH 10/10] fix multi words

---
 splade4elastic/elastic_splade.py | 9 +++++----
 test/basic_test.py               | 4 ++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/splade4elastic/elastic_splade.py b/splade4elastic/elastic_splade.py
index 3b281e2..1181b7d 100644
--- a/splade4elastic/elastic_splade.py
+++ b/splade4elastic/elastic_splade.py
@@ -109,7 +109,7 @@ def mask_expansion(self, txt):
             logits = self.model(X_m).logits
             all_combinations = []
             for mask_token_index, token, _ in lst:
-                mask_token_logits = logits[0, mask_token_index + 1, :] # need to add 1 because of the bos token we added
+                mask_token_logits = logits[0, mask_token_index, :] # need to add 1 because of the bos token we added
                 max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[
                 ::-1
                 ][:self.k]
                 max_tokens = self.tokenizer.convert_ids_to_tokens(max_ids)
@@ -126,7 +126,7 @@ def mask_expansion(self, txt):
             if self.multi_word == "filter":
                 # filter out sub-words that are not in linux built-in dictionary
-                all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab]
+                all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab or len(w.split(" ")) > 1]

             all_combinations = [(w, s) for w, s in all_combinations if len(w) > 0] # filter out empty sub-words
             ret[wi].extend(all_combinations)
@@ -172,7 +172,8 @@ def combine_and_normalize(self, all_combinations):
         return result

-    def __only_alpha(self, txt):
+    def __only_alpha(self, txt):
+        # consider to delete this function
         return "".join(c for c in txt if c in string.ascii_letters)

     def logits2weights(self, word_logits):
@@ -192,7 +193,7 @@ def __elastic_format(self, expanded_list, text):
             # print(words)
             # The original word should have a higher score
             words = self.__force_weights(words, text)
-            words = [(self.__only_alpha(w[0]).lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
+            words = [(w[0].lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
             # unite equal words and sum their scores
             unique_words = {w[0] for w in words}
             words = [(unique, sum(w[1] for w in words if w[0] == unique)) for unique in unique_words]

diff --git a/test/basic_test.py b/test/basic_test.py
index 3d389fb..84d6f4c 100644
--- a/test/basic_test.py
+++ b/test/basic_test.py
@@ -7,8 +7,8 @@ class BasicTest(unittest.TestCase):
     def setUp(self):
-        # self.text = "Coffee is good for you"
-        self.text = "My name is John"
+        self.text = "Coffee is good for you"
+        # self.text = "My name is John"

     def test_base_with_ignore(self):
         splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="ignore")
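For context, the expanded strings the rewriters produce, the "(word^weight OR word^weight ...)" groups visible in the notebook output above, are intended to be passed to Elasticsearch as a query_string query, where the ^weights become per-term boosts. A minimal usage sketch, assuming the elasticsearch Python client with its 8.x-style API, a local cluster, and a hypothetical index my_documents with a text field body (none of these names come from the patches themselves):

from elasticsearch import Elasticsearch
from splade4elastic import SpladeRewriter

rewriter = SpladeRewriter("roberta-base", expansions_per_word=3)
expanded = rewriter.query_expand("Coffee is good for you")
# expanded is a string of weighted OR-groups, one group per input word

es = Elasticsearch("http://localhost:9200")   # hypothetical local cluster
resp = es.search(
    index="my_documents",                     # hypothetical index name
    query={
        "query_string": {
            "query": expanded,                # weighted OR-groups become boosted terms
            "default_field": "body",          # hypothetical text field
        }
    },
)
print(resp["hits"]["total"])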