Skip to content

Commit

Permalink
test (fix): sdpa import
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Nov 28, 2024
1 parent 3798923 commit 6b1e51a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/brevitas/nn/test_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from packaging import version
import pytest
import torch
from torch.nn.functional import scaled_dot_product_attention
import torch.nn.functional as F

from brevitas import torch_version
from brevitas.nn import QuantScaledDotProductAttention
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask)
attn_mask = None
if dropout_p > 0.0:
torch.manual_seed(DROPOUT_SEED)
ref_out = scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs)
ref_out = F.scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs)
if dropout_p > 0.0:
torch.manual_seed(DROPOUT_SEED)
out = m(q, k, v, attn_mask, **extra_kwargs)
Expand Down

0 comments on commit 6b1e51a

Please sign in to comment.