diff --git a/README.md b/README.md index c1f6a22..ddb04d5 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,8 @@ from palm_rlhf_pytorch import PaLM palm = PaLM( num_tokens = 20000, dim = 512, - depth = 12 + depth = 12, + flash_attn = True ).cuda() seq = torch.randint(0, 20000, (1, 2048)).cuda() @@ -209,9 +210,9 @@ answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 204 ```bibtex @inproceedings{dao2022flashattention, - title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, - author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - booktitle={Advances in Neural Information Processing Systems}, - year={2022} + title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, + author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + booktitle = {Advances in Neural Information Processing Systems}, + year = {2022} } ``` diff --git a/palm_rlhf_pytorch/attention.py b/palm_rlhf_pytorch/attention.py new file mode 100644 index 0000000..f0f7507 --- /dev/null +++ b/palm_rlhf_pytorch/attention.py @@ -0,0 +1,143 @@ +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from collections import namedtuple +from functools import wraps +from packaging import version + +from einops import rearrange + +# constants + +Config = namedtuple('FlashConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +# helpers + +def exists(val): + return val is not None + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +# main class + +class Attention(nn.Module): + def __init__( + self, + dropout = 0., + causal = False, + use_flash_attn = False + ): + super().__init__() + self.dropout = dropout + self.causal = causal + self.attn_dropout = nn.Dropout(dropout) + + assert version.parse(torch.__version__) >= version.parse('2.0.0'), 'in order to use flash attention, you must be using pytorch 2.0 or above' + self.use_flash_attn = use_flash_attn + + self.register_buffer("mask", None, persistent=False) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def flash_attn(self, q, k, v, mask = None): + _, heads, q_len, _, k_len = *q.shape, k.shape[-2] + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + v = rearrange(v, 'b ... 
-> b 1 ...').expand_as(q) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = rearrange(mask, 'b j -> b 1 1 j') + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + device_str = 'cuda' if torch.cuda.is_available() and q.is_cuda else 'cpu' + device = torch.device(device_str) + + try: + if device_str == 'cuda': + device_properties = torch.cuda.get_device_properties(device) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once('A100 GPU detected, using flash attention') + config = Config(True, False, False) + else: + print_once('Non-A100 GPU detected, using math or mem efficient attention') + config = Config(False, True, True) + else: + print_once('CPU detected, using default context manager settings') + config = Config(True, True, True) + + except RuntimeError as error: + print(f'An error occurred: {error}.') + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + dropout_p = self.dropout if self.training else 0., + is_causal = self.causal + ) + + return out + + def forward(self, q, k, v, mask = None): + n, device = q.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + if self.use_flash_attn: + return self.flash_attn(q, k, v, mask = mask) + + # similarity + + sim = einsum("b h i d, b j d -> b h i j", q, k) * scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, 'b j -> b 1 1 j') + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + return out diff --git a/palm_rlhf_pytorch/palm.py b/palm_rlhf_pytorch/palm.py index 64eb90d..cbc1692 100644 --- a/palm_rlhf_pytorch/palm.py +++ b/palm_rlhf_pytorch/palm.py @@ -2,8 +2,8 @@ import copy from pathlib import Path from collections import namedtuple +from functools import wraps from itertools import zip_longest -from packaging import version from tqdm import tqdm from beartype import beartype @@ -16,6 +16,7 @@ from einops import rearrange, repeat, reduce, pack, unpack from einops.layers.torch import Rearrange, Reduce +from palm_rlhf_pytorch.attention import Attention from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator from palm_rlhf_pytorch.lora import LoRA @@ -142,6 +143,12 @@ def __init__( self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) + self.attend = Attention( + causal = causal, + dropout = attn_dropout, + use_flash_attn = flash_attn + ) + self.heads = heads self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale self.causal = causal @@ -165,18 +172,9 @@ def __init__( # for caching causal mask and rotary embeddings - self.register_buffer("mask", None, persistent=False) self.register_buffer("pos_emb", None, persistent=False) self.register_buffer("pos_emb_scale", None, persistent=False) - def get_mask(self, n, device): - if exists(self.mask) and self.mask.shape[-1] >= n: - return self.mask[:n, :n] - - mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) - 
self.register_buffer("mask", mask, persistent=False) - return mask - def get_rotary_embedding(self, n, device): if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n: return self.pos_emb[:n], self.pos_emb_scale[:n] @@ -241,101 +239,9 @@ def forward( q = apply_rotary_pos_emb(positions, q, scale) k = apply_rotary_pos_emb(positions, k, scale ** -1) - # flash attention triton - - if self.flash_attn: - - # Check to see if the correct version of PyTorch is supported - - try: - assert version.parse(torch.__version__) >= version.parse('2.0.0') - except: - raise Exception("flash attention requires pytorch 2.0") - - # Recommended for multi-query single-key-value attention by Tri Dao - # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) - - k = k.unsqueeze(1).expand_as(q) - v = v.unsqueeze(1).expand_as(q) - - # Check if mask exists and expand to compatible shape - # The mask is B L, so it would have to be expanded to B N L - - if exists(mask): - mask = mask.unsqueeze(1).unsqueeze(-1).expand(-1, h, q.shape[-2], -1) - - # Check if there is a compatible device for flash attention - - try: - if torch.cuda.is_available(): - if x.device.type == 'cuda': - flash_device = torch.device('cuda') - else: - flash_device = torch.device('cpu') - else: - flash_device = torch.device('cpu') - - if flash_device.type == 'cuda': - device_properties = torch.cuda.get_device_properties(device) - if device_properties.major == 8 and device_properties.minor == 0: - print('A100 GPU detected, using flash attention') - enable_flash = True - enable_math = False - enable_mem_efficient = False - else: - print('Non-A100 GPU detected, using math or mem efficient attention') - enable_flash = False - enable_math = True - enable_mem_efficient = True - elif flash_device.type == 'cpu': - # Default context manager settings with CPU - print('CPU detected, using default context manager settings') - enable_flash = True - enable_math = True - enable_mem_efficient = True - except RuntimeError as error: - print(f'An error occurred: {error}.') - - # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale - - with torch.backends.cuda.sdp_kernel( - enable_flash=enable_flash, - enable_math=enable_math, - enable_mem_efficient=enable_mem_efficient - ): - out = F.scaled_dot_product_attention( - q, k, v, - attn_mask = mask, - dropout_p = self.flash_attn_dropout, - is_causal = self.causal, - scale = self.scale - ) - - else: - # similarity - - sim = einsum("b h i d, b j d -> b h i j", q, k) * self.scale - - # key padding mask - - if exists(mask): - mask = rearrange(mask, 'b j -> b 1 1 j') - sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) - - # causal mask - - if self.causal: - causal_mask = self.get_mask(n, device) - sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) - - # attention - - attn = sim.softmax(dim=-1) - attn = self.attn_dropout(attn) - - # aggregate values + # attention function, either regular or flash - out = einsum("b h i j, b j d -> b h i d", attn, v) + out = self.attend(q, k, v, mask = mask) # merge heads diff --git a/setup.py b/setup.py index c1a2599..24f3af8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'PaLM-rlhf-pytorch', packages = find_packages(exclude=[]), - version = '0.0.68', + version = '0.1.0', license='MIT', description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch', author = 'Phil Wang', diff --git a/train.py b/train.py index 27fd86a..440e0b7 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,8 @@ def decode_tokens(tokens): model = PaLM( 
num_tokens=256, dim=512, - depth=8 + depth=8, + flash_attn=True ).to(device) # prepare enwik8 data
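
For reference, here is a minimal sketch (not part of the diff above) of how the new standalone `palm_rlhf_pytorch.attention.Attention` module could be exercised directly. Shapes follow the multi-query convention used in `flash_attn`: queries carry a heads dimension while a single key/value head is shared across all query heads, and the boolean `mask` is a key-padding mask of shape `(batch, seq_len)` with `True` marking positions to keep. Note that, as written, the module asserts PyTorch 2.0 at construction time even when `use_flash_attn = False`. The shapes and values below are arbitrary.

```python
# illustrative sketch only - shapes and values are arbitrary

import torch
from palm_rlhf_pytorch.attention import Attention

attend = Attention(
    causal = True,
    dropout = 0.1,
    use_flash_attn = False   # set True to route through F.scaled_dot_product_attention
)

batch, heads, seq_len, dim_head = 1, 8, 512, 64

q = torch.randn(batch, heads, seq_len, dim_head)       # per-head queries:  (b, h, n, d)
k = torch.randn(batch, seq_len, dim_head)               # shared key head:   (b, n, d)
v = torch.randn(batch, seq_len, dim_head)               # shared value head: (b, n, d)

mask = torch.ones(batch, seq_len, dtype = torch.bool)   # key padding mask, True = attend

out = attend(q, k, v, mask = mask)                      # (1, 8, 512, 64)
```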
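Since the flash path defers to `F.scaled_dot_product_attention` (which applies the default `1/sqrt(d)` scaling internally) while the non-flash path scales the similarity matrix by `q.shape[-1] ** -0.5`, the two branches should agree numerically when dropout is disabled. A small sanity check along those lines, with an assumed tolerance and requiring PyTorch >= 2.0:

```python
# sanity-check sketch - the tolerance is an assumption, not part of the diff

import torch
from palm_rlhf_pytorch.attention import Attention

q = torch.randn(2, 8, 256, 64)   # (b, h, n, d)
k = torch.randn(2, 256, 64)      # (b, n, d)
v = torch.randn(2, 256, 64)      # (b, n, d)

vanilla = Attention(causal = True, use_flash_attn = False)
flash   = Attention(causal = True, use_flash_attn = True)

out_vanilla = vanilla(q, k, v)   # manual einsum path
out_flash   = flash(q, k, v)     # scaled_dot_product_attention path

assert torch.allclose(out_vanilla, out_flash, atol = 1e-5)
```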