-
Notifications
You must be signed in to change notification settings - Fork 1
/
transformer.py
105 lines (89 loc) · 4 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import sys
import os
import numpy as np
import textwrap
# Shared 70-column text wrapper. It is not referenced by the code visible in
# this chunk; presumably used elsewhere for printing wrapped text -- confirm.
wrapper = textwrap.TextWrapper(width=70)
import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.supervised import training
from argparse import ArgumentParser
from utils import *
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.

    Computes ``softmax(query @ key^T / sqrt(d)) @ value``. Positions whose
    ``mask`` entry is falsy are pushed to a large negative logit before the
    softmax, so they receive (numerically) zero attention weight.

    Args:
        query (jax.interpreters.xla.DeviceArray): query representations with
            shape (..., L_q, d); leading batch axes are allowed.
        key (jax.interpreters.xla.DeviceArray): key representations with
            shape (..., L_k, d).
        value (jax.interpreters.xla.DeviceArray): value representations with
            shape (..., L_k, d), where L_v = L_k.
        mask (jax.interpreters.xla.DeviceArray or None): attention mask
            broadcastable to (..., L_q, L_k). Truthy entries keep a
            query/key pair, falsy entries mask it out. ``None`` disables
            masking.

    Returns:
        jax.interpreters.xla.DeviceArray: attention output with shape
        (..., L_q, d) -- one weighted sum of value rows per query.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embedding dimensions of q, k, v aren't all the same"
    depth = query.shape[-1]
    # Scaled logits: (..., L_q, d) @ (..., d, L_k) -> (..., L_q, L_k).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
    if mask is not None:
        # Masked-out logits become -1e9, so exp() below drives them to ~0.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Numerically stable softmax over the key axis: exp(x - logsumexp(x)).
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = jnp.exp(dots - logsumexp)
    # Weighted sum of values: (..., L_q, L_k) @ (..., L_k, d) -> (..., L_q, d).
    attention = jnp.matmul(dots, value)
    return attention
def compute_attention_heads_closure(n_heads, d_head):
    """Build the head-splitting function used inside a CausalAttention layer.

    Args:
        n_heads (int): number of attention heads.
        d_head (int): dimensionality of each head.

    Returns:
        function: ``compute_attention_heads(x)``, which folds the heads of
        ``x`` into its batch axis.
    """

    def compute_attention_heads(x):
        """Split the feature axis of ``x`` into heads and merge them into the batch axis.

        Args:
            x (jax.interpreters.xla.DeviceArray): tensor of shape
                (batch_size, seqlen, n_heads X d_head).

        Returns:
            jax.interpreters.xla.DeviceArray: tensor of shape
            (batch_size X n_heads, seqlen, d_head).
        """
        batch, length = x.shape[0], x.shape[1]
        # (B, L, H*D) -> (B, L, H, D): expose the head axis.
        split = jnp.reshape(x, (batch, length, n_heads, d_head))
        # (B, L, H, D) -> (B, H, L, D): put heads next to batch so they can merge.
        per_head = jnp.transpose(split, (0, 2, 1, 3))
        # (B, H, L, D) -> (B*H, L, D)
        return jnp.reshape(per_head, (batch * n_heads, length, d_head))

    return compute_attention_heads
def dot_product_self_attention(q, k, v):
    """Causal (masked) dot product self-attention.

    Query position i may attend only to key positions j <= i.

    Args:
        q (jax.interpreters.xla.DeviceArray): queries.
        k (jax.interpreters.xla.DeviceArray): keys.
        v (jax.interpreters.xla.DeviceArray): values.

    Returns:
        jax.interpreters.xla.DeviceArray: masked dot product self attention tensor.
    """
    seq_len = q.shape[-2]
    # Boolean ones of shape (1, seq_len, seq_len); the leading singleton axis
    # lets the mask broadcast against the leading axis of the logits.
    full = jnp.ones((1, seq_len, seq_len), dtype=jnp.bool_)
    # Keep the diagonal and everything below it (tril with the default k=0).
    causal_mask = jnp.tril(full)
    return DotProductAttention(q, k, v, causal_mask)
def compute_attention_output_closure(n_heads, d_head):
    """Build the head-merging function used inside a CausalAttention layer.

    It is the inverse of the head-splitting step: the heads folded into the
    batch axis are moved back and concatenated along the feature axis.

    Args:
        n_heads (int): number of attention heads.
        d_head (int): dimensionality of each head.

    Returns:
        function: ``compute_attention_output(x)``, which performs the merge.
    """

    def compute_attention_output(x):
        """Merge per-head outputs back into a single feature axis.

        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape
                (batch_size X n_heads, seqlen, d_head).

        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape
            (batch_size, seqlen, n_heads X d_head).
        """
        seqlen = x.shape[1]
        # The leading axis packs batch_size * n_heads rows. Use integer
        # division rather than float division followed by int() truncation.
        batch_size = x.shape[0] // n_heads
        # (B*H, L, D) -> (B, H, L, D)
        x = jnp.reshape(x, (batch_size, n_heads, seqlen, d_head))
        # (B, H, L, D) -> (B, L, H, D): bring heads next to the feature axis.
        x = jnp.transpose(x, (0, 2, 1, 3))
        # (B, L, H, D) -> (B, L, H*D); explicit batch_size mirrors the split step.
        return jnp.reshape(x, (batch_size, seqlen, n_heads * d_head))

    return compute_attention_output