Merge pull request #12 from WhereIsAI/bugfix/bi_attention_mask
Bugfix/bi attention mask
SeanLee97 authored May 31, 2024
2 parents d64f1bc + 8ed4837 commit 8688388
Showing 6 changed files with 20 additions and 16 deletions.
README.md: 14 changes (9 additions, 5 deletions)
@@ -21,6 +21,13 @@ Tool for converting LLMs from uni-directional to bi-directional for tasks like c
</a>


+## Supported Models
+
+- LLaMA
+- Mistral
+- Qwen2
+- OpenELM
+
## Usage

1) `python -m pip install -U billm`
@@ -103,13 +110,10 @@ tokens = token_classifier(sentence)
print(tokens)
```

+### Sentence Embeddings

-## Supported Models
+refer to AnglE: https://github.com/SeanLee97/AnglE

-- LLaMA
-- Mistral
-- Qwen2
-- OpenELM

## Citation

pdm.lock: 14 changes (7 additions, 7 deletions)

Some generated files are not rendered by default.

src/billm/modeling_llama.py: 2 changes (1 addition, 1 deletion)
@@ -103,7 +103,7 @@ def forward(
causal_mask = self._update_causal_mask(
    attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
-bi_attention_mask = torch.zeros_like(causal_mask)
+bi_attention_mask = torch.zeros_like(causal_mask) if causal_mask is not None else None

# embed positions
hidden_states = inputs_embeds
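The guarded line above is the heart of the fix: recent `transformers` releases can return `None` from `_update_causal_mask` (for example on attention-backend paths where no explicit mask tensor is needed), and `torch.zeros_like(None)` raises a `TypeError`. Below is a minimal sketch of the guard; `make_bi_attention_mask` is a hypothetical helper name used only for illustration, not part of BiLLM's API:

```python
import torch

def make_bi_attention_mask(causal_mask):
    # Hypothetical helper (illustration only): an all-zeros additive mask of
    # the same shape as the causal mask blocks nothing, and a missing mask is
    # propagated as-is instead of crashing on torch.zeros_like(None).
    return torch.zeros_like(causal_mask) if causal_mask is not None else None

# Mask present: same shape back, filled with zeros.
causal = torch.triu(torch.full((1, 1, 4, 4), float("-inf")), diagonal=1)
print(make_bi_attention_mask(causal).shape)  # torch.Size([1, 1, 4, 4])

# Mask absent (nothing to mask): stays None instead of raising.
print(make_bi_attention_mask(None))  # None
```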
src/billm/modeling_mistral.py: 2 changes (1 addition, 1 deletion)
@@ -138,7 +138,7 @@ def forward(
    past_key_values_length,
    sliding_window=self.config.sliding_window,
)
-bi_attention_mask = torch.zeros_like(attention_mask)
+bi_attention_mask = torch.zeros_like(attention_mask) if attention_mask is not None else None

hidden_states = inputs_embeds
src/billm/modeling_openelm.py: 2 changes (1 addition, 1 deletion)
@@ -658,7 +658,7 @@ def forward(
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
-bi_attention_mask = torch.zeros_like(causal_mask)
+bi_attention_mask = torch.zeros_like(causal_mask) if causal_mask is not None else None

# embed positions
hidden_states = inputs_embeds
src/billm/modeling_qwen2.py: 2 changes (1 addition, 1 deletion)
@@ -138,7 +138,7 @@ def forward(
    past_key_values_length,
    sliding_window=self.config.sliding_window,
)
-bi_attention_mask = torch.zeros_like(attention_mask)
+bi_attention_mask = torch.zeros_like(attention_mask) if attention_mask is not None else None

hidden_states = inputs_embeds
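All four modeling files build `bi_attention_mask` as an all-zeros tensor with the mask's shape; with additive attention masks, zeros block nothing, which is what opens attention up bi-directionally. The toy example below is only an illustration of that effect with plain softmax weights, not BiLLM's internals:

```python
import torch

seq_len = 4
scores = torch.zeros(seq_len, seq_len)  # dummy pre-softmax attention scores

# Additive causal mask: -inf above the diagonal blocks attention to future tokens.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
# All-zeros mask of the same shape: nothing is blocked.
bi_attention_mask = torch.zeros_like(causal_mask)

print(torch.softmax(scores + causal_mask, dim=-1))        # lower-triangular weights
print(torch.softmax(scores + bi_attention_mask, dim=-1))  # uniform weights over all positions
```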
