Commit 85cffc6

improvise a bidirectional forgetting transformer
lucidrains committed Oct 31, 2024
1 parent 881be6b commit 85cffc6
Showing 2 changed files with 22 additions and 5 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.4',
+  version = '1.42.5',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py (25 changes: 21 additions & 4 deletions)
@@ -512,12 +512,15 @@ def __init__(
         self,
         dim,
         heads,
+        causal = True,
         bias_init = 5.,
-        post_log_scale = 1.
+        post_log_scale = 1.,
     ):
         super().__init__()
 
-        linear = nn.Linear(dim, heads)
+        self.causal = causal
+
+        linear = nn.Linear(dim, heads * (1 if causal else 2))
 
         self.to_forget_gates = nn.Sequential(
             linear,
@@ -529,9 +532,21 @@ def __init__(
         self.post_log_scale = post_log_scale
 
     def forward(self, x):
+        bidirectional = not self.causal
+
         forget_gates = self.to_forget_gates(x) * self.post_log_scale
+
         forget_gates = forget_gates.cumsum(dim = -1)
+
+        if bidirectional:
+            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
+
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        if bidirectional:
+            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
+            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
+
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
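
A note on what this forward pass computes: each forget gate is a non-positive log-decay (assuming the log-sigmoid activation this class applies in to_forget_gates), and the cumulative sum along the sequence turns pairwise differences into the total decay accumulated between two positions, i.e. a data-dependent ALiBi bias. In the bidirectional case the linear layer emits two gate sets per head (hence the doubled width in __init__): the forward set fills the lower triangle (attending to the past) and the second set fills the upper triangle (attending to the future). A minimal sketch of just this math, not the library code, with the einx calls replaced by plain broadcasting and post_log_scale omitted:

import torch
import torch.nn.functional as F

b, h, n = 1, 2, 5

# two gate sets per head for the bidirectional case
gates = F.logsigmoid(torch.randn(b, 2 * h, n))  # every entry <= 0
gates = gates.cumsum(dim = -1)                  # running log-decay, non-increasing

fwd, rev = gates.chunk(2, dim = 1)              # (b, h, n) each

# lower triangle: bias[i, j] = fwd[i] - fwd[j], decay accrued from j up to i
bias_fwd = fwd.unsqueeze(-1) - fwd.unsqueeze(-2)   # b h i j

# upper triangle: bias[i, j] = rev[j] - rev[i], decay accrued from i up to j
bias_rev = rev.unsqueeze(-2) - rev.unsqueeze(-1)   # b h i j

# both halves are zero on the diagonal, so tril + triu does not double count
bias = bias_fwd.tril() + bias_rev.triu()           # added to attention logits pre-softmax
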
@@ -541,10 +556,13 @@ def __init__(
         self,
         dim,
         heads,
+        causal = True,
         dim_head = 8,
         post_log_scale = 1.
     ):
         super().__init__()
+        assert causal, 'bidirectional not supported yet'
+
         self.scale = dim_head ** -0.5
 
         linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
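
The per-row variant keeps its causal guard for now, but the base class can already be exercised in both modes. A quick shape check, a sketch assuming only the constructor and forward shown in this diff:

import torch
from x_transformers.x_transformers import DataDependentAlibi

dda = DataDependentAlibi(dim = 512, heads = 8, causal = False)

x = torch.randn(2, 128, 512)   # (batch, seq, dim)
bias = dda(x)                  # pairwise forget-gate bias per head

assert bias.shape == (2, 8, 128, 128)
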
@@ -1138,10 +1156,9 @@ def __init__(
         self.data_dependent_alibi = None
 
         if data_dependent_alibi:
-            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
             dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
-            dda_kwargs = dict(dim = dim, heads = heads)
+            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)
 
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
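
End to end, the causal flag now reaches DataDependentAlibi through dda_kwargs, so a non-causal stack can request the forgetting bias. A hedged usage sketch, assuming the attn_-prefixed kwarg routing x-transformers uses elsewhere and that Encoder builds its attention layers with causal = False:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(                 # non-causal => bidirectional gates
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True   # routed to Attention(data_dependent_alibi = True)
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
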
