
Commit

add ability to use flash attention if using pytorch 2.0, thanks to @conceptofmind for the initial PR!
lucidrains committed Mar 17, 2023
1 parent caf4755 commit 0a18450
Showing 5 changed files with 162 additions and 111 deletions.
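As a quick orientation, here is a minimal sketch (not part of the commit, assuming only the `flash_attn` keyword introduced below) of gating the new flag on the installed PyTorch version, since flash attention requires 2.0 or above:

```python
import torch
from packaging import version

from palm_rlhf_pytorch import PaLM

# only request flash attention when the runtime supports it (pytorch 2.0+)
use_flash = version.parse(torch.__version__) >= version.parse('2.0.0')

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = use_flash
).cuda()
```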
11 changes: 6 additions & 5 deletions README.md
@@ -51,7 +51,8 @@ from palm_rlhf_pytorch import PaLM
palm = PaLM(
num_tokens = 20000,
dim = 512,
depth = 12
depth = 12,
flash_attn = True
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()
@@ -209,9 +210,9 @@ answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 204

```bibtex
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```
143 changes: 143 additions & 0 deletions palm_rlhf_pytorch/attention.py
@@ -0,0 +1,143 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange

# constants

Config = namedtuple('FlashConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
return val is not None

def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner

print_once = once(print)

# main class

class Attention(nn.Module):
def __init__(
self,
dropout = 0.,
causal = False,
use_flash_attn = False
):
super().__init__()
self.dropout = dropout
self.causal = causal
self.attn_dropout = nn.Dropout(dropout)

        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
        self.use_flash_attn = use_flash_attn

self.register_buffer("mask", None, persistent=False)

def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]

mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask

def flash_attn(self, q, k, v, mask = None):
_, heads, q_len, _, k_len = *q.shape, k.shape[-2]

# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
mask = mask.expand(-1, heads, q_len, -1)

# Check if there is a compatible device for flash attention

device_str = 'cuda' if torch.cuda.is_available() and q.is_cuda else 'cpu'
device = torch.device(device_str)

try:
if device_str == 'cuda':
device_properties = torch.cuda.get_device_properties(device)

if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention')
config = Config(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention')
config = Config(False, True, True)
else:
print_once('CPU detected, using default context manager settings')
config = Config(True, True, True)

        except RuntimeError as error:
            print(f'An error occurred: {error}.')
            config = Config(False, True, True)  # fall back to math / mem efficient kernels so config is always defined

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = self.causal
)

return out

def forward(self, q, k, v, mask = None):
n, device = q.shape[-2], q.device

scale = q.shape[-1] ** -0.5

if self.use_flash_attn:
return self.flash_attn(q, k, v, mask = mask)

# similarity

sim = einsum("b h i d, b j d -> b h i j", q, k) * scale

# key padding mask

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

# causal mask

if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

# attention

attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)

# aggregate values

out = einsum("b h i j, b j d -> b h i d", attn, v)

return out
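For reference, a minimal sketch (not part of the diff) of exercising the new module on its own, assuming the multi-query shapes described in the comments above: queries carry a head dimension, keys and values are shared across heads, and the optional mask is a boolean key-padding mask:

```python
import torch
from palm_rlhf_pytorch.attention import Attention

attend = Attention(causal = True, dropout = 0.1, use_flash_attn = False)

b, heads, n, d = 2, 8, 512, 64

q = torch.randn(b, heads, n, d)               # per-head queries
k = torch.randn(b, n, d)                      # single shared key head (multi-query)
v = torch.randn(b, n, d)                      # single shared value head
mask = torch.ones(b, n, dtype = torch.bool)   # key padding mask, True = attend

out = attend(q, k, v, mask = mask)            # (b, heads, n, d)
```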
114 changes: 10 additions & 104 deletions palm_rlhf_pytorch/palm.py
@@ -2,8 +2,8 @@
import copy
from pathlib import Path
from collections import namedtuple
from functools import wraps
from itertools import zip_longest
from packaging import version

from tqdm import tqdm
from beartype import beartype
@@ -16,6 +16,7 @@
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from palm_rlhf_pytorch.attention import Attention
from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
from palm_rlhf_pytorch.lora import LoRA

@@ -142,6 +143,12 @@ def __init__(
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))

self.attend = Attention(
causal = causal,
dropout = attn_dropout,
use_flash_attn = flash_attn
)

self.heads = heads
self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale
self.causal = causal
@@ -165,18 +172,9 @@ def __init__(

# for caching causal mask and rotary embeddings

self.register_buffer("mask", None, persistent=False)
self.register_buffer("pos_emb", None, persistent=False)
self.register_buffer("pos_emb_scale", None, persistent=False)

def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]

mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask

def get_rotary_embedding(self, n, device):
if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n:
return self.pos_emb[:n], self.pos_emb_scale[:n]
@@ -241,101 +239,9 @@ def forward(
q = apply_rotary_pos_emb(positions, q, scale)
k = apply_rotary_pos_emb(positions, k, scale ** -1)

# flash attention triton

if self.flash_attn:

# Check to see if the correct version of PyTorch is supported

try:
assert version.parse(torch.__version__) >= version.parse('2.0.0')
except:
raise Exception("flash attention requires pytorch 2.0")

# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

k = k.unsqueeze(1).expand_as(q)
v = v.unsqueeze(1).expand_as(q)

# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B N L

if exists(mask):
mask = mask.unsqueeze(1).unsqueeze(-1).expand(-1, h, q.shape[-2], -1)

# Check if there is a compatible device for flash attention

try:
if torch.cuda.is_available():
if x.device.type == 'cuda':
flash_device = torch.device('cuda')
else:
flash_device = torch.device('cpu')
else:
flash_device = torch.device('cpu')

if flash_device.type == 'cuda':
device_properties = torch.cuda.get_device_properties(device)
if device_properties.major == 8 and device_properties.minor == 0:
print('A100 GPU detected, using flash attention')
enable_flash = True
enable_math = False
enable_mem_efficient = False
else:
print('Non-A100 GPU detected, using math or mem efficient attention')
enable_flash = False
enable_math = True
enable_mem_efficient = True
elif flash_device.type == 'cpu':
# Default context manager settings with CPU
print('CPU detected, using default context manager settings')
enable_flash = True
enable_math = True
enable_mem_efficient = True
except RuntimeError as error:
print(f'An error occurred: {error}.')

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

with torch.backends.cuda.sdp_kernel(
enable_flash=enable_flash,
enable_math=enable_math,
enable_mem_efficient=enable_mem_efficient
):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.flash_attn_dropout,
is_causal = self.causal,
scale = self.scale
)

else:
# similarity

sim = einsum("b h i d, b j d -> b h i j", q, k) * self.scale

# key padding mask

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

# causal mask

if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

# attention

attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)

# aggregate values
# attention function, either regular or flash

out = einsum("b h i j, b j d -> b h i d", attn, v)
out = self.attend(q, k, v, mask = mask)

# merge heads

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.68',
version = '0.1.0',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',
3 changes: 2 additions & 1 deletion train.py
@@ -47,7 +47,8 @@ def decode_tokens(tokens):
model = PaLM(
num_tokens=256,
dim=512,
depth=8
depth=8,
flash_attn=True
).to(device)

# prepare enwik8 data
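A minimal end-to-end sketch (not part of the diff) of a single training step with flash attention enabled; the `return_loss` keyword follows the project's README training example and is assumed here:

```python
import torch
from palm_rlhf_pytorch import PaLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = PaLM(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    flash_attn = True   # requires pytorch 2.0+
).to(device)

# byte-level token ids, as in the enwik8 setup
seq = torch.randint(0, 256, (4, 1024)).to(device)

loss = model(seq, return_loss = True)  # assumed signature, per the README's training example
loss.backward()
```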
