diff --git a/.gitignore b/.gitignore index 3e563d1d..83339037 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,13 @@ __pycache__ dist .venv +# Byte-compiled / optimized / DLL files +*.py[cod] +*$py.class + +# C extensions +*.so + # Log *.log *.log.* @@ -33,4 +40,157 @@ tests/state_of_the_union.txt # Build build -!dummy_file \ No newline at end of file +!dummy_file + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/llm_rl/LICENSE b/llm_rl/LICENSE new file mode 100644 index 00000000..b09cd785 --- /dev/null +++ b/llm_rl/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/llm_rl/README.md b/llm_rl/README.md new file mode 100644 index 00000000..54da92da --- /dev/null +++ b/llm_rl/README.md @@ -0,0 +1,501 @@ +# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort + +[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers) +[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) +[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) +[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/) +[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/) +[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) +[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd) + +👋 Join our [WeChat](assets/wechat.jpg). + +\[ English | [中文](README_zh.md) \] + +## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory + +Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet) + +Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU. + +https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1 + +## Changelog + +[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`. + +[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention. + +[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models. + +[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs. + +[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings. + +[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models. + +[23/07/31] We supported **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode. + +[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details. 
+ +[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. + +[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. + +[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details. + +[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**. + +[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models. + +## Supported Models + +| Model | Model size | Default module | Template | +| -------------------------------------------------------- | --------------------------- | ----------------- | --------- | +| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan | +| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | +| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | +| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | +| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 | +| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - | +| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | +| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - | +| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | +| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral | +| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | +| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | qwen | +| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse | + +> [!NOTE] +> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules. +> +> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models. + +Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported. 
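+
+For example, combining the two columns for Baichuan2 in the table above gives a command along the following lines. This is only a sketch: `baichuan-inc/Baichuan2-7B-Chat` is just an illustrative model path, and the remaining hyper-parameters follow the full examples below.
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage sft \
+    --model_name_or_path baichuan-inc/Baichuan2-7B-Chat \
+    --do_train \
+    --dataset alpaca_gpt4_en \
+    --template baichuan2 \
+    --finetuning_type lora \
+    --lora_target W_pack \
+    --output_dir path_to_sft_checkpoint \
+    ... # remaining hyper-parameters (same as in the examples below)
+```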
+ +## Supported Training Approaches + +| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA | +| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | +| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Reward Modeling | | | :white_check_mark: | :white_check_mark: | +| PPO Training | | | :white_check_mark: | :white_check_mark: | +| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: | + +> [!NOTE] +> Use `--quantization_bit 4/8` argument to enable QLoRA. + +## Provided Datasets + +
Pre-training datasets + +- [Wiki Demo (en)](data/wiki_demo.txt) +- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) +- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2) +- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) +- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) +- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) +- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) +- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) +- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) + +
+ +
Supervised fine-tuning datasets + +- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) +- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) +- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) +- [Self-cognition (zh)](data/self_cognition.json) +- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) +- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) +- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) +- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) +- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) +- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) +- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) +- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) +- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) +- [UltraChat (en)](https://github.com/thunlp/UltraChat) +- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) +- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) +- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) +- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) +- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) +- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) +- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) +- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) +- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) +- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) +- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) +- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) +- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) +- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) +- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) + +
+ +
Preference datasets + +- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) +- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) +- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) + +
+ +Please refer to [data/README.md](data/README.md) for details. + +Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands. + +```bash +pip install --upgrade huggingface_hub +huggingface-cli login +``` + +## Requirement + +- Python 3.8+ and PyTorch 1.13.1+ +- 🤗Transformers, Datasets, Accelerate, PEFT and TRL +- sentencepiece, protobuf and tiktoken +- fire, jieba, rouge-chinese and nltk (used at evaluation and predict) +- gradio and matplotlib (used in web UI) +- uvicorn, fastapi and sse-starlette (used in API) + +And **powerful GPUs**! + +## Getting Started + +### Data Preparation (optional) + +Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset. + +> [!NOTE] +> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`. + +### Dependence Installation (optional) + +```bash +git clone https://github.com/hiyouga/LLaMA-Factory.git +conda create -n llama_factory python=3.10 +conda activate llama_factory +cd LLaMA-Factory +pip install -r requirements.txt +``` + +If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1. + +```bash +pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl +``` + +### Train on a single GPU + +> [!IMPORTANT] +> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training). 
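+
+Every command below runs LoRA training on a single GPU. To switch any of them to QLoRA instead, it should be enough to append `--quantization_bit 4` (or `8`) as described above, for example (a sketch reusing the Supervised Fine-Tuning example that follows):
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage sft \
+    --quantization_bit 4 \
+    ... # remaining arguments (same as the Supervised Fine-Tuning example below)
+```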
+ +#### Pre-Training + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage pt \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset wiki_demo \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --output_dir path_to_pt_checkpoint \ + --overwrite_cache \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 5e-5 \ + --num_train_epochs 3.0 \ + --plot_loss \ + --fp16 +``` + +#### Supervised Fine-Tuning + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset alpaca_gpt4_en \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --output_dir path_to_sft_checkpoint \ + --overwrite_cache \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 5e-5 \ + --num_train_epochs 3.0 \ + --plot_loss \ + --fp16 +``` + +#### Reward Modeling + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage rm \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset comparison_gpt4_en \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_rm_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-6 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + +#### PPO Training + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage ppo \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset alpaca_gpt4_en \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --reward_model path_to_rm_checkpoint \ + --output_dir path_to_ppo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + +#### DPO Training + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage dpo \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset comparison_gpt4_en \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_dpo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + +### Distributed Training + +#### Use Huggingface Accelerate + +```bash +accelerate config # configure the environment +accelerate launch src/train_bash.py # arguments (same as above) +``` + +
Example config for LoRA training + +```yaml +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +``` + +
+ +#### Use DeepSpeed + +```bash +deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \ + --deepspeed ds_config.json \ + ... # arguments (same as above) +``` + +
Example config for full-parameter training with DeepSpeed ZeRO-2 + +```json +{ + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "overlap_comm": false, + "contiguous_gradients": true + } +} +``` + +
+ +### Export model + +```bash +python src/export_model.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint \ + --export_dir path_to_export +``` + +### API Demo + +```bash +python src/api_demo.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint +``` + +> [!NOTE] +> Visit `http://localhost:8000/docs` for API documentation. + +### CLI Demo + +```bash +python src/cli_demo.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint +``` + +### Web Demo + +```bash +python src/web_demo.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint +``` + +### Evaluation + +```bash +CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \ + --model_name_or_path path_to_llama_model \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint \ + --template vanilla \ + --task mmlu \ + --split test \ + --lang en \ + --n_shot 5 \ + --batch_size 4 +``` + +### Predict + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft \ + --model_name_or_path path_to_llama_model \ + --do_predict \ + --dataset alpaca_gpt4_en \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint \ + --output_dir path_to_predict_result \ + --per_device_eval_batch_size 8 \ + --max_samples 100 \ + --predict_with_generate +``` + +> [!NOTE] +> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict. + +## Projects using LLaMA Factory + +- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B. +- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge. +- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B. +- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B. + +## License + +This repository is licensed under the [Apache-2.0 License](LICENSE). 
+ +Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) + +## Citation + +If this work is helpful, please kindly cite as: + +```bibtex +@Misc{llama-factory, + title = {LLaMA Factory}, + author = {hiyouga}, + howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}}, + year = {2023} +} +``` + +## Acknowledgement + +This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works. + +## Star History + +![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date) diff --git a/llm_rl/README_zh.md b/llm_rl/README_zh.md new file mode 100644 index 00000000..c69e3983 --- /dev/null +++ b/llm_rl/README_zh.md @@ -0,0 +1,500 @@ +# LLaMA Factory: 轻松的大模型训练与评估 + +[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers) +[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) +[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) +[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/) +[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/) +[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) +[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd) + +👋 加入我们的[微信群](assets/wechat.jpg)。 + +\[ [English](README.md) | 中文 \] + +## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory + +使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练) + +下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。 + +https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1 + +## 更新日志 + +[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。 + +[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。 + +[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。 + +[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。 + +[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 
参数评估模型。 + +[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。 + +[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。 + +[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。 + +[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。 + +[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。 + +[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft)。 + +[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。 + +[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。 + +## 模型 + +| 模型名 | 模型大小 | 默认模块 | Template | +| -------------------------------------------------------- | --------------------------- | ----------------- | --------- | +| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan | +| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | +| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | +| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | +| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 | +| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - | +| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | +| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - | +| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | +| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral | +| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | +| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | qwen | +| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse | + +> [!NOTE] +> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。 +> +> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。 + +项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。 + +## 训练方法 + +| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA | +| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | +| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: | +| PPO 训练 | | | :white_check_mark: | :white_check_mark: | +| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: | + +> [!NOTE] +> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。 + +## 数据集 + +
预训练数据集 + +- [Wiki Demo (en)](data/wiki_demo.txt) +- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) +- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2) +- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) +- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) +- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) +- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) +- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) +- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) + +
+ +
指令微调数据集 + +- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) +- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) +- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) +- [Self-cognition (zh)](data/self_cognition.json) +- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) +- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) +- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) +- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) +- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) +- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) +- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) +- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) +- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) +- [UltraChat (en)](https://github.com/thunlp/UltraChat) +- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) +- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) +- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) +- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) +- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) +- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) +- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) +- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) +- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) +- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) +- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) +- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) +- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) +- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) +- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) + +
+ +
偏好数据集 + +- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) +- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) +- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) + +
+ +使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。 + +部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。 + +```bash +pip install --upgrade huggingface_hub +huggingface-cli login +``` + +## 软件依赖 + +- Python 3.8+ 和 PyTorch 1.13.1+ +- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL +- sentencepiece, protobuf 和 tiktoken +- fire, jieba, rouge-chinese 和 nltk (用于评估及预测) +- gradio 和 matplotlib (用于网页端交互) +- uvicorn, fastapi 和 sse-starlette (用于 API) + +以及 **强而有力的 GPU**! + +## 如何使用 + +### 数据准备(可跳过) + +关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。 + +> [!NOTE] +> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。 + +### 环境搭建(可跳过) + +```bash +git clone https://github.com/hiyouga/LLaMA-Factory.git +conda create -n llama_factory python=3.10 +conda activate llama_factory +cd LLaMA-Factory +pip install -r requirements.txt +``` + +如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1. + +```bash +pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl +``` + +### 单 GPU 训练 + +> [!IMPORTANT] +> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。 + +#### 预训练 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage pt \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset wiki_demo \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --output_dir path_to_pt_checkpoint \ + --overwrite_cache \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 5e-5 \ + --num_train_epochs 3.0 \ + --plot_loss \ + --fp16 +``` + +#### 指令监督微调 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset alpaca_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --output_dir path_to_sft_checkpoint \ + --overwrite_cache \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 5e-5 \ + --num_train_epochs 3.0 \ + --plot_loss \ + --fp16 +``` + +#### 奖励模型训练 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage rm \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset comparison_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_rm_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-6 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + +#### PPO 训练 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage ppo \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset alpaca_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --reward_model path_to_rm_checkpoint \ + --output_dir path_to_ppo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + 
--learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss +``` + +#### DPO 训练 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage dpo \ + --model_name_or_path path_to_llama_model \ + --do_train \ + --dataset comparison_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_dpo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + +### 多 GPU 分布式训练 + +#### 使用 Huggingface Accelerate + +```bash +accelerate config # 首先配置分布式环境 +accelerate launch src/train_bash.py # 参数同上 +``` + +
LoRA 训练的 Accelerate 配置示例 + +```yaml +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +``` + +
+ +#### 使用 DeepSpeed + +```bash +deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \ + --deepspeed ds_config.json \ + ... # 参数同上 +``` + +
使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例 + +```json +{ + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "overlap_comm": false, + "contiguous_gradients": true + } +} +``` + +
+ +### 导出微调后的完整模型 + +```bash +python src/export_model.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint \ + --export_dir path_to_export +``` + +### API 服务 + +```bash +python src/api_demo.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint +``` + +> [!NOTE] +> 关于 API 文档请见 `http://localhost:8000/docs`。 + +### 命令行测试 + +```bash +python src/cli_demo.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint +``` + +### 浏览器测试 + +```bash +python src/web_demo.py \ + --model_name_or_path path_to_llama_model \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint +``` + +### 模型评估 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \ + --model_name_or_path path_to_llama_model \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint \ + --template vanilla \ + --task ceval \ + --split validation \ + --lang zh \ + --n_shot 5 \ + --batch_size 4 +``` + +### 模型预测 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft \ + --model_name_or_path path_to_llama_model \ + --do_predict \ + --dataset alpaca_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --checkpoint_dir path_to_checkpoint \ + --output_dir path_to_predict_result \ + --per_device_eval_batch_size 8 \ + --max_samples 100 \ + --predict_with_generate +``` + +> [!NOTE] +> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。 + +## 使用了 LLaMA Factory 的项目 + +- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。 +- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。 +- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。 +- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。 + +## 协议 + +本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 + +使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) + +## 引用 + +如果您觉得此项目有帮助,请考虑以下列格式引用 + +```bibtex +@Misc{llama-factory, + title = {LLaMA Factory}, + author = {hiyouga}, + howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}}, + year = {2023} +} +``` + +## 致谢 + +本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。 + +## Star History + 
+![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date) diff --git a/llm_rl/assets/wechat.jpg b/llm_rl/assets/wechat.jpg new file mode 100644 index 00000000..8df68f9c Binary files /dev/null and b/llm_rl/assets/wechat.jpg differ diff --git a/llm_rl/pyproject.toml b/llm_rl/pyproject.toml new file mode 100644 index 00000000..638dd9c5 --- /dev/null +++ b/llm_rl/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" diff --git a/llm_rl/requirements.txt b/llm_rl/requirements.txt new file mode 100644 index 00000000..840d2f2d --- /dev/null +++ b/llm_rl/requirements.txt @@ -0,0 +1,20 @@ +torch>=1.13.1 +transformers>=4.31.0,<4.35.0 +datasets>=2.12.0 +accelerate>=0.21.0 +peft>=0.4.0 +trl>=0.7.2 +gradio>=3.38.0,<4.0.0 +scipy +sentencepiece +protobuf +tiktoken +fire +jieba +rouge-chinese +nltk +uvicorn +pydantic +fastapi +sse-starlette +matplotlib diff --git a/llm_rl/reward_model.sh b/llm_rl/reward_model.sh new file mode 100644 index 00000000..3068fb43 --- /dev/null +++ b/llm_rl/reward_model.sh @@ -0,0 +1,21 @@ +python src/train_bash.py \ + --stage rm \ + --model_name_or_path meta-llama/Llama-2-13b \ + --do_train \ + --dataset comparison_gpt4_en \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --resume_lora_training False \ + --checkpoint_dir ./llama-2-13b-rm \ + --output_dir ./llama-2-13b-rm \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-6 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 \ + --hf_auth_token "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" \ No newline at end of file diff --git a/llm_rl/setup.py b/llm_rl/setup.py new file mode 100644 index 00000000..7638eaab --- /dev/null +++ b/llm_rl/setup.py @@ -0,0 +1,55 @@ +import os +import re +from setuptools import setup, find_packages + + +def get_version(): + with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f: + file_content = f.read() + pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__") + version, = re.findall(pattern, file_content) + return version + + +def get_requires(): + with open("requirements.txt", "r", encoding="utf-8") as f: + file_content = f.read() + lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] + return lines + + +def main(): + + setup( + name="llmtuner", + version=get_version(), + author="hiyouga", + author_email="hiyouga" "@" "buaa.edu.cn", + description="Easy-to-use LLM fine-tuning framework", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], + license="Apache 2.0 License", + url="https://github.com/hiyouga/LLaMA-Factory", + package_dir={"": "src"}, + packages=find_packages("src"), + python_requires=">=3.8.0", + install_requires=get_requires(), + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: 
Artificial Intelligence", + ] + ) + + +if __name__ == "__main__": + main() diff --git a/llm_rl/src/api_demo.py b/llm_rl/src/api_demo.py new file mode 100644 index 00000000..720089fd --- /dev/null +++ b/llm_rl/src/api_demo.py @@ -0,0 +1,14 @@ +import uvicorn + +from llmtuner import ChatModel, create_app + + +def main(): + chat_model = ChatModel() + app = create_app(chat_model) + print("Visit http://localhost:8000/docs for API document.") + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) + + +if __name__ == "__main__": + main() diff --git a/llm_rl/src/cli_demo.py b/llm_rl/src/cli_demo.py new file mode 100644 index 00000000..fe6a0bc4 --- /dev/null +++ b/llm_rl/src/cli_demo.py @@ -0,0 +1,39 @@ +import readline +from llmtuner import ChatModel + + +def main(): + chat_model = ChatModel() + history = [] + print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") + + while True: + try: + query = input("\nUser: ") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") + continue + except Exception: + raise + + if query.strip() == "exit": + break + + if query.strip() == "clear": + history = [] + print("History has been removed.") + continue + + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in chat_model.stream_chat(query, history): + print(new_text, end="", flush=True) + response += new_text + print() + + history = history + [(query, response)] + + +if __name__ == "__main__": + main() diff --git a/llm_rl/src/evaluate.py b/llm_rl/src/evaluate.py new file mode 100644 index 00000000..8af8c12c --- /dev/null +++ b/llm_rl/src/evaluate.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Evaluates the performance of pre-trained models. 
+# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla +# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result +# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py + +import os +import fire +import json +import torch +import numpy as np +import transformers +from collections import Counter +from datasets import load_dataset +from dataclasses import dataclass +from tqdm import tqdm, trange +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple + +from llmtuner import ChatModel + +if TYPE_CHECKING: + from datasets import Dataset + + +choices = ["A", "B", "C", "D"] + + +@dataclass +class EvalTemplate: + + system: str + choice: str + answer: str + prefix: str + + def parse_example( + self, + example: Dict[str, str] + ) -> Tuple[str, str]: + candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example] + return "".join([example["question"]] + candidates + [self.answer]), example["answer"] + + def format_example( + self, + target_data: Dict[str, str], + support_set: "Dataset", + subject_name: str, + use_history: bool + ) -> Tuple[str, str, List[Tuple[str, str]]]: + query, resp = self.parse_example(target_data) + history = [self.parse_example(support_set[k]) for k in range(len(support_set))] + + if len(history): + temp = history.pop(0) + history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1])) + else: + query = self.system.format(subject=subject_name) + query + + if not use_history: + query = "\n\n".join(["".join(item) for item in history] + [query]) + history = [] + return query.strip(), resp, history + + +eval_templates = { + "en": EvalTemplate( + system="The following are multiple choice questions (with answers) about {subject}.\n\n", + choice="\n{choice}. {content}", + answer="\nAnswer: ", + prefix=" " + ), + "zh": EvalTemplate( + system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", + choice="\n{choice}. 
{content}", + answer="\n答案:", + prefix="\n" + ) +} + + +@torch.inference_mode() +def batch_inference( + chat_model: ChatModel, + batch_input: Dict[str, torch.Tensor], + prefix_char: str +) -> List[str]: + logits = chat_model.model(**batch_input).logits + lengths = torch.sum(batch_input["attention_mask"], dim=-1) + nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) + probs = torch.nn.functional.softmax( + torch.stack( + [ + nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]] + for choice in choices + ], + dim=-1 + ), + dim=-1 + ).detach() + return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)] + + +def evaluate( + model_name_or_path: str, + finetuning_type: Optional[str] = "lora", + checkpoint_dir: Optional[str] = None, + template: Optional[str] = "vanilla", + task: Optional[str] = "ceval", + dataset_dir: Optional[str] = "evaluation", + split: Optional[Literal["validation", "test"]] = "validation", + lang: Optional[Literal["zh", "en"]] = "zh", + n_shot: Optional[int] = 5, + n_avg: Optional[int] = 1, + batch_size: Optional[int] = 4, + save_name: Optional[str] = None, + seed: Optional[int] = 42 +): + with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: + categorys: Dict[str, Dict[str, str]] = json.load(f) + + transformers.set_seed(seed) + chat_model = ChatModel(dict( + model_name_or_path=model_name_or_path, + finetuning_type=finetuning_type, + checkpoint_dir=checkpoint_dir, + template=template + )) + chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 + eval_template = eval_templates[lang] + + category_corrects: Dict[str, np.ndarray] = { + subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"] + } + pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) + results = {} + for subject in pbar: + dataset = load_dataset(os.path.join(dataset_dir, task), subject) + labels, answers, all_outputs = [], [], [] + for epoch in range(n_avg): + pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch)) + inputs, outputs = [], [] + for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False): + support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"])))) + query, resp, history = eval_template.format_example( + target_data=dataset[split][i], + support_set=support_set, + subject_name=categorys[subject]["name"], + use_history=chat_model.template.use_history + ) + input_ids, _ = chat_model.template.encode_oneturn( + tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history + ) + inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}) + if epoch == 0: + labels.append(resp) + + for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False): + batch_input = chat_model.tokenizer.pad( + inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt" + ).to(chat_model.model.device) + preds = batch_inference(chat_model, batch_input, eval_template.prefix) + outputs += preds + all_outputs.append(outputs) + + for i in range(len(all_outputs[0])): + count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)]) + answers.append(count.most_common(1)[0][0]) + + corrects = (np.array(answers) == np.array(labels)) + category_name = categorys[subject]["category"] + category_corrects[category_name] = 
np.concatenate([category_corrects[category_name], corrects], axis=0) + category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0) + results[subject] = {str(i): answers[i] for i in range(len(answers))} + + score_info = "\n".join([ + "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) + for category_name, category_correct in category_corrects.items() if len(category_correct) + ]) + + print(score_info) + if save_name is not None: + with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f: + json.dump(results, f, indent=2) + + with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f: + f.write(score_info) + + +if __name__ == "__main__": + fire.Fire(evaluate) diff --git a/llm_rl/src/export_model.py b/llm_rl/src/export_model.py new file mode 100644 index 00000000..4baeb2c3 --- /dev/null +++ b/llm_rl/src/export_model.py @@ -0,0 +1,9 @@ +from llmtuner import export_model + + +def main(): + export_model() + + +if __name__ == "__main__": + main() diff --git a/llm_rl/src/llmtuner/__init__.py b/llm_rl/src/llmtuner/__init__.py new file mode 100644 index 00000000..37eb9535 --- /dev/null +++ b/llm_rl/src/llmtuner/__init__.py @@ -0,0 +1,9 @@ +# Level: api, webui > chat > tuner > dsets > extras, hparams + +from llmtuner.api import create_app +from llmtuner.chat import ChatModel +from llmtuner.tuner import export_model, run_exp +from llmtuner.webui import create_ui, create_web_demo + + +__version__ = "0.2.0" diff --git a/llm_rl/src/llmtuner/api/__init__.py b/llm_rl/src/llmtuner/api/__init__.py new file mode 100644 index 00000000..b3ce183a --- /dev/null +++ b/llm_rl/src/llmtuner/api/__init__.py @@ -0,0 +1 @@ +from llmtuner.api.app import create_app diff --git a/llm_rl/src/llmtuner/api/app.py b/llm_rl/src/llmtuner/api/app.py new file mode 100644 index 00000000..27fb19e0 --- /dev/null +++ b/llm_rl/src/llmtuner/api/app.py @@ -0,0 +1,146 @@ +import json +import uvicorn +from fastapi import FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +from sse_starlette import EventSourceResponse +from typing import List, Tuple +from pydantic import BaseModel + +from llmtuner.extras.misc import torch_gc +from llmtuner.chat import ChatModel +from llmtuner.api.protocol import ( + Role, + Finish, + ModelCard, + ModelList, + ChatMessage, + DeltaMessage, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionResponseUsage +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): # collects GPU memory + yield + torch_gc() + + +def to_json(data: BaseModel) -> str: + try: # pydantic v2 + return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) + except: # pydantic v1 + return data.json(exclude_unset=True, ensure_ascii=False) + + +def create_app(chat_model: ChatModel) -> FastAPI: + app = FastAPI(lifespan=lifespan) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/v1/models", response_model=ModelList) + async def list_models(): + model_card = ModelCard(id="gpt-3.5-turbo") + return ModelList(data=[model_card]) + + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) + async def create_chat_completion(request: ChatCompletionRequest): + if len(request.messages) < 1 or 
request.messages[-1].role != Role.USER: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") + + query = request.messages[-1].content + prev_messages = request.messages[:-1] + if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: + system = prev_messages.pop(0).content + else: + system = None + + history = [] + if len(prev_messages) % 2 == 0: + for i in range(0, len(prev_messages), 2): + if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT: + history.append([prev_messages[i].content, prev_messages[i+1].content]) + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") + + if request.stream: + generate = predict(query, history, system, request) + return EventSourceResponse(generate, media_type="text/event-stream") + + response, (prompt_length, response_length) = chat_model.chat( + query, history, system, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + num_return_sequences=request.n + ) + + usage = ChatCompletionResponseUsage( + prompt_tokens=prompt_length, + completion_tokens=response_length, + total_tokens=prompt_length+response_length + ) + + choices = [ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role=Role.ASSISTANT, content=choice), + finish_reason=Finish.STOP + ) for i, choice in enumerate(response)] + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role=Role.ASSISTANT), + finish_reason=None + ) + chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) + yield to_json(chunk) + + for new_text in chat_model.stream_chat( + query, history, system, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens + ): + if len(new_text) == 0: + continue + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=new_text), + finish_reason=None + ) + chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) + yield to_json(chunk) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason=Finish.STOP + ) + chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) + yield to_json(chunk) + yield "[DONE]" + + return app + + +if __name__ == "__main__": + chat_model = ChatModel() + app = create_app(chat_model) + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/llm_rl/src/llmtuner/api/protocol.py b/llm_rl/src/llmtuner/api/protocol.py new file mode 100644 index 00000000..6b99da40 --- /dev/null +++ b/llm_rl/src/llmtuner/api/protocol.py @@ -0,0 +1,83 @@ +import time +from enum import Enum +from pydantic import BaseModel, Field +from typing import List, Optional + + +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + +class Finish(str, Enum): + STOP = "stop" + LENGTH = "length" + + +class ModelCard(BaseModel): + id: str + object: Optional[str] = "model" + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + owned_by: Optional[str] = "owner" + + +class ModelList(BaseModel): + object: Optional[str] = "list" + data: Optional[List[ModelCard]] = [] + + +class 
ChatMessage(BaseModel): + role: Role + content: str + + +class DeltaMessage(BaseModel): + role: Optional[Role] = None + content: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + do_sample: Optional[bool] = True + temperature: Optional[float] = None + top_p: Optional[float] = None + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stream: Optional[bool] = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Finish + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Finish] = None + + +class ChatCompletionResponseUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: Optional[str] = "chatcmpl-default" + object: Optional[str] = "chat.completion" + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: ChatCompletionResponseUsage + + +class ChatCompletionStreamResponse(BaseModel): + id: Optional[str] = "chatcmpl-default" + object: Optional[str] = "chat.completion.chunk" + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] diff --git a/llm_rl/src/llmtuner/chat/__init__.py b/llm_rl/src/llmtuner/chat/__init__.py new file mode 100644 index 00000000..ba240d05 --- /dev/null +++ b/llm_rl/src/llmtuner/chat/__init__.py @@ -0,0 +1 @@ +from llmtuner.chat.stream_chat import ChatModel diff --git a/llm_rl/src/llmtuner/chat/stream_chat.py b/llm_rl/src/llmtuner/chat/stream_chat.py new file mode 100644 index 00000000..cc815d1b --- /dev/null +++ b/llm_rl/src/llmtuner/chat/stream_chat.py @@ -0,0 +1,109 @@ +import torch +from typing import Any, Dict, Generator, List, Optional, Tuple +from threading import Thread +from transformers import GenerationConfig, TextIteratorStreamer + +from llmtuner.extras.misc import dispatch_model, get_logits_processor +from llmtuner.extras.template import get_template_and_fix_tokenizer +from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer + + +class ChatModel: + + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) + self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + self.tokenizer.padding_side = "left" + self.model = dispatch_model(self.model) + self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) + self.system_prompt = data_args.system_prompt + + def process_args( + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + **input_kwargs + ) -> Tuple[Dict[str, Any], int]: + system = system or self.system_prompt + prompt, _ = self.template.encode_oneturn( + tokenizer=self.tokenizer, query=query, resp="", history=history, system=system + ) + prompt_length = len(prompt) + input_ids = torch.tensor([prompt], device=self.model.device) + + do_sample = input_kwargs.pop("do_sample", None) + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + num_return_sequences = input_kwargs.pop("num_return_sequences", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + 
max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + generating_args = self.generating_args.to_dict() + generating_args.update(dict( + do_sample=do_sample if do_sample is not None else generating_args["do_sample"], + temperature=temperature or generating_args["temperature"], + top_p=top_p or generating_args["top_p"], + top_k=top_k or generating_args["top_k"], + num_return_sequences=num_return_sequences or 1, + repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], + eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + pad_token_id=self.tokenizer.pad_token_id + )) + + if isinstance(num_return_sequences, int) and num_return_sequences > 1: + generating_args["do_sample"] = True + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=input_ids, + generation_config=GenerationConfig(**generating_args), + logits_processor=get_logits_processor() + ) + + return gen_kwargs, prompt_length + + @torch.inference_mode() + def chat( + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + **input_kwargs + ) -> Tuple[List[str], Tuple[int, int]]: + gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) + generate_output = self.model.generate(**gen_kwargs) + response_ids = generate_output[:, prompt_length:] + response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + response_length = 0 + for i in range(len(response_ids)): + eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() + response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i]) + + return response, (prompt_length, response_length) + + @torch.inference_mode() + def stream_chat( + self, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None, + **input_kwargs + ) -> Generator[str, None, None]: + gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) + streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + + thread = Thread(target=self.model.generate, kwargs=gen_kwargs) + thread.start() + + yield from streamer diff --git a/llm_rl/src/llmtuner/dsets/__init__.py b/llm_rl/src/llmtuner/dsets/__init__.py new file mode 100644 index 00000000..cccbd745 --- /dev/null +++ b/llm_rl/src/llmtuner/dsets/__init__.py @@ -0,0 +1,3 @@ +from llmtuner.dsets.loader import get_dataset +from llmtuner.dsets.preprocess import preprocess_dataset +from llmtuner.dsets.utils import split_dataset diff --git a/llm_rl/src/llmtuner/dsets/loader.py b/llm_rl/src/llmtuner/dsets/loader.py new file mode 100644 index 00000000..834ef733 --- /dev/null +++ b/llm_rl/src/llmtuner/dsets/loader.py @@ -0,0 +1,145 @@ +import os +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from datasets import concatenate_datasets, interleave_datasets, load_dataset + +from llmtuner.dsets.utils import checksum, EXT2TYPE +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from llmtuner.hparams import ModelArguments, DataArguments + + +logger = get_logger(__name__) + + +def get_dataset( + model_args: "ModelArguments", + data_args: 
"DataArguments" +) -> Union["Dataset", "IterableDataset"]: + max_samples = data_args.max_samples + all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets + + for dataset_attr in data_args.dataset_list: + logger.info("Loading dataset {}...".format(dataset_attr)) + + if dataset_attr.load_from == "hf_hub": + data_path = dataset_attr.dataset_name + data_name = dataset_attr.subset + data_files = None + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_name = dataset_attr.subset + data_files = None + elif dataset_attr.load_from == "file": + data_path, data_name = None, None + data_files: List[str] = [] + if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory + for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) + if data_path is None: + data_path = EXT2TYPE.get(file_name.split(".")[-1], None) + else: + assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical." + elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file + data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) + data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None) + else: + raise ValueError("File not found.") + + assert data_path, "File extension must be txt, csv, json or jsonl." + checksum(data_files, dataset_attr.dataset_sha1) + else: + raise NotImplementedError + + dataset = load_dataset( + path=data_path, + name=data_name, + data_files=data_files, + split=data_args.split, + cache_dir=model_args.cache_dir, + streaming=data_args.streaming, + use_auth_token=True if model_args.use_auth_token else None + ) + + if max_samples is not None: # truncate dataset + dataset = dataset.select(range(min(len(dataset), max_samples))) + + def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # convert dataset from sharegpt format to alpaca format + outputs = {"prompt": [], "query": [], "response": [], "history": []} + for msg_list in examples[dataset_attr.messages]: + msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2 + if len(msg_list) == 0: + continue + + msg_pairs = [] + user_role, assistant_role = None, None + for idx in range(0, len(msg_list), 2): + if user_role is None and assistant_role is None: + user_role = msg_list[idx][dataset_attr.role] + assistant_role = msg_list[idx + 1][dataset_attr.role] + else: + if ( + msg_list[idx][dataset_attr.role] != user_role + or msg_list[idx+1][dataset_attr.role] != assistant_role + ): + raise ValueError("Only accepts conversation in u/a/u/a/u/a order.") + msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content])) + + if len(msg_pairs) != 0: + outputs["prompt"].append(msg_pairs[-1][0]) + outputs["query"].append("") + outputs["response"].append(msg_pairs[-1][1]) + outputs["history"].append(msg_pairs[:-1]) + + return outputs + + if dataset_attr.formatting == "sharegpt": # convert format + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache), + desc="Converting format of dataset" + ) + + dataset = dataset.map( + convert_format, + batched=True, + 
remove_columns=column_names, + **kwargs + ) + else: + for column_name in ["prompt", "query", "response", "history"]: # align dataset + if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: + dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) + + if dataset_attr.system_prompt: # add system prompt + system_prompt = dataset_attr.system_prompt + if data_args.streaming: + dataset = dataset.map(lambda _: {"system": system_prompt}) + else: + dataset = dataset.add_column("system", [system_prompt] * len(dataset)) + + all_datasets.append(dataset) + + if len(data_args.dataset_list) == 1: + return all_datasets[0] + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning("The samples between different datasets will not be mixed in streaming mode.") + return concatenate_datasets(all_datasets) + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + return interleave_datasets( + datasets=all_datasets, + probabilities=data_args.interleave_probs, + seed=data_args.seed, + stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" + ) + else: + raise ValueError("Unknown mixing strategy.") diff --git a/llm_rl/src/llmtuner/dsets/preprocess.py b/llm_rl/src/llmtuner/dsets/preprocess.py new file mode 100644 index 00000000..0484b78e --- /dev/null +++ b/llm_rl/src/llmtuner/dsets/preprocess.py @@ -0,0 +1,268 @@ +import os +import tiktoken +from itertools import chain +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union + +from datasets import load_from_disk + +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.logging import get_logger +from llmtuner.extras.template import get_template_and_fix_tokenizer + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + from transformers.tokenization_utils import PreTrainedTokenizer + from llmtuner.hparams import DataArguments + + +logger = get_logger(__name__) + + +def preprocess_dataset( + dataset: Union["Dataset", "IterableDataset"], + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"] +) -> Union["Dataset", "IterableDataset"]: + template = get_template_and_fix_tokenizer(data_args.template, tokenizer) + + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: + for i in range(len(examples["prompt"])): + query, response = examples["prompt"][i], examples["response"][i] + query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query + history = examples["history"][i] if "history" in examples else None + system = examples["system"][i] if "system" in examples else None + yield query, response, history, system + + def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: + # build grouped texts with format `X1 X2 X3 ...` + if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) + kwargs = dict(allowed_special="all") + else: + kwargs = dict(add_special_tokens=True) + + if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer + 
setattr(tokenizer, "add_eos_token", True) + + tokenized_examples = tokenizer(examples["prompt"], **kwargs) + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + result = { + k: [t[i: i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + return result + + def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + + for query, response, history, system in construct_example(examples): + if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): + continue + + input_ids, labels = [], [] + for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( + tokenizer, query, response, history, system + )): + total_len = len(source_ids) + len(target_ids) + max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len)) + max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len)) + + if len(source_ids) > max_source_len: + source_ids = source_ids[:max_source_len] + if len(target_ids) > max_target_len: + target_ids = target_ids[:max_target_len] + + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + if len(input_ids) > data_args.cutoff_len: + input_ids = input_ids[:data_args.cutoff_len] + labels = labels[:data_args.cutoff_len] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + + return model_inputs + + def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... 
Y2 ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + input_ids, labels = [], [] + for query, response, history, system in construct_example(examples): + if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): + continue + + for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( + tokenizer, query, response, history, system + )): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + total_length = len(input_ids) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + for i in range(0, total_length, block_size): + model_inputs["input_ids"].append(input_ids[i: i + block_size]) + model_inputs["attention_mask"].append([1] * block_size) + model_inputs["labels"].append(labels[i: i + block_size]) + + return model_inputs + + def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: + # build inputs with format ` X` and labels with format `Y ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + + for query, response, history, system in construct_example(examples): + if not (isinstance(query, str) and query != ""): + continue + + input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system) + + if template.efficient_eos: + labels += [tokenizer.eos_token_id] + + if len(input_ids) > data_args.cutoff_len: + input_ids = input_ids[:data_args.cutoff_len] + if len(labels) > data_args.cutoff_len: + labels = labels[:data_args.cutoff_len] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + + return model_inputs + + def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} + for query, response, history, system in construct_example(examples): + if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1): + continue + + prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system) + _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system) + + if template.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids)) + max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len)) + max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len)) + + if len(prompt_ids) > max_source_len: + prompt_ids = prompt_ids[:max_source_len] + if len(chosen_ids) > max_target_len: + chosen_ids = chosen_ids[:max_target_len] + if len(rejected_ids) > max_target_len: + rejected_ids = rejected_ids[:max_target_len] + + model_inputs["prompt_ids"].append(prompt_ids) + 
model_inputs["chosen_ids"].append(chosen_ids) + model_inputs["rejected_ids"].append(rejected_ids) + + return model_inputs + + def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format( + tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) + )) + + def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None: + print("prompt_ids:\n{}".format(example["prompt_ids"])) + print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) + print("chosen_ids:\n{}".format(example["chosen_ids"])) + print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) + print("rejected_ids:\n{}".format(example["rejected_ids"])) + print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) + + def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + + if stage == "pt": + preprocess_func = preprocess_pretrain_dataset + print_function = print_unsupervised_dataset_example + elif stage == "sft" and not training_args.predict_with_generate: + preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset + print_function = print_supervised_dataset_example + elif stage == "rm": + preprocess_func = preprocess_pairwise_dataset + print_function = print_pairwise_dataset_example + else: + preprocess_func = preprocess_unsupervised_dataset + print_function = print_unsupervised_dataset_example + + if data_args.cache_path is not None and os.path.exists(data_args.cache_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + return load_from_disk(data_args.cache_path) + + with training_args.main_process_first(desc="dataset map pre-processing"): + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache), + desc="Running tokenizer on dataset" + ) + + dataset = dataset.map( + preprocess_func, + batched=True, + remove_columns=column_names, + **kwargs + ) + + if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): + if training_args.should_save: + dataset.save_to_disk(data_args.cache_path) + raise SystemExit("Dataset saved, rerun this script with the same `--cache_file`.") + + if training_args.should_log: + try: + print_function(next(iter(dataset))) + except StopIteration: + raise RuntimeError("Empty dataset!") + + return dataset diff --git a/llm_rl/src/llmtuner/dsets/utils.py b/llm_rl/src/llmtuner/dsets/utils.py new file mode 100644 index 00000000..bf337014 --- /dev/null +++ b/llm_rl/src/llmtuner/dsets/utils.py @@ -0,0 +1,59 @@ +import hashlib +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import TrainingArguments + from llmtuner.hparams import DataArguments + + +logger = get_logger(__name__) + + +EXT2TYPE = { + "csv": "csv", + 
"json": "json", + "jsonl": "json", + "txt": "text" +} + + +def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: + if file_sha1 is None: + logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") + return + + if len(data_files) != 1: + logger.warning("Checksum failed: too many files.") + return + + with open(data_files[0], "rb") as f: + sha1 = hashlib.sha1(f.read()).hexdigest() + if sha1 != file_sha1: + logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) + + +def split_dataset( + dataset: Union["Dataset", "IterableDataset"], + data_args: "DataArguments", + training_args: "TrainingArguments" +) -> Dict[str, "Dataset"]: + if training_args.do_train: + if data_args.val_size > 1e-6: # Split the dataset + if data_args.streaming: + val_set = dataset.take(int(data_args.val_size)) + train_set = dataset.skip(int(data_args.val_size)) + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + return {"train_dataset": train_set, "eval_dataset": val_set} + else: + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) + return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} + else: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + return {"train_dataset": dataset} + else: # do_eval or do_predict + return {"eval_dataset": dataset} diff --git a/llm_rl/src/llmtuner/extras/__init__.py b/llm_rl/src/llmtuner/extras/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llm_rl/src/llmtuner/extras/callbacks.py b/llm_rl/src/llmtuner/extras/callbacks.py new file mode 100644 index 00000000..7398d424 --- /dev/null +++ b/llm_rl/src/llmtuner/extras/callbacks.py @@ -0,0 +1,155 @@ +import os +import json +import time +from typing import TYPE_CHECKING +from datetime import timedelta + +from transformers import TrainerCallback +from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR + +from llmtuner.extras.constants import LOG_FILE_NAME +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from transformers import TrainingArguments, TrainerState, TrainerControl + + +logger = get_logger(__name__) + + +class SavePeftModelCallback(TrainerCallback): + + def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a checkpoint save. + """ + if args.should_save: + output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) + model = kwargs.pop("model") + if getattr(model, "is_peft_model", False): + getattr(model, "pretrained_model").save_pretrained(output_dir) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. 
+ """ + if args.should_save: + model = kwargs.pop("model") + if getattr(model, "is_peft_model", False): + getattr(model, "pretrained_model").save_pretrained(args.output_dir) + + +class LogCallback(TrainerCallback): + + def __init__(self, runner=None): + self.runner = runner + self.in_training = False + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + + def timing(self): + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 + remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. + """ + if state.is_local_process_zero: + self.in_training = True + self.start_time = time.time() + self.max_steps = state.max_steps + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: + logger.warning("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if state.is_local_process_zero: + self.in_training = False + self.cur_steps = 0 + self.max_steps = 0 + + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of an substep during gradient accumulation. + """ + if state.is_local_process_zero and self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of a training step. + """ + if state.is_local_process_zero: + self.cur_steps = state.global_step + self.timing() + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after an evaluation phase. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): + r""" + Event called after a successful prediction. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: + r""" + Event called after logging the last logs. 
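+
+        For reference, a record appended to ``trainer_log.jsonl`` by this callback looks roughly
+        like the following (the values are illustrative; metrics missing from the last log entry
+        are serialized as null):
+
+            {"current_steps": 10, "total_steps": 1000, "loss": 1.2345, "eval_loss": null,
+             "predict_loss": null, "reward": null, "learning_rate": 5e-05, "epoch": 0.01,
+             "percentage": 1.0, "elapsed_time": "0:00:42", "remaining_time": "1:09:18"}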
+ """ + if not state.is_local_process_zero: + return + + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time + ) + if self.runner is not None: + logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( + logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 + )) + + os.makedirs(args.output_dir, exist_ok=True) + with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + + def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a prediction step. + """ + eval_dataloader = kwargs.pop("eval_dataloader", None) + if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: + if self.max_steps == 0: + self.max_steps = len(eval_dataloader) + self.cur_steps += 1 + self.timing() diff --git a/llm_rl/src/llmtuner/extras/constants.py b/llm_rl/src/llmtuner/extras/constants.py new file mode 100644 index 00000000..dc55a080 --- /dev/null +++ b/llm_rl/src/llmtuner/extras/constants.py @@ -0,0 +1,92 @@ +IGNORE_INDEX = -100 + +LOG_FILE_NAME = "trainer_log.jsonl" + +LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"] + +METHODS = ["full", "freeze", "lora"] + +TRAINING_STAGES = { + "Supervised Fine-Tuning": "sft", + "Reward Modeling": "rm", + "PPO": "ppo", + "DPO": "dpo", + "Pre-Training": "pt" +} + +SUPPORTED_MODELS = { + "LLaMA-7B": "huggyllama/llama-7b", + "LLaMA-13B": "huggyllama/llama-13b", + "LLaMA-30B": "huggyllama/llama-30b", + "LLaMA-65B": "huggyllama/llama-65b", + "LLaMA2-7B": "meta-llama/Llama-2-7b-hf", + "LLaMA2-13B": "meta-llama/Llama-2-13b-hf", + "LLaMA2-70B": "meta-llama/Llama-2-70b-hf", + "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf", + "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf", + "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf", + "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b", + "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b", + "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b", + "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b", + "BLOOM-560M": "bigscience/bloom-560m", + "BLOOM-3B": "bigscience/bloom-3b", + "BLOOM-7B1": "bigscience/bloom-7b1", + "BLOOMZ-560M": "bigscience/bloomz-560m", + "BLOOMZ-3B": "bigscience/bloomz-3b", + "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt", + "Falcon-7B": "tiiuae/falcon-7b", + "Falcon-40B": "tiiuae/falcon-40b", + "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct", + "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", + "Baichuan-7B": "baichuan-inc/Baichuan-7B", + "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base", + "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", + "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base", + "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base", + "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat", + "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat", + "InternLM-7B": "internlm/internlm-7b", + "InternLM-20B": "internlm/internlm-20b", + 
"InternLM-7B-Chat": "internlm/internlm-chat-7b", + "InternLM-20B-Chat": "internlm/internlm-chat-20b", + "Qwen-7B": "Qwen/Qwen-7B", + "Qwen-14B": "Qwen/Qwen-14B", + "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", + "XVERSE-13B": "xverse/XVERSE-13B", + "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat", + "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b", + "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base", + "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b", + "Phi1.5-1.3B": "microsoft/phi-1_5" +} + +DEFAULT_MODULE = { + "LLaMA": "q_proj,v_proj", + "LLaMA2": "q_proj,v_proj", + "ChineseLLaMA2": "q_proj,v_proj", + "BLOOM": "query_key_value", + "BLOOMZ": "query_key_value", + "Falcon": "query_key_value", + "Baichuan": "W_pack", + "Baichuan2": "W_pack", + "InternLM": "q_proj,v_proj", + "Qwen": "c_attn", + "XVERSE": "q_proj,v_proj", + "ChatGLM2": "query_key_value", + "ChatGLM3": "query_key_value", + "Phi1.5": "Wqkv" +} + +DEFAULT_TEMPLATE = { + "LLaMA2": "llama2", + "ChineseLLaMA2": "llama2_zh", + "Baichuan": "baichuan", + "Baichuan2": "baichuan2", + "InternLM": "intern", + "Qwen": "chatml", + "XVERSE": "xverse", + "ChatGLM2": "chatglm2", + "ChatGLM3": "chatglm3" +} diff --git a/llm_rl/src/llmtuner/extras/logging.py b/llm_rl/src/llmtuner/extras/logging.py new file mode 100644 index 00000000..d6f185e6 --- /dev/null +++ b/llm_rl/src/llmtuner/extras/logging.py @@ -0,0 +1,43 @@ +import sys +import logging + + +class LoggerHandler(logging.Handler): + + def __init__(self): + super().__init__() + self.log = "" + + def reset(self): + self.log = "" + + def emit(self, record): + if record.name == "httpx": + return + log_entry = self.format(record) + self.log += log_entry + self.log += "\n\n" + + +def reset_logging(): + r""" + Removes basic config of root logger + """ + root = logging.getLogger() + list(map(root.removeHandler, root.handlers)) + list(map(root.removeFilter, root.filters)) + + +def get_logger(name: str) -> logging.Logger: + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S" + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + return logger diff --git a/llm_rl/src/llmtuner/extras/misc.py b/llm_rl/src/llmtuner/extras/misc.py new file mode 100644 index 00000000..960d43ee --- /dev/null +++ b/llm_rl/src/llmtuner/extras/misc.py @@ -0,0 +1,118 @@ +import gc +import torch +from typing import TYPE_CHECKING, Tuple +from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList + +try: + from transformers.utils import ( + is_torch_bf16_cpu_available, + is_torch_bf16_gpu_available, + is_torch_cuda_available, + is_torch_npu_available + ) + _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() + _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available +except ImportError: + _is_fp16_available = torch.cuda.is_available() + _is_bf16_available = torch.cuda.is_bf16_supported() + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + + +class AverageMeter: + r""" + Computes and stores the average and current value. 
+ """ + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: + r""" + Returns the number of trainable parameters and number of all parameters in the model. + """ + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 + if param.__class__.__name__ == "Params4bit": + num_params = num_params * 2 + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + +def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: + r""" + Infers the optimal dtype according to the model_dtype and device compatibility. + """ + if _is_bf16_available and model_dtype == torch.bfloat16: + return torch.bfloat16 + elif _is_fp16_available: + return torch.float16 + else: + return torch.float32 + + +def get_logits_processor() -> LogitsProcessorList: + r""" + Gets logits processor that removes NaN and Inf logits. + """ + logits_processor = LogitsProcessorList() + logits_processor.append(InfNanRemoveLogitsProcessor()) + return logits_processor + + +def torch_gc() -> None: + r""" + Collects GPU memory. + """ + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": + r""" + Dispatches a pre-trained model to GPUs with balanced memory. + Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 + """ + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing + return model + + if torch.cuda.device_count() > 1: + from accelerate import dispatch_model + from accelerate.utils import infer_auto_device_map, get_balanced_memory + + if model._no_split_modules is None: + raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") + + kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} + max_memory = get_balanced_memory(model, **kwargs) + # Make sure tied weights are tied before creating the device map. 
+ model.tie_weights() + device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) + return dispatch_model(model, device_map) + else: + return model.cuda() diff --git a/llm_rl/src/llmtuner/extras/patches/__init__.py b/llm_rl/src/llmtuner/extras/patches/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llm_rl/src/llmtuner/extras/patches/llama_patch.py b/llm_rl/src/llmtuner/extras/patches/llama_patch.py new file mode 100644 index 00000000..a8473311 --- /dev/null +++ b/llm_rl/src/llmtuner/extras/patches/llama_patch.py @@ -0,0 +1,218 @@ +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple +from transformers.utils import logging +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore + from flash_attn.bert_padding import pad_input, unpad_input # type: ignore +except ImportError: + print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.") + + +logger = logging.get_logger(__name__) + + +# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +class LlamaShiftShortAttention(LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if getattr(self, "num_key_value_groups"): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + def shift(state: torch.Tensor) -> torch.Tensor: + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat(( + state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) + ), dim=2) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = 
shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ), dim=2)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # LlamaFlashAttention2 attention does not support output_attentions
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None: # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # cast to half precision
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            logger.warning_once("The input hidden states seem to be silently cast to float32.")
+            query_states = query_states.to(self.config.torch_dtype)
+            key_states = key_states.to(self.config.torch_dtype)
+            value_states = value_states.to(self.config.torch_dtype)
+
+        if getattr(self, "num_key_value_groups", None):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+        key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+        value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
+
+        if attention_mask is not None:
+            logger.warning_once("Padded sequences are less efficient in FlashAttention.")
+            # -q_len: assumes left padding when q_len != kv_len
+            unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
+            unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
+            unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
+            attn_output_unpad = flash_attn_varlen_func(
+                unpadded_q,
+                unpadded_k,
+                unpadded_v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                dropout_p=0.0,
+                softmax_scale=None,
+                causal=True,
+            )
+            attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
+            )
+
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ), dim=2)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as flash attention
+# takes a boolean padding_mask. Fills in the past kv length for use in forward.
+def _prepare_decoder_attention_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_shape: torch.Tensor,
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int
+) -> torch.Tensor:
+    if attention_mask is not None and torch.all(attention_mask):
+        return None # This uses the faster call when training with full samples
+
+    return attention_mask
diff --git a/llm_rl/src/llmtuner/extras/ploting.py b/llm_rl/src/llmtuner/extras/ploting.py
new file mode 100644
index 00000000..82530e45
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/ploting.py
@@ -0,0 +1,52 @@
+import os
+import math
+import json
+import matplotlib.pyplot as plt
+from typing import List, Optional
+from transformers.trainer import TRAINER_STATE_NAME
+
+from llmtuner.extras.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def smooth(scalars: List[float]) -> List[float]:
+    r"""
+    EMA implementation according to TensorBoard.
+ """ + last = scalars[0] + smoothed = list() + weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: + + with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: + data = json.load(f) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), label="smoothed") + plt.title("training {} of {}".format(key, save_dictionary)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) + print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) diff --git a/llm_rl/src/llmtuner/extras/save_and_load.py b/llm_rl/src/llmtuner/extras/save_and_load.py new file mode 100644 index 00000000..6d819ce6 --- /dev/null +++ b/llm_rl/src/llmtuner/extras/save_and_load.py @@ -0,0 +1,21 @@ +import os +import torch +from transformers.trainer import WEIGHTS_NAME + +from llmtuner.extras.logging import get_logger + + +logger = get_logger(__name__) + + +def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: + vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) + if not os.path.exists(vhead_file): + logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir)) + return False + vhead_params = torch.load(vhead_file, map_location="cpu") + model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) + model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) + model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) + model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) + return True diff --git a/llm_rl/src/llmtuner/extras/template.py b/llm_rl/src/llmtuner/extras/template.py new file mode 100644 index 00000000..401750ce --- /dev/null +++ b/llm_rl/src/llmtuner/extras/template.py @@ -0,0 +1,713 @@ +import tiktoken +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + +logger = get_logger(__name__) + + +@dataclass +class Template: + + prefix: List[Union[str, Dict[str, str]]] + prompt: List[Union[str, Dict[str, str]]] + system: str + sep: List[Union[str, Dict[str, str]]] + stop_words: List[str] + use_history: bool + efficient_eos: bool + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None + ) -> Tuple[List[int], List[int]]: + r""" + Returns a single pair of token ids representing prompt and response respectively. 
+ """ + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1] + return prompt_ids, answer_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) + return encoded_pairs + + def _format( + self, + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + system: Optional[str] = None + ) -> Tuple[str, List[Tuple[str, str]]]: + r""" + Aligns inputs to the standard format. + """ + system = system or self.system # use system if provided + history = history if (history and self.use_history) else [] + history = history + [(query, resp)] + return system, history + + def _get_special_ids( + self, + tokenizer: "PreTrainedTokenizer" + ) -> Tuple[List[int], List[int]]: + if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): + bos_ids = [tokenizer.bos_token_id] + else: # baichuan, qwen and gpt2 models have no bos token + bos_ids = [] + + if tokenizer.eos_token_id is None: + raise ValueError("EOS token is required.") + + if self.efficient_eos: # used in baichuan, qwen, chatglm, etc. + eos_ids = [] + else: + eos_ids = [tokenizer.eos_token_id] + + return bos_ids, eos_ids + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + system: str, + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + sep + query resp + eos + Turn t: sep + bos + query resp + eos + """ + bos_ids, eos_ids = self._get_special_ids(tokenizer) + sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) + if len(prefix_ids) != 0: # has prefix + prefix_ids = bos_ids + prefix_ids + sep_ids + else: + prefix_ids = bos_ids + else: + prefix_ids = sep_ids + bos_ids + + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx)) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) + return encoded_pairs + + def _convert_inputs_to_ids( + self, + tokenizer: "PreTrainedTokenizer", + context: List[Union[str, Dict[str, str]]], + system: Optional[str] = None, + query: Optional[str] = None, + idx: Optional[str] = None + ) -> List[int]: + r""" + Converts context to token ids. 
+ """ + if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) + kwargs = dict(allowed_special="all") + else: + kwargs = dict(add_special_tokens=False) + + token_ids = [] + for elem in context: + if isinstance(elem, str): + elem = elem.replace("{{system}}", system, 1) if system is not None else elem + elem = elem.replace("{{query}}", query, 1) if query is not None else elem + elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem + if len(elem) != 0: + token_ids = token_ids + tokenizer.encode(elem, **kwargs) + elif isinstance(elem, dict): + token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + else: + raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem))) + + return token_ids + + +@dataclass +class Llama2Template(Template): + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + system: str, + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + query resp + eos + Turn t: bos + query resp + eos + """ + bos_ids, eos_ids = self._get_special_ids(tokenizer) + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: # llama2 template has no sep_ids + query = self.prefix[0].replace("{{system}}", system) + query + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids)) + return encoded_pairs + + +templates: Dict[str, Template] = {} + + +def register_template( + name: str, + prefix: List[Union[str, Dict[str, str]]], + prompt: List[Union[str, Dict[str, str]]], + system: str, + sep: List[Union[str, Dict[str, str]]], + stop_words: Optional[List[str]] = [], + use_history: Optional[bool] = True, + efficient_eos: Optional[bool] = False +) -> None: + template_class = Llama2Template if "llama2" in name else Template + templates[name] = template_class( + prefix=prefix, + prompt=prompt, + system=system, + sep=sep, + stop_words=stop_words, + use_history=use_history, + efficient_eos=efficient_eos + ) + + +def get_template_and_fix_tokenizer( + name: str, + tokenizer: "PreTrainedTokenizer" +) -> Template: + if tokenizer.eos_token_id is None: + tokenizer.eos_token = "<|endoftext|>" + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + if name is None: + return None + + template = templates.get(name, None) + assert template is not None, "Template {} does not exist.".format(name) + tokenizer.add_special_tokens( + dict(additional_special_tokens=template.stop_words), + replace_additional_special_tokens=False + ) + return template + + +r""" +Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff +""" +register_template( + name="alpaca", + prefix=[ + "{{system}}" + ], + prompt=[ + "### Instruction:\n{{query}}\n\n### Response:\n" + ], + system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." 
+ ), + sep=[ + "\n\n" + ] +) + + +r""" +Supports: https://huggingface.co/BAAI/AquilaChat-7B + https://huggingface.co/BAAI/AquilaChat2-7B + https://huggingface.co/BAAI/AquilaChat2-34B +""" +register_template( + name="aquila", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}###Assistant:" + ], + system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ), + sep=[ + "###" + ], + stop_words=[ + "" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat +""" +register_template( + name="baichuan", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": ""}, # user token + "{{query}}", + {"token": ""} # assistant token + ], + system="", + sep=[], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat + https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat +""" +register_template( + name="baichuan2", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": ""}, # user token + "{{query}}", + {"token": ""} # assistant token + ], + system="", + sep=[], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B +""" +register_template( + name="belle", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}\n\nBelle: " + ], + system="", + sep=[ + "\n\n" + ] +) + + +r""" +Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat +""" +register_template( + name="bluelm", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": "[|Human|]:"}, + "{{query}}", + {"token": "[|AI|]:"} + ], + system="", + sep=[] +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm2-6b +""" +register_template( + name="chatglm2", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + "{{system}}" + ], + prompt=[ + "[Round {{idx}}]\n\n问:{{query}}\n\n答:" + ], + system="", + sep=[ + "\n\n" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm3-6b +""" +register_template( + name="chatglm3", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + "{{system}}" + ], + prompt=[ + {"token": "<|user|>"}, + "\n", + "{{query}}", + {"token": "<|assistant|>"} + ], + system="", + sep=[], + stop_words=[ + "<|user|>", + "<|observation|>" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct + https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct + https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct +""" +register_template( + name="deepseek", + prefix=[ + "{{system}}" + ], + prompt=[ + "### Instruction:\n{{query}}\n\n### Response:\n" + ], + system=( + "You are an AI programming assistant, utilizing the Deepseek Coder model, " + "developed by Deepseek Company, and you only answer questions related to computer science. " + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer." + ), + sep=[ + "\n", + {"token": "<|EOT|>"}, + "\n\n" + ], + stop_words=[ + "<|EOT|>" + ], + efficient_eos=True +) + + +r""" +Default template. +""" +register_template( + name="default", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}\nAssistant:" + ], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." 
+ ), + sep=[ + "\n" + ] +) + + +r""" +Supports: https://huggingface.co/internlm/internlm-chat-7b + https://huggingface.co/internlm/internlm-chat-20b +""" +register_template( + name="intern", + prefix=[ + "{{system}}" + ], + prompt=[ + "<|User|>:{{query}}", + {"token": ""}, + "\n<|Bot|>:" + ], + system="", + sep=[ + {"token": ""}, + "\n" + ], + stop_words=[ + "" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf + https://huggingface.co/meta-llama/Llama-2-13b-chat-hf + https://huggingface.co/meta-llama/Llama-2-70b-chat-hf +""" +register_template( + name="llama2", + prefix=[ + "<>\n{{system}}\n<>\n\n" + ], + prompt=[ + "[INST] {{query}} [/INST]" + ], + system=( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + sep=[] +) + + +r""" +Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b + https://huggingface.co/ziqingyang/chinese-alpaca-2-13b +""" +register_template( + name="llama2_zh", + prefix=[ + "<>\n{{system}}\n<>\n\n" + ], + prompt=[ + "[INST] {{query}} [/INST]" + ], + system="You are a helpful assistant. 你是一个乐于助人的助手。", + sep=[] +) + + +r""" +Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 +""" +register_template( + name="mistral", + prefix=[ + "{{system}}" + ], + prompt=[ + "[INST] {{query}} [/INST]" + ], + system="", + sep=[] +) + + +r""" +Supports: https://huggingface.co/openchat/openchat_3.5 +""" +register_template( + name="openchat", + prefix=[ + "{{system}}" + ], + prompt=[ + "GPT4 Correct User: {{query}}", + {"token": "<|end_of_turn|>"}, + "GPT4 Correct Assistant:" + ], + system="You are a helpful assistant.", + sep=[ + {"token": "<|end_of_turn|>"} + ], + stop_words=[ + "<|end_of_turn|>" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/Qwen/Qwen-7B-Chat + https://huggingface.co/Qwen/Qwen-14B-Chat +""" +register_template( + name="qwen", + prefix=[ + {"token": "<|im_start|>"}, + "system\n{{system}}" + ], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{{query}}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n" + ], + system="You are a helpful assistant.", + sep=[ + {"token": "<|im_end|>"}, + "\n" + ], + stop_words=[ + "<|im_end|>" + ], + efficient_eos=True +) + + +r""" +Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha + https://huggingface.co/HuggingFaceH4/starchat-beta +""" +register_template( + name="starchat", + prefix=[ + {"token": "<|system|>"}, + "\n{{system}}", + ], + prompt=[ + {"token": "<|user|>"}, + "\n{{query}}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"} + ], + system="", + sep=[ + {"token": "<|end|>"}, + "\n" + ], + stop_words=[ + "<|end|>" + ], + efficient_eos=True +) + + +r""" +Supports language model inference without histories. 
+""" +register_template( + name="vanilla", + prefix=[], + prompt=[ + "{{query}}" + ], + system="", + sep=[], + use_history=False +) + + +r""" +Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5 + https://huggingface.co/lmsys/vicuna-13b-v1.5 +""" +register_template( + name="vicuna", + prefix=[ + "{{system}}" + ], + prompt=[ + "USER: {{query}} ASSISTANT:" + ], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + sep=[] +) + + +r""" +Supports: https://huggingface.co/xverse/XVERSE-7B-Chat + https://huggingface.co/xverse/XVERSE-13B-Chat +""" +register_template( + name="xverse", + prefix=[ + "{{system}}" + ], + prompt=[ + "Human: {{query}}\n\nAssistant: " + ], + system="", + sep=[] +) + + +r""" +Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha + https://huggingface.co/HuggingFaceH4/zephyr-7b-beta +""" +register_template( + name="zephyr", + prefix=[ + {"token": "<|system|>"}, + "\n{{system}}", + {"token": ""} + ], + prompt=[ + {"token": "<|user|>"}, + "\n{{query}}", + {"token": ""}, + {"token": "<|assistant|>"} + ], + system="You are a friendly chatbot who always responds in the style of a pirate", + sep=[] +) + + +r""" +Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 + https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1 + https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat +""" +register_template( + name="ziya", + prefix=[ + "{{system}}" + ], + prompt=[ + {"token": ""}, + ":{{query}}\n", + {"token": ""}, + ":" + ], + system="", + sep=[ + "\n" + ] +) diff --git a/llm_rl/src/llmtuner/hparams/__init__.py b/llm_rl/src/llmtuner/hparams/__init__.py new file mode 100644 index 00000000..f0547cc5 --- /dev/null +++ b/llm_rl/src/llmtuner/hparams/__init__.py @@ -0,0 +1,4 @@ +from .data_args import DataArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments diff --git a/llm_rl/src/llmtuner/hparams/data_args.py b/llm_rl/src/llmtuner/hparams/data_args.py new file mode 100644 index 00000000..4c67dd65 --- /dev/null +++ b/llm_rl/src/llmtuner/hparams/data_args.py @@ -0,0 +1,169 @@ +import os +import json +from typing import List, Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class DatasetAttr: + + load_from: str + dataset_name: Optional[str] = None + dataset_sha1: Optional[str] = None + system_prompt: Optional[str] = None + subset: Optional[str] = None + ranking: Optional[bool] = False + formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" + + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None + messages: Optional[str] = "conversations" + role: Optional[str] = "from" + content: Optional[str] = "value" + + def __repr__(self) -> str: + return self.dataset_name + + +@dataclass +class DataArguments: + r""" + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + template: Optional[str] = field( + default=None, + metadata={"help": "Which template to use for constructing prompts in training and inference."} + ) + dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of provided dataset(s) to use. 
Use commas to separate multiple datasets."} + ) + dataset_dir: Optional[str] = field( + default="data", + metadata={"help": "The name of the folder containing datasets."} + ) + split: Optional[str] = field( + default="train", + metadata={"help": "Which dataset split to use for training and evaluation."} + ) + cutoff_len: Optional[int] = field( + default=1024, + metadata={"help": "The maximum length of the model inputs after tokenization."} + ) + train_on_prompt: Optional[bool] = field( + default=False, + metadata={"help": "Whether to disable the mask on the prompt or not."} + ) + streaming: Optional[bool] = field( + default=False, + metadata={"help": "Enable dataset streaming."} + ) + buffer_size: Optional[int] = field( + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."} + ) + mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( + default="concat", + metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."} + ) + interleave_probs: Optional[str] = field( + default=None, + metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."} + ) + overwrite_cache: Optional[bool] = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."} + ) + max_samples: Optional[int] = field( + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} + ) + ignore_pad_token_for_loss: Optional[bool] = field( + default=True, + metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} + ) + system_prompt: Optional[str] = field( + default=None, + metadata={"help": "System prompt to add before the user query. 
Use `|` to separate multiple prompts in training."} + ) + val_size: Optional[float] = field( + default=0, + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} + ) + sft_packing: Optional[bool] = field( + default=False, + metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} + ) + cache_path: Optional[str] = field( + default=None, + metadata={"help": "Path to save or load the preprocessed datasets."} + ) + + def __post_init__(self): + if self.streaming and self.val_size > 1e-6 and self.val_size < 1: + raise ValueError("Streaming mode should have an integer val size.") + + if self.streaming and self.max_samples is not None: + raise ValueError("`max_samples` is incompatible with `streaming`.") + + if self.streaming and self.cache_path: + raise ValueError("`cache_path` is incompatible with `streaming`.") + + def init_for_training(self, seed: int): # support mixing multiple datasets + self.seed = seed + dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else [] + try: + with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: + dataset_info = json.load(f) + except Exception: + if self.dataset is not None: + raise ValueError("Cannot find dataset_info.json in `dataset_dir`.") + dataset_info = None + + prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] + prompt_list = prompt_list * (len(dataset_names) // len(prompt_list)) + assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1." + + if self.interleave_probs is not None: + self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")] + + self.dataset_list: List[DatasetAttr] = [] + for i, name in enumerate(dataset_names): + if name not in dataset_info: + raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) + + if "hf_hub_url" in dataset_info[name]: + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + else: + dataset_attr = DatasetAttr( + "file", + dataset_name=dataset_info[name]["file_name"], + dataset_sha1=dataset_info[name].get("file_sha1", None) + ) + + if "columns" in dataset_info[name]: + dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query = dataset_info[name]["columns"].get("query", None) + dataset_attr.response = dataset_info[name]["columns"].get("response", None) + dataset_attr.history = dataset_info[name]["columns"].get("history", None) + dataset_attr.messages = dataset_info[name]["columns"].get("messages", None) + dataset_attr.role = dataset_info[name]["columns"].get("role", None) + dataset_attr.content = dataset_info[name]["columns"].get("content", None) + + dataset_attr.subset = dataset_info[name].get("subset", None) + dataset_attr.ranking = dataset_info[name].get("ranking", False) + dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") + dataset_attr.system_prompt = prompt_list[i] + self.dataset_list.append(dataset_attr) diff --git a/llm_rl/src/llmtuner/hparams/finetuning_args.py b/llm_rl/src/llmtuner/hparams/finetuning_args.py new file mode 100644 index 00000000..d5ef323d --- /dev/null +++ b/llm_rl/src/llmtuner/hparams/finetuning_args.py @@ -0,0 +1,107 @@ +import json +from typing import Literal, Optional +from dataclasses import asdict, 
dataclass, field + + +@dataclass +class FinetuningArguments: + r""" + Arguments pertaining to which techniques we are going to fine-tuning with. + """ + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."} + ) + finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."} + ) + num_layer_trainable: Optional[int] = field( + default=3, + metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} + ) + name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( + default="mlp", + metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + LLaMA choices: [\"mlp\", \"self_attn\"], \ + BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \ + Qwen choices: [\"mlp\", \"attn\"], \ + Phi-1.5 choices: [\"mlp\", \"mixer\"], \ + LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."} + ) + lora_rank: Optional[int] = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} + ) + lora_alpha: Optional[float] = field( + default=32.0, + metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} + ) + lora_dropout: Optional[float] = field( + default=0.1, + metadata={"help": "Dropout rate for the LoRA fine-tuning."} + ) + lora_target: Optional[str] = field( + default=None, + metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ + Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ + Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} + ) + additional_target: Optional[str] = field( + default=None, + metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} + ) + resume_lora_training: Optional[bool] = field( + default=True, + metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} + ) + ppo_score_norm: Optional[bool] = field( + default=False, + metadata={"help": "Use score normalization in PPO training."} + ) + ppo_logger: Optional[str] = field( + default=None, + metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."} + ) + ppo_target: Optional[float] = field( + default=6.0, + metadata={"help": "Target KL value for adaptive KL control in PPO training."} + ) + dpo_beta: Optional[float] = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."} + ) + upcast_layernorm: Optional[bool] = field( + default=False, + metadata={"help": "Whether to upcast the layernorm weights in fp32."} + ) + neft_alpha: Optional[float] = field( + default=0, + metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."} + ) + + def __post_init__(self): + if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA + self.lora_target = [target.strip() for target in self.lora_target.split(",")] + + if 
isinstance(self.additional_target, str): + self.additional_target = [target.strip() for target in self.additional_target.split(",")] + + assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method." + + def save_to_json(self, json_path: str): + r"""Saves the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + r"""Creates an instance from the content of `json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) diff --git a/llm_rl/src/llmtuner/hparams/general_args.py b/llm_rl/src/llmtuner/hparams/general_args.py new file mode 100644 index 00000000..c0c1a0de --- /dev/null +++ b/llm_rl/src/llmtuner/hparams/general_args.py @@ -0,0 +1,13 @@ +from typing import Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class GeneralArguments: + r""" + Arguments pertaining to which stage we are going to perform. + """ + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."} + ) diff --git a/llm_rl/src/llmtuner/hparams/generating_args.py b/llm_rl/src/llmtuner/hparams/generating_args.py new file mode 100644 index 00000000..c04a5c36 --- /dev/null +++ b/llm_rl/src/llmtuner/hparams/generating_args.py @@ -0,0 +1,53 @@ +from typing import Any, Dict, Optional +from dataclasses import asdict, dataclass, field + + +@dataclass +class GeneratingArguments: + r""" + Arguments pertaining to specify the decoding parameters. + """ + do_sample: Optional[bool] = field( + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} + ) + temperature: Optional[float] = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."} + ) + top_p: Optional[float] = field( + default=0.7, + metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} + ) + top_k: Optional[int] = field( + default=50, + metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} + ) + num_beams: Optional[int] = field( + default=1, + metadata={"help": "Number of beams for beam search. 1 means no beam search."} + ) + max_length: Optional[int] = field( + default=512, + metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} + ) + max_new_tokens: Optional[int] = field( + default=512, + metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} + ) + repetition_penalty: Optional[float] = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 
1.0 means no penalty."} + ) + length_penalty: Optional[float] = field( + default=1.0, + metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} + ) + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + if args.get("max_new_tokens", -1) > 0: + args.pop("max_length", None) + else: + args.pop("max_new_tokens", None) + return args diff --git a/llm_rl/src/llmtuner/hparams/model_args.py b/llm_rl/src/llmtuner/hparams/model_args.py new file mode 100644 index 00000000..7c25fad1 --- /dev/null +++ b/llm_rl/src/llmtuner/hparams/model_args.py @@ -0,0 +1,93 @@ +from typing import Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class ModelArguments: + r""" + Arguments pertaining to which model/config/tokenizer we are going to fine-tune. + """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} + ) + use_fast_tokenizer: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} + ) + split_special_tokens: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} + ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={"help": "Will use the token generated when running `huggingface-cli login`."} + ) + model_revision: Optional[str] = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} + ) + quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the model."} + ) + quantization_type: Optional[Literal["fp4", "nf4"]] = field( + default="nf4", + metadata={"help": "Quantization data type to use in int4 training."} + ) + double_quantization: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use double quantization in int4 training or not."} + ) + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( + default=None, + metadata={"help": "Adopt scaled rotary positional embeddings."} + ) + checkpoint_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} + ) + flash_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable FlashAttention-2 for faster training."} + ) + shift_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + ) + plot_loss: Optional[bool] = field( + default=False, + metadata={"help": "Whether to plot the training loss after fine-tuning or not."} + ) + hf_auth_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."} + ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."} + ) + + def __post_init__(self): + self.compute_dtype = None + self.model_max_length = None + + if self.split_special_tokens and self.use_fast_tokenizer: + raise ValueError("`split_special_tokens` is only 
supported for slow tokenizers.") + + if self.checkpoint_dir is not None: # support merging multiple lora weights + self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] + + if self.quantization_bit is not None: + assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." + + if self.use_auth_token == True and self.hf_auth_token is not None: + from huggingface_hub.hf_api import HfFolder # lazy load + HfFolder.save_token(self.hf_auth_token) diff --git a/llm_rl/src/llmtuner/tuner/__init__.py b/llm_rl/src/llmtuner/tuner/__init__.py new file mode 100644 index 00000000..4d5a83e4 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.tune import export_model, run_exp diff --git a/llm_rl/src/llmtuner/tuner/core/__init__.py b/llm_rl/src/llmtuner/tuner/core/__init__.py new file mode 100644 index 00000000..bd1c5cf0 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/core/__init__.py @@ -0,0 +1,2 @@ +from llmtuner.tuner.core.parser import get_train_args, get_infer_args +from llmtuner.tuner.core.loader import load_model_and_tokenizer diff --git a/llm_rl/src/llmtuner/tuner/core/adapter.py b/llm_rl/src/llmtuner/tuner/core/adapter.py new file mode 100644 index 00000000..4fcc6e62 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/core/adapter.py @@ -0,0 +1,101 @@ +import torch +from typing import TYPE_CHECKING + +from peft import ( + PeftModel, + TaskType, + LoraConfig, + get_peft_model +) + +from llmtuner.extras.logging import get_logger +from llmtuner.tuner.core.utils import find_all_linear_modules + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from llmtuner.hparams import ModelArguments, FinetuningArguments + + +logger = get_logger(__name__) + + +def init_adapter( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + is_mergeable: bool +) -> "PreTrainedModel": + r""" + Initializes the adapters. + + Support full-parameter, freeze and LoRA training. + + Note that the trainable parameters must be cast to float32. 
+ """ + + if finetuning_args.finetuning_type == "none" and is_trainable: + raise ValueError("You cannot use finetuning_type=none while training.") + + if finetuning_args.finetuning_type == "full" and is_trainable: + logger.info("Fine-tuning method: Full") + model = model.float() + + if finetuning_args.finetuning_type == "freeze": + logger.info("Fine-tuning method: Freeze") + num_layers = getattr(model.config, "num_layers") + if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] + + trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids] + for name, param in model.named_parameters(): + if not any(trainable_layer in name for trainable_layer in trainable_layers): + param.requires_grad_(False) + else: + param.data = param.data.to(torch.float32) + + if finetuning_args.finetuning_type == "lora": + logger.info("Fine-tuning method: LoRA") + latest_checkpoint = None + + if model_args.checkpoint_dir is not None: + if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning + checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] + else: + checkpoints_to_merge = model_args.checkpoint_dir + + for checkpoint in checkpoints_to_merge: + model = PeftModel.from_pretrained(model, checkpoint) + model = model.merge_and_unload() + + if len(checkpoints_to_merge) > 0: + logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) + + if latest_checkpoint is not None: # resume lora training or quantized inference + model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable) + + if is_trainable and latest_checkpoint is None: # create new lora weights while training + if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": + target_modules = find_all_linear_modules(model, model_args.quantization_bit) + else: + target_modules = finetuning_args.lora_target + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=finetuning_args.lora_rank, + lora_alpha=finetuning_args.lora_alpha, + lora_dropout=finetuning_args.lora_dropout, + target_modules=target_modules, + modules_to_save=finetuning_args.additional_target + ) + model = get_peft_model(model, lora_config) + if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 + model.base_model.peft_config = model.peft_config + + if model_args.checkpoint_dir is not None: + logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) + + return model diff --git a/llm_rl/src/llmtuner/tuner/core/loader.py b/llm_rl/src/llmtuner/tuner/core/loader.py new file mode 100644 index 00000000..e77c4945 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/core/loader.py @@ -0,0 +1,244 @@ +import os +import math +import torch +from types import MethodType +from typing import TYPE_CHECKING, Literal, Optional, Tuple + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase +) +from transformers.models.llama import modeling_llama as LlamaModule 
+from transformers.utils.versions import require_version +from trl import AutoModelForCausalLMWithValueHead + +try: + from transformers.integrations import is_deepspeed_zero3_enabled +except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1 + from transformers.deepspeed import is_deepspeed_zero3_enabled + +from llmtuner.extras.logging import reset_logging, get_logger +from llmtuner.extras.misc import count_parameters, infer_optim_dtype +from llmtuner.extras.patches import llama_patch as LlamaPatches +from llmtuner.extras.save_and_load import load_valuehead_params +from llmtuner.hparams import FinetuningArguments +from llmtuner.tuner.core.adapter import init_adapter +from llmtuner.tuner.core.utils import prepare_model_for_training + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + from llmtuner.hparams import ModelArguments + + +logger = get_logger(__name__) + + +require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"") +require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") +require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") +require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") +require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2") + + +def load_model_and_tokenizer( + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: Optional[bool] = False, + stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft" +) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]: + r""" + Loads pretrained model and tokenizer. + + Support both training and inference. + """ + if (not is_trainable) and model_args.checkpoint_dir is None: + logger.warning("Checkpoint is not found at evaluation, load the original model.") + finetuning_args = FinetuningArguments(finetuning_type="none") + + config_kwargs = { + "trust_remote_code": True, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + split_special_tokens=model_args.split_special_tokens, + padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow + **config_kwargs + ) + + if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: + model_to_load = model_args.checkpoint_dir[0] + else: + model_to_load = model_args.model_name_or_path + + config = AutoConfig.from_pretrained(model_to_load, **config_kwargs) + + # Fix tokenizer (for ChatGLM2 and ChatGLM3) + if getattr(config, "model_type", None) == "chatglm": + tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + + # Set model dtype + if model_args.compute_dtype is not None: # for training + setattr(config, "torch_dtype", model_args.compute_dtype) + else: # for evaluation, priority: bf16 > fp16 > fp32 + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + + # Fix config (for Qwen) + if getattr(config, "model_type", None) == "qwen": + for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: + setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype) + + # Set RoPE scaling + if model_args.rope_scaling is not None: + if hasattr(config, "use_dynamic_ntk"): # for Qwen models + if is_trainable: + logger.warning("Qwen model does 
not support RoPE scaling in training.") + else: + setattr(config, "use_dynamic_ntk", True) + setattr(config, "use_logn_attn", True) + logger.info("Using dynamic NTK scaling.") + + elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models + if is_trainable: + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info("Using {} scaling strategy and setting scaling factor to {}".format( + model_args.rope_scaling, scaling_factor + )) + + else: + logger.warning("Current model does not support RoPE scaling.") + + # Set FlashAttention-2 + if model_args.flash_attn: + if getattr(config, "model_type", None) == "llama": + LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 + LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask + logger.info("Using FlashAttention-2 for faster training and inference.") + elif getattr(config, "model_type", None) == "qwen": + logger.info("Qwen models automatically enable FlashAttention if installed.") + else: + logger.warning("Current model does not support FlashAttention-2.") + elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama": + LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention + logger.warning("Using `--flash_attn` for faster training in large context length.") + + # Set shift short attention (S^2-Attn) + if is_trainable and model_args.shift_attn: + if getattr(config, "model_type", None) == "llama": + setattr(config, "group_size_ratio", 0.25) + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") + + # Quantization configurations (using bitsandbytes library). + is_mergeable = True + if model_args.quantization_bit is not None: + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + config_kwargs["load_in_8bit"] = True + config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + config_kwargs["load_in_4bit"] = True + config_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type + ) + + is_mergeable = False + config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto" + logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + + # Load and prepare pre-trained models (without valuehead). 
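For readers skimming the RoPE-scaling branch above: when training with linear scaling, the factor is simply the ceiling of the requested context length over the model's original `max_position_embeddings`. A minimal sketch of that arithmetic, using assumed example values (4096 original positions, 8192 requested via `cutoff_len`), not defaults of this patch:

```python
import math

# Assumed example values for illustration only.
max_position_embeddings = 4096   # from the pretrained config
model_max_length = 8192          # set from data_args.cutoff_len in get_train_args

if model_max_length > max_position_embeddings:
    scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))  # -> 2.0
else:
    scaling_factor = 1.0

# Mirrors the setattr(config, "rope_scaling", ...) call above for --rope_scaling linear.
rope_scaling = {"type": "linear", "factor": scaling_factor}
```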
+ model = AutoModelForCausalLM.from_pretrained( + model_to_load, + config=config, + torch_dtype=model_args.compute_dtype, + low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), + **config_kwargs + ) + + # Disable custom generate method (for Qwen and Baichuan2) + if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__): + model.generate = MethodType(PreTrainedModel.generate, model) + + # Fix LM head (for ChatGLM2 and ChatGLM3) + if getattr(config, "model_type", None) == "chatglm": + setattr(model, "lm_head", model.transformer.output_layer) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + # Register auto class to save the custom code files. + if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}): + config.__class__.register_for_auto_class() + if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}): + model.__class__.register_for_auto_class() + if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): + tokenizer.__class__.register_for_auto_class() + + # Initialize adapters + model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model + model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) + model = model.train() if is_trainable else model.eval() + + # Prepare model with valuehead for RLHF + if stage == "rm" or stage == "ppo": + model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) + reset_logging() + if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model + logger.warning("Only the last checkpoint containing valuehead will be loaded.") + if load_valuehead_params(model, model_args.checkpoint_dir[-1]): + model.v_head.load_state_dict({ + "summary.weight": getattr(model, "reward_head_weight"), + "summary.bias": getattr(model, "reward_head_bias") + }) + + if stage == "ppo": # load reward model + logger.info("Load reward model from {}".format(model_args.reward_model)) + if getattr(model, "is_peft_model", False): + model.pretrained_model.load_adapter(model_args.reward_model, "reward") + assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." 
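As a hedged aside on the value-head handling above: `load_valuehead_params` (from `extras/save_and_load.py` in this patch) registers the checkpoint's `v_head.summary.*` tensors as `reward_head_weight` / `reward_head_bias` buffers. A hypothetical sanity check, not part of the patch, could compare those buffers against the checkpoint file directly:

```python
import os
import torch
from transformers.trainer import WEIGHTS_NAME

def check_valuehead(model, checkpoint_dir):
    # Hypothetical helper: verify the buffers registered by load_valuehead_params
    # match the tensors stored in the checkpoint's weights file.
    state = torch.load(os.path.join(checkpoint_dir, WEIGHTS_NAME), map_location="cpu")
    assert torch.equal(model.reward_head_weight, state["v_head.summary.weight"])
    assert torch.equal(model.reward_head_bias, state["v_head.summary.bias"])
```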
+ + # Prepare model for inference + if not is_trainable: + model.requires_grad_(False) # fix all model params + model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model + + trainable_params, all_param = count_parameters(model) + logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + )) + + if not is_trainable: + logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.") + + return model, tokenizer diff --git a/llm_rl/src/llmtuner/tuner/core/parser.py b/llm_rl/src/llmtuner/tuner/core/parser.py new file mode 100644 index 00000000..603fc1bc --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/core/parser.py @@ -0,0 +1,226 @@ +import os +import sys +import torch +import datasets +import transformers +from typing import Any, Dict, Optional, Tuple +from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.trainer_utils import get_last_checkpoint + +from llmtuner.extras.logging import get_logger +from llmtuner.hparams import ( + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +) + + +logger = get_logger(__name__) + + +def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + else: + return parser.parse_args_into_dataclasses() + + +def parse_train_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments +]: + parser = HfArgumentParser(( + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments + )) + return _parse_args(parser, args) + + +def parse_infer_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +]: + parser = HfArgumentParser(( + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments + )) + return _parse_args(parser, args) + + +def get_train_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments +]: + model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args) + + # Setup logging + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
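Since `_parse_args` above accepts a plain dict (besides CLI arguments and a single `.yaml`/`.json` path), a minimal usage sketch of `get_train_args` could look like the following; the model path and dataset name are placeholders, and the call assumes `data/dataset_info.json` defines the dataset:

```python
from llmtuner.tuner.core.parser import get_train_args

# Placeholder values for illustration only.
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(dict(
    stage="sft",
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # assumed example model
    dataset="example_dataset",                      # must exist in data/dataset_info.json
    template="llama2",
    finetuning_type="lora",
    lora_target="q_proj,v_proj",
    output_dir="output/sft-example",
    do_train=True,
))
```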
+ transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Check arguments + data_args.init_for_training(training_args.seed) + + if finetuning_args.stage != "pt" and data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if finetuning_args.stage != "sft" and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") + + if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": + raise ValueError("RM and PPO stages can only be performed with the LoRA method.") + + if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None: + raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.") + + if finetuning_args.stage == "ppo" and not training_args.do_train: + raise ValueError("PPO training does not support evaluation.") + + if finetuning_args.stage in ["rm", "dpo"]: + for dataset_attr in data_args.dataset_list: + if not dataset_attr.ranking: + raise ValueError("Please use ranked datasets for reward modeling or DPO training.") + + if finetuning_args.stage == "ppo" and model_args.reward_model is None: + raise ValueError("Reward model is necessary for PPO training.") + + if finetuning_args.stage == "ppo" and model_args.shift_attn: + raise ValueError("PPO training is incompatible with S^2-Attn.") + + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") + + if training_args.do_train and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True while training.") + + if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None: + raise ValueError("Please specify `lora_target` in LoRA training.") + + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1: + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") + + if model_args.quantization_bit is not None: + if len(model_args.checkpoint_dir) != 1: + raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.") + + if not finetuning_args.resume_lora_training: + raise ValueError("Quantized model cannot create new LoRA weight. 
Merge them first.") + + if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm): + logger.warning("We recommend enable `upcast_layernorm` in quantized training.") + + if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): + logger.warning("We recommend enable mixed precision training.") + + if (not training_args.do_train) and model_args.quantization_bit is not None: + logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + + # postprocess training_args + if ( + training_args.local_rank != -1 + and training_args.ddp_find_unused_parameters is None + and finetuning_args.finetuning_type == "lora" + ): + logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") + training_args_dict = training_args.to_dict() + training_args_dict.update(dict(ddp_find_unused_parameters=False)) + training_args = Seq2SeqTrainingArguments(**training_args_dict) + + if ( + training_args.resume_from_checkpoint is None + and training_args.do_train + and os.path.isdir(training_args.output_dir) + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") + + if last_checkpoint is not None: + training_args_dict = training_args.to_dict() + training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint)) + training_args = Seq2SeqTrainingArguments(**training_args_dict) + logger.info( + "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid." + ) + + # postprocess model_args + model_args.compute_dtype = ( + torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None) + ) + model_args.model_max_length = data_args.cutoff_len + + # Log on each process the small summary: + logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format( + training_args.local_rank, training_args.device, training_args.n_gpu, + bool(training_args.local_rank != -1), str(model_args.compute_dtype) + )) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, generating_args + + +def get_infer_args( + args: Optional[Dict[str, Any]] = None +) -> Tuple[ + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +]: + model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + if model_args.checkpoint_dir is not None: + if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1: + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") + + if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1: + raise ValueError("Quantized model only accepts a single checkpoint. 
Merge them first.") + + return model_args, data_args, finetuning_args, generating_args diff --git a/llm_rl/src/llmtuner/tuner/core/utils.py b/llm_rl/src/llmtuner/tuner/core/utils.py new file mode 100644 index 00000000..d9a1aac9 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/core/utils.py @@ -0,0 +1,94 @@ +import torch +from types import MethodType +from typing import TYPE_CHECKING, List, Optional + +from llmtuner.extras.constants import LAYERNORM_NAMES +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from llmtuner.hparams import FinetuningArguments + + +logger = get_logger(__name__) + + +def find_all_linear_modules( + model: "PreTrainedModel", + quantization_bit: Optional[int] = None, + output_layer_name: Optional[str] = "lm_head" +) -> List[str]: + if quantization_bit is not None: + import bitsandbytes as bnb + linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt + else: + linear_cls = torch.nn.Linear + + module_names = set() + for name, module in model.named_modules(): + if output_layer_name not in name and isinstance(module, linear_cls): + module_names.add(name.split(".")[-1]) + + if output_layer_name in module_names: + module_names.pop(output_layer_name) + + return list(module_names) + + +def prepare_model_for_training( + model: "PreTrainedModel", + finetuning_args: "FinetuningArguments", + output_layer_name: Optional[str] = "lm_head", + use_gradient_checkpointing: Optional[bool] = True, + layernorm_names: Optional[List[str]] = LAYERNORM_NAMES +) -> "PreTrainedModel": + r""" + Includes: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) upcast the lm_head to fp32 + Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33 + """ + if finetuning_args.upcast_layernorm: + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names): + param.data = param.data.to(torch.float32) + logger.info("Upcasting weights in layernorm in float32.") + + if finetuning_args.neft_alpha > 1e-6: + input_embed = model.get_input_embeddings() + if isinstance(input_embed, torch.nn.Embedding): + def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor: + embeddings = input_embed.__class__.forward(self, x) + if self.training: + dims = self.num_embeddings * self.embedding_dim + mag_norm = finetuning_args.neft_alpha / (dims ** 0.5) + embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm) + return embeddings + + input_embed.forward = MethodType(noisy_forward, input_embed) + logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha)) + else: + logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.") + + if use_gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model.gradient_checkpointing_enable() + model.config.use_cache = False # turn off when gradient checkpointing is enabled + logger.info("Gradient checkpointing enabled.") + + if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name): + output_layer = getattr(model, output_layer_name) + if isinstance(output_layer, 
torch.nn.Linear): + def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor: + return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32) + + output_layer.forward = MethodType(forward_in_fp32, output_layer) + + return model diff --git a/llm_rl/src/llmtuner/tuner/dpo/__init__.py b/llm_rl/src/llmtuner/tuner/dpo/__init__.py new file mode 100644 index 00000000..f2b5cfb5 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/dpo/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.dpo.workflow import run_dpo diff --git a/llm_rl/src/llmtuner/tuner/dpo/collator.py b/llm_rl/src/llmtuner/tuner/dpo/collator.py new file mode 100644 index 00000000..5c862b4f --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/dpo/collator.py @@ -0,0 +1,51 @@ +import torch +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple +from transformers import DataCollatorForSeq2Seq + + +@dataclass +class DPODataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: + padded_labels = [] + for feature, (prompt_len, answer_len) in zip(batch, positions): + if self.tokenizer.padding_side == "left": + start, end = feature.size(0) - answer_len, feature.size(0) + else: + start, end = prompt_len, prompt_len + answer_len + padded_tensor = self.label_pad_token_id * torch.ones_like(feature) + padded_tensor[start:end] = feature[start:end] + padded_labels.append(padded_tensor) + return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. 
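Note: the docstring above describes the 2 * n layout, and the loop in `__call__` that follows builds it by concatenating the prompt with the chosen answer for the first n rows and with the rejected answer for the last n rows. A dependency-free toy illustration with made-up token ids:

```python
# Hypothetical token ids, only to show the pairing layout the collator produces.
features = [
    {"prompt_ids": [1, 2], "chosen_ids": [3, 4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7],    "chosen_ids": [8, 9],    "rejected_ids": [10, 11]},
]

concatenated = [
    {
        "input_ids": f["prompt_ids"] + f[key],
        "attention_mask": [1] * (len(f["prompt_ids"]) + len(f[key])),
    }
    for key in ("chosen_ids", "rejected_ids")  # chosen rows first, rejected rows last
    for f in features
]

for row in concatenated:
    print(row["input_ids"])
# [1, 2, 3, 4, 5]
# [7, 8, 9]
# [1, 2, 6]
# [7, 10, 11]
```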
+ """ + concatenated_features = [] + label_positions = [] + for key in ("chosen_ids", "rejected_ids"): + for feature in features: + prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) + concatenated_features.append({ + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (prompt_len + answer_len) + }) + label_positions.append((prompt_len, answer_len)) + + batch = self.tokenizer.pad( + concatenated_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) + return batch diff --git a/llm_rl/src/llmtuner/tuner/dpo/trainer.py b/llm_rl/src/llmtuner/tuner/dpo/trainer.py new file mode 100644 index 00000000..8a9f8dd6 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/dpo/trainer.py @@ -0,0 +1,104 @@ +import torch +import deepspeed # type: ignore +from copy import deepcopy +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union +from transformers import BatchEncoding, Trainer +from trl import DPOTrainer +from trl.trainer.utils import disable_dropout_in_model + +from llmtuner.extras.constants import IGNORE_INDEX + +if TYPE_CHECKING: + from transformers import PreTrainedModel + from trl import PreTrainedModelWrapper + + +class CustomDPOTrainer(DPOTrainer): + + def __init__( + self, + beta: float, + model: Union["PreTrainedModel", torch.nn.Module], + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, + disable_dropout: Optional[bool] = True, + loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid", + **kwargs + ): + if disable_dropout: + disable_dropout_in_model(model) + if ref_model is not None: + disable_dropout_in_model(ref_model) + + self.is_encoder_decoder = model.config.is_encoder_decoder + self.ref_model = ref_model + self.use_dpo_data_collator = True # hack to avoid warning + self.generate_during_eval = False # disable at evaluation + self.label_pad_token_id = IGNORE_INDEX + self.padding_value = 0 + self.beta = beta + self.loss_type = loss_type + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + Trainer.__init__(self, model=model, **kwargs) + if not hasattr(self, "accelerator"): + raise AttributeError("Please update `transformers`.") + + if ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": 
hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def concatenated_forward( + self, + model: Optional[torch.nn.Module] = None, + batch: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error + + all_logits = model( + input_ids=batch_copied["input_ids"], + attention_mask=batch_copied["attention_mask"], + return_dict=True + ).logits.to(torch.float32) + + all_logps = self._get_batch_logps( + all_logits, + batch["labels"], + average_log_prob=False + ) + batch_size = batch["input_ids"].size(0) // 2 + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) + return chosen_logps, rejected_logps, chosen_logits, rejected_logits diff --git a/llm_rl/src/llmtuner/tuner/dpo/workflow.py b/llm_rl/src/llmtuner/tuner/dpo/workflow.py new file mode 100644 index 00000000..6e16dd18 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/dpo/workflow.py @@ -0,0 +1,66 @@ +# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py + +from copy import deepcopy +from peft import PeftModel +from typing import TYPE_CHECKING, Optional, List +from transformers import Seq2SeqTrainingArguments + +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding +from llmtuner.tuner.dpo.trainer import CustomDPOTrainer + +if TYPE_CHECKING: + from transformers import TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + + +def run_dpo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") + data_collator = DPODataCollatorWithPadding( + tokenizer=tokenizer, + pad_to_multiple_of=4, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + ) + + training_args_dict = training_args.to_dict() + training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset + training_args = Seq2SeqTrainingArguments(**training_args_dict) + + # Initialize our Trainer + trainer = CustomDPOTrainer( + beta=finetuning_args.dpo_beta, + model=model, + ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None, + args=training_args, + tokenizer=tokenizer, + 
data_collator=data_collator, + callbacks=callbacks, + **split_dataset(dataset, data_args, training_args) + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) diff --git a/llm_rl/src/llmtuner/tuner/ppo/__init__.py b/llm_rl/src/llmtuner/tuner/ppo/__init__.py new file mode 100644 index 00000000..11519bab --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/ppo/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.ppo.workflow import run_ppo diff --git a/llm_rl/src/llmtuner/tuner/ppo/trainer.py b/llm_rl/src/llmtuner/tuner/ppo/trainer.py new file mode 100644 index 00000000..372c4891 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/ppo/trainer.py @@ -0,0 +1,310 @@ +import os +import sys +import math +import torch +from tqdm import tqdm +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + +from trl import PPOTrainer +from trl.core import PPODecorators, logprobs_from_logits + +from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback +from llmtuner.extras.logging import get_logger +from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor +from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from trl import AutoModelForCausalLMWithValueHead + from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments + + +logger = get_logger(__name__) + + +class CustomPPOTrainer(PPOTrainer, Trainer): + r""" + Inherits PPOTrainer. + """ + + def __init__( + self, + model_args: "ModelArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: List["TrainerCallback"], + **kwargs + ): + PPOTrainer.__init__(self, **kwargs) + self.args = training_args + self.model_args = model_args + self.finetuning_args = finetuning_args + self.generation_config = GenerationConfig( + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + **generating_args.to_dict() + ) + self.state = TrainerState() + self.control = TrainerControl() + self.log_callback, self.save_callback = callbacks[0], callbacks[1] + assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback) + if self.args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + def ppo_train(self) -> None: + r""" + Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. 
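Note: `CustomPPOTrainer.__init__` above builds a `GenerationConfig` whose `eos_token_id` is a list covering the regular EOS token plus any additional special tokens. A sketch of that assembly; the checkpoint name and the sampling values are assumptions chosen for illustration:

```python
# Sketch of the GenerationConfig assembly in CustomPPOTrainer.__init__ (values assumed).
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # illustrative checkpoint
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # many causal LMs ship without a pad token

generation_config = GenerationConfig(
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
    do_sample=True,
    top_p=0.9,
    temperature=0.95,
    max_new_tokens=128,
)
print(generation_config.eos_token_id)  # e.g. [50256] for a GPT-2 style tokenizer
```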
+ """ + total_train_batch_size = ( + self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size + ) + if self.args.max_steps > 0: + num_examples = total_train_batch_size * self.args.max_steps + num_train_epochs = sys.maxsize + max_steps = self.args.max_steps + steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps + else: + len_dataloader = len(self.dataloader) + num_examples = len(self.dataset) + num_train_epochs = self.args.num_train_epochs + max_steps = math.ceil(num_train_epochs * len_dataloader) + steps_in_epoch = len_dataloader + + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + if self.is_world_process_zero(): + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}") + + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + dataiter = iter(self.dataloader) + loss_meter = AverageMeter() + reward_meter = AverageMeter() + self.log_callback.on_train_begin(self.args, self.state, self.control) + + for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): + try: + batch = next(dataiter) + except StopIteration: + dataiter = iter(self.dataloader) + batch = next(dataiter) + + # Cast to inference mode + unwrapped_model.gradient_checkpointing_disable() + unwrapped_model.config.use_cache = True + self.model.eval() + + # Get inputs + queries, responses = self.get_inputs(batch) + self.tokenizer.padding_side = "right" # change padding side + rewards = self.get_rewards(queries, responses, unwrapped_model) + + # Cast to training mode + unwrapped_model.gradient_checkpointing_enable() + unwrapped_model.config.use_cache = False + self.model.train() + + # Run PPO step + stats = self.step(queries, responses, rewards) + self.tokenizer.padding_side = "left" # restore padding side + loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards)) + reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) + + if self.config.log_with is not None: + try: + batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True) + batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True) + self.log_stats(stats, batch, rewards) + except: + logger.warning("Failed to save stats due to unknown errors.") + + self.state.global_step += 1 + self.log_callback.on_step_end(self.args, self.state, self.control) + + if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0: + logs = dict( + loss=round(loss_meter.avg, 4), + reward=round(reward_meter.avg, 4), + learning_rate=stats["ppo/learning_rate"], + epoch=round(step / steps_in_epoch, 2) + ) + tqdm.write(str(logs)) + logs["step"] = step + self.state.log_history.append(logs) + self.log_callback.on_log(self.args, self.state, self.control) + loss_meter.reset() + reward_meter.reset() + + if (step+1) 
% self.args.save_steps == 0: # save checkpoint + self.save_model(os.path.join( + self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step) + )) + self.save_callback.on_save( + self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) + ) + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + + self.log_callback.on_train_end(self.args, self.state, self.control) + self.save_callback.on_train_end( + self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) + ) + + @torch.no_grad() + def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + r""" + Generates model's responses given queries. + """ + if self.finetuning_args.upcast_layernorm: + layernorm_params = dump_layernorm(self.model) + + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + response: torch.Tensor = unwrapped_model.generate( + generation_config=self.generation_config, + logits_processor=get_logits_processor(), + **batch + ) + + if self.finetuning_args.upcast_layernorm: + restore_layernorm(self.model, layernorm_params) + + query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu() + queries, responses = [], [] + for i in range(len(query)): + query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item() + response_index = (response[i] != self.tokenizer.pad_token_id).nonzero() + + if len(response_index) == 0: + response_length = 1 # allow empty response + elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + response_length = response_index[-1].item() + 2 # save the EOS token + else: + response_length = response_index[-1].item() + 1 + + queries.append(query[i, query_length:]) # remove padding from left + responses.append(response[i, :response_length]) # remove padding from right + + return queries, responses + + @torch.no_grad() + def get_rewards( + self, + queries: List[torch.Tensor], + responses: List[torch.Tensor], + unwrapped_model: "AutoModelForCausalLMWithValueHead" + ) -> List[torch.Tensor]: + r""" + Computes scores using given reward model. + """ + replace_model(unwrapped_model, target="reward") + batch = self.prepare_model_inputs(queries, responses) + + with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 + _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True) + + if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2 + values = torch.transpose(values, 0, 1) + + rewards = [] + for i in range(values.size(0)): + end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token + rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type + + replace_model(unwrapped_model, target="default") + return rewards + + @PPODecorators.empty_cuda_cache() + def batched_forward_pass( + self, + model: "AutoModelForCausalLMWithValueHead", + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + return_logits: Optional[bool] = False, + response_masks: Optional[torch.Tensor] = None + ): + r""" + Calculates model outputs in multiple batches. + + Subclass and override to inject custom behavior. 
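Note: in `get_rewards` above, the value head emits one value per token and the reward is read at the last attended (non-padding) position of each sequence. A toy check with random tensors standing in for the value-head output of a right-padded batch:

```python
# Toy tensors, not real model outputs, to show the per-sequence reward extraction.
import torch

values = torch.randn(2, 7)                      # (batch, seq_len) value-head outputs
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1]])

rewards = []
for i in range(values.size(0)):
    end_index = attention_mask[i].nonzero()[-1].item()  # last non-padding token
    rewards.append(values[i, end_index].float())

print(torch.stack(rewards).shape)  # torch.Size([2]) -> one scalar reward per sequence
```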
+ """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + for i in range(math.ceil(bs / fbs)): + input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 + logits, _, values = model(**input_kwargs) + + if values.size(0) != input_ids.size(0): # adapt to chatglm2 + values = torch.transpose(values, 0, 1) + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + start = len(query_batch[j]) - 1 + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0].item() + end = start + len(response_batch[j]) + + if response_masks is not None: + response_masks_batch = torch.cat( + (torch.zeros_like(query_batch[j]), response_masks_batch[j]) + )[1:] + + masks[j, :start] = 0 + masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] + + if return_logits: + all_logits.append(logits) + else: + del logits + + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + def save_model(self, output_dir: Optional[str] = None) -> None: + r""" + Saves model checkpoint. + + Subclass and override to inject custom behavior. 
+ """ + if self.args.should_save: + self._save(output_dir) diff --git a/llm_rl/src/llmtuner/tuner/ppo/utils.py b/llm_rl/src/llmtuner/tuner/ppo/utils.py new file mode 100644 index 00000000..74453a39 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/ppo/utils.py @@ -0,0 +1,35 @@ +import torch +from typing import TYPE_CHECKING, Dict, Literal, Optional + +if TYPE_CHECKING: + from transformers import PreTrainedModel + from trl import AutoModelForCausalLMWithValueHead + + +def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: + if target == "reward": # save default head temporarily + valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict() + setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone()) + setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone()) + + model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active + model.v_head.load_state_dict({ + "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(), + "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone() + }) + + +def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: + layer_norm_params = {} + for name, param in model.named_parameters(): + if param.data.dtype == torch.float32: + layer_norm_params[name] = param.data.detach().clone() + param.data = param.data.to(model.config.torch_dtype) + + return layer_norm_params + + +def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None: + for name, param in model.named_parameters(): + if name in layernorm_params: + param.data = layernorm_params[name] diff --git a/llm_rl/src/llmtuner/tuner/ppo/workflow.py b/llm_rl/src/llmtuner/tuner/ppo/workflow.py new file mode 100644 index 00000000..4c35f628 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/ppo/workflow.py @@ -0,0 +1,92 @@ +# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py + +import math +from trl import PPOConfig +from torch.optim import AdamW +from typing import TYPE_CHECKING, Optional, List +from transformers import DataCollatorWithPadding +from transformers.optimization import get_scheduler + +from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.extras.callbacks import SavePeftModelCallback +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.ppo.trainer import CustomPPOTrainer + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + + +def run_ppo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") + + tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + ppo_config = PPOConfig( + 
model_name=model_args.model_name_or_path, + learning_rate=training_args.learning_rate, + mini_batch_size=training_args.per_device_train_batch_size, + batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + ppo_epochs=1, + max_grad_norm=training_args.max_grad_norm, + seed=training_args.seed, + optimize_cuda_cache=True, + target=finetuning_args.ppo_target, + log_with=finetuning_args.ppo_logger, + use_score_scaling=finetuning_args.ppo_score_norm, + use_score_norm=finetuning_args.ppo_score_norm, + accelerator_kwargs={"step_scheduler_with_optimizer": False} + ) + + optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) + if training_args.max_steps > 0: + num_training_steps = training_args.max_steps + else: + total_train_batch_size = ( + training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size + ) + num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) + + lr_scheduler = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=training_args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps + ) + + # Initialize our Trainer + ppo_trainer = CustomPPOTrainer( + model_args=model_args, + training_args=training_args, + finetuning_args=finetuning_args, + generating_args=generating_args, + callbacks=callbacks + [SavePeftModelCallback()], + config=ppo_config, + model=model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=data_collator, + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + + # Training + if training_args.do_train: + ppo_trainer.ppo_train() + ppo_trainer.save_model() + ppo_trainer.save_state() # must be called after save_model to have a folder + if ppo_trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "reward"]) diff --git a/llm_rl/src/llmtuner/tuner/pt/__init__.py b/llm_rl/src/llmtuner/tuner/pt/__init__.py new file mode 100644 index 00000000..8ce509db --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/pt/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.pt.workflow import run_pt diff --git a/llm_rl/src/llmtuner/tuner/pt/workflow.py b/llm_rl/src/llmtuner/tuner/pt/workflow.py new file mode 100644 index 00000000..66d08de7 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/pt/workflow.py @@ -0,0 +1,58 @@ +# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py + +import math +from typing import TYPE_CHECKING, Optional, List +from transformers import DataCollatorForLanguageModeling, Trainer + +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + + +def run_pt( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") + dataset = 
preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + **split_dataset(dataset, data_args, training_args) + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + + metrics["perplexity"] = perplexity + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) diff --git a/llm_rl/src/llmtuner/tuner/rm/__init__.py b/llm_rl/src/llmtuner/tuner/rm/__init__.py new file mode 100644 index 00000000..54d3d943 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/rm/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.rm.workflow import run_rm diff --git a/llm_rl/src/llmtuner/tuner/rm/collator.py b/llm_rl/src/llmtuner/tuner/rm/collator.py new file mode 100644 index 00000000..161f003d --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/rm/collator.py @@ -0,0 +1,27 @@ +import torch +from dataclasses import dataclass +from typing import Any, Dict, Sequence +from transformers import DataCollatorWithPadding + + +@dataclass +class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): + r""" + Data collator for pairwise data. + """ + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. 
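Note on `run_pt` above: evaluation reports perplexity as the exponential of the mean eval loss, with an overflow guard. A quick numeric check using made-up loss values:

```python
# Perplexity is exp(cross-entropy loss); very large losses overflow to infinity.
import math

for eval_loss in (2.0, 3.5, 800.0):
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")
    print(eval_loss, round(perplexity, 2))
# 2.0 7.39
# 3.5 33.12
# 800.0 inf
```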
+ """ + features = [ + { + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) + } + for key in ("chosen_ids", "rejected_ids") for feature in features + ] + return super().__call__(features) diff --git a/llm_rl/src/llmtuner/tuner/rm/metric.py b/llm_rl/src/llmtuner/tuner/rm/metric.py new file mode 100644 index 00000000..db9c9243 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/rm/metric.py @@ -0,0 +1,7 @@ +import numpy as np +from typing import Dict, Sequence, Tuple, Union + + +def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + preds, _ = eval_preds + return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])} diff --git a/llm_rl/src/llmtuner/tuner/rm/trainer.py b/llm_rl/src/llmtuner/tuner/rm/trainer.py new file mode 100644 index 00000000..80502937 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/rm/trainer.py @@ -0,0 +1,105 @@ +import os +import json +import torch +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from transformers import Trainer + +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from transformers.trainer import PredictionOutput + from transformers.modeling_utils import PreTrainedModel + + +logger = get_logger(__name__) + + +class PairwiseTrainer(Trainer): + r""" + Inherits PeftTrainer to compute pairwise loss. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.can_return_loss = True # override property to return eval_loss + + def compute_loss( + self, + model: "PreTrainedModel", + inputs: Dict[str, torch.Tensor], + return_outputs: Optional[bool] = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + r""" + Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. + + Subclass and override to inject custom behavior. + + Note that the first element will be removed from the output tuple. + See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 + """ + # Compute rewards + _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) + if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2 + values = torch.transpose(values, 0, 1) + + # Split the inputs and rewards into two parts, chosen and rejected + batch_size = inputs["input_ids"].size(0) // 2 + chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:] + chosen_attn_mask, rejected_attn_mask = ( + inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:] + ) + chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:] + chosen_scores, rejected_scores = [], [] + + # Compute pairwise loss. 
Only backprop on the different tokens before padding + # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py + loss = 0 + for i in range(batch_size): + chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1 + rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1 + check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero() + + if len(check_divergence) == 0: + end_index = chosen_length + div_index = end_index - 1 + else: + end_index = max(chosen_length, rejected_length) + div_index = check_divergence[0] + + assert div_index > 0 + chosen_trunc_rewards = chosen_rewards[i, div_index:end_index] + rejected_trunc_rewards = rejected_rewards[i, div_index:end_index] + if return_outputs: # use the score on the EOS token for inference + chosen_scores.append(chosen_rewards[i, chosen_length-1]) + rejected_scores.append(rejected_rewards[i, rejected_length-1]) + loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean() + + loss = loss / batch_size + if return_outputs: + chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores) + return loss, [loss, chosen_scores, rejected_scores] + + return loss + + def save_predictions( + self, + predict_results: "PredictionOutput" + ) -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. + """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + chosen_scores, rejected_scores = predict_results.predictions + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for c_score, r_score in zip(chosen_scores, rejected_scores): + res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) + writer.write("\n".join(res)) diff --git a/llm_rl/src/llmtuner/tuner/rm/workflow.py b/llm_rl/src/llmtuner/tuner/rm/workflow.py new file mode 100644 index 00000000..6d2c4422 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/rm/workflow.py @@ -0,0 +1,68 @@ +# Inspired by: +# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py + +from typing import TYPE_CHECKING, Optional, List +from transformers import Seq2SeqTrainingArguments + +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.extras.callbacks import SavePeftModelCallback +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.rm.metric import compute_accuracy +from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding +from llmtuner.tuner.rm.trainer import PairwiseTrainer + +if TYPE_CHECKING: + from transformers import TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + + +def run_rm( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") + data_collator = 
PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4) + + training_args_dict = training_args.to_dict() + training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset + training_args = Seq2SeqTrainingArguments(**training_args_dict) + + # Initialize our Trainer + trainer = PairwiseTrainer( + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks + [SavePeftModelCallback()], + compute_metrics=compute_accuracy, + **split_dataset(dataset, data_args, training_args) + ) + + # Training + if training_args.do_train: + train_result = trainer.train() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset, metric_key_prefix="predict") + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(predict_results) diff --git a/llm_rl/src/llmtuner/tuner/sft/__init__.py b/llm_rl/src/llmtuner/tuner/sft/__init__.py new file mode 100644 index 00000000..493dd1a7 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/sft/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.sft.workflow import run_sft diff --git a/llm_rl/src/llmtuner/tuner/sft/metric.py b/llm_rl/src/llmtuner/tuner/sft/metric.py new file mode 100644 index 00000000..812896ee --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/sft/metric.py @@ -0,0 +1,53 @@ +import numpy as np +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union + +import jieba +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction + +from llmtuner.extras.constants import IGNORE_INDEX + +if TYPE_CHECKING: + from transformers.tokenization_utils import PreTrainedTokenizer + + +@dataclass +class ComputeMetrics: + r""" + Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. + """ + + tokenizer: "PreTrainedTokenizer" + + def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + r""" + Uses the model predictions to compute metrics. 
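Note: `ComputeMetrics.__call__`, whose body follows, segments predictions and references with jieba and scores them with ROUGE (rouge-chinese) and BLEU-4 (nltk). A standalone sketch on two hand-written strings, assuming those packages are installed as this module already requires:

```python
# Standalone metric sketch on toy strings instead of decoded model outputs.
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

pred = "今天天气很好"
label = "今天天气不错"

hypothesis = list(jieba.cut(pred))     # word-level segmentation for ROUGE
reference = list(jieba.cut(label))

scores = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]
bleu = sentence_bleu([list(label)], list(pred),  # character-level BLEU-4, as in the diff
                     smoothing_function=SmoothingFunction().method3)

print({k: round(v["f"] * 100, 2) for k, v in scores.items()})
print(round(bleu * 100, 2))
```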
+ """ + preds, labels = eval_preds + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + + if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: + result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + return {k: float(np.mean(v)) for k, v in score_dict.items()} diff --git a/llm_rl/src/llmtuner/tuner/sft/trainer.py b/llm_rl/src/llmtuner/tuner/sft/trainer.py new file mode 100644 index 00000000..c65cd255 --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/sft/trainer.py @@ -0,0 +1,92 @@ +import os +import json +import torch +import numpy as np +import torch.nn as nn +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from transformers import Seq2SeqTrainer + +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.logging import get_logger + +if TYPE_CHECKING: + from transformers.trainer import PredictionOutput + + +logger = get_logger(__name__) + + +class CustomSeq2SeqTrainer(Seq2SeqTrainer): + r""" + Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. + """ + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. + """ + labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels + if self.args.predict_with_generate: + assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." + prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) + if prompt_len > label_len: + inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) + if label_len > prompt_len: + inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs + + loss, generated_tokens, _ = super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + if generated_tokens is not None and self.args.predict_with_generate: + generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id + generated_tokens = generated_tokens.contiguous() + + return loss, generated_tokens, labels + + def _pad_tensors_to_target_len( + self, + src_tensor: torch.Tensor, + tgt_tensor: torch.Tensor + ) -> torch.Tensor: + r""" + Pads the tensor to the same length as the target tensor. + """ + assert self.tokenizer.pad_token_id is not None, "Pad token is required." 
+ padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) + padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding + return padded_tensor.contiguous() # in contiguous memory + + def save_predictions( + self, + predict_results: "PredictionOutput" + ) -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. + """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) + labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for pred, label in zip(decoded_preds, decoded_labels): + res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + writer.write("\n".join(res)) diff --git a/llm_rl/src/llmtuner/tuner/sft/workflow.py b/llm_rl/src/llmtuner/tuner/sft/workflow.py new file mode 100644 index 00000000..8d53605d --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/sft/workflow.py @@ -0,0 +1,90 @@ +# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py + +from typing import TYPE_CHECKING, Optional, List +from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments + +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.misc import get_logits_processor +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.sft.metric import ComputeMetrics +from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer + +if TYPE_CHECKING: + from transformers import TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + + +def run_sft( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") + + if training_args.predict_with_generate: + tokenizer.padding_side = "left" # use left-padding in generation + + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args_dict = training_args.to_dict() + training_args_dict.update(dict( + generation_max_length=training_args.generation_max_length or data_args.cutoff_len, + 
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams + )) + training_args = Seq2SeqTrainingArguments(**training_args_dict) + + # Initialize our Trainer + trainer = CustomSeq2SeqTrainer( + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, + **split_dataset(dataset, data_args, training_args) + ) + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + metrics.pop("eval_loss", None) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + predict_results.metrics.pop("predict_loss", None) + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(predict_results) diff --git a/llm_rl/src/llmtuner/tuner/tune.py b/llm_rl/src/llmtuner/tuner/tune.py new file mode 100644 index 00000000..4eb7f78f --- /dev/null +++ b/llm_rl/src/llmtuner/tuner/tune.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.logging import get_logger +from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer +from llmtuner.tuner.pt import run_pt +from llmtuner.tuner.sft import run_sft +from llmtuner.tuner.rm import run_rm +from llmtuner.tuner.ppo import run_ppo +from llmtuner.tuner.dpo import run_dpo + +if TYPE_CHECKING: + from transformers import TrainerCallback + + +logger = get_logger(__name__) + + +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): + model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + callbacks = [LogCallback()] if callbacks is None else callbacks + + if finetuning_args.stage == "pt": + run_pt(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "sft": + run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif finetuning_args.stage == "rm": + run_rm(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "ppo": + run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif finetuning_args.stage == "dpo": + run_dpo(model_args, 
data_args, training_args, finetuning_args, callbacks) + else: + raise ValueError("Unknown task.") + + +def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): + model_args, _, finetuning_args, _ = get_infer_args(args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + model.config.use_cache = True + model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size) + try: + tokenizer.padding_side = "left" # restore padding side + tokenizer.init_kwargs["padding_side"] = "left" + tokenizer.save_pretrained(model_args.export_dir) + except: + logger.warning("Cannot save tokenizer, please copy the files manually.") + + +if __name__ == "__main__": + run_exp() diff --git a/llm_rl/src/llmtuner/webui/__init__.py b/llm_rl/src/llmtuner/webui/__init__.py new file mode 100644 index 00000000..a27c7f6e --- /dev/null +++ b/llm_rl/src/llmtuner/webui/__init__.py @@ -0,0 +1 @@ +from llmtuner.webui.interface import create_ui, create_web_demo diff --git a/llm_rl/src/llmtuner/webui/chatter.py b/llm_rl/src/llmtuner/webui/chatter.py new file mode 100644 index 00000000..57eadb01 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/chatter.py @@ -0,0 +1,101 @@ +import gradio as gr +from gradio.components import Component # cannot use TYPE_CHECKING here +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple + +from llmtuner.chat.stream_chat import ChatModel +from llmtuner.extras.misc import torch_gc +from llmtuner.hparams import GeneratingArguments +from llmtuner.webui.common import get_save_dir +from llmtuner.webui.locales import ALERTS + +if TYPE_CHECKING: + from llmtuner.webui.manager import Manager + + +class WebChatModel(ChatModel): + + def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None: + self.manager = manager + self.model = None + self.tokenizer = None + self.generating_args = GeneratingArguments() + if not lazy_init: + super().__init__() + + @property + def loaded(self) -> bool: + return self.model is not None + + def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: + get = lambda name: data[self.manager.get_elem_by_name(name)] + lang = get("top.lang") + error = "" + if self.loaded: + error = ALERTS["err_exists"][lang] + elif not get("top.model_name"): + error = ALERTS["err_no_model"][lang] + elif not get("top.model_path"): + error = ALERTS["err_no_path"][lang] + + if error: + gr.Warning(error) + yield error + return + + if get("top.checkpoints"): + checkpoint_dir = ",".join([ + get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") + ]) + else: + checkpoint_dir = None + + yield ALERTS["info_loading"][lang] + args = dict( + model_name_or_path=get("top.model_path"), + checkpoint_dir=checkpoint_dir, + finetuning_type=get("top.finetuning_type"), + quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + template=get("top.template"), + system_prompt=get("top.system_prompt"), + flash_attn=get("top.flash_attn"), + shift_attn=get("top.shift_attn"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None + ) + super().__init__(args) + + yield ALERTS["info_loaded"][lang] + + def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: + lang = data[self.manager.get_elem_by_name("top.lang")] + yield ALERTS["info_unloading"][lang] + self.model = None + self.tokenizer = None + torch_gc() + yield 
ALERTS["info_unloaded"][lang]
+
+    def predict(
+        self,
+        chatbot: List[Tuple[str, str]],
+        query: str,
+        history: List[Tuple[str, str]],
+        system: str,
+        max_new_tokens: int,
+        top_p: float,
+        temperature: float
+    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+        chatbot.append([query, ""])
+        response = ""
+        for new_text in self.stream_chat(
+            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+        ):
+            response += new_text
+            new_history = history + [(query, response)]
+            chatbot[-1] = [query, self.postprocess(response)]
+            yield chatbot, new_history
+
+    def postprocess(self, response: str) -> str:
+        blocks = response.split("```")
+        for i, block in enumerate(blocks):
+            if i % 2 == 0: # escape HTML tags outside of fenced code blocks
+                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
+        return "```".join(blocks)
diff --git a/llm_rl/src/llmtuner/webui/common.py b/llm_rl/src/llmtuner/webui/common.py
new file mode 100644
index 00000000..5a6c16d3
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/common.py
@@ -0,0 +1,103 @@
+import os
+import json
+import gradio as gr
+from typing import Any, Dict, Optional
+from transformers.utils import (
+    WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME
+)
+
+from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
+
+
+DEFAULT_CACHE_DIR = "cache"
+DEFAULT_DATA_DIR = "data"
+DEFAULT_SAVE_DIR = "saves"
+USER_CONFIG = "user.config"
+DATA_CONFIG = "dataset_info.json"
+CKPT_NAMES = [
+    WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME
+]
+
+
+def get_save_dir(*args) -> os.PathLike:
+    return os.path.join(DEFAULT_SAVE_DIR, *args)
+
+
+def get_config_path() -> os.PathLike:
+    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
+
+
+def load_config() -> Dict[str, Any]:
+    try:
+        with open(get_config_path(), "r", encoding="utf-8") as f:
+            return json.load(f)
+    except Exception: # fall back to an empty user config if the file is missing or malformed
+        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+
+
+def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
+    user_config = load_config()
+    user_config["lang"] = lang or user_config["lang"]
+    if model_name:
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
+    with open(get_config_path(), "w", encoding="utf-8") as f:
+        json.dump(user_config, f, indent=2, ensure_ascii=False)
+
+
+def get_model_path(model_name: str) -> str:
+    user_config = load_config()
+    return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
+
+
+def get_module(model_name: str) -> str:
+    return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
+
+
+def get_template(model_name: str) -> str:
+    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    return "default"
+
+
+def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+    checkpoints = []
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for checkpoint in os.listdir(save_dir):
+                if (
+                    os.path.isdir(os.path.join(save_dir, checkpoint))
+                    and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
+                ):
+                    checkpoints.append(checkpoint)
+    
return gr.update(value=[], choices=checkpoints) + + +def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: + try: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + return json.load(f) + except: + print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir)) + return {} + + +def list_dataset( + dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0] +) -> Dict[str, Any]: + dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) + ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"] + datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] + return gr.update(value=[], choices=datasets) diff --git a/llm_rl/src/llmtuner/webui/components/__init__.py b/llm_rl/src/llmtuner/webui/components/__init__.py new file mode 100644 index 00000000..32228b8e --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/__init__.py @@ -0,0 +1,6 @@ +from llmtuner.webui.components.top import create_top +from llmtuner.webui.components.train import create_train_tab +from llmtuner.webui.components.eval import create_eval_tab +from llmtuner.webui.components.infer import create_infer_tab +from llmtuner.webui.components.export import create_export_tab +from llmtuner.webui.components.chatbot import create_chat_box diff --git a/llm_rl/src/llmtuner/webui/components/chatbot.py b/llm_rl/src/llmtuner/webui/components/chatbot.py new file mode 100644 index 00000000..13e2dd4d --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/chatbot.py @@ -0,0 +1,49 @@ +import gradio as gr +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +if TYPE_CHECKING: + from gradio.blocks import Block + from gradio.components import Component + from llmtuner.webui.engine import Engine + + +def create_chat_box( + engine: "Engine", + visible: Optional[bool] = False +) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: + with gr.Box(visible=visible) as chat_box: + chatbot = gr.Chatbot() + history = gr.State([]) + with gr.Row(): + with gr.Column(scale=4): + system = gr.Textbox(show_label=False) + query = gr.Textbox(show_label=False, lines=8) + submit_btn = gr.Button(variant="primary") + + with gr.Column(scale=1): + clear_btn = gr.Button() + gen_kwargs = engine.chatter.generating_args + max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1) + top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) + temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) + + submit_btn.click( + engine.chatter.predict, + [chatbot, query, history, system, max_new_tokens, top_p, temperature], + [chatbot, history], + show_progress=True + ).then( + lambda: gr.update(value=""), outputs=[query] + ) + + clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) + + return chat_box, chatbot, history, dict( + system=system, + query=query, + submit_btn=submit_btn, + clear_btn=clear_btn, + max_new_tokens=max_new_tokens, + top_p=top_p, + temperature=temperature + ) diff --git a/llm_rl/src/llmtuner/webui/components/data.py b/llm_rl/src/llmtuner/webui/components/data.py new file mode 100644 index 00000000..effa39da --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/data.py @@ -0,0 +1,103 @@ +import os +import json +import gradio as gr +from typing import TYPE_CHECKING, Any, Dict, Tuple + +from llmtuner.webui.common import DATA_CONFIG + +if TYPE_CHECKING: + from gradio.components import Component + + +PAGE_SIZE = 2 + + +def 
prev_page(page_index: int) -> int: + return page_index - 1 if page_index > 0 else page_index + + +def next_page(page_index: int, total_num: int) -> int: + return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index + + +def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + if ( + len(dataset) > 0 + and "file_name" in dataset_info[dataset[0]] + and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) + ): + return gr.update(interactive=True) + else: + return gr.update(interactive=False) + + +def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + data_file: str = dataset_info[dataset[0]]["file_name"] + with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: + if data_file.endswith(".json"): + data = json.load(f) + elif data_file.endswith(".jsonl"): + data = [json.loads(line) for line in f] + else: + data = [line for line in f] + return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True) + + +def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: + data_preview_btn = gr.Button(interactive=False, scale=1) + with gr.Column(visible=False, elem_classes="modal-box") as preview_box: + with gr.Row(): + preview_count = gr.Number(value=0, interactive=False, precision=0) + page_index = gr.Number(value=0, interactive=False, precision=0) + + with gr.Row(): + prev_btn = gr.Button() + next_btn = gr.Button() + close_btn = gr.Button() + + with gr.Row(): + preview_samples = gr.JSON(interactive=False) + + dataset.change( + can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False + ).then( + lambda: 0, outputs=[page_index], queue=False + ) + data_preview_btn.click( + get_preview, + [dataset_dir, dataset, page_index], + [preview_count, preview_samples, preview_box], + queue=False + ) + prev_btn.click( + prev_page, [page_index], [page_index], queue=False + ).then( + get_preview, + [dataset_dir, dataset, page_index], + [preview_count, preview_samples, preview_box], + queue=False + ) + next_btn.click( + next_page, [page_index, preview_count], [page_index], queue=False + ).then( + get_preview, + [dataset_dir, dataset, page_index], + [preview_count, preview_samples, preview_box], + queue=False + ) + close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) + return dict( + data_preview_btn=data_preview_btn, + preview_count=preview_count, + page_index=page_index, + prev_btn=prev_btn, + next_btn=next_btn, + close_btn=close_btn, + preview_samples=preview_samples + ) diff --git a/llm_rl/src/llmtuner/webui/components/eval.py b/llm_rl/src/llmtuner/webui/components/eval.py new file mode 100644 index 00000000..36c994a6 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/eval.py @@ -0,0 +1,70 @@ +import gradio as gr +from typing import TYPE_CHECKING, Dict + +from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR +from llmtuner.webui.components.data import create_preview_box + +if TYPE_CHECKING: + from gradio.components import Component + from llmtuner.webui.engine import Engine + + +def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with 
gr.Row(): + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) + dataset = gr.Dropdown(multiselect=True, scale=4) + preview_elems = create_preview_box(dataset_dir, dataset) + + dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False) + + input_elems.update({dataset_dir, dataset}) + elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) + + with gr.Row(): + cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) + max_samples = gr.Textbox(value="100000") + batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) + predict = gr.Checkbox(value=True) + + input_elems.update({cutoff_len, max_samples, batch_size, predict}) + elem_dict.update(dict( + cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict + )) + + with gr.Row(): + max_new_tokens = gr.Slider(10, 2048, value=128, step=1) + top_p = gr.Slider(0.01, 1, value=0.7, step=0.01) + temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01) + + input_elems.update({max_new_tokens, top_p, temperature}) + elem_dict.update(dict( + max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + )) + + with gr.Row(): + cmd_preview_btn = gr.Button() + start_btn = gr.Button() + stop_btn = gr.Button() + + with gr.Row(): + resume_btn = gr.Checkbox(visible=False, interactive=False, value=False) + process_bar = gr.Slider(visible=False, interactive=False) + + with gr.Box(): + output_box = gr.Markdown() + + output_elems = [output_box, process_bar] + elem_dict.update(dict( + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, + resume_btn=resume_btn, process_bar=process_bar, output_box=output_box + )) + + cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems) + start_btn.click(engine.runner.run_eval, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort, queue=False) + resume_btn.change(engine.runner.monitor, outputs=output_elems) + + return elem_dict diff --git a/llm_rl/src/llmtuner/webui/components/export.py b/llm_rl/src/llmtuner/webui/components/export.py new file mode 100644 index 00000000..d16fa3d1 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/export.py @@ -0,0 +1,79 @@ +import gradio as gr +from typing import TYPE_CHECKING, Dict, Generator, List + +from llmtuner.tuner import export_model +from llmtuner.webui.common import get_save_dir +from llmtuner.webui.locales import ALERTS + +if TYPE_CHECKING: + from gradio.components import Component + from llmtuner.webui.engine import Engine + + +def save_model( + lang: str, + model_name: str, + model_path: str, + checkpoints: List[str], + finetuning_type: str, + template: str, + max_shard_size: int, + export_dir: str +) -> Generator[str, None, None]: + error = "" + if not model_name: + error = ALERTS["err_no_model"][lang] + elif not model_path: + error = ALERTS["err_no_path"][lang] + elif not checkpoints: + error = ALERTS["err_no_checkpoint"][lang] + elif not export_dir: + error = ALERTS["err_no_export_dir"][lang] + + if error: + gr.Warning(error) + yield error + return + + args = dict( + model_name_or_path=model_path, + checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]), + finetuning_type=finetuning_type, + template=template, + export_dir=export_dir + ) + + yield ALERTS["info_exporting"][lang] + export_model(args, max_shard_size="{}GB".format(max_shard_size)) + yield ALERTS["info_exported"][lang] + + +def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: + with gr.Row(): + 
export_dir = gr.Textbox() + max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) + + export_btn = gr.Button() + info_box = gr.Textbox(show_label=False, interactive=False) + + export_btn.click( + save_model, + [ + engine.manager.get_elem_by_name("top.lang"), + engine.manager.get_elem_by_name("top.model_name"), + engine.manager.get_elem_by_name("top.model_path"), + engine.manager.get_elem_by_name("top.checkpoints"), + engine.manager.get_elem_by_name("top.finetuning_type"), + engine.manager.get_elem_by_name("top.template"), + max_shard_size, + export_dir + ], + [info_box] + ) + + return dict( + export_dir=export_dir, + max_shard_size=max_shard_size, + export_btn=export_btn, + info_box=info_box + ) diff --git a/llm_rl/src/llmtuner/webui/components/infer.py b/llm_rl/src/llmtuner/webui/components/infer.py new file mode 100644 index 00000000..d6dd7eed --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/infer.py @@ -0,0 +1,39 @@ +import gradio as gr +from typing import TYPE_CHECKING, Dict + +from llmtuner.webui.components.chatbot import create_chat_box + +if TYPE_CHECKING: + from gradio.components import Component + from llmtuner.webui.engine import Engine + + +def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with gr.Row(): + load_btn = gr.Button() + unload_btn = gr.Button() + + info_box = gr.Textbox(show_label=False, interactive=False) + elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box)) + + chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False) + elem_dict.update(dict(chat_box=chat_box, **chat_elems)) + + load_btn.click( + engine.chatter.load_model, input_elems, [info_box] + ).then( + lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] + ) + + unload_btn.click( + engine.chatter.unload_model, input_elems, [info_box] + ).then( + lambda: ([], []), outputs=[chatbot, history] + ).then( + lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] + ) + + return elem_dict diff --git a/llm_rl/src/llmtuner/webui/components/top.py b/llm_rl/src/llmtuner/webui/components/top.py new file mode 100644 index 00000000..c6299cab --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/top.py @@ -0,0 +1,74 @@ +import gradio as gr +from typing import TYPE_CHECKING, Dict + +from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS +from llmtuner.extras.template import templates +from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config +from llmtuner.webui.utils import can_quantize + +if TYPE_CHECKING: + from gradio.components import Component + + +def create_top() -> Dict[str, "Component"]: + available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] + + with gr.Row(): + lang = gr.Dropdown(choices=["en", "zh"], scale=1) + model_name = gr.Dropdown(choices=available_models, scale=3) + model_path = gr.Textbox(scale=3) + + with gr.Row(): + finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) + checkpoints = gr.Dropdown(multiselect=True, scale=5) + refresh_btn = gr.Button(scale=1) + + with gr.Accordion(label="Advanced config", open=False) as advanced_tab: + with gr.Row(): + quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1) + template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1) + system_prompt = gr.Textbox(scale=2) + + with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab: + with gr.Row(): + 
with gr.Column(): + flash_attn = gr.Checkbox(value=False) + shift_attn = gr.Checkbox(value=False) + rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none") + + model_name.change( + list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False + ).then( + get_model_path, [model_name], [model_path], queue=False + ).then( + get_template, [model_name], [template], queue=False + ) # do not save config since the below line will save + + model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False) + + finetuning_type.change( + list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False + ).then( + can_quantize, [finetuning_type], [quantization_bit], queue=False + ) + + refresh_btn.click( + list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False + ) + + return dict( + lang=lang, + model_name=model_name, + model_path=model_path, + finetuning_type=finetuning_type, + checkpoints=checkpoints, + refresh_btn=refresh_btn, + advanced_tab=advanced_tab, + quantization_bit=quantization_bit, + template=template, + system_prompt=system_prompt, + llama_tab=llama_tab, + flash_attn=flash_attn, + shift_attn=shift_attn, + rope_scaling=rope_scaling + ) diff --git a/llm_rl/src/llmtuner/webui/components/train.py b/llm_rl/src/llmtuner/webui/components/train.py new file mode 100644 index 00000000..11109c97 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/components/train.py @@ -0,0 +1,154 @@ +import gradio as gr +from typing import TYPE_CHECKING, Dict +from transformers.trainer_utils import SchedulerType + +from llmtuner.extras.constants import TRAINING_STAGES +from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR +from llmtuner.webui.components.data import create_preview_box +from llmtuner.webui.utils import gen_plot + +if TYPE_CHECKING: + from gradio.components import Component + from llmtuner.webui.engine import Engine + + +def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with gr.Row(): + training_stage = gr.Dropdown( + choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2 + ) + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) + dataset = gr.Dropdown(multiselect=True, scale=4) + preview_elems = create_preview_box(dataset_dir, dataset) + + training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) + dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) + + input_elems.update({training_stage, dataset_dir, dataset}) + elem_dict.update(dict( + training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems + )) + + with gr.Row(): + cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) + learning_rate = gr.Textbox(value="5e-5") + num_train_epochs = gr.Textbox(value="3.0") + max_samples = gr.Textbox(value="100000") + compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16") + + input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type}) + elem_dict.update(dict( + cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs, + max_samples=max_samples, compute_type=compute_type + )) + + with gr.Row(): + batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) + gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) + lr_scheduler_type = gr.Dropdown( + choices=[scheduler.value for 
scheduler in SchedulerType], value="cosine" + ) + max_grad_norm = gr.Textbox(value="1.0") + val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) + + input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size}) + elem_dict.update(dict( + batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, + lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size + )) + + with gr.Accordion(label="Advanced config", open=False) as advanced_tab: + with gr.Row(): + logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) + save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) + warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) + neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1) + + with gr.Column(): + train_on_prompt = gr.Checkbox(value=False) + upcast_layernorm = gr.Checkbox(value=False) + + input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm}) + elem_dict.update(dict( + advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps, + neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm + )) + + with gr.Accordion(label="LoRA config", open=False) as lora_tab: + with gr.Row(): + lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) + lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) + lora_target = gr.Textbox(scale=1) + additional_target = gr.Textbox(scale=1) + resume_lora_training = gr.Checkbox(value=True, scale=1) + + input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training}) + elem_dict.update(dict( + lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target, + additional_target=additional_target, resume_lora_training=resume_lora_training, + )) + + with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: + with gr.Row(): + dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) + reward_model = gr.Dropdown(scale=3) + refresh_btn = gr.Button(scale=1) + + refresh_btn.click( + list_checkpoint, + [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")], + [reward_model], + queue=False + ) + + input_elems.update({dpo_beta, reward_model}) + elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn)) + + with gr.Row(): + cmd_preview_btn = gr.Button() + start_btn = gr.Button() + stop_btn = gr.Button() + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(): + output_dir = gr.Textbox() + + with gr.Row(): + resume_btn = gr.Checkbox(visible=False, interactive=False, value=False) + process_bar = gr.Slider(visible=False, interactive=False) + + with gr.Box(): + output_box = gr.Markdown() + + with gr.Column(scale=1): + loss_viewer = gr.Plot() + + input_elems.add(output_dir) + output_elems = [output_box, process_bar] + + cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems) + start_btn.click(engine.runner.run_train, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort, queue=False) + resume_btn.change(engine.runner.monitor, outputs=output_elems) + + elem_dict.update(dict( + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir, + resume_btn=resume_btn, process_bar=process_bar, 
output_box=output_box, loss_viewer=loss_viewer + )) + + output_box.change( + gen_plot, + [ + engine.manager.get_elem_by_name("top.model_name"), + engine.manager.get_elem_by_name("top.finetuning_type"), + output_dir + ], + loss_viewer, + queue=False + ) + + return elem_dict diff --git a/llm_rl/src/llmtuner/webui/css.py b/llm_rl/src/llmtuner/webui/css.py new file mode 100644 index 00000000..c86fb96b --- /dev/null +++ b/llm_rl/src/llmtuner/webui/css.py @@ -0,0 +1,20 @@ +CSS = r""" +.modal-box { + position: fixed !important; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); /* center horizontally */ + max-width: 1000px; + max-height: 750px; + overflow-y: auto; + background-color: var(--input-background-fill); + flex-wrap: nowrap !important; + border: 2px solid black !important; + z-index: 1000; + padding: 10px; +} + +.dark .modal-box { + border: 2px solid white !important; +} +""" diff --git a/llm_rl/src/llmtuner/webui/engine.py b/llm_rl/src/llmtuner/webui/engine.py new file mode 100644 index 00000000..661dfb48 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/engine.py @@ -0,0 +1,57 @@ +import gradio as gr +from gradio.components import Component # cannot use TYPE_CHECKING here +from typing import Any, Dict, Generator, Optional + +from llmtuner.webui.chatter import WebChatModel +from llmtuner.webui.common import get_model_path, list_dataset, load_config +from llmtuner.webui.locales import LOCALES +from llmtuner.webui.manager import Manager +from llmtuner.webui.runner import Runner +from llmtuner.webui.utils import get_time + + +class Engine: + + def __init__(self, pure_chat: Optional[bool] = False) -> None: + self.pure_chat = pure_chat + self.manager: "Manager" = Manager() + self.runner: "Runner" = Runner(self.manager) + self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat)) + + def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]): + return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()} + + def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]: + user_config = load_config() + lang = user_config.get("lang", None) or "en" + + init_dict = { + "top.lang": {"value": lang}, + "infer.chat_box": {"visible": self.chatter.loaded} + } + + if not self.pure_chat: + init_dict["train.dataset"] = {"choices": list_dataset()["choices"]} + init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]} + + if user_config.get("last_model", None): + init_dict["top.model_name"] = {"value": user_config["last_model"]} + init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])} + + yield self._form_dict(init_dict) + + if not self.pure_chat: + if self.runner.alive: + yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()} + if self.runner.do_train: + yield self._form_dict({"train.resume_btn": {"value": True}}) + else: + yield self._form_dict({"eval.resume_btn": {"value": True}}) + else: + yield self._form_dict({"train.output_dir": {"value": get_time()}}) + + def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: + return { + component: gr.update(**LOCALES[name][lang]) + for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES + } diff --git a/llm_rl/src/llmtuner/webui/interface.py b/llm_rl/src/llmtuner/webui/interface.py new file mode 100644 index 00000000..ba663f24 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/interface.py @@ -0,0 +1,66 @@ +import gradio as gr +from 
transformers.utils.versions import require_version + +from llmtuner.webui.components import ( + create_top, + create_train_tab, + create_eval_tab, + create_infer_tab, + create_export_tab, + create_chat_box +) +from llmtuner.webui.common import save_config +from llmtuner.webui.css import CSS +from llmtuner.webui.engine import Engine + + +require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"") + + +def create_ui() -> gr.Blocks: + engine = Engine(pure_chat=False) + + with gr.Blocks(title="LLaMA Board", css=CSS) as demo: + engine.manager.all_elems["top"] = create_top() + lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang") + + with gr.Tab("Train"): + engine.manager.all_elems["train"] = create_train_tab(engine) + + with gr.Tab("Evaluate"): + engine.manager.all_elems["eval"] = create_eval_tab(engine) + + with gr.Tab("Chat"): + engine.manager.all_elems["infer"] = create_infer_tab(engine) + + with gr.Tab("Export"): + engine.manager.all_elems["export"] = create_export_tab(engine) + + demo.load(engine.resume, outputs=engine.manager.list_elems()) + lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False) + lang.input(save_config, inputs=[lang], queue=False) + + return demo + + +def create_web_demo() -> gr.Blocks: + engine = Engine(pure_chat=True) + + with gr.Blocks(title="Web Demo", css=CSS) as demo: + lang = gr.Dropdown(choices=["en", "zh"]) + engine.manager.all_elems["top"] = dict(lang=lang) + + chat_box, _, _, chat_elems = create_chat_box(engine, visible=True) + engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems) + + demo.load(engine.resume, outputs=engine.manager.list_elems()) + lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False) + lang.input(save_config, inputs=[lang], queue=False) + + return demo + + +if __name__ == "__main__": + demo = create_ui() + demo.queue() + demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True) diff --git a/llm_rl/src/llmtuner/webui/locales.py b/llm_rl/src/llmtuner/webui/locales.py new file mode 100644 index 00000000..cc2a1842 --- /dev/null +++ b/llm_rl/src/llmtuner/webui/locales.py @@ -0,0 +1,698 @@ +LOCALES = { + "lang": { + "en": { + "label": "Lang" + }, + "zh": { + "label": "语言" + } + }, + "model_name": { + "en": { + "label": "Model name" + }, + "zh": { + "label": "模型名称" + } + }, + "model_path": { + "en": { + "label": "Model path", + "info": "Path to pretrained model or model identifier from Hugging Face." + }, + "zh": { + "label": "模型路径", + "info": "本地模型的文件路径或 Hugging Face 的模型标识符。" + } + }, + "finetuning_type": { + "en": { + "label": "Finetuning method" + }, + "zh": { + "label": "微调方法" + } + }, + "checkpoints": { + "en": { + "label": "Checkpoints" + }, + "zh": { + "label": "模型断点" + } + }, + "refresh_btn": { + "en": { + "value": "Refresh checkpoints" + }, + "zh": { + "value": "刷新断点" + } + }, + "advanced_tab": { + "en": { + "label": "Advanced configurations" + }, + "zh": { + "label": "高级设置" + } + }, + "quantization_bit": { + "en": { + "label": "Quantization bit", + "info": "Enable 4/8-bit model quantization (QLoRA)." + }, + "zh": { + "label": "量化等级", + "info": "启用 4/8 比特模型量化(QLoRA)。" + } + }, + "template": { + "en": { + "label": "Prompt template", + "info": "The template used in constructing prompts." + }, + "zh": { + "label": "提示模板", + "info": "构建提示词时使用的模板" + } + }, + "system_prompt": { + "en": { + "label": "System prompt (optional)", + "info": "A sequence used as the default system prompt." 
+ }, + "zh": { + "label": "系统提示词(非必填)", + "info": "默认使用的系统提示词" + } + }, + "llama_tab": { + "en": { + "label": "Model configurations (LLaMA only)" + }, + "zh": { + "label": "模型设置(仅LLaMA)" + } + }, + "flash_attn": { + "en": { + "label": "Use FlashAttention-2" + }, + "zh": { + "label": "使用 FlashAttention-2" + } + }, + "shift_attn": { + "en": { + "label": "Use shift short attention (S^2-Attn)" + }, + "zh": { + "label": "使用 shift short attention (S^2-Attn)" + } + }, + "rope_scaling": { + "en": { + "label": "RoPE scaling" + }, + "zh": { + "label": "RoPE 插值方法" + } + }, + "training_stage": { + "en": { + "label": "Stage", + "info": "The stage to perform in training." + }, + "zh": { + "label": "训练阶段", + "info": "目前采用的训练方式。" + } + }, + "dataset_dir": { + "en": { + "label": "Data dir", + "info": "Path of the data directory." + }, + "zh": { + "label": "数据路径", + "info": "数据文件夹的路径。" + } + }, + "dataset": { + "en": { + "label": "Dataset" + }, + "zh": { + "label": "数据集" + } + }, + "data_preview_btn": { + "en": { + "value": "Preview dataset" + }, + "zh": { + "value": "预览数据集" + } + }, + "preview_count": { + "en": { + "label": "Count" + }, + "zh": { + "label": "数量" + } + }, + "page_index": { + "en": { + "label": "Page" + }, + "zh": { + "label": "页数" + } + }, + "prev_btn": { + "en": { + "value": "Prev" + }, + "zh": { + "value": "上一页" + } + }, + "next_btn": { + "en": { + "value": "Next" + }, + "zh": { + "value": "下一页" + } + }, + "close_btn": { + "en": { + "value": "Close" + }, + "zh": { + "value": "关闭" + } + }, + "preview_samples": { + "en": { + "label": "Samples" + }, + "zh": { + "label": "样例" + } + }, + "cutoff_len": { + "en": { + "label": "Cutoff length", + "info": "Max tokens in input sequence." + }, + "zh": { + "label": "截断长度", + "info": "输入序列分词后的最大长度。" + } + }, + "learning_rate": { + "en": { + "label": "Learning rate", + "info": "Initial learning rate for AdamW." + }, + "zh": { + "label": "学习率", + "info": "AdamW 优化器的初始学习率。" + } + }, + "num_train_epochs": { + "en": { + "label": "Epochs", + "info": "Total number of training epochs to perform." + }, + "zh": { + "label": "训练轮数", + "info": "需要执行的训练总轮数。" + } + }, + "max_samples": { + "en": { + "label": "Max samples", + "info": "Maximum samples per dataset." + }, + "zh": { + "label": "最大样本数", + "info": "每个数据集最多使用的样本数。" + } + }, + "compute_type": { + "en": { + "label": "Compute type", + "info": "Whether to use fp16 or bf16 mixed precision training." + }, + "zh": { + "label": "计算类型", + "info": "是否启用 FP16 或 BF16 混合精度训练。" + } + }, + "batch_size": { + "en": { + "label": "Batch size", + "info": "Number of samples to process per GPU." + }, + "zh":{ + "label": "批处理大小", + "info": "每块 GPU 上处理的样本数量。" + } + }, + "gradient_accumulation_steps": { + "en": { + "label": "Gradient accumulation", + "info": "Number of gradient accumulation steps." + }, + "zh": { + "label": "梯度累积", + "info": "梯度累积的步数。" + } + }, + "lr_scheduler_type": { + "en": { + "label": "LR Scheduler", + "info": "Name of learning rate scheduler.", + }, + "zh": { + "label": "学习率调节器", + "info": "采用的学习率调节器名称。" + } + }, + "max_grad_norm": { + "en": { + "label": "Maximum gradient norm", + "info": "Norm for gradient clipping.." + }, + "zh": { + "label": "最大梯度范数", + "info": "用于梯度裁剪的范数。" + } + }, + "val_size": { + "en": { + "label": "Val size", + "info": "Proportion of data in the dev set." + }, + "zh": { + "label": "验证集比例", + "info": "验证集占全部样本的百分比。" + } + }, + "logging_steps": { + "en": { + "label": "Logging steps", + "info": "Number of steps between two logs." 
+ }, + "zh": { + "label": "日志间隔", + "info": "每两次日志输出间的更新步数。" + } + }, + "save_steps": { + "en": { + "label": "Save steps", + "info": "Number of steps between two checkpoints." + }, + "zh": { + "label": "保存间隔", + "info": "每两次断点保存间的更新步数。" + } + }, + "warmup_steps": { + "en": { + "label": "Warmup steps", + "info": "Number of steps used for warmup." + }, + "zh": { + "label": "预热步数", + "info": "学习率预热采用的步数。" + } + }, + "neft_alpha": { + "en": { + "label": "NEFTune Alpha", + "info": "Magnitude of noise adding to embedding vectors." + }, + "zh": { + "label": "NEFTune 噪声参数", + "info": "嵌入向量所添加的噪声大小。" + } + }, + "train_on_prompt": { + "en": { + "label": "Train on prompt", + "info": "Compute loss on the prompt tokens in supervised fine-tuning." + }, + "zh": { + "label": "计算输入损失", + "info": "在监督微调时候计算输入序列的损失。" + } + }, + "upcast_layernorm": { + "en": { + "label": "Upcast LayerNorm", + "info": "Upcast weights of layernorm in float32." + }, + "zh": { + "label": "缩放归一化层", + "info": "将归一化层权重缩放至 32 位浮点数。" + } + }, + "lora_tab": { + "en": { + "label": "LoRA configurations" + }, + "zh": { + "label": "LoRA 参数设置" + } + }, + "lora_rank": { + "en": { + "label": "LoRA rank", + "info": "The rank of LoRA matrices." + }, + "zh": { + "label": "LoRA 秩", + "info": "LoRA 矩阵的秩。" + } + }, + "lora_dropout": { + "en": { + "label": "LoRA Dropout", + "info": "Dropout ratio of LoRA weights." + }, + "zh": { + "label": "LoRA 随机丢弃", + "info": "LoRA 权重随机丢弃的概率。" + } + }, + "lora_target": { + "en": { + "label": "LoRA modules (optional)", + "info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules." + }, + "zh": { + "label": "LoRA 作用模块(非必填)", + "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。" + } + }, + "additional_target": { + "en": { + "label": "Additional modules (optional)", + "info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules." + }, + "zh": { + "label": "附加模块(非必填)", + "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。" + } + }, + "resume_lora_training": { + "en": { + "label": "Resume LoRA training", + "info": "Whether to resume training from the last LoRA weights or create new lora weights." + }, + "zh": { + "label": "继续上次的训练", + "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。" + } + }, + "rlhf_tab": { + "en": { + "label": "RLHF configurations" + }, + "zh": { + "label": "RLHF 参数设置" + } + }, + "dpo_beta": { + "en": { + "label": "DPO beta", + "info": "Value of the beta parameter in the DPO loss." + }, + "zh": { + "label": "DPO beta 参数", + "info": "DPO 损失函数中 beta 超参数大小。" + } + }, + "reward_model": { + "en": { + "label": "Reward model", + "info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)" + }, + "zh": { + "label": "奖励模型", + "info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)" + } + }, + "cmd_preview_btn": { + "en": { + "value": "Preview command" + }, + "zh": { + "value": "预览命令" + } + }, + "start_btn": { + "en": { + "value": "Start" + }, + "zh": { + "value": "开始" + } + }, + "stop_btn": { + "en": { + "value": "Abort" + }, + "zh": { + "value": "中断" + } + }, + "output_dir": { + "en": { + "label": "Checkpoint name", + "info": "Directory to save checkpoint." + }, + "zh": { + "label": "断点名称", + "info": "保存模型断点的文件夹名称。" + } + }, + "output_box": { + "en": { + "value": "Ready." 
+ }, + "zh": { + "value": "准备就绪。" + } + }, + "loss_viewer": { + "en": { + "label": "Loss" + }, + "zh": { + "label": "损失" + } + }, + "predict": { + "en": { + "label": "Save predictions" + }, + "zh": { + "label": "保存预测结果" + } + }, + "load_btn": { + "en": { + "value": "Load model" + }, + "zh": { + "value": "加载模型" + } + }, + "unload_btn": { + "en": { + "value": "Unload model" + }, + "zh": { + "value": "卸载模型" + } + }, + "info_box": { + "en": { + "value": "Model unloaded, please load a model first." + }, + "zh": { + "value": "模型未加载,请先加载模型。" + } + }, + "system": { + "en": { + "placeholder": "System prompt (optional)" + }, + "zh": { + "placeholder": "系统提示词(非必填)" + } + }, + "query": { + "en": { + "placeholder": "Input..." + }, + "zh": { + "placeholder": "输入..." + } + }, + "submit_btn": { + "en": { + "value": "Submit" + }, + "zh": { + "value": "提交" + } + }, + "clear_btn": { + "en": { + "value": "Clear history" + }, + "zh": { + "value": "清空历史" + } + }, + "max_length": { + "en": { + "label": "Maximum length" + }, + "zh": { + "label": "最大长度" + } + }, + "max_new_tokens": { + "en": { + "label": "Maximum new tokens" + }, + "zh": { + "label": "最大生成长度" + } + }, + "top_p": { + "en": { + "label": "Top-p" + }, + "zh": { + "label": "Top-p 采样值" + } + }, + "temperature": { + "en": { + "label": "Temperature" + }, + "zh": { + "label": "温度系数" + } + }, + "export_dir": { + "en": { + "label": "Export dir", + "info": "Directory to save exported model." + }, + "zh": { + "label": "导出目录", + "info": "保存导出模型的文件夹路径。" + } + }, + "max_shard_size": { + "en": { + "label": "Max shard size (GB)", + "info": "The maximum size for a model file." + }, + "zh": { + "label": "最大分块大小(GB)", + "info": "模型文件的最大大小。" + } + }, + "export_btn": { + "en": { + "value": "Export" + }, + "zh": { + "value": "开始导出" + } + } +} + + +ALERTS = { + "err_conflict": { + "en": "A process is in running, please abort it firstly.", + "zh": "任务已存在,请先中断训练。" + }, + "err_exists": { + "en": "You have loaded a model, please unload it first.", + "zh": "模型已存在,请先卸载模型。" + }, + "err_no_model": { + "en": "Please select a model.", + "zh": "请选择模型。" + }, + "err_no_path": { + "en": "Model not found.", + "zh": "模型未找到。" + }, + "err_no_dataset": { + "en": "Please choose a dataset.", + "zh": "请选择数据集。" + }, + "err_no_checkpoint": { + "en": "Please select a checkpoint.", + "zh": "请选择断点。" + }, + "err_no_export_dir": { + "en": "Please provide export dir.", + "zh": "请填写导出目录" + }, + "err_failed": { + "en": "Failed.", + "zh": "训练出错。" + }, + "info_aborting": { + "en": "Aborted, wait for terminating...", + "zh": "训练中断,正在等待线程结束……" + }, + "info_aborted": { + "en": "Ready.", + "zh": "准备就绪。" + }, + "info_finished": { + "en": "Finished.", + "zh": "训练完毕。" + }, + "info_loading": { + "en": "Loading model...", + "zh": "加载中……" + }, + "info_unloading": { + "en": "Unloading model...", + "zh": "卸载中……" + }, + "info_loaded": { + "en": "Model loaded, now you can chat with your model!", + "zh": "模型已加载,可以开始聊天了!" 
+ }, + "info_unloaded": { + "en": "Model unloaded.", + "zh": "模型已卸载。" + }, + "info_exporting": { + "en": "Exporting model...", + "zh": "正在导出模型……" + }, + "info_exported": { + "en": "Model exported.", + "zh": "模型导出完成。" + } +} diff --git a/llm_rl/src/llmtuner/webui/manager.py b/llm_rl/src/llmtuner/webui/manager.py new file mode 100644 index 00000000..ca067aea --- /dev/null +++ b/llm_rl/src/llmtuner/webui/manager.py @@ -0,0 +1,35 @@ +from typing import TYPE_CHECKING, Dict, List, Set + +if TYPE_CHECKING: + from gradio.components import Component + + +class Manager: + + def __init__(self) -> None: + self.all_elems: Dict[str, Dict[str, "Component"]] = {} + + def get_elem_by_name(self, name: str) -> "Component": + r""" + Example: top.lang, train.dataset + """ + tab_name, elem_name = name.split(".") + return self.all_elems[tab_name][elem_name] + + def get_base_elems(self) -> Set["Component"]: + return { + self.all_elems["top"]["lang"], + self.all_elems["top"]["model_name"], + self.all_elems["top"]["model_path"], + self.all_elems["top"]["checkpoints"], + self.all_elems["top"]["finetuning_type"], + self.all_elems["top"]["quantization_bit"], + self.all_elems["top"]["template"], + self.all_elems["top"]["system_prompt"], + self.all_elems["top"]["flash_attn"], + self.all_elems["top"]["shift_attn"], + self.all_elems["top"]["rope_scaling"] + } + + def list_elems(self) -> List["Component"]: + return [elem for elems in self.all_elems.values() for elem in elems.values()] diff --git a/llm_rl/src/llmtuner/webui/runner.py b/llm_rl/src/llmtuner/webui/runner.py new file mode 100644 index 00000000..ab9e9ffc --- /dev/null +++ b/llm_rl/src/llmtuner/webui/runner.py @@ -0,0 +1,254 @@ +import os +import time +import logging +import gradio as gr +from threading import Thread +from gradio.components import Component # cannot use TYPE_CHECKING here +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple + +import transformers +from transformers.trainer import TRAINING_ARGS_NAME + +from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.constants import TRAINING_STAGES +from llmtuner.extras.logging import LoggerHandler +from llmtuner.extras.misc import torch_gc +from llmtuner.tuner import run_exp +from llmtuner.webui.common import get_module, get_save_dir, load_config +from llmtuner.webui.locales import ALERTS +from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar + +if TYPE_CHECKING: + from llmtuner.webui.manager import Manager + + +class Runner: + + def __init__(self, manager: "Manager") -> None: + self.manager = manager + """ Resume """ + self.thread: "Thread" = None + self.do_train = True + self.running_data: Dict["Component", Any] = None + self.monitor_inputs: Dict[str, str] = None + """ State """ + self.aborted = False + self.running = False + """ Handler """ + self.logger_handler = LoggerHandler() + self.logger_handler.setLevel(logging.INFO) + logging.root.addHandler(self.logger_handler) + transformers.logging.add_handler(self.logger_handler) + + @property + def alive(self) -> bool: + return self.thread is not None + + def set_abort(self) -> None: + self.aborted = True + self.running = False + + def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str: + get = lambda name: data[self.manager.get_elem_by_name(name)] + lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") + dataset = get("train.dataset") if do_train else get("eval.dataset") + + if self.running: + return ALERTS["err_conflict"][lang] + + if not 
model_name: + return ALERTS["err_no_model"][lang] + + if not model_path: + return ALERTS["err_no_path"][lang] + + if len(dataset) == 0: + return ALERTS["err_no_dataset"][lang] + + self.aborted = False + self.logger_handler.reset() + self.trainer_callback = LogCallback(self) + return "" + + def _finalize(self, lang: str, finish_info: str) -> str: + self.thread = None + self.running = False + torch_gc() + if self.aborted: + return ALERTS["info_aborted"][lang] + else: + return finish_info + + def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]: + get = lambda name: data[self.manager.get_elem_by_name(name)] + user_config = load_config() + + if get("top.checkpoints"): + checkpoint_dir = ",".join([ + get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") + ]) + else: + checkpoint_dir = None + + args = dict( + stage=TRAINING_STAGES[get("train.training_stage")], + model_name_or_path=get("top.model_path"), + do_train=True, + cache_dir=user_config.get("cache_dir", None), + checkpoint_dir=checkpoint_dir, + finetuning_type=get("top.finetuning_type"), + quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + template=get("top.template"), + system_prompt=get("top.system_prompt"), + flash_attn=get("top.flash_attn"), + shift_attn=get("top.shift_attn"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + dataset_dir=get("train.dataset_dir"), + dataset=",".join(get("train.dataset")), + cutoff_len=get("train.cutoff_len"), + learning_rate=float(get("train.learning_rate")), + num_train_epochs=float(get("train.num_train_epochs")), + max_samples=int(get("train.max_samples")), + per_device_train_batch_size=get("train.batch_size"), + gradient_accumulation_steps=get("train.gradient_accumulation_steps"), + lr_scheduler_type=get("train.lr_scheduler_type"), + max_grad_norm=float(get("train.max_grad_norm")), + logging_steps=get("train.logging_steps"), + save_steps=get("train.save_steps"), + warmup_steps=get("train.warmup_steps"), + neft_alpha=get("train.neft_alpha"), + train_on_prompt=get("train.train_on_prompt"), + upcast_layernorm=get("train.upcast_layernorm"), + lora_rank=get("train.lora_rank"), + lora_dropout=get("train.lora_dropout"), + lora_target=get("train.lora_target") or get_module(get("top.model_name")), + additional_target=get("train.additional_target") if get("train.additional_target") else None, + resume_lora_training=get("train.resume_lora_training"), + output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")) + ) + args[get("train.compute_type")] = True + args["disable_tqdm"] = True + + if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]: + args["resume_lora_training"] = (args["quantization_bit"] is not None) + + if args["quantization_bit"] is not None: + args["upcast_layernorm"] = True + + if args["stage"] == "ppo": + args["reward_model"] = get("train.reward_model") + + if args["stage"] == "dpo": + args["dpo_beta"] = get("train.dpo_beta") + + if get("train.val_size") > 1e-6 and args["stage"] != "ppo": + args["val_size"] = get("train.val_size") + args["evaluation_strategy"] = "steps" + args["eval_steps"] = get("train.save_steps") + args["load_best_model_at_end"] = True + + return args + + def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]: + get = lambda name: data[self.manager.get_elem_by_name(name)] + user_config = load_config() + + if 
get("top.checkpoints"): + checkpoint_dir = ",".join([ + get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") + ]) + output_dir = get_save_dir( + get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints")) + ) + else: + checkpoint_dir = None + output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base") + + args = dict( + stage="sft", + model_name_or_path=get("top.model_path"), + do_eval=True, + predict_with_generate=True, + cache_dir=user_config.get("cache_dir", None), + checkpoint_dir=checkpoint_dir, + finetuning_type=get("top.finetuning_type"), + quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + template=get("top.template"), + system_prompt=get("top.system_prompt"), + flash_attn=get("top.flash_attn"), + shift_attn=get("top.shift_attn"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + dataset_dir=get("eval.dataset_dir"), + dataset=",".join(get("eval.dataset")), + cutoff_len=get("eval.cutoff_len"), + max_samples=int(get("eval.max_samples")), + per_device_eval_batch_size=get("eval.batch_size"), + max_new_tokens=get("eval.max_new_tokens"), + top_p=get("eval.top_p"), + temperature=get("eval.temperature"), + output_dir=output_dir + ) + + if get("eval.predict"): + args.pop("do_eval", None) + args["do_predict"] = True + + return args + + def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + error = self._initialize(data, do_train) + if error: + gr.Warning(error) + yield error, gr.update(visible=False) + else: + args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) + yield gen_cmd(args), gr.update(visible=False) + + def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + error = self._initialize(data, do_train) + if error: + gr.Warning(error) + yield error, gr.update(visible=False) + else: + args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) + run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) + self.running = True + self.do_train, self.running_data = do_train, data + self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"]) + self.thread = Thread(target=run_exp, kwargs=run_kwargs) + self.thread.start() + yield from self.monitor() + + def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + yield from self._preview(data, do_train=True) + + def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + yield from self._preview(data, do_train=False) + + def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + yield from self._launch(data, do_train=True) + + def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + yield from self._launch(data, do_train=False) + + def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"] + while self.thread.is_alive(): + time.sleep(2) + if self.aborted: + yield ALERTS["info_aborting"][lang], gr.update(visible=False) + else: + yield self.logger_handler.log, update_process_bar(self.trainer_callback) + + if 
self.do_train: + if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): + finish_info = ALERTS["info_finished"][lang] + else: + finish_info = ALERTS["err_failed"][lang] + else: + if os.path.exists(os.path.join(output_dir, "all_results.json")): + finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) + else: + finish_info = ALERTS["err_failed"][lang] + + yield self._finalize(lang, finish_info), gr.update(visible=False) diff --git a/llm_rl/src/llmtuner/webui/utils.py b/llm_rl/src/llmtuner/webui/utils.py new file mode 100644 index 00000000..933d951d --- /dev/null +++ b/llm_rl/src/llmtuner/webui/utils.py @@ -0,0 +1,85 @@ +import os +import json +import gradio as gr +import matplotlib.figure +import matplotlib.pyplot as plt +from typing import TYPE_CHECKING, Any, Dict +from datetime import datetime + +from llmtuner.extras.ploting import smooth +from llmtuner.webui.common import get_save_dir + +if TYPE_CHECKING: + from llmtuner.extras.callbacks import LogCallback + + +def update_process_bar(callback: "LogCallback") -> Dict[str, Any]: + if not callback.max_steps: + return gr.update(visible=False) + + percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0 + label = "Running {:d}/{:d}: {} < {}".format( + callback.cur_steps, + callback.max_steps, + callback.elapsed_time, + callback.remaining_time + ) + return gr.update(label=label, value=percentage, visible=True) + + +def get_time() -> str: + return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + + +def can_quantize(finetuning_type: str) -> Dict[str, Any]: + if finetuning_type != "lora": + return gr.update(value="None", interactive=False) + else: + return gr.update(interactive=True) + + +def gen_cmd(args: Dict[str, Any]) -> str: + args.pop("disable_tqdm", None) + args["plot_loss"] = args.get("do_train", None) + cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "] + for k, v in args.items(): + if v is not None and v != "": + cmd_lines.append(" --{} {} ".format(k, str(v))) + cmd_text = "\\\n".join(cmd_lines) + cmd_text = "```bash\n{}\n```".format(cmd_text) + return cmd_text + + +def get_eval_results(path: os.PathLike) -> str: + with open(path, "r", encoding="utf-8") as f: + result = json.dumps(json.load(f), indent=4) + return "```json\n{}\n```\n".format(result) + + +def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure: + if not base_model: + return + log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl") + if not os.path.isfile(log_file): + return + + plt.close("all") + fig = plt.figure() + ax = fig.add_subplot(111) + steps, losses = [], [] + with open(log_file, "r", encoding="utf-8") as f: + for line in f: + log_info = json.loads(line) + if log_info.get("loss", None): + steps.append(log_info["current_steps"]) + losses.append(log_info["loss"]) + + if len(losses) == 0: + return None + + ax.plot(steps, losses, alpha=0.4, label="original") + ax.plot(steps, smooth(losses), label="smoothed") + ax.legend() + ax.set_xlabel("step") + ax.set_ylabel("loss") + return fig diff --git a/llm_rl/src/train_bash.py b/llm_rl/src/train_bash.py new file mode 100644 index 00000000..9ddd0586 --- /dev/null +++ b/llm_rl/src/train_bash.py @@ -0,0 +1,14 @@ +from llmtuner import run_exp + + +def main(): + run_exp() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/llm_rl/src/train_web.py b/llm_rl/src/train_web.py new file mode 100644 index 
--- /dev/null
+++ b/llm_rl/src/train_web.py
@@ -0,0 +1,11 @@
+from llmtuner import create_ui
+
+
+def main():
+    demo = create_ui()
+    demo.queue()
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm_rl/src/web_demo.py b/llm_rl/src/web_demo.py
new file mode 100644
index 00000000..257536ab
--- /dev/null
+++ b/llm_rl/src/web_demo.py
@@ -0,0 +1,11 @@
+from llmtuner import create_web_demo
+
+
+def main():
+    demo = create_web_demo()
+    demo.queue()
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm_rl/tests/cal_flops.py b/llm_rl/tests/cal_flops.py
new file mode 100644
index 00000000..01b005af
--- /dev/null
+++ b/llm_rl/tests/cal_flops.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+# Calculates the FLOPs of pre-trained models.
+# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
+
+import fire
+import torch
+from typing import Optional
+from deepspeed.accelerator import get_accelerator # type: ignore
+from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
+
+from llmtuner import ChatModel
+
+
+def calculate(
+    model_name_or_path: str,
+    batch_size: Optional[int] = 1,
+    seq_length: Optional[int] = 256,
+    flash_attn: Optional[bool] = False
+):
+    with get_accelerator().device(0):
+        chat_model = ChatModel(dict(
+            model_name_or_path=model_name_or_path,
+            template="vanilla",
+            flash_attn=flash_attn
+        ))
+        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
+        input_dict = {
+            "input_ids": fake_input,
+            "labels": fake_input.clone()
+        }
+        flops, macs, params = get_model_profile(
+            chat_model.model,
+            kwargs=input_dict,
+            print_profile=True,
+            detailed=True
+        )
+        print("FLOPs:", flops)
+        print("MACs:", macs)
+        print("Params:", params)
+
+
+if __name__ == "__main__":
+    fire.Fire(calculate)
diff --git a/llm_rl/tests/llamafy_baichuan2.py b/llm_rl/tests/llamafy_baichuan2.py
new file mode 100644
index 00000000..d08eee1c
--- /dev/null
+++ b/llm_rl/tests/llamafy_baichuan2.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Converts the Baichuan2-7B model into the same format as LLaMA2-7B.
+# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
+# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
+# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+
+import os
+import fire
+import json
+import torch
+from collections import OrderedDict
+from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from typing import Any, Dict
+
+
+CONFIG_NAME = "config.json"
+
+
+def save_weight(
+    input_dir: str,
+    output_dir: str,
+    shard_size: str
+):
+    baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    for filepath in os.listdir(input_dir):
+        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
+            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
+            baichuan2_state_dict.update(shard_weight)
+
+    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    for key, value in baichuan2_state_dict.items():
+        if "W_pack" in key:
+            proj_size = value.size(0) // 3
+            llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
+            llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
+            llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
+        elif "lm_head" in key:
+            llama2_state_dict[key] = torch.nn.functional.normalize(value)
+        else:
+            llama2_state_dict[key] = value
+
+    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
+    for shard_file, shard in shards.items():
+        torch.save(shard, os.path.join(output_dir, shard_file))
+
+    if index is None:
+        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+    else:
+        with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+            json.dump(index, f, indent=2, sort_keys=True)
+        print("Model weights saved in {}".format(output_dir))
+
+
+def save_config(
+    input_dir: str,
+    output_dir: str
+):
+    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+        llama2_config_dict: Dict[str, Any] = json.load(f)
+
+    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama2_config_dict.pop("auto_map", None)
+    llama2_config_dict.pop("tokenizer_class", None)
+    llama2_config_dict["model_type"] = "llama"
+
+    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
+        json.dump(llama2_config_dict, f, indent=2)
+    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+
+
+def llamafy_baichuan2(
+    input_dir: str,
+    output_dir: str,
+    shard_size: str
+):
+    try:
+        os.makedirs(output_dir, exist_ok=False)
+    except Exception as e:
+        raise RuntimeError("Output dir {} already exists.".format(output_dir)) from e
+
+    save_weight(input_dir, output_dir, shard_size)
+    save_config(input_dir, output_dir)
+
+
+if __name__ == "__main__":
+    fire.Fire(llamafy_baichuan2)
diff --git a/llm_rl/tests/llamafy_qwen.py b/llm_rl/tests/llamafy_qwen.py
new file mode 100644
index 00000000..8b9fc395
--- /dev/null
+++ b/llm_rl/tests/llamafy_qwen.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Converts the Qwen models into the same format as LLaMA2.
+# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
+
+import os
+import fire
+import json
+import torch
+from collections import OrderedDict
+from safetensors import safe_open
+from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.utils import check_min_version
+from typing import Any, Dict
+
+try:
+    check_min_version("4.34.0")
+except Exception:
+    raise ValueError("Please upgrade `transformers` to 4.34.0")
+
+
+CONFIG_NAME = "config.json"
+
+
+def save_weight(
+    input_dir: str,
+    output_dir: str,
+    shard_size: str
+) -> str:
+    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    for filepath in os.listdir(input_dir):
+        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
+            with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    qwen_state_dict[key] = f.get_tensor(key)
+
+    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    torch_dtype = None
+    for key, value in qwen_state_dict.items():
+        if torch_dtype is None:
+            torch_dtype = value.dtype
+        if "wte" in key:
+            llama2_state_dict["model.embed_tokens.weight"] = value
+        elif "ln_f" in key:
+            llama2_state_dict["model.norm.weight"] = value
+        else:
+            key = key.replace("transformer.h", "model.layers")
+            if "attn.c_attn" in key:
+                proj_size = value.size(0) // 3
+                llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
+                llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
+                llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
+            elif "attn.c_proj" in key:
+                llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
+                llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
+                    torch.zeros_like(value[:, 0]).squeeze()
+                )
+            elif "ln_1" in key:
+                llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
+            elif "ln_2" in key:
+                llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
+            elif "mlp.w1" in key:
+                llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
+            elif "mlp.w2" in key:
+                llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
+            elif "mlp.c_proj" in key:
+                llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
+            elif "lm_head" in key:
+                llama2_state_dict[key] = value
+            else:
+                raise KeyError("Unable to process key {}".format(key))
+
+    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
+    for shard_file, shard in shards.items():
+        torch.save(shard, os.path.join(output_dir, shard_file))
+
+    if index is None:
+        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+    else:
+        with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+            json.dump(index, f, indent=2, sort_keys=True)
+        print("Model weights saved in {}".format(output_dir))
+
+    return str(torch_dtype).replace("torch.", "")
+
+
+def save_config(
+    input_dir: str,
+    output_dir: str,
+    torch_dtype: str
+):
+    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+        qwen_config_dict: Dict[str, Any] = json.load(f)
+
+    llama2_config_dict: Dict[str, Any] = OrderedDict()
+    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
+    llama2_config_dict["hidden_act"] = "silu"
+    llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"] + llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2 + llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"] + llama2_config_dict["model_type"] = "llama" + llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"] + llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"] + llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"] + llama2_config_dict["pretraining_tp"] = 1 + llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"] + llama2_config_dict["rope_scaling"] = None + llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"] + llama2_config_dict["torch_dtype"] = torch_dtype + llama2_config_dict["transformers_version"] = "4.34.0" + llama2_config_dict["use_cache"] = True + llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"] + llama2_config_dict["attention_bias"] = True + + with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: + json.dump(llama2_config_dict, f, indent=2) + print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) + + +def llamafy_qwen( + input_dir: str, + output_dir: str, + shard_size: str +): + try: + os.makedirs(output_dir, exist_ok=False) + except Exception as e: + raise print("Output dir already exists", e) + + torch_dtype = save_weight(input_dir, output_dir, shard_size) + save_config(input_dir, output_dir, torch_dtype) + + +if __name__ == "__main__": + fire.Fire(llamafy_qwen) diff --git a/llm_rl/tests/quantize.py b/llm_rl/tests/quantize.py new file mode 100644 index 00000000..25321cf3 --- /dev/null +++ b/llm_rl/tests/quantize.py @@ -0,0 +1,50 @@ +# coding=utf-8 +# Quantizes models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ). +# Usage: python quantize.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json +# --max_length 1024 --max_samples 1024 +# dataset format: instruction (string), input (string), output (string), history (List[string]) + + +import fire +from datasets import load_dataset +from transformers import AutoTokenizer +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig + + +def quantize(input_dir: str, output_dir: str, data_file: str, max_length: int, max_samples: int): + tokenizer = AutoTokenizer.from_pretrained(input_dir, use_fast=False, padding_side="left") + + def format_example(examples): + prefix=("A chat between a curious user and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the user's questions.") + texts = [] + for i in range(len(examples["instruction"])): + prompt = prefix + "\n" + if "history" in examples: + for user_query, bot_resp in examples["history"][i]: + prompt += "Human: {}\nAssistant: {}\n".format(user_query, bot_resp) + prompt += "Human: {}\nAssistant: {}".format( + examples["instruction"][i] + "\n" + examples["input"][i], examples["output"][i] + ) + texts.append(prompt) + return tokenizer(texts, truncation=True, max_length=max_length) + + dataset = load_dataset("json", data_files=data_file)["train"] + column_names = list(dataset.column_names) + dataset = dataset.select(range(min(len(dataset), max_samples))) + dataset = dataset.map(format_example, batched=True, remove_columns=column_names) + dataset = dataset.shuffle() + + quantize_config = BaseQuantizeConfig( + bits=4, + group_size=128, + desc_act=False + ) + + model = AutoGPTQForCausalLM.from_pretrained(input_dir, quantize_config, trust_remote_code=True) + model.quantize(dataset) + model.save_quantized(output_dir) + + +if __name__ == "__main__": + fire.Fire(quantize)