Commit 7e15c09

address #280

lucidrains committed Nov 6, 2024
1 parent 0607310 commit 7e15c09
Showing 3 changed files with 33 additions and 4 deletions.
setup.py: 2 changes (1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.6',
+  version = '1.42.7',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py: 18 changes (18 additions & 0 deletions)

@@ -376,3 +376,21 @@ def test_neo_mlp():
 
     out = mlp(x)
     assert out.shape == (3, 7)
+
+def test_custom_alibi():
+    model = TransformerWrapper(
+        num_tokens = 20_000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            alibi_pos_bias = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 4))
+
+    pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])
+
+    logits = model(x, pos = pos)
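
The new test above is also the clearest usage example in this commit. Below is a lightly annotated sketch of the same call with the imports made explicit; the final shape assertion is an assumption added for illustration and is not part of the test.

import torch
from x_transformers import TransformerWrapper, Decoder

# alibi positional bias is required here: the change further down asserts that
# custom positions are only supported for AlibiPositionalBias at the moment
model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        alibi_pos_bias = True
    )
)

x = torch.randint(0, 20_000, (2, 4))

# explicit, possibly non-contiguous positions, one row per sequence in the batch
pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])

logits = model(x, pos = pos)

# assumed output shape for illustration: (batch, seq_len, num_tokens)
assert logits.shape == (2, 4, 20_000)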
x_transformers/x_transformers.py: 17 changes (14 additions & 3 deletions)

@@ -1246,6 +1246,7 @@ def forward(
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,
@@ -1392,7 +1393,14 @@ def forward(
 
         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
         # prepare data dependent alibi from forgetting transformers paper, if needed
@@ -1843,6 +1851,7 @@ def forward(
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1906,7 +1915,9 @@ def forward(
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
-            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if not exists(pos):
+                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
             rotary_pos_emb = self.rotary_pos_emb(pos)
 
         # assume cached key / values
@@ -2030,7 +2041,7 @@ def forward(
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
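
Taken together, the changes let a caller pass an explicit pos tensor through AttentionLayers.forward; when given, it replaces the default torch.arange positions fed to the rotary embedding and is forwarded to each self-attention block, where an AlibiPositionalBias rel_pos module consumes it via forward_custom_pos. The body of forward_custom_pos is not part of this diff, so the following is only a sketch, under the assumption that a custom-position alibi bias is the usual per-head slope times the pairwise distance between the supplied positions.

import torch

def alibi_bias_from_positions(pos, slopes):
    # hypothetical helper for illustration only, not the library's forward_custom_pos
    # pos:    (batch, seq_len) integer positions, possibly non-contiguous
    # slopes: (heads,) per-head alibi slopes
    # returns an additive attention bias of shape (batch, heads, seq_len, seq_len)
    rel_dist = (pos[:, None, :] - pos[:, :, None]).abs()
    return -rel_dist[:, None, :, :] * slopes[None, :, None, None]

pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])
slopes = torch.tensor([0.25, 0.0625])          # example slopes for 2 heads
bias = alibi_bias_from_positions(pos, slopes)
print(bias.shape)                              # torch.Size([2, 2, 4, 4])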
