
# FMS Acceleration for Fused Operations and Kernels

This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:

1. Fused operations and kernels extracted from unsloth:
    - Low-Rank Adapter Fused Operations
    - Fast RoPE Triton Kernels
    - Fast RMS LayerNorm Triton Kernels
    - Fast Cross Entropy Triton Kernels

## Plugins

| Plugin | Description | Depends | Loading | Augmentation | Callbacks |
|--|--|--|--|--|--|
| `fast_kernels` | Enhanced version of `fast_quantized_peft`; also works for full-FT and non-quantized PEFT | Contains extracted code | | | |

## Supported DataType Settings

**Compatibility Matrix with Mixed Precision**

| torch_dtype | Mixed Precision | Full-FT-FOAK | PEFT-FOAK | QPEFT-FOAK |
|--|--|--|--|--|
| FLOAT16 | - | Compatible | Compatible | |
| FLOAT16 | FP16 | ValueError: Attempting to unscale FP16 gradients (see here) | Compatible | Compatible |
| BFLOAT16 | - | Compatible | Compatible | |
| BFLOAT16 | BF16 | Compatible | Compatible | Less Performant |

**NOTE**: this chart is also a good reference for supported types, even for the non-FOAK case.
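As a point of reference, here is a minimal sketch of selecting the BFLOAT16 + BF16 row of the matrix above using standard Hugging Face `transformers` arguments; the model name is only a placeholder:

```python
# Minimal sketch: torch_dtype = BFLOAT16 with BF16 mixed precision,
# matching the fully "Compatible" row of the matrix above.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder; any supported model
    torch_dtype=torch.bfloat16,  # torch_dtype column
)
args = TrainingArguments(
    output_dir="output",
    bf16=True,                   # Mixed Precision column (BF16)
)
```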

## Code Extracted from Unsloth

Notes on the extraction of code from unsloth:

- While unsloth is released under Apache 2.0, there are comments indicating some exceptions strewn throughout the codebase; see an example here:

  > it would require a commercial license if used to run on more than 4 GPUs ...

- These exceptions appear to be located around the trainer improvements; see another example here.
- These exceptions appeared around the Feb 2024 release; no code is extracted from any file in which such an exception occurs.
- In its place we adopt a different approach: model patching, as opposed to unsloth's approach of rewriting the model. Our approach is written completely from scratch.
- We have also enabled dropout on the LoRA fused operations.
- All extracted code predates the Feb 2024 release.
- The table below records what was extracted and the exact commit from which it was taken.
| Path | Description | Extracted From | Modifications | Date |
|--|--|--|--|--|
| `fused_ops/unsloth_lora` | QLoRA fast dequant, activation kernels | `unsloth/main` @ `1ecc0185` | | 28 Jan 2024 |
| `fused_ops/unsloth_lora/bnb` | BNB fast lora | `unsloth/main` @ `1ecc0185` | `fast_lora.py` | 28 Jan 2024 |
| `fused_ops/unsloth_lora/gptq` | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ `2839d39` | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024 |
| `kernels/unsloth` | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ `1ecc0185` | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024 |

## Supported Models

| Model | norm | pos emb | cross-ent | fused_lora |
|--|--|--|--|--|
| `LlamaForCausalLM` | | | | |
| `MistralForCausalLM` | | | | |
| `MixtralForCausalLM` | | | | |
| `GPTBigCodeForCausalLM` | | | | |
| `GraniteForCausalLM` | | | | |
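Since architectures outside this table will simply not be patched, a quick pre-flight check against it can save debugging a silent fallback. The helper below is illustrative only, not part of the library:

```python
# Illustrative helper (not part of the library): check whether a model's
# architecture appears in the supported-models table above.
FOAK_SUPPORTED = {
    "LlamaForCausalLM",
    "MistralForCausalLM",
    "MixtralForCausalLM",
    "GPTBigCodeForCausalLM",
    "GraniteForCausalLM",
}

def is_foak_supported(model) -> bool:
    return type(model).__name__ in FOAK_SUPPORTED
```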

## Adding Support For A New Model

Adding support is relatively easy if you follow an existing template; in what follows we use `GraniteForCausalLM` as an example.

- Implement a `get_mp_rules` for the new model, which returns a list of `ModelPatcherRule` (a fuller sketch follows this list).
- The main logic to change is the set of classes that the rules trigger on. Import the relevant module classes like so:

  ```python
  from transformers.models.granite.modeling_granite import (
      GraniteAttention,
      GraniteMLP,
      GraniteRMSNorm,
  )
  ```

- Replace the classes appropriately in the various locations in `ModelPatcherRule`, in particular the `ModelPatcherTrigger` portions. Name each `rule_id` appropriately:

  ```python
  ModelPatcherRule(
      rule_id="granite-rms",
      trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
      forward=fast_rms_layernorm,
  )
  ```
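Putting the pieces together, here is a minimal sketch of what a `get_mp_rules` for Granite might look like. The import paths for `ModelPatcherRule`/`ModelPatcherTrigger` and `fast_rms_layernorm`, and the `base_type` argument, are assumptions; mirror an existing model file in this plugin (e.g. the Llama or Mistral one) for the exact details:

```python
# A minimal sketch of get_mp_rules for GraniteForCausalLM. Import paths
# below are assumptions; copy them from an existing model file.
from fms_acceleration.model_patcher import (  # assumed location
    ModelPatcherRule,
    ModelPatcherTrigger,
)
from transformers.models.granite.modeling_granite import GraniteRMSNorm

from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm  # assumed location


def get_mp_rules(base_type: str):
    # base_type (e.g. "auto_gptq" or "bnb") would be consumed by the
    # fused-LoRA rules; it is unused in this norm-only sketch.
    return [
        # swap GraniteRMSNorm's forward for the fast Triton RMS kernel
        ModelPatcherRule(
            rule_id="granite-rms",
            trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
            forward=fast_rms_layernorm,
        ),
        # rules for RoPE, cross-entropy, and fused LoRA follow the same
        # pattern, triggered on GraniteAttention / GraniteMLP.
    ]
```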

## Known Issues

- Mixed precision (`--fp16` or `--bf16`) should be used with `fast_lora`.
- `fast_lora` has issues with FSDP v1 when using the `peft` style of FSDP wrapping.
    - This is because the adapters' forward functions are bypassed in the fused ops.
    - For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP, so they will be unsharded in time for the fused ops.
- `fast_rope_embeddings` does not work with `position_ids`; currently `position_ids` are ignored, which can give wrong results.
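Until the `position_ids` issue is resolved, a defensive check like the following (a hypothetical helper, not part of the library) can fail loudly instead of silently producing wrong results:

```python
import warnings

def warn_if_position_ids(batch: dict) -> None:
    # fast_rope_embeddings ignores position_ids, so surface the problem
    # instead of letting it silently corrupt results.
    if batch.get("position_ids") is not None:
        warnings.warn(
            "batch contains position_ids, which fast_rope_embeddings "
            "ignores; results may be incorrect"
        )
```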