Skip to content

Commit

Permalink
fix multi words
Browse files Browse the repository at this point in the history
  • Loading branch information
aviad-has committed Sep 28, 2023
1 parent c519b2a commit bf4ede9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions splade4elastic/elastic_splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def mask_expansion(self, txt):
logits = self.model(X_m).logits
all_combinations = []
for mask_token_index, token, _ in lst:
mask_token_logits = logits[0, mask_token_index + 1, :] # need to add 1 because of the bos token we added
mask_token_logits = logits[0, mask_token_index, :] # index is used as-is; the former +1 bos-token offset was removed
max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[
::-1
][:self.k]
Expand All @@ -126,7 +126,7 @@ def mask_expansion(self, txt):

if self.multi_word == "filter":
# filter out sub-words that are not in linux built-in dictionary
all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab]
all_combinations = [(w, s) for w, s in all_combinations if w.lower() in self.vocab or len(w.split(" ")) > 1]

all_combinations = [(w, s) for w, s in all_combinations if len(w) > 0] # filter out empty sub-words
ret[wi].extend(all_combinations)
Expand Down Expand Up @@ -172,7 +172,8 @@ def combine_and_normalize(self, all_combinations):

return result

def __only_alpha(self, txt):
def __only_alpha(self, txt):
# consider deleting this function
return "".join(c for c in txt if c in string.ascii_letters)

def logits2weights(self, word_logits):
Expand All @@ -192,7 +193,7 @@ def __elastic_format(self, expanded_list, text):
# print(words)
# The original word should have a higher score
words = self.__force_weights(words, text)
words = [(self.__only_alpha(w[0]).lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
words = [(w[0].lower(), w[1]) for w in words if w[0] != self.tokenizer.bos_token]
# unite equal words and sum their scores
unique_words = {w[0] for w in words}
words = [(unique, sum(w[1] for w in words if w[0] == unique)) for unique in unique_words]
Expand Down
4 changes: 2 additions & 2 deletions test/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

class BasicTest(unittest.TestCase):
def setUp(self):
# self.text = "Coffee is good for you"
self.text = "My name is John"
self.text = "Coffee is good for you"
# self.text = "My name is John"

def test_base_with_ignore(self):
splade = MLMBaseRewriter("roberta-base", expansions_per_word=3, multi_word="ignore")
Expand Down

0 comments on commit bf4ede9

Please sign in to comment.