Commit 6c8028a
flash attention sdp context config only needs to be done once
lucidrains committed Mar 17, 2023
1 parent 734afb3 commit 6c8028a
Showing 2 changed files with 23 additions and 24 deletions.
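In short: the CUDA capability check that previously ran inside flash_attn on every call now runs once in __init__, and each call simply looks up the cached config. A minimal standalone sketch of that pattern, with a hypothetical class and method name rather than the repository's full Attention module:

from collections import namedtuple

import torch
from torch import nn

# which scaled-dot-product backends to enable: flash, math, memory-efficient
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

class CachedSDPConfig(nn.Module):
    # hypothetical helper, not the repository's Attention class
    def __init__(self, use_flash_attn = False):
        super().__init__()
        # cpu fallback leaves every backend enabled
        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash_attn:
            return

        # one-time device capability check (previously repeated on every forward pass)
        props = torch.cuda.get_device_properties(torch.device('cuda'))

        if (props.major, props.minor) == (8, 0):   # A100-class GPU (sm_80)
            self.cuda_config = Config(True, False, False)
        else:
            self.cuda_config = Config(False, True, True)

    def pick(self, tensor):
        # per call: just select the config chosen at construction time
        return self.cuda_config if tensor.is_cuda else self.cpu_config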
45 changes: 22 additions & 23 deletions palm_rlhf_pytorch/attention.py
@@ -10,7 +10,7 @@

 # constants

-Config = namedtuple('FlashConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

 # helpers

@@ -41,13 +41,30 @@ def __init__(
     ):
         super().__init__()
         self.dropout = dropout
-        self.causal = causal
         self.attn_dropout = nn.Dropout(dropout)

+        self.causal = causal
+        self.register_buffer("mask", None, persistent=False)
+
         self.use_flash_attn = use_flash_attn
         assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

-        self.register_buffer("mask", None, persistent=False)
+        # determine efficient attention configs for cuda and cpu
+
+        self.cpu_config = Config(True, True, True)
+        self.cuda_config = None
+
+        if not torch.cuda.is_available() or not use_flash_attn:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+            self.cuda_config = Config(True, False, False)
+        else:
+            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+            self.cuda_config = Config(False, True, True)

     def get_mask(self, n, device):
         if exists(self.mask) and self.mask.shape[-1] >= n:
@@ -58,7 +75,7 @@ def get_mask(self, n, device):
         return mask

     def flash_attn(self, q, k, v, mask = None):
-        _, heads, q_len, _, k_len = *q.shape, k.shape[-2]
+        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

         # Recommended for multi-query single-key-value attention by Tri Dao
         # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
@@ -75,25 +92,7 @@ def flash_attn(self, q, k, v, mask = None):

         # Check if there is a compatible device for flash attention

-        device_str = 'cuda' if torch.cuda.is_available() and q.is_cuda else 'cpu'
-        device = torch.device(device_str)
-
-        try:
-            if device_str == 'cuda':
-                device_properties = torch.cuda.get_device_properties(device)
-
-                if device_properties.major == 8 and device_properties.minor == 0:
-                    print_once('A100 GPU detected, using flash attention')
-                    config = Config(True, False, False)
-                else:
-                    print_once('Non-A100 GPU detected, using math or mem efficient attention')
-                    config = Config(False, True, True)
-            else:
-                print_once('CPU detected, using default context manager settings')
-                config = Config(True, True, True)
-
-        except RuntimeError as error:
-            print(f'An error occurred: {error}.')
+        config = self.cuda_config if is_cuda else self.cpu_config

         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

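The hunks above cut off at the "# pytorch 2.0 flash attn" comment, so the code that consumes the selected config is not shown in this diff. For context, a hedged sketch of the usual PyTorch 2.0 pattern that comment points at: the namedtuple fields map onto torch.backends.cuda.sdp_kernel, which scopes F.scaled_dot_product_attention to the enabled backends. The helper name and exact arguments below are illustrative, not the repository's code.

import torch
import torch.nn.functional as F

def sdp_attention(q, k, v, config, mask = None, causal = False, dropout_p = 0.):
    # config is the cached EfficientAttentionConfig namedtuple selected per call
    # (cuda_config when the inputs live on cuda, cpu_config otherwise)
    with torch.backends.cuda.sdp_kernel(**config._asdict()):
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask = mask,       # optional boolean attention mask
            dropout_p = dropout_p,  # the module proper would apply this only in training
            is_causal = causal
        )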
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',

0 comments on commit 6c8028a
