diff --git a/palm_rlhf_pytorch/attention.py b/palm_rlhf_pytorch/attention.py
index 9716f0f..9fad546 100644
--- a/palm_rlhf_pytorch/attention.py
+++ b/palm_rlhf_pytorch/attention.py
@@ -10,7 +10,7 @@
 
 # constants
 
-Config = namedtuple('FlashConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
 
 # helpers
 
@@ -41,13 +41,30 @@ def __init__(
     ):
         super().__init__()
         self.dropout = dropout
-        self.causal = causal
         self.attn_dropout = nn.Dropout(dropout)
 
+        self.causal = causal
+        self.register_buffer("mask", None, persistent=False)
+
         self.use_flash_attn = use_flash_attn
         assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
 
-        self.register_buffer("mask", None, persistent=False)
+        # determine efficient attention configs for cuda and cpu
+
+        self.cpu_config = Config(True, True, True)
+        self.cuda_config = None
+
+        if not torch.cuda.is_available() or not use_flash_attn:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+            self.cuda_config = Config(True, False, False)
+        else:
+            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+            self.cuda_config = Config(False, True, True)
 
     def get_mask(self, n, device):
         if exists(self.mask) and self.mask.shape[-1] >= n:
@@ -58,7 +75,7 @@ def get_mask(self, n, device):
         return mask
 
     def flash_attn(self, q, k, v, mask = None):
-        _, heads, q_len, _, k_len = *q.shape, k.shape[-2]
+        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
 
         # Recommended for multi-query single-key-value attention by Tri Dao
         # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
@@ -75,25 +92,7 @@ def flash_attn(self, q, k, v, mask = None):
 
         # Check if there is a compatible device for flash attention
 
-        device_str = 'cuda' if torch.cuda.is_available() and q.is_cuda else 'cpu'
-        device = torch.device(device_str)
-
-        try:
-            if device_str == 'cuda':
-                device_properties = torch.cuda.get_device_properties(device)
-
-                if device_properties.major == 8 and device_properties.minor == 0:
-                    print_once('A100 GPU detected, using flash attention')
-                    config = Config(True, False, False)
-                else:
-                    print_once('Non-A100 GPU detected, using math or mem efficient attention')
-                    config = Config(False, True, True)
-            else:
-                print_once('CPU detected, using default context manager settings')
-                config = Config(True, True, True)
-
-        except RuntimeError as error:
-            print(f'An error occurred: {error}.')
+        config = self.cuda_config if is_cuda else self.cpu_config
 
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
diff --git a/setup.py b/setup.py
index 311f00e..8d45f38 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
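
Note on the change: the device check that previously ran on every `flash_attn` call (and swallowed `RuntimeError`) now runs once in `__init__`, storing a `cpu_config` and an optional `cuda_config`; `flash_attn` then simply selects between them based on `q.is_cuda`. The fields of the renamed `EfficientAttentionConfig` namedtuple match the keyword arguments of PyTorch 2.0's `torch.backends.cuda.sdp_kernel` context manager, which (per the trailing "pytorch 2.0 flash attn" comment) is presumably how the config is consumed further down in `flash_attn`. A minimal sketch of that usage outside the class, assuming PyTorch >= 2.0; tensor shapes are illustrative and not taken from the diff:

```python
import torch
import torch.nn.functional as F
from collections import namedtuple

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# non-A100 cuda path from the diff: disable flash, allow math + mem-efficient kernels
config = Config(enable_flash = False, enable_math = True, enable_mem_efficient = True)

# (batch, heads, seq_len, dim_head) - illustrative shapes
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

# restrict which scaled-dot-product-attention kernels pytorch 2.0 may dispatch to,
# then run fused attention with a causal mask
with torch.backends.cuda.sdp_kernel(**config._asdict()):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

print(out.shape)  # torch.Size([1, 8, 512, 64])
```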