
Commit

Fix MLM loss ignore idx (#552)
farhadrgh authored Dec 23, 2024
1 parent 30527b1 · commit c28772b
Showing 3 changed files with 14 additions and 4 deletions.
5 changes: 5 additions & 0 deletions requirements-test.txt
@@ -8,3 +8,8 @@ awscli==1.33.33
 nbval==0.11.0
 # For NvFaidx equivalence tests
 pyfaidx==0.8.1.3
+
+# Temporary pin for pytorch-lightning until megatron callbacks in ProgressPrinter can get fixed.
+# See https://nvidia.slack.com/archives/C02A7LYGHK8/p1734727482697309
+pytorch-lightning<2.5.0
+lightning<2.5.0
@@ -27,6 +27,7 @@

 from bionemo.esm2.api import ESM2GenericConfig, ESM2Model
 from bionemo.esm2.data import tokenizer
+from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
 from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
 from bionemo.llm.data.types import BertSample
 from bionemo.llm.model.biobert.model import BioBertOutput
@@ -230,6 +231,7 @@ def __init__(
         self.tokenizer = tokenizer
         label_tokenizer = Label2IDTokenizer()
         self.label_tokenizer = label_tokenizer.build_vocab("CHE")
+        self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX

     def __len__(self) -> int:
         """Length of dataset."""
@@ -257,13 +259,13 @@ def _tokenize_labels(self, labels_sequence: str) -> Tensor:

         # # for multi-label classification with BCEWithLogitsLoss
         # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
-        # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), -1, dtype=tokenized_labels.dtype)
+        # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)

         # for multi-class (mutually exclusive) classification with CrossEntropyLoss
         tokenized_labels = label_ids
-        cls_eos = torch.tensor([-1], dtype=tokenized_labels.dtype)
+        cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)

-        # add cls / eos labels with padding value -1 to have the same shape as tokenized_sequence
+        # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
         labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
         return labels

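For context on the hunk above: torch.nn.CrossEntropyLoss skips any target position equal to its ignore_index, which defaults to -100, the value MLM_LOSS_IGNORE_INDEX now supplies; the old -1 sentinel is not ignored by that default. A minimal sketch of the behavior (shapes are illustrative; the three classes mirror the "CHE" vocabulary built above):

    import torch

    logits = torch.randn(5, 3)  # 5 token positions, 3 label classes (the "CHE" vocab)
    labels = torch.tensor([-100, 0, 2, 1, -100])  # cls/eos positions carry the sentinel

    loss_fn = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
    loss = loss_fn(logits, labels)  # the two sentinel positions contribute nothing
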
5 changes: 4 additions & 1 deletion sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py
@@ -28,6 +28,9 @@
 _warned_once: bool = False


+MLM_LOSS_IGNORE_INDEX = -100  # This should match the masked value used in the MLM loss mask.
+
+
 def padding_collate_fn(
     batch: Sequence[_T],
     padding_values: dict[str, int],
@@ -105,7 +108,7 @@ def bert_padding_collate_fn(
"text": padding_value,
"types": 0,
"attention_mask": False,
"labels": -100, # This should match the masked value used in the MLM loss mask.
"labels": MLM_LOSS_IGNORE_INDEX, # This should match the masked value used in the MLM loss mask.
"loss_mask": False,
"is_random": 0,
}
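The same sentinel also has to serve as the pad value during batch collation, which is what the hunk above wires up: if variable-length label tensors were padded with anything else, the pad positions would be scored as real classes. A hedged sketch of that invariant in plain torch (pad_sequence stands in for the library's own collate function, which is not reproduced here):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    MLM_LOSS_IGNORE_INDEX = -100  # mirrors the constant added in collate.py

    labels_a = torch.tensor([-100, 1, 0, 2, -100])  # length 5
    labels_b = torch.tensor([-100, 2, 1, -100])     # length 4
    batch = pad_sequence([labels_a, labels_b], batch_first=True, padding_value=MLM_LOSS_IGNORE_INDEX)
    # batch[1] -> tensor([-100, 2, 1, -100, -100]); the trailing pad is ignored
    # by CrossEntropyLoss just like the cls / eos sentinels.
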
