
Commit

fix bug hfcrossscorer
Yuxuan ZONG committed Feb 23, 2023
1 parent 2e76935 commit 5c4c04f
Showing 2 changed files with 63 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/xpmir/letor/samplers.py
@@ -588,6 +588,7 @@ def execute(self):
         for batch in tqdm(self.iter_batches()):
 
             # scores in shape: [batch_size, 2]
+            self.teacher_model.eval()
             scores = self.teacher_model(batch)
             scores = scores.reshape(2, -1).T
 
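The one-line fix above puts the teacher model into evaluation mode before it scores distillation pairs, so dropout (and any batch-norm updates) no longer perturb the target scores. A minimal sketch of the intended pattern, with illustrative names rather than the sampler's actual attributes; the no_grad guard is an extra precaution not present in the commit:

import torch

# Sketch only: put the teacher into inference mode before producing target scores.
def score_with_teacher(teacher_model: torch.nn.Module, batch):
    teacher_model.eval()   # disables dropout and freezes batch-norm statistics
    with torch.no_grad():  # teacher scores are training targets, so no gradients are needed
        return teacher_model(batch)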
77 changes: 62 additions & 15 deletions src/xpmir/neural/huggingface.py
@@ -2,9 +2,14 @@
 from xpmir.letor.records import BaseRecords
 from xpmir.neural import TorchLearnableScorer
 from experimaestro import Param
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+from xpmir.letor.records import TokenizedTexts
+from typing import List, Tuple
+from xpmir.distributed import DistributableModel
+import torch
 
 
-class HFCrossScorer(TorchLearnableScorer):
+class HFCrossScorer(TorchLearnableScorer, DistributableModel):
     """Load a cross scorer model from the huggingface"""
 
     hf_id: Param[str]
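The class now also derives from DistributableModel; as the distribute_models method near the end of this diff shows, this lets the training loop rewrap the underlying Hugging Face model through an update callback (for instance for multi-GPU data parallelism). A hypothetical callback, assuming nothing about xpmir beyond the hook shown in this commit:

import torch

# Hypothetical update callback: wrap the inner model for data parallelism
# when several GPUs are available, otherwise return it unchanged.
def update(module: torch.nn.Module) -> torch.nn.Module:
    if torch.cuda.device_count() > 1:
        return torch.nn.DataParallel(module)
    return module

# scorer.distribute_models(update)  # would replace scorer.model with the wrapped module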
@@ -14,19 +19,61 @@ class HFCrossScorer(TorchLearnableScorer):
"""the max length for the transformer model"""

def __post_init__(self):
try:
from sentence_transformers import CrossEncoder
except Exception:
self.logger.error(
"Sentence transformer is not installed:"
"pip install -U sentence_transformers"
)
raise
self.model = CrossEncoder(self.hf_id, max_length=self.max_length)
+        # FIXME: consider how to treat with the device
+        self.device = torch.nn.Parameter(torch.Tensor()).device
+
+        self.config = AutoConfig.from_pretrained(self.hf_id)
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            self.hf_id, config=self.config
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id)
+
+    def _initialize(self, random):
+        pass
+
+    def batch_tokenize(
+        self,
+        input_records: BaseRecords,
+        maxlen=None,
+        mask=False,
+    ) -> TokenizedTexts:
+        """Transform the text to tokens by using the tokenizer"""
+        if maxlen is None:
+            maxlen = self.tokenizer.model_max_length
+        else:
+            maxlen = min(maxlen, self.tokenizer.model_max_length)
+
+        texts: List[Tuple[str, str]] = [
+            (q.text, d.text)
+            for q, d in zip(input_records.queries, input_records.documents)
+        ]
+
+        r = self.tokenizer(
+            texts,
+            max_length=maxlen,
+            truncation=True,
+            padding=True,
+            return_tensors="pt",
+            return_length=True,
+            return_attention_mask=mask,
+        )
+        return TokenizedTexts(
+            None,
+            r["input_ids"].to(self.device),
+            r["length"],
+            r.get("attention_mask", None),
+            r.get("token_type_ids", None),  # if r["token_type_ids"] else None
+        )

     def forward(self, inputs: BaseRecords, info: TrainerContext = None):
-        predict_score_list = self.model.predict(
-            [(q.text, d.text) for q, d in zip(inputs.queries, inputs.documents)],
-            convert_to_tensor=True,
-        )  # Tensor[float] of length records size
-        return predict_score_list

+        tokenized = self.batch_tokenize(inputs, maxlen=self.max_length, mask=True)
+        # strange that some existing models on the huggingface don't use the token_type
+        with torch.set_grad_enabled(torch.is_grad_enabled()):
+            result = self.model(
+                tokenized.ids, attention_mask=tokenized.mask.to(self.device)
+            ).logits  # Tensor[float] of length records size
+        return result
+
+    def distribute_models(self, update):
+        self.model = update(self.model)
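Read on its own, the rewritten scorer tokenizes (query, document) pairs jointly, feeds them to an AutoModelForSequenceClassification, and returns the logits as relevance scores. The standalone sketch below reproduces that flow with plain transformers calls; the checkpoint name and example texts are illustrative and not taken from the commit:

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# Assumed cross-encoder checkpoint; any sequence-classification reranker works the same way.
hf_id = "cross-encoder/ms-marco-MiniLM-L-6-v2"
config = AutoConfig.from_pretrained(hf_id)
model = AutoModelForSequenceClassification.from_pretrained(hf_id, config=config)
tokenizer = AutoTokenizer.from_pretrained(hf_id)
model.eval()

# One (query, document) pair per row, tokenized jointly as the new batch_tokenize does.
pairs = [("what is a cross-encoder?",
          "A cross-encoder scores a query and a document jointly in one forward pass.")]
batch = tokenizer(pairs, max_length=512, truncation=True, padding=True, return_tensors="pt")

with torch.no_grad():
    scores = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits
print(scores)  # for this checkpoint: one relevance logit per pair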
