This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:
- Fused operations and kernels extracted from unsloth:
    - Low-Rank Adapter Fused Operations (a reference sketch of the computation follows this list)
    - Fast RoPE Triton Kernels
    - Fast RMS LayerNorm Triton Kernels
    - Fast Cross Entropy Triton Kernels
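To make the Low-Rank Adapter Fused Operations concrete, the sketch below is a minimal, *unfused* PyTorch reference of the computation those ops perform (base projection plus a scaled low-rank update, with dropout on the adapter path). It is illustrative only; it is not the Triton kernel shipped in this library, and the function name, shapes, and `scaling = lora_alpha / r` convention follow standard LoRA usage rather than this repository's API.

```python
import torch
import torch.nn.functional as F

def lora_linear_reference(
    X: torch.Tensor,         # activations, shape (..., in_features)
    W: torch.Tensor,         # frozen base weight, shape (out_features, in_features)
    A: torch.Tensor,         # LoRA "A" matrix, shape (r, in_features)
    B: torch.Tensor,         # LoRA "B" matrix, shape (out_features, r)
    scaling: float,          # typically lora_alpha / r
    dropout_p: float = 0.1,  # dropout on the adapter path (supported by the fused ops)
) -> torch.Tensor:
    # base projection with the frozen weight
    base = X @ W.t()
    # low-rank update, with dropout applied only to the adapter input
    X_d = F.dropout(X, p=dropout_p, training=True)
    return base + (X_d @ A.t() @ B.t()) * scaling
```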
Plugin | Description | Depends | Loading | Augmentation | Callbacks |
---|---|---|---|---|---|
fast_kernels | Enhanced version of `fast_quantized_peft`, also works for full-FT and non-quant peft | Contains extracted code | | ✅ | |
Compatibility Matrix with Mixed Precision
torch_dtype | Mixed Precision | Full-FT-FOAK | PEFT-FOAK | QPEFT-FOAK |
---|---|---|---|---|
FLOAT16 | - | Compatible | Compatible | ✗ |
FLOAT16 | FP16 | ValueError: Attempting to unscale FP16 gradients. See here | Compatible | Compatible |
BFLOAT16 | - | Compatible | Compatible | ✗ |
BFLOAT16 | BF16 | Compatible | Compatible | Less Performant |
NOTE: this chart is also a good reference for supported types, even for the non-FOAK case.
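As a concrete illustration of the recommended BFLOAT16 + BF16 row above, the following is a minimal sketch using Hugging Face `transformers`: load the base model in bfloat16 (the `torch_dtype` column) and enable BF16 mixed precision (the Mixed Precision column). The model id and output directory are placeholders.

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# torch_dtype column: load the base model weights in bfloat16
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",  # placeholder model id
    torch_dtype=torch.bfloat16,
)

# Mixed Precision column: enable BF16 mixed precision for training
training_args = TrainingArguments(
    output_dir="./results",  # placeholder
    bf16=True,
)
```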
Notes on the extraction of code from unsloth:
- While unsloth is released under Apache 2.0, there are comments indicating some exceptions strewn throughout the code base, see an example here:
    > it would require a commercial license if used to run on more than 4 GPUs ...
- These exceptions appear to be located around the trainer improvements, see another example here.
- These exceptions appear around the Feb 2024 release; any code that appears in any file where such exceptions occur is not extracted.
- Instead, in its place we adopt a model patching approach, as opposed to unsloth's approach of rewriting the model; our patching implementation is novel and written completely from scratch.
- We have also enabled dropout on the LoRA fused operations.
- All extracted code appears before the Feb 2024 Release.
- In the table below we record what was extracted, and the exact commit from which it was taken.
Path | Description | Extracted From | Modifications | Date |
---|---|---|---|---|
fused_ops/unsloth_lora | QLoRA fast dequant, activation kernels | unsloth/main @ 1ecc0185 | | 28 Jan 2024 |
fused_ops/unsloth_lora/bnb | BNB fast lora | unsloth/main @ 1ecc0185 | fast_lora.py | 28 Jan 2024 |
fused_ops/unsloth_lora/gptq | GPTQ fast dequant (triton_v2) | jeromeku/main @ 2839d39 | fast_lora.py, triton/layers.py | 6 Feb 2024 |
kernels/unsloth | Fast RMS, RoPE, CrossEnt kernels | unsloth/main @ 1ecc0185 | cross_entropy_loss.py, rms_layernorm.py | 28 Jan 2024 |
Model | norm | pos emb | cross-ent | fused_lora |
---|---|---|---|---|
LlamaForCausalLM | ✅ | ✅ | ✅ | ✅ |
MistralForCausalLM | ✅ | ✅ | ✅ | ✅ |
MixtralForCausalLM | ✅ | ✅ | ✅ | ✅ |
GPTBigCodeForCausalLM | ❌ | ❌ | ✅ | ❌ |
GraniteForCausalLM | ✅ | ✅ | ✅ | ✅ |
Adding support for a new model is relatively easy by following an existing template; in what follows we use `GraniteForCausalLM` as an example.
- Implement a `get_mp_rules` for the new model, which returns a list of `ModelPatcherRule` (a fuller sketch follows these steps).
- The logic that needs to be changed is the various classes that the rules are triggered on. Import the various module classes like so:

    ```python
    from transformers.models.granite.modeling_granite import (
        GraniteAttention,
        GraniteMLP,
        GraniteRMSNorm,
    )
    ```

- Replace the classes appropriately in the various locations in `ModelPatcherRule`, in particular in the `ModelPatcherTrigger` portions of it. Name the `rule_id` appropriately.

    ```python
    ModelPatcherRule(
        rule_id="granite-rms",
        trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
        forward=fast_rms_layernorm,
    )
    ```
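Putting the steps together, the following is a sketch of what a `get_mp_rules` module for `GraniteForCausalLM` might look like. The import paths for `ModelPatcherRule`, `ModelPatcherTrigger`, and `fast_rms_layernorm`, as well as the function signature, are assumptions here; copy them from an existing model template (e.g. the Llama or Mistral rules), and add the remaining RoPE, cross-entropy, and fused-LoRA rules in the same way.

```python
# NOTE: the import paths below are assumptions; mirror an existing template
# (e.g. the Llama rules) for the actual locations in this repository.
from fms_acceleration.model_patcher import (
    ModelPatcherRule,
    ModelPatcherTrigger,
)
from transformers.models.granite.modeling_granite import (
    GraniteAttention,   # used by the fused-LoRA / RoPE rules (not shown)
    GraniteMLP,         # used by the fused-LoRA rules (not shown)
    GraniteRMSNorm,
)

from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm


def get_mp_rules(base_type: str):
    """Return the ModelPatcherRule list for GraniteForCausalLM."""
    return [
        # swap GraniteRMSNorm's forward for the fast Triton RMS LayerNorm
        ModelPatcherRule(
            rule_id="granite-rms",
            trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
            forward=fast_rms_layernorm,
        ),
        # ... add RoPE, cross-entropy, and fused-LoRA rules for
        #     GraniteAttention / GraniteMLP following the same pattern
    ]
```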
- MixedPrecision `--fp16` or `--bf16` should be used with `fast_lora`.
- `fast_lora` has issues with FSDP V1 with the `peft` style of FSDP wrapping.
    - This is because the adapter's forward functions are bypassed in the fused ops.
    - For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
- `fast_rope_embeddings` does not work with `position_ids`. Currently `position_ids` are ignored and could give wrong results.