Merge pull request #1 from argmaxml/develop

Develop

urigoren authored Oct 2, 2023
2 parents fc272de + 1458a18 commit 1ee1ef8
Showing 4 changed files with 180 additions and 45 deletions.
61 changes: 47 additions & 14 deletions notebooks/elastic_splade.ipynb
@@ -254,7 +254,7 @@
"\n",
"sys.path.append(\"../simple_splade\")\n",
"\n",
"from elastic_splade import splade"
"from splade4elastic import SpladeRewriter"
]
},
{
@@ -269,7 +269,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -284,16 +284,16 @@
}
],
"source": [
"spalde_model = splade(model_name, model_name)"
"spalde_model = SpladeRewriter(model_name, expansions_per_word=3)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"test_texts = [\n",
"texts = [\n",
" \"My name is John\",\n",
" \"The quick brown fox jumps over the lazy dog\",\n",
" \"I like to eat apples\",\n",
@@ -302,26 +302,59 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using eos_token, but it is not set yet.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"My name is John\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using eos_token, but it is not set yet.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(^0.79 OR the^0.21) (my^1.0 OR his^0.71 OR her^0.08) (name^1.0 OR father^0.0 OR husband^0.0) (is^1.0 OR was^0.04 OR means^0.0)\n",
"The quick brown fox jumps over the lazy dog\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using eos_token, but it is not set yet.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"My name is John\n",
"(my^0.25 OR his^0.13 OR her^0.11 OR the^0.09 OR your^0.09 OR their^0.08 OR its^0.07 OR our^0.07 OR last^0.06 OR another^0.06) (name^0.3 OR father^0.09 OR husband^0.08 OR dad^0.08 OR brother^0.08 OR surname^0.08 OR nickname^0.07 OR title^0.07 OR boyfriend^0.07 OR son^0.07) (is^0.33 OR ^0.27 OR was^0.13 OR means^0.08 OR says^0.06 OR are^0.06 OR goes^0.06)\n",
"The quick brown fox jumps over the lazy dog\n",
"(the^0.29 OR a^0.15 OR one^0.09 OR some^0.08 OR little^0.07 OR this^0.07 OR his^0.06 OR another^0.06 OR no^0.06 OR my^0.06) (lazy^0.21 OR little^0.12 OR fat^0.09 OR young^0.09 OR big^0.09 OR great^0.08 OR hungry^0.08 OR small^0.08 OR large^0.08 OR old^0.08) (thinking^0.11 OR little^0.1 OR old^0.1 OR ^0.1 OR talking^0.1 OR y^0.1 OR ie^0.1 OR ing^0.1 OR en^0.09 OR ening^0.09) (dog^0.2 OR cat^0.1 OR ie^0.1 OR bear^0.09 OR man^0.09 OR one^0.09 OR boy^0.09 OR girl^0.08 OR guy^0.08 OR wolf^0.08) (took^0.12 OR takes^0.11 OR watched^0.11 OR watches^0.1 OR loomed^0.1 OR looked^0.1 OR ran^0.09 OR stood^0.09 OR watching^0.09 OR wins^0.09) (over^0.2 OR ^0.18 OR at^0.11 OR on^0.1 OR and^0.09 OR after^0.09 OR from^0.08 OR off^0.08 OR behind^0.08) (the^0.27 OR a^0.11 OR his^0.11 OR her^0.09 OR their^0.08 OR this^0.08 OR my^0.07 OR its^0.07 OR another^0.07 OR that^0.06) (little^0.12 OR big^0.12 OR small^0.11 OR other^0.1 OR wild^0.1 OR dead^0.1 OR hot^0.09 OR startled^0.09 OR barking^0.09 OR old^0.09)\n",
"(the^1.0 OR ^0.77) (the^1.0 OR a^0.69 OR one^0.0) (little^0.87 OR fat^0.07 OR young^0.06) (thinking^0.49 OR little^0.26 OR old^0.26) (dog^1.0 OR cat^0.33 OR ie^0.27) (took^0.48 OR takes^0.26 OR watched^0.26) (at^0.58 OR on^0.21 OR ^0.2) (the^1.0 OR a^0.07 OR his^0.06) (little^0.43 OR big^0.38 OR small^0.18)\n",
"I like to eat apples\n",
"(i^0.27 OR they^0.11 OR we^0.1 OR you^0.1 OR people^0.09 OR and^0.07 OR just^0.07 OR some^0.07 OR he^0.06 OR women^0.06) (like^0.2 OR want^0.1 OR wanted^0.09 OR had^0.09 OR need^0.09 OR have^0.09 OR used^0.09 OR needed^0.09 OR forgot^0.08 OR refuse^0.08) (to^0.37 OR i^0.23 OR and^0.08 OR ta^0.06 OR you^0.06 OR or^0.06 OR they^0.05 OR too^0.05 OR not^0.05) (eat^0.21 OR like^0.18 OR grow^0.09 OR pick^0.08 OR love^0.08 OR have^0.08 OR see^0.07 OR drink^0.07 OR steal^0.07 OR buy^0.07)\n"
"(^0.79 OR the^0.21) (i^1.0 OR they^0.1 OR we^0.05) (want^0.6 OR wanted^0.21 OR had^0.19) (to^1.0 OR i^1.0 OR and^0.0) (eat^1.0 OR grow^0.13 OR pick^0.09)\n"
]
}
],
"source": [
"for test_text in test_texts:\n",
" print(test_text)\n",
" print(spalde_model.splade_it(test_text))"
"for text in texts:\n",
" print(text)\n",
" print(spalde_model.query_expand(text))"
]
}
],
2 changes: 1 addition & 1 deletion splade4elastic/__init__.py
@@ -1,2 +1,2 @@
__version__ = "0.0.23"
__version__ = "0.0.24"
from .elastic_splade import SpladeRewriter, MLMBaseRewriter, LinearMLMRewriter
131 changes: 110 additions & 21 deletions splade4elastic/elastic_splade.py
@@ -7,22 +7,52 @@
class MLMBaseRewriter:
"""Elastic SPLADE model"""

def __init__(self, model_name: str, expansions_per_word:int = 10, exluded_words=[]):
def __init__(self, model_name: str, expansions_per_word:int = 10, multi_word="split", exluded_words=[]):
"""Initialize the model
Args:
model_name (str): name of the model
multi_word (str, optional): How to handle multi-word tokens. Defaults to "split". Can be "filter" or "ignore"
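expansions_per_word (int, optional): number of expansion candidates kept per word. Defaults to 10.
exluded_words (list, optional): words that are never expanded. Defaults to [].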
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token="<s>")
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
self.k=expansions_per_word
self.exluded_words = exluded_words
self.const_weight = 1
self.multi_word = multi_word
self.vocab = self.read_vocab() if multi_word == "filter" else None

def read_vocab(self, vocab='/usr/share/dict/words'):
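# Load a newline-delimited word list (the Linux system dictionary by default),
# used by the "filter" multi_word mode; fall back to an empty vocab if the file is missing.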
try:
with open(vocab, 'r') as f:
words = [l.strip() for l in f.readlines()]
except FileNotFoundError:
print(f"Could not find {vocab} file, using empty vocab")
return set()
words = [w.lower() for w in words if len(w)>1]
return frozenset(words)

def __tokenize_to_words(self, sentence):
return sentence.translate(
{ord(c): " " for c in string.punctuation}
).split()
# Split the sentence into words
words = sentence.split()

# Define a translation table to replace punctuation marks with spaces
translation_table = str.maketrans('', '', string.punctuation)

# Initialize a list to store the cleaned words
cleaned_words = []

for word in words:
# Check if the word is not a special token
if word not in self.tokenizer.all_special_tokens:
# Remove punctuation marks from the word
cleaned_word = word.translate(translation_table)
cleaned_words.append(cleaned_word)
else:
# If the word is a special token, add it as is
cleaned_words.append(word)

return cleaned_words

def __tokenize_and_preserve(self, sentence, text_labels=None):
if type(sentence) == str:
@@ -52,43 +82,98 @@ def __tokenize_and_preserve(self, sentence, text_labels=None):
]

def do_expansion(self, word):
if word in self.exluded_words:
return False
return True
return word not in self.exluded_words # expand all words except for the excluded ones

def mask_expansion(self, txt):
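# For each word in the text, mask all of its sub-tokens at once, run the MLM,
# and keep the top-k candidates per masked position; multi-sub-token words are
# handled according to self.multi_word ("split", "ignore" or "filter").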
ret = collections.defaultdict(list)
special_tokens = self.tokenizer.all_special_tokens
if self.tokenizer.bos_token:
txt = self.tokenizer.bos_token + ' ' + txt
X = self.tokenizer.encode(txt, return_tensors="pt")
word2token = self.__tokenize_and_preserve(txt)
words = self.__tokenize_to_words(txt)

for wi, lst in word2token:
if not self.do_expansion(words[wi]):
# skip this word
ret[wi].append((words[wi], self.const_weight))
continue
if self.multi_word == "ignore" and len(lst) > 1: # skip multi-word tokens
ret[wi].append((words[wi], self.const_weight))
continue
X_m = X.clone()
for mask_token_index, token, _ in lst:
ti = mask_token_index
# if self.tokenizer.bos_token:
# ti += 1
X_m[0, ti] = self.tokenizer.mask_token_id
logits = self.model(X_m).logits
all_combinations = []
for mask_token_index, token, _ in lst:
mask_token_logits = logits[0, mask_token_index, :]
mask_token_logits = logits[0, mask_token_index, :] # need to add 1 because of the bos token we added
max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[
::-1
][:self.k]
max_tokens = self.tokenizer.convert_ids_to_tokens(max_ids)
# max_tokens = [t[1:] if t.startswith("Ġ") else t for t in max_tokens] # remove the leading space
max_scores = np.sort(mask_token_logits.to("cpu").detach().numpy())[::-1][ :self.k]

ret[wi].extend(zip(max_tokens, max_scores))
tokens_tuple = zip(max_tokens, max_scores)
sub_words = [(t, s) for t, s in tokens_tuple if t not in special_tokens]
all_combinations.append(sub_words)

# create all possible combinations of sub-words and normalize their scores
all_combinations = self.combine_and_normalize(all_combinations)
all_combinations = [(w[1:], s) if w.startswith("Ġ") else (w, s) for w, s in all_combinations]

if self.multi_word == "filter":
# filter out sub-words that are not in the Linux built-in dictionary
all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab or len(w.split(" ")) > 1]

all_combinations = [(w, s) for w, s in all_combinations if len(w) > 0] # filter out empty sub-words
ret[wi].extend(all_combinations)

ret = dict(ret)
if self.tokenizer.bos_token:
del ret[0]
return list(ret.values())

def __only_alpha(self, txt):

def combine_and_normalize(self, all_combinations):
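# Each entry of all_combinations is a list of (sub-word, score) candidates for one
# masked position. Build every cross-position combination, join the sub-words into a
# single surface form (a leading 'Ġ' becomes a space), and score each combination by
# the product of its scores, each normalized by that position's maximum score.
# Example: [("play", 5.0), ("work", 3.0)] and [("ing", 2.0)] combine to
# [("playing", 1.0), ("working", 0.6)].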
result = []

# Filter out empty sub-lists
non_empty_combinations = [sub_list for sub_list in all_combinations if sub_list]

# Check if there are any non-empty sub-lists
if not non_empty_combinations:
return result

# Initialize with the first non-empty sub-list
initial_combination = non_empty_combinations[0]

# Check if there's only one non-empty sub-list
if len(non_empty_combinations) == 1:
return initial_combination

# Create a dictionary to store maximum scores for each word
max_scores = {word: max(score for _, score in sub_list) for sub_list in non_empty_combinations for word, _ in sub_list}

# Iterate through all possible combinations of sub-words
for sub_word_combination in itertools.product(*non_empty_combinations):
combined_sub_words = ''.join(word.replace('Ġ', ' ') for word, _ in sub_word_combination) # Combine the sub-words and use 'Ġ' to decide where to add spaces

# Calculate the product of scores for the sub-words in the combination
combined_score = 1.0 # Initialize with a score of 1.0
for word, score in sub_word_combination:
combined_score *= score / max_scores[word] # Normalize the score

# Append the combined sub-words and their normalized score
result.append((combined_sub_words, combined_score))

# Sort the result by normalized scores (optional)
result.sort(key=lambda x: x[1], reverse=True)

return result

def __only_alpha(self, txt):
# consider to delete this function
return "".join(c for c in txt if c in string.ascii_letters)

def logits2weights(self, word_logits):
@@ -101,21 +186,21 @@ def __force_weights(self, word_logits, txt):
def __elastic_format(self, expanded_list, text):
ret = []
text = text.lower().split()
# print(text)
for words in expanded_list:
# print(words)
words = self.logits2weights(words)
# print(words)
# The original word should have a higher score
words = self.__force_weights(words, text)
words = [(self.__only_alpha(w[0]).lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
words = [(w[0].lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
# unite equal words and sum their scores
unique_words = {w[0] for w in words}
words = [(unique, sum(w[1] for w in words if w[0] == unique)) for unique in unique_words]
# sort by score
words = sorted(words, key=lambda x: x[1], reverse=True)
# print(words)
or_statement = []
for w in words:
or_statement.append(f"{w[0]}^{round(float(w[1]), 2)}")
or_statement = " OR ".join(or_statement)
or_statement_list = [f"{w[0]}^{round(float(w[1]), 2)}" for w in words]
or_statement = " OR ".join(or_statement_list)
or_statement = f"({or_statement})"
ret.append(or_statement)
return " ".join(ret)
@@ -143,13 +228,17 @@ def transform(self, X: List[str]):

class LinearMLMRewriter(MLMBaseRewriter):
def logits2weights(self, word_logits):
min_score = min(w[1] for w in word_logits)
if len(word_logits) == 0:
return word_logits # empty list
min_score = min(w[1] for w in word_logits)
ret = [(w[0], w[1] - min_score) for w in word_logits]
norm_factor = sum(w[1] for w in ret)
ret = [(w[0], w[1]/norm_factor) for w in ret]
return ret
class SpladeRewriter(MLMBaseRewriter):
def logits2weights(self, word_logits):
if len(word_logits) == 0:
return word_logits # empty list
ret = [(w[0], np.exp(w[1])) for w in word_logits]
norm_factor = sum(w[1] for w in ret)
ret = [(w[0], w[1]/norm_factor) for w in ret]
31 changes: 22 additions & 9 deletions test/basic_test.py
@@ -7,24 +7,37 @@

class BasicTest(unittest.TestCase):
def setUp(self):
pass
self.text = "Coffee is good for you"
# self.text = "My name is John"

def test_base(self):
splade = MLMBaseRewriter("roberta-base", expansions_per_word=3)
text = "Coffee is good for you"
print(splade.query_expand(text))
def test_base_with_ignore(self):
splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="ignore")
print("Testing MLMBaseRewriter with ignore multi-word option")
print(splade.query_expand(self.text), end="\n\n")
self.assertTrue(True)

def test_base_with_split(self):
splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="split")
print("Testing MLMBaseRewriter with split multi-word option")
print(splade.query_expand(self.text), end="\n\n")
self.assertTrue(True)

def test_base_with_filter(self):
splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="filter")
print("Testing MLMBaseRewriter with filter multi-word option")
print(splade.query_expand(self.text), end="\n\n")
self.assertTrue(True)

def test_splade(self):
splade = SpladeRewriter("roberta-base", expansions_per_word=3)
text = "Coffee is good for you"
print(splade.query_expand(text))
print("Testing SpladeRewriter")
print(splade.query_expand(self.text), end="\n\n")
self.assertTrue(True)

def test_linear(self):
splade = LinearMLMRewriter("roberta-base", expansions_per_word=3)
text = "Coffee is good for you"
print(splade.query_expand(text))
print("Testing LinearMLMRewriter")
print(splade.query_expand(self.text), end="\n\n")
self.assertTrue(True)


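The string returned by query_expand is plain Lucene query-string syntax (boosted terms joined with OR), so it can be passed directly to an Elasticsearch query_string query. Below is a minimal sketch of that wiring; the index name "articles", the "body" field, the localhost URL, and the elasticsearch-py 8.x client are illustrative assumptions, not part of this commit.

from elasticsearch import Elasticsearch
from splade4elastic import SpladeRewriter

# Build the rewriter the same way as in the notebook and tests above.
splade = SpladeRewriter("roberta-base", expansions_per_word=3)

# Expand the raw query into a weighted boolean string,
# e.g. "(coffee^1.0 OR tea^0.3) (is^1.0 OR was^0.1) ..."
expanded = splade.query_expand("Coffee is good for you")

# Hypothetical index and field names; adjust to your own mapping.
es = Elasticsearch("http://localhost:9200")
resp = es.search(
    index="articles",
    query={"query_string": {"query": expanded, "default_field": "body"}},
)
for hit in resp["hits"]["hits"]:
    print(hit["_score"], hit["_source"]["body"])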
