Commit b499ff3
demo
gordicaleksa committed Aug 2, 2024
1 parent c414d02 commit b499ff3
Showing 1 changed file with 20 additions and 14 deletions.
train_gpt2.py: 34 changes (20 additions & 14 deletions)
@@ -76,11 +76,8 @@ def __init__(self, config):
         self.n_rep = self.n_head // self.n_kv_head
         self.head_dim = config.n_embd // config.n_head
 
-        # TODO(gordicaleksa): this can be easily made the same as the above (c_attn, c_proj)
-        self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
-        self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
-        self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False)
-        self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)
+        self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.head_dim)
+        self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)
 
         # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
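
The fused c_attn above folds the separate query/key/value projections of grouped-query attention into one matmul of width (n_head + 2 * n_kv_head) * head_dim; note it keeps nn.Linear's default bias, while the wq/wk/wv it replaces used bias=False. A minimal standalone sketch (illustrative sizes, bias=False throughout, not the committed module) showing that row-stacking the three weight matrices reproduces the separate projections:

import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd, n_head, n_kv_head = 32, 4, 2          # illustrative sizes
head_dim = n_embd // n_head

wq = nn.Linear(n_embd, n_head * head_dim, bias=False)
wk = nn.Linear(n_embd, n_kv_head * head_dim, bias=False)
wv = nn.Linear(n_embd, n_kv_head * head_dim, bias=False)

c_attn = nn.Linear(n_embd, (n_head + 2 * n_kv_head) * head_dim, bias=False)
with torch.no_grad():
    # nn.Linear stores weight as (out_features, in_features), so concatenating
    # along dim=0 stacks the output rows in [q | k | v] order.
    c_attn.weight.copy_(torch.cat([wq.weight, wk.weight, wv.weight], dim=0))

x = torch.randn(2, 8, n_embd)                  # (batch, seq, n_embd)
q_ref, k_ref, v_ref = wq(x), wk(x), wv(x)
q, k, v = torch.split(c_attn(x), [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-1)
print(torch.allclose(q, q_ref), torch.allclose(k, k_ref), torch.allclose(v, v_ref))  # True True True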
@@ -114,7 +111,8 @@ def forward(self, x, freqs_cis=None):
         bsz, seqlen, _ = x.shape
 
         # QKV
-        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq, xk, xv = torch.split(self.c_attn(x), [self.n_head * self.head_dim, self.n_kv_head * self.head_dim, self.n_kv_head * self.head_dim], dim=-1)
+
         xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
         xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
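
The widths handed to torch.split must line up with the views that follow, since keys and values carry only n_kv_head heads. A small shape check under assumed sizes (bsz=2, seqlen=8, n_embd=32, n_head=4, n_kv_head=2); the repeat_interleave line is just one illustrative way the n_rep factor from __init__ could expand the kv heads, not necessarily this file's helper:

import torch

bsz, seqlen, n_embd = 2, 8, 32
n_head, n_kv_head = 4, 2
head_dim = n_embd // n_head
n_rep = n_head // n_kv_head

qkv = torch.randn(bsz, seqlen, (n_head + 2 * n_kv_head) * head_dim)  # stands in for c_attn(x)
xq, xk, xv = torch.split(qkv, [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-1)

xq = xq.view(bsz, seqlen, n_head, head_dim)     # (2, 8, 4, 8)
xk = xk.view(bsz, seqlen, n_kv_head, head_dim)  # (2, 8, 2, 8)
xv = xv.view(bsz, seqlen, n_kv_head, head_dim)  # (2, 8, 2, 8)

# one way to broadcast the kv heads up to n_head for the attention matmul
xk_expanded = torch.repeat_interleave(xk, n_rep, dim=2)  # (2, 8, 4, 8)
print(xq.shape, xk_expanded.shape, xv.shape)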
@@ -136,7 +134,7 @@ def forward(self, x, freqs_cis=None):
         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
         output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-        return self.wo(output)
+        return self.c_proj(output)
 
 class MLP(nn.Module):
 
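For the output path, the per-head results are moved back to (bsz, seqlen, n_head, head_dim), flattened, and projected by the renamed c_proj. A shape-only sketch under the same illustrative sizes (bias=False assumed for c_proj, as in the diff):

import torch
import torch.nn as nn

bsz, seqlen, n_embd, n_head = 2, 8, 32, 4
head_dim = n_embd // n_head

output = torch.randn(bsz, n_head, seqlen, head_dim)                  # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)   # (2, 8, 32)

c_proj = nn.Linear(n_head * head_dim, n_embd, bias=False)
print(c_proj(output).shape)                                          # torch.Size([2, 8, 32])
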
@@ -308,16 +306,24 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig):
                 new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.{suffix}'
                 checkpoint[new_key] = checkpoint.pop(old_key)
 
-    # layers.x.attention.wq.weight -> transformer.h.x.attn.wq.weight
-    # layers.x.attention.wk.weight -> transformer.h.x.attn.wk.weight
-    # layers.x.attention.wv.weight -> transformer.h.x.attn.wv.weight
-    # layers.x.attention.wo.weight -> transformer.h.x.attn.wo.weight
+    # we merge the following 3:
+    # layers.x.attention.wq.weight
+    # layers.x.attention.wk.weight
+    # layers.x.attention.wv.weight
+    # into transformer.h.x.attn.c_attn.weight
+    # layers.x.attention.wo.weight -> transformer.h.x.attn.c_proj.weight
     for i in range(config.n_layer):
-        for name in ['attention.wq', 'attention.wk', 'attention.wv', 'attention.wo']:
+        for name in ['attention.wq', 'attention.wk', 'attention.wv']:
             for suffix in ['weight']:
                 old_key = f'layers.{i}.{name}.{suffix}'
-                new_key = f'transformer.h.{i}.attn.{name.split(".")[-1]}.{suffix}'
-                checkpoint[new_key] = checkpoint.pop(old_key)
+                new_key = f'transformer.h.{i}.attn.c_attn.weight'
+                if name == 'attention.wq':
+                    checkpoint[new_key] = checkpoint.pop(old_key)
+                else:
+                    checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0)
+        old_key = f'layers.{i}.attention.wo.weight'
+        new_key = f'transformer.h.{i}.attn.c_proj.weight'
+        checkpoint[new_key] = checkpoint.pop(old_key)
 
     # layers.x.feed_forward.w1.weight -> transformer.h.x.mlp.w1.weight
     # layers.x.feed_forward.w2.weight -> transformer.h.x.mlp.w2.weight
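A toy run of the merge above on a fabricated single-layer state dict (hypothetical shapes, keys mimicking the Meta-LLaMA naming): concatenating the wq/wk/wv weights along dim=0 yields a c_attn.weight whose rows split back into q, k, v in exactly the order forward() recovers with torch.split.

import torch

n_embd, n_head, n_kv_head = 32, 4, 2
head_dim = n_embd // n_head

checkpoint = {
    'layers.0.attention.wq.weight': torch.randn(n_head * head_dim, n_embd),
    'layers.0.attention.wk.weight': torch.randn(n_kv_head * head_dim, n_embd),
    'layers.0.attention.wv.weight': torch.randn(n_kv_head * head_dim, n_embd),
    'layers.0.attention.wo.weight': torch.randn(n_embd, n_head * head_dim),
}
originals = {k: v.clone() for k, v in checkpoint.items()}

# same merge as the loop in the hunk above, for a single layer i=0
i = 0
for name in ['attention.wq', 'attention.wk', 'attention.wv']:
    old_key = f'layers.{i}.{name}.weight'
    new_key = f'transformer.h.{i}.attn.c_attn.weight'
    if name == 'attention.wq':
        checkpoint[new_key] = checkpoint.pop(old_key)
    else:
        checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0)
checkpoint[f'transformer.h.{i}.attn.c_proj.weight'] = checkpoint.pop(f'layers.{i}.attention.wo.weight')

fused = checkpoint['transformer.h.0.attn.c_attn.weight']
q_rows, k_rows, v_rows = torch.split(fused, [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=0)
print(torch.equal(q_rows, originals['layers.0.attention.wq.weight']))  # True
print(torch.equal(k_rows, originals['layers.0.attention.wk.weight']))  # True
print(torch.equal(v_rows, originals['layers.0.attention.wv.weight']))  # True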
