MLoRAx

MLoRAx is a minimalist library for low-rank adaptation (LoRA), designed to effortlessly enable parameter-efficient training of JAX-based Transformer models.

Installation

Choice 1: Install with pip

You can install mlorax with pip. This is the recommended way to use mlorax if you want to receive future updates.

pip install mlorax

Choice 2: Just copy the code

You can also directly copy the code from mlorax.py and paste it into your project. This is the easiest way to use mlorax if you do not care about future updates.

Choice 3: Install from source

You can also install mlorax from source. You only need to do this if you want to contribute to the project.

git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .

Usage

mlorax makes it extremely easy to convert any Flax model to a LoRA model. The following code snippet shows how to do so for a T5 model, using HuggingFace's FlaxT5ForConditionalGeneration class.

+ import mlorax
  model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
  state = TrainState(apply_fn=apply_fn, params=params, tx=tx, **kwargs)

That's it! You can now train the model as usual.
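
To make the snippet above concrete, here is a minimal end-to-end training-step sketch. Only LoRASpec and lora_init come from the snippet above; the adamw optimizer, the batch layout (input_ids, attention_mask, decoder_input_ids, and unshifted integer labels), and the cross-entropy loss are illustrative assumptions, and the returned apply_fn is assumed to accept the same keyword arguments as model.__call__.

import jax
import optax
import mlorax
from flax.training.train_state import TrainState
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')

# Wrap the model so that only the low-rank factors of the matched
# attention projections are trainable.
lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)

# Any optax optimizer works here; adamw is just an illustrative choice.
tx = optax.adamw(learning_rate=1e-4)
state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
    # `batch` is assumed to contain input_ids, attention_mask,
    # decoder_input_ids, and integer labels (label shifting omitted).
    def loss_fn(params):
        logits = state.apply_fn(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_input_ids'],
            params=params,
        ).logits
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['labels']
        )
        return loss.mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss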

Principles

Always use the returned apply_fn for model forwarding whenever possible. For code paths that bypass apply_fn, pass params=merge_fn(params) instead, so the function receives the merged parameters (the base weights with the low-rank updates folded in). For example, to use model.generate for text generation:

- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
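
A hypothetical end-to-end sketch (the tokenizer and prompt are illustrative; model, params, and merge_fn are the objects from the Usage section above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
batch = tokenizer(['translate English to German: Hello, world!'], return_tensors='np')

# model.generate bypasses the LoRA-aware apply_fn, so merge the low-rank
# updates back into the base weights before decoding.
outputs = model.generate(**batch, params=merge_fn(params))
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))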

Example and Results

Please refer to the examples folder for details.

Citation

If you find MLoRAx useful, please cite it as follows:

@software{hao2023lmrax,
  author = {Yongchang Hao},
  title = {{T}he {LMR}ax {E}cosystem: a minimalist library for training {T}ransformer models with {JAX}},
  year = {2023},
  url = {https://github.com/yongchanghao/LMRax},
  version = {0.9.5}
}
