Add external KV
gordicaleksa committed Aug 10, 2024
1 parent 6e6a528 commit aeb9817
Showing 1 changed file, train_llama3.py, with 29 additions and 20 deletions.
@@ -22,11 +22,9 @@
import inspect
from contextlib import nullcontext
from dataclasses import dataclass
import json
from pathlib import Path
from typing import (
    AbstractSet,
    Callable,
    Collection,
    Dict,
    Iterator,
@@ -58,6 +56,25 @@
# using a global to toggle flash-attention
FLASH = 0
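
The attention-kernel selection that FLASH gates is outside this diff. As a purely hypothetical sketch of how such a toggle is commonly wired (the helper name and mask handling below are assumptions, not the file's code):

```python
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    # hypothetical helper; q, k, v: (B, n_head, T, hd); mask: bool, True = masked out
    if FLASH:
        # fused flash-attention path (PyTorch 2.x)
        return F.scaled_dot_product_attention(q, k, v, is_causal=mask is not None)
    att = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
    if mask is not None:
        att = att.masked_fill(mask, float("-inf"))
    return F.softmax(att, dim=-1) @ v
```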

class KVCache:
    def __init__(self, num_layers, max_batch_size, block_size, n_kv_head, hd, device, dtype):
        # one static cache per layer, indexed by layer in the model's forward pass
        self.kv = [KV(max_batch_size, block_size, n_kv_head, hd, device, dtype) for _ in range(num_layers)]

    def clear(self):
        # drop the per-layer caches so the preallocated tensors can be freed
        self.kv = []

class KV:
    # Static KV cache: we preallocate the memory for the key and value tensors
    def __init__(self, max_batch_size, block_size, n_kv_head, hd, device, dtype):
        self.k = torch.zeros((max_batch_size, block_size, n_kv_head, hd), dtype=dtype, device=device)
        self.v = torch.zeros((max_batch_size, block_size, n_kv_head, hd), dtype=dtype, device=device)

    def add(self, k, v, B, T, start_pos):
        # write the new keys/values at [start_pos, start_pos + T) and return the full prefix
        assert B == k.shape[0] and T == k.shape[1] and self.k.shape[2] == k.shape[2] and self.k.shape[3] == k.shape[3]
        self.k[:B, start_pos : start_pos + T] = k
        self.v[:B, start_pos : start_pos + T] = v
        return self.k[:B, : start_pos + T], self.v[:B, : start_pos + T]
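
A minimal usage sketch for the two classes above (the shapes mirror the diff; the sizes, device, and dtype are illustrative, not from the commit):

```python
import torch

num_layers, max_bsz, block_size, n_kv_head, hd = 2, 4, 64, 8, 16
cache = KVCache(num_layers, max_bsz, block_size, n_kv_head, hd,
                device="cpu", dtype=torch.float32)

# prefill: write a 10-token prompt into layer 0 at start_pos=0
B, T = 1, 10
k = torch.randn(B, T, n_kv_head, hd)
v = torch.randn(B, T, n_kv_head, hd)
k_all, v_all = cache.kv[0].add(k, v, B, T, start_pos=0)
print(k_all.shape)  # torch.Size([1, 10, 8, 16])

# decode: append a single token at start_pos=10 and get the whole prefix back
k1, v1 = torch.randn(B, 1, n_kv_head, hd), torch.randn(B, 1, n_kv_head, hd)
k_all, v_all = cache.kv[0].add(k1, v1, B, 1, start_pos=10)
print(k_all.shape)  # torch.Size([1, 11, 8, 16])
```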

# Used in Grouped-Query Attention (GQA): broadcasts the key and value heads so they match the number of query heads
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -156,17 +173,11 @@ def __init__(self, config):
        self.n_kv_head = config.n_kv_head
        self.n_rep = self.n_head // self.n_kv_head
        self.hd = config.n_embd // config.n_head
        self.use_kv = config.use_kv

        self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection

        # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed
        if self.use_kv:
            self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
            self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))

    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
    def forward(self, x, freqs_cis=None, kv_cache=None, start_pos=None, mask=None):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
@@ -175,11 +186,8 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):

        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2

        if self.use_kv and not self.training and start_pos >= 0: # use kv-caching during inference
            self.cache_k[:B, start_pos : start_pos + T] = k
            self.cache_v[:B, start_pos : start_pos + T] = v
            k = self.cache_k[:B, : start_pos + T]
            v = self.cache_v[:B, : start_pos + T]
        if kv_cache and not self.training: # use kv-caching during inference
            k, v = kv_cache.add(k, v, B, T, start_pos)

        k = repeat_kv(k, self.n_rep) # GQA <-- 2. difference compared to GPT-2
        v = repeat_kv(v, self.n_rep)
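
To make the decode-time shapes concrete: with 32 query heads and 8 KV heads (n_rep = 4), the cache hands back the full prefix and repeat_kv then widens it to match the query heads. Illustrative sizes, not from the diff:

```python
import torch

# suppose the cache already holds 11 positions for 8 KV heads of dim 16
k_full = torch.randn(1, 11, 8, 16)       # (B, start_pos + T, n_kv_head, hd)
k_wide = repeat_kv(k_full, n_rep=4)      # (1, 11, 32, 16): each KV head copied 4x
assert torch.equal(k_wide[:, :, 0], k_wide[:, :, 3])  # query heads 0-3 share KV head 0
```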
@@ -233,8 +241,8 @@ def __init__(self, config):
        self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
        x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
    def forward(self, x, freqs_cis=None, kv_cache=None, start_pos=None, mask=None):
        x = x + self.attn(self.ln_1(x), freqs_cis, kv_cache, start_pos, mask)
        x = x + self.mlp(self.ln_2(x))
        return x

@@ -256,7 +264,6 @@ class LlamaConfig:
    rope_theta: float = 500000.0
    use_scaled_rope: bool = True
    max_gen_batch_size: int = 4
    use_kv: bool = True

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
@@ -290,7 +297,7 @@ def __init__(self, config):
            config.use_scaled_rope,
        )

    def forward(self, idx, targets=None, return_logits=True, start_pos=0):
    def forward(self, idx, targets=None, return_logits=True, kv_cache=None, start_pos=0):
        _, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

@@ -301,7 +308,7 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0):
        mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1)

        for i, block in enumerate(self.transformer.h):
            x = block(x, freqs_cis, start_pos, mask)
            x = block(x, freqs_cis, kv_cache.kv[i] if kv_cache is not None else None, start_pos, mask)
        x = self.transformer.ln_f(x)

        if targets is not None:
@@ -530,8 +537,10 @@ def generate(

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device)

        dtype = self.transformer.wte.weight.dtype
        kv_cache = KVCache(self.config.n_layer, bsz, self.config.block_size, self.config.n_kv_head, self.config.n_embd // self.config.n_head, device, dtype)
        for cur_pos in range(min_prompt_len, total_len):
            logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos)
            logits, _ = self.forward(tokens[:, prev_pos:cur_pos], kv_cache=kv_cache, start_pos=prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
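
generate() leans on sample_top_p, which sits outside this hunk; the nucleus-sampling routine it presumably matches is the reference Llama one (a sketch, not necessarily the file's exact code):

```python
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # sort probabilities in descending order and take their cumulative sum
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # zero out tokens outside the top-p nucleus, then renormalize
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # sample from the truncated distribution and map back to vocab indices
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)
```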