Skip to content

Commit

Permalink
use bool masks (mlfoundations#665)
Browse files Browse the repository at this point in the history
  • Loading branch information
SeyedAlirezaFatemi authored and Interpause committed May 23, 2024
1 parent 72b06a2 commit d9a2bbe
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/open_clip/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def build_attention_mask(self):

def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
def test_poolers():
bs, sl, d = 2, 10, 5
h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
mask = torch.ones(bs, sl, dtype=torch.long)
mask[:2, 6:] = 0
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:2, 6:] = False
x = BaseModelOutput(h)
for name, cls in _POOLERS.items():
pooler = cls()
Expand Down

0 comments on commit d9a2bbe

Please sign in to comment.