This repository implements several types of attention modules in PyTorch, including:
- Attention: the basic scaled dot-product attention module.
- Multi-Head Attention: performs attention in parallel over multiple "heads" of the input sequence, where each head has its own set of Q, K, and V projections.
- Multi-Query Attention: uses multiple query heads but shares a single key head and a single value head among all of them.
- Grouped-Query Attention: groups the query heads so that each group shares one key/value head, interpolating between multi-head and multi-query attention (a minimal sketch covering these three variants follows this list).
- Linformer: reduces the overall self-attention complexity from O(n²) to O(n) in both time and space by projecting the keys and values down to a fixed length (a second sketch below illustrates the idea).
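Below is a minimal PyTorch sketch of how the first three variants relate. It is not the repository's actual code: the module name `GroupedQueryAttention` and the parameters `embed_dim`, `num_heads`, and `num_kv_heads` are illustrative assumptions. Setting `num_kv_heads == num_heads` gives multi-head attention, `num_kv_heads == 1` gives multi-query attention, and values in between give grouped-query attention.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    """Basic attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


class GroupedQueryAttention(nn.Module):
    """One module covering MHA, MQA, and GQA (names/params are assumptions):
    num_kv_heads == num_heads     -> multi-head attention
    num_kv_heads == 1             -> multi-query attention
    1 < num_kv_heads < num_heads  -> grouped-query attention
    """

    def __init__(self, embed_dim, num_heads, num_kv_heads):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(num_heads * self.head_dim, embed_dim)

    def forward(self, x, mask=None):
        b, n, _ = x.shape
        # Split projections into heads: (batch, heads, seq, head_dim)
        q = self.q_proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, n, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, n, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Each group of query heads shares one key/value head.
        group = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        out = scaled_dot_product_attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.out_proj(out)


# Example: 8 query heads sharing 2 key/value heads (grouped-query attention).
x = torch.randn(2, 16, 64)
gqa = GroupedQueryAttention(embed_dim=64, num_heads=8, num_kv_heads=2)
print(gqa(x).shape)  # torch.Size([2, 16, 64])
```

The only difference between the three variants is how many key/value heads are projected, which is why they can be expressed as a single module.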
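The second sketch illustrates the Linformer idea. Again, this is a hedged illustration rather than the repository's code: `LinformerSelfAttention`, `seq_len`, and `proj_dim` are assumed names. The learned projections compress the sequence dimension of K and V from n down to a fixed `proj_dim`, so the attention matrix is n × proj_dim instead of n × n.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinformerSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, seq_len, proj_dim=64):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Learned projections that compress the length dimension n -> proj_dim.
        self.e_proj = nn.Linear(seq_len, proj_dim, bias=False)
        self.f_proj = nn.Linear(seq_len, proj_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        b, n, d = x.shape
        split = lambda t: t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
        # Compress the sequence dimension of K and V: (b, h, n, d_h) -> (b, h, proj_dim, d_h)
        k = self.e_proj(k.transpose(-2, -1)).transpose(-2, -1)
        v = self.f_proj(v.transpose(-2, -1)).transpose(-2, -1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)  # (b, h, n, proj_dim)
        out = F.softmax(scores, dim=-1) @ v                          # (b, h, n, d_h)
        return self.out_proj(out.transpose(1, 2).reshape(b, n, d))


# Example: sequence length 512 compressed to 64 projected positions.
x = torch.randn(2, 512, 128)
attn = LinformerSelfAttention(embed_dim=128, num_heads=8, seq_len=512, proj_dim=64)
print(attn(x).shape)  # torch.Size([2, 512, 128])
```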
The multi-query attention and grouped-query attention modules are alternatives to multi-head attention with much lower memory-bandwidth requirements, and they have been used in many well-known models.
I implemented these modules in a simple, unoptimized way, with the goal of understanding how attention works.
If you are looking for an attention implementation with better performance, I suggest using one of the optimized attention libraries instead of this repository.
For more information, please see the following papers:
- Attention is All You Need (Vaswani et al., 2017)
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
- Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)