diff --git a/.gitignore b/.gitignore
index 3e563d1d..83339037 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,13 @@ __pycache__
dist
.venv
+# Byte-compiled / optimized / DLL files
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
# Log
*.log
*.log.*
@@ -33,4 +40,157 @@ tests/state_of_the_union.txt
# Build
build
-!dummy_file
\ No newline at end of file
+!dummy_file
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
\ No newline at end of file
diff --git a/llm_rl/LICENSE b/llm_rl/LICENSE
new file mode 100644
index 00000000..b09cd785
--- /dev/null
+++ b/llm_rl/LICENSE
@@ -0,0 +1,201 @@
+Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/llm_rl/README.md b/llm_rl/README.md
new file mode 100644
index 00000000..54da92da
--- /dev/null
+++ b/llm_rl/README.md
@@ -0,0 +1,501 @@
+# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
+
+[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
+[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
+[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
+[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
+[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
+[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
+[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd)
+
+👋 Join our [WeChat](assets/wechat.jpg).
+
+\[ English | [中文](README_zh.md) \]
+
+## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
+
+Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (Multiple GPUs are not supported yet.)
+
+Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
+
+https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
+
+## Changelog
+
+[23/10/21] We added support for the **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try the `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
+
+[23/09/27] We added support for **$S^2$-Attn**, proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA), for the LLaMA models. Try the `--shift_attn` argument to enable shift short attention.
+
+[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this repo. See [this example](#evaluation) to evaluate your models.
+
+[23/09/10] We added support for **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try the `--flash_attn` argument to enable FlashAttention-2 if you are using RTX 4090, A100 or H100 GPUs.
+
+[23/08/12] We added support for **RoPE scaling** to extend the context length of the LLaMA models. Try the `--rope_scaling linear` argument during training and the `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
+
+[23/08/11] We added support for **[DPO training](https://arxiv.org/abs/2305.18290)** of instruction-tuned models. See [this example](#dpo-training) to train your models.
+
+[23/07/31] We added support for **dataset streaming**. Try the `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
+
+[23/07/29] We released two instruction-tuned 13B models on Hugging Face. See these Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
+
+[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your web browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+
+[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+
+[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets; see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
+
+[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into **arbitrary ChatGPT-based applications**.
+
+[23/06/03] We added support for quantized training and inference (a.k.a. **[QLoRA](https://github.com/artidoro/qlora)**). Try the `--quantization_bit 4/8` argument to work with quantized models.
+
+## Supported Models
+
+| Model | Model size | Default module | Template |
+| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
+| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
+| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
+| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
+| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
+| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
+| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | qwen |
+| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
+
+> [!NOTE]
+> **Default module** is the default value of the `--lora_target` argument. You can use `--lora_target all` to specify all the available modules.
+>
+> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "chat" models.
+
+Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of the models we support.
+
+## Supported Training Approaches
+
+| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
+| PPO Training | | | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+
+> [!NOTE]
+> Use the `--quantization_bit 4/8` argument to enable QLoRA.
+
+## Provided Datasets
+
+**Pre-training datasets**
+
+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+
+
+**Supervised fine-tuning datasets**
+
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Self-cognition (zh)](data/self_cognition.json)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+
+
+
+**Preference datasets**
+
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+
+
+
+Please refer to [data/README.md](data/README.md) for details.
+
+Some datasets require confirmation before use, so we recommend logging in to your Hugging Face account with the following commands.
+
+```bash
+pip install --upgrade huggingface_hub
+huggingface-cli login
+```
+
+## Requirements
+
+- Python 3.8+ and PyTorch 1.13.1+
+- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
+- sentencepiece, protobuf and tiktoken
+- fire, jieba, rouge-chinese and nltk (used for evaluation and prediction)
+- gradio and matplotlib (used in web UI)
+- uvicorn, fastapi and sse-starlette (used for the API)
+
+And **powerful GPUs**!
+
+## Getting Started
+
+### Data Preparation (optional)
+
+Please refer to [data/README.md](data/README.md) for details about the format of the dataset files. You can use either a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
+
+> [!NOTE]
+> Please update `data/dataset_info.json` to use your custom dataset. For the format of this file, please refer to `data/README.md`.
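+
+As a rough sketch of what registering a custom dataset can look like (the authoritative schema lives in `data/README.md`; the dataset name `my_dataset`, the file `my_dataset.json` and the column mapping below are illustrative assumptions), you could add an entry to `data/dataset_info.json` programmatically:
+
+```python
+import json
+
+# Illustrative entry for a local alpaca-style file named my_dataset.json placed in the
+# data folder. The exact schema of dataset_info.json is described in data/README.md.
+entry = {
+    "my_dataset": {
+        "file_name": "my_dataset.json",
+        "columns": {
+            "prompt": "instruction",
+            "query": "input",
+            "response": "output"
+        }
+    }
+}
+
+with open("data/dataset_info.json", "r", encoding="utf-8") as f:
+    dataset_info = json.load(f)
+
+dataset_info.update(entry)
+
+with open("data/dataset_info.json", "w", encoding="utf-8") as f:
+    json.dump(dataset_info, f, ensure_ascii=False, indent=2)
+```
+
+Afterwards, the new dataset can be selected with `--dataset my_dataset` in the training commands below.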
+
+### Dependency Installation (optional)
+
+```bash
+git clone https://github.com/hiyouga/LLaMA-Factory.git
+conda create -n llama_factory python=3.10
+conda activate llama_factory
+cd LLaMA-Factory
+pip install -r requirements.txt
+```
+
+If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library that supports CUDA 11.1 to 12.1.
+
+```bash
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+```
+
+### Train on a single GPU
+
+> [!IMPORTANT]
+> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
+
+#### Pre-Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage pt \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset wiki_demo \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --output_dir path_to_pt_checkpoint \
+ --overwrite_cache \
+ --per_device_train_batch_size 4 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 5e-5 \
+ --num_train_epochs 3.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### Supervised Fine-Tuning
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage sft \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset alpaca_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --output_dir path_to_sft_checkpoint \
+ --overwrite_cache \
+ --per_device_train_batch_size 4 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 5e-5 \
+ --num_train_epochs 3.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### Reward Modeling
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage rm \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset comparison_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --output_dir path_to_rm_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-6 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### PPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage ppo \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset alpaca_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --reward_model path_to_rm_checkpoint \
+ --output_dir path_to_ppo_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-5 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage dpo \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset comparison_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --output_dir path_to_dpo_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-5 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
+### Distributed Training
+
+#### Use Hugging Face Accelerate
+
+```bash
+accelerate config # configure the environment
+accelerate launch src/train_bash.py # arguments (same as above)
+```
+
+**Example config for LoRA training**
+
+```yaml
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+
+
+#### Use DeepSpeed
+
+```bash
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+ --deepspeed ds_config.json \
+ ... # arguments (same as above)
+```
+
+**Example config for full-parameter training with DeepSpeed ZeRO-2**
+
+```json
+{
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "zero_allow_untested_optimizer": true,
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 5e8,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "overlap_comm": false,
+ "contiguous_gradients": true
+ }
+}
+```
+
+
+
+### Export model
+
+```bash
+python src/export_model.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint \
+ --export_dir path_to_export
+```
+
+### API Demo
+
+```bash
+python src/api_demo.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint
+```
+
+> [!NOTE]
+> Visit `http://localhost:8000/docs` for API documentation.
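+
+As a minimal client-side sketch (assuming the server launched above exposes an OpenAI-style `/v1/chat/completions` route; the payload fields below are assumptions, so check `http://localhost:8000/docs` for the exact schema), you could query it with `requests`:
+
+```python
+import requests
+
+# Hypothetical client for the locally running api_demo.py server. The payload follows
+# the OpenAI chat completion format; check http://localhost:8000/docs for the exact
+# fields (the `model` value may be ignored by a locally served model).
+payload = {
+    "model": "default",
+    "messages": [{"role": "user", "content": "Hello, who are you?"}],
+    "temperature": 0.7,
+}
+
+response = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
+response.raise_for_status()
+print(response.json()["choices"][0]["message"]["content"])
+```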
+
+### CLI Demo
+
+```bash
+python src/cli_demo.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint
+```
+
+### Web Demo
+
+```bash
+python src/web_demo.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint
+```
+
+### Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+ --model_name_or_path path_to_llama_model \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint \
+ --template vanilla \
+ --task mmlu \
+ --split test \
+ --lang en \
+ --n_shot 5 \
+ --batch_size 4
+```
+
+### Predict
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage sft \
+ --model_name_or_path path_to_llama_model \
+ --do_predict \
+ --dataset alpaca_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint \
+ --output_dir path_to_predict_result \
+ --per_device_eval_batch_size 8 \
+ --max_samples 100 \
+ --predict_with_generate
+```
+
+> [!NOTE]
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4-bit or 8-bit prediction.
+
+## Projects using LLaMA Factory
+
+- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, based on ChatGLM2-6B and Qwen-14B.
+- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning over legal knowledge.
+- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+
+## License
+
+This repository is licensed under the [Apache-2.0 License](LICENSE).
+
+Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+
+## Citation
+
+If this work is helpful, please cite it as:
+
+```bibtex
+@Misc{llama-factory,
+ title = {LLaMA Factory},
+ author = {hiyouga},
+ howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
+ year = {2023}
+}
+```
+
+## Acknowledgement
+
+This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
+
+## Star History
+
+![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
diff --git a/llm_rl/README_zh.md b/llm_rl/README_zh.md
new file mode 100644
index 00000000..c69e3983
--- /dev/null
+++ b/llm_rl/README_zh.md
@@ -0,0 +1,500 @@
+# LLaMA Factory: 轻松的大模型训练与评估
+
+[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
+[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
+[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
+[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
+[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
+[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
+[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd)
+
+👋 加入我们的[微信群](assets/wechat.jpg)。
+
+\[ [English](README.md) | 中文 \]
+
+## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
+
+使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
+
+下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
+
+https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
+
+## 更新日志
+
+[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
+
+[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
+
+[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
+
+[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
+
+[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
+
+[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
+
+[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
+
+[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
+
+[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
+
+[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
+
+[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft)。
+
+[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
+
+[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
+
+## 模型
+
+| 模型名 | 模型大小 | 默认模块 | Template |
+| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
+| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
+| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
+| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
+| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
+| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
+| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | qwen |
+| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
+
+> [!NOTE]
+> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
+>
+> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
+
+项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
+
+## 训练方法
+
+| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
+| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
+| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+
+> [!NOTE]
+> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
+
+## 数据集
+
+预训练数据集
+
+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+
+
+指令微调数据集
+
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Self-cognition (zh)](data/self_cognition.json)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+
+
+
+偏好数据集
+
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+
+
+
+使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
+
+部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
+
+```bash
+pip install --upgrade huggingface_hub
+huggingface-cli login
+```
+
+## 软件依赖
+
+- Python 3.8+ 和 PyTorch 1.13.1+
+- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
+- sentencepiece, protobuf 和 tiktoken
+- fire, jieba, rouge-chinese 和 nltk (用于评估及预测)
+- gradio 和 matplotlib (用于网页端交互)
+- uvicorn, fastapi 和 sse-starlette (用于 API)
+
+以及 **强而有力的 GPU**!
+
+## 如何使用
+
+### 数据准备(可跳过)
+
+关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
+
+> [!NOTE]
+> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
+
+### 环境搭建(可跳过)
+
+```bash
+git clone https://github.com/hiyouga/LLaMA-Factory.git
+conda create -n llama_factory python=3.10
+conda activate llama_factory
+cd LLaMA-Factory
+pip install -r requirements.txt
+```
+
+如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库,该库支持 CUDA 11.1 到 12.1。
+
+```bash
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+```
+
+### 单 GPU 训练
+
+> [!IMPORTANT]
+> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
+
+#### 预训练
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage pt \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset wiki_demo \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --output_dir path_to_pt_checkpoint \
+ --overwrite_cache \
+ --per_device_train_batch_size 4 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 5e-5 \
+ --num_train_epochs 3.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### 指令监督微调
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage sft \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset alpaca_gpt4_zh \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --output_dir path_to_sft_checkpoint \
+ --overwrite_cache \
+ --per_device_train_batch_size 4 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 5e-5 \
+ --num_train_epochs 3.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### 奖励模型训练
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage rm \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset comparison_gpt4_zh \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --output_dir path_to_rm_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-6 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
+#### PPO 训练
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage ppo \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset alpaca_gpt4_zh \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --reward_model path_to_rm_checkpoint \
+ --output_dir path_to_ppo_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-5 \
+ --num_train_epochs 1.0 \
+ --plot_loss
+```
+
+#### DPO 训练
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage dpo \
+ --model_name_or_path path_to_llama_model \
+ --do_train \
+ --dataset comparison_gpt4_zh \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --output_dir path_to_dpo_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-5 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
+### 多 GPU 分布式训练
+
+#### 使用 Huggingface Accelerate
+
+```bash
+accelerate config # 首先配置分布式环境
+accelerate launch src/train_bash.py # 参数同上
+```
+
+LoRA 训练的 Accelerate 配置示例
+
+```yaml
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+
+
+#### 使用 DeepSpeed
+
+```bash
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+ --deepspeed ds_config.json \
+ ... # 参数同上
+```
+
+使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例
+
+```json
+{
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "zero_allow_untested_optimizer": true,
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 5e8,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "overlap_comm": false,
+ "contiguous_gradients": true
+ }
+}
+```
+
+
+
+### 导出微调后的完整模型
+
+```bash
+python src/export_model.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint \
+ --export_dir path_to_export
+```
+
+### API 服务
+
+```bash
+python src/api_demo.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint
+```
+
+> [!NOTE]
+> 关于 API 文档请见 `http://localhost:8000/docs`。
+
+### 命令行测试
+
+```bash
+python src/cli_demo.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint
+```
+
+### 浏览器测试
+
+```bash
+python src/web_demo.py \
+ --model_name_or_path path_to_llama_model \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint
+```
+
+### 模型评估
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+ --model_name_or_path path_to_llama_model \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint \
+ --template vanilla \
+ --task ceval \
+ --split validation \
+ --lang zh \
+ --n_shot 5 \
+ --batch_size 4
+```
+
+### 模型预测
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage sft \
+ --model_name_or_path path_to_llama_model \
+ --do_predict \
+ --dataset alpaca_gpt4_zh \
+ --template default \
+ --finetuning_type lora \
+ --checkpoint_dir path_to_checkpoint \
+ --output_dir path_to_predict_result \
+ --per_device_eval_batch_size 8 \
+ --max_samples 100 \
+ --predict_with_generate
+```
+
+> [!NOTE]
+> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
+
+## 使用了 LLaMA Factory 的项目
+
+- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
+- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
+- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
+- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
+
+## 协议
+
+本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
+
+使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+
+## 引用
+
+如果您觉得此项目有帮助,请考虑以下列格式引用
+
+```bibtex
+@Misc{llama-factory,
+ title = {LLaMA Factory},
+ author = {hiyouga},
+ howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
+ year = {2023}
+}
+```
+
+## 致谢
+
+本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
+
+## Star History
+
+![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
diff --git a/llm_rl/assets/wechat.jpg b/llm_rl/assets/wechat.jpg
new file mode 100644
index 00000000..8df68f9c
Binary files /dev/null and b/llm_rl/assets/wechat.jpg differ
diff --git a/llm_rl/pyproject.toml b/llm_rl/pyproject.toml
new file mode 100644
index 00000000..638dd9c5
--- /dev/null
+++ b/llm_rl/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
diff --git a/llm_rl/requirements.txt b/llm_rl/requirements.txt
new file mode 100644
index 00000000..840d2f2d
--- /dev/null
+++ b/llm_rl/requirements.txt
@@ -0,0 +1,20 @@
+torch>=1.13.1
+transformers>=4.31.0,<4.35.0
+datasets>=2.12.0
+accelerate>=0.21.0
+peft>=0.4.0
+trl>=0.7.2
+gradio>=3.38.0,<4.0.0
+scipy
+sentencepiece
+protobuf
+tiktoken
+fire
+jieba
+rouge-chinese
+nltk
+uvicorn
+pydantic
+fastapi
+sse-starlette
+matplotlib
diff --git a/llm_rl/reward_model.sh b/llm_rl/reward_model.sh
new file mode 100644
index 00000000..3068fb43
--- /dev/null
+++ b/llm_rl/reward_model.sh
@@ -0,0 +1,21 @@
+python src/train_bash.py \
+ --stage rm \
+ --model_name_or_path meta-llama/Llama-2-13b \
+ --do_train \
+ --dataset comparison_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --lora_target q_proj,v_proj \
+ --resume_lora_training False \
+ --checkpoint_dir ./llama-2-13b-rm \
+ --output_dir ./llama-2-13b-rm \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-6 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16 \
+    --hf_auth_token "${HF_AUTH_TOKEN}"  # set HF_AUTH_TOKEN in your environment; do not commit real tokens
\ No newline at end of file
diff --git a/llm_rl/setup.py b/llm_rl/setup.py
new file mode 100644
index 00000000..7638eaab
--- /dev/null
+++ b/llm_rl/setup.py
@@ -0,0 +1,55 @@
+import os
+import re
+from setuptools import setup, find_packages
+
+
+def get_version():
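+    # Read __version__ from src/llmtuner/__init__.py without importing the package.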
+ with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
+ file_content = f.read()
+ pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
+ version, = re.findall(pattern, file_content)
+ return version
+
+
+def get_requires():
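+    # Collect install requirements from requirements.txt, skipping comment lines.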
+ with open("requirements.txt", "r", encoding="utf-8") as f:
+ file_content = f.read()
+ lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
+ return lines
+
+
+def main():
+
+ setup(
+ name="llmtuner",
+ version=get_version(),
+ author="hiyouga",
+ author_email="hiyouga" "@" "buaa.edu.cn",
+ description="Easy-to-use LLM fine-tuning framework",
+ long_description=open("README.md", "r", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
+ license="Apache 2.0 License",
+ url="https://github.com/hiyouga/LLaMA-Factory",
+ package_dir={"": "src"},
+ packages=find_packages("src"),
+ python_requires=">=3.8.0",
+ install_requires=get_requires(),
+ classifiers=[
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/src/api_demo.py b/llm_rl/src/api_demo.py
new file mode 100644
index 00000000..720089fd
--- /dev/null
+++ b/llm_rl/src/api_demo.py
@@ -0,0 +1,14 @@
+import uvicorn
+
+from llmtuner import ChatModel, create_app
+
+
+def main():
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ print("Visit http://localhost:8000/docs for API document.")
+ uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/src/cli_demo.py b/llm_rl/src/cli_demo.py
new file mode 100644
index 00000000..fe6a0bc4
--- /dev/null
+++ b/llm_rl/src/cli_demo.py
@@ -0,0 +1,39 @@
+import readline  # imported for its side effects: enables line editing and history in input()
+from llmtuner import ChatModel
+
+
+def main():
+ chat_model = ChatModel()
+ history = []
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+ while True:
+ try:
+ query = input("\nUser: ")
+ except UnicodeDecodeError:
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
+ continue
+ except Exception:
+ raise
+
+ if query.strip() == "exit":
+ break
+
+ if query.strip() == "clear":
+ history = []
+ print("History has been removed.")
+ continue
+
+ print("Assistant: ", end="", flush=True)
+
+ response = ""
+ for new_text in chat_model.stream_chat(query, history):
+ print(new_text, end="", flush=True)
+ response += new_text
+ print()
+
+ history = history + [(query, response)]
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/src/evaluate.py b/llm_rl/src/evaluate.py
new file mode 100644
index 00000000..8af8c12c
--- /dev/null
+++ b/llm_rl/src/evaluate.py
@@ -0,0 +1,190 @@
+# coding=utf-8
+# Evaluates the performance of pre-trained models.
+# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
+# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
+# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+
+import os
+import fire
+import json
+import torch
+import numpy as np
+import transformers
+from collections import Counter
+from datasets import load_dataset
+from dataclasses import dataclass
+from tqdm import tqdm, trange
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
+
+from llmtuner import ChatModel
+
+if TYPE_CHECKING:
+ from datasets import Dataset
+
+
+choices = ["A", "B", "C", "D"]
+
+
+@dataclass
+class EvalTemplate:
+
+ system: str
+ choice: str
+ answer: str
+ prefix: str
+
+ def parse_example(
+ self,
+ example: Dict[str, str]
+ ) -> Tuple[str, str]:
+ candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
+ return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
+
+ def format_example(
+ self,
+ target_data: Dict[str, str],
+ support_set: "Dataset",
+ subject_name: str,
+ use_history: bool
+ ) -> Tuple[str, str, List[Tuple[str, str]]]:
+ query, resp = self.parse_example(target_data)
+ history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
+
+ if len(history):
+ temp = history.pop(0)
+ history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
+ else:
+ query = self.system.format(subject=subject_name) + query
+
+ if not use_history:
+ query = "\n\n".join(["".join(item) for item in history] + [query])
+ history = []
+ return query.strip(), resp, history
+
+
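+# Prompt templates used to build the few-shot multiple-choice evaluation prompts in English and Chinese.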
+eval_templates = {
+ "en": EvalTemplate(
+ system="The following are multiple choice questions (with answers) about {subject}.\n\n",
+ choice="\n{choice}. {content}",
+ answer="\nAnswer: ",
+ prefix=" "
+ ),
+ "zh": EvalTemplate(
+ system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
+ choice="\n{choice}. {content}",
+ answer="\n答案:",
+ prefix="\n"
+ )
+}
+
+
+@torch.inference_mode()
+def batch_inference(
+ chat_model: ChatModel,
+ batch_input: Dict[str, torch.Tensor],
+ prefix_char: str
+) -> List[str]:
+ logits = chat_model.model(**batch_input).logits
+ lengths = torch.sum(batch_input["attention_mask"], dim=-1)
+ nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
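+ # Compare the next-token logits of the choice letters ("A"-"D") and return the most probable letter per sample.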
+ probs = torch.nn.functional.softmax(
+ torch.stack(
+ [
+ nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
+ for choice in choices
+ ],
+ dim=-1
+ ),
+ dim=-1
+ ).detach()
+ return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
+
+
+def evaluate(
+ model_name_or_path: str,
+ finetuning_type: Optional[str] = "lora",
+ checkpoint_dir: Optional[str] = None,
+ template: Optional[str] = "vanilla",
+ task: Optional[str] = "ceval",
+ dataset_dir: Optional[str] = "evaluation",
+ split: Optional[Literal["validation", "test"]] = "validation",
+ lang: Optional[Literal["zh", "en"]] = "zh",
+ n_shot: Optional[int] = 5,
+ n_avg: Optional[int] = 1,
+ batch_size: Optional[int] = 4,
+ save_name: Optional[str] = None,
+ seed: Optional[int] = 42
+):
+ with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
+ categorys: Dict[str, Dict[str, str]] = json.load(f)
+
+ transformers.set_seed(seed)
+ chat_model = ChatModel(dict(
+ model_name_or_path=model_name_or_path,
+ finetuning_type=finetuning_type,
+ checkpoint_dir=checkpoint_dir,
+ template=template
+ ))
+ chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
+ eval_template = eval_templates[lang]
+
+ category_corrects: Dict[str, np.ndarray] = {
+ subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
+ }
+ pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+ results = {}
+ for subject in pbar:
+ dataset = load_dataset(os.path.join(dataset_dir, task), subject)
+ labels, answers, all_outputs = [], [], []
+ for epoch in range(n_avg):
+ pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
+ inputs, outputs = [], []
+ for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
+ support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
+ query, resp, history = eval_template.format_example(
+ target_data=dataset[split][i],
+ support_set=support_set,
+ subject_name=categorys[subject]["name"],
+ use_history=chat_model.template.use_history
+ )
+ input_ids, _ = chat_model.template.encode_oneturn(
+ tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
+ )
+ inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+ if epoch == 0:
+ labels.append(resp)
+
+ for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
+ batch_input = chat_model.tokenizer.pad(
+ inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
+ ).to(chat_model.model.device)
+ preds = batch_inference(chat_model, batch_input, eval_template.prefix)
+ outputs += preds
+ all_outputs.append(outputs)
+
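+ # Majority vote over the n_avg trials to obtain the final answer for each question.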
+ for i in range(len(all_outputs[0])):
+ count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
+ answers.append(count.most_common(1)[0][0])
+
+ corrects = (np.array(answers) == np.array(labels))
+ category_name = categorys[subject]["category"]
+ category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
+ category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
+ results[subject] = {str(i): answers[i] for i in range(len(answers))}
+
+ score_info = "\n".join([
+ "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+ for category_name, category_correct in category_corrects.items() if len(category_correct)
+ ])
+
+ print(score_info)
+ if save_name is not None:
+ with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
+ json.dump(results, f, indent=2)
+
+ with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
+ f.write(score_info)
+
+
+if __name__ == "__main__":
+ fire.Fire(evaluate)
diff --git a/llm_rl/src/export_model.py b/llm_rl/src/export_model.py
new file mode 100644
index 00000000..4baeb2c3
--- /dev/null
+++ b/llm_rl/src/export_model.py
@@ -0,0 +1,9 @@
+from llmtuner import export_model
+
+
+def main():
+ export_model()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/src/llmtuner/__init__.py b/llm_rl/src/llmtuner/__init__.py
new file mode 100644
index 00000000..37eb9535
--- /dev/null
+++ b/llm_rl/src/llmtuner/__init__.py
@@ -0,0 +1,9 @@
+# Level: api, webui > chat > tuner > dsets > extras, hparams
+
+from llmtuner.api import create_app
+from llmtuner.chat import ChatModel
+from llmtuner.tuner import export_model, run_exp
+from llmtuner.webui import create_ui, create_web_demo
+
+
+__version__ = "0.2.0"
diff --git a/llm_rl/src/llmtuner/api/__init__.py b/llm_rl/src/llmtuner/api/__init__.py
new file mode 100644
index 00000000..b3ce183a
--- /dev/null
+++ b/llm_rl/src/llmtuner/api/__init__.py
@@ -0,0 +1 @@
+from llmtuner.api.app import create_app
diff --git a/llm_rl/src/llmtuner/api/app.py b/llm_rl/src/llmtuner/api/app.py
new file mode 100644
index 00000000..27fb19e0
--- /dev/null
+++ b/llm_rl/src/llmtuner/api/app.py
@@ -0,0 +1,146 @@
+import json
+import uvicorn
+from fastapi import FastAPI, HTTPException, status
+from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
+from sse_starlette import EventSourceResponse
+from typing import List, Tuple
+from pydantic import BaseModel
+
+from llmtuner.extras.misc import torch_gc
+from llmtuner.chat import ChatModel
+from llmtuner.api.protocol import (
+ Role,
+ Finish,
+ ModelCard,
+ ModelList,
+ ChatMessage,
+ DeltaMessage,
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ChatCompletionStreamResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionResponseUsage
+)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI): # collects GPU memory
+ yield
+ torch_gc()
+
+
+def to_json(data: BaseModel) -> str:
+ try: # pydantic v2
+ return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+ except Exception: # pydantic v1
+ return data.json(exclude_unset=True, ensure_ascii=False)
+
+
+def create_app(chat_model: ChatModel) -> FastAPI:
+ app = FastAPI(lifespan=lifespan)
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ @app.get("/v1/models", response_model=ModelList)
+ async def list_models():
+ model_card = ModelCard(id="gpt-3.5-turbo")
+ return ModelList(data=[model_card])
+
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
+ async def create_chat_completion(request: ChatCompletionRequest):
+ if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+ query = request.messages[-1].content
+ prev_messages = request.messages[:-1]
+ if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
+ system = prev_messages.pop(0).content
+ else:
+ system = None
+
+ history = []
+ if len(prev_messages) % 2 == 0:
+ for i in range(0, len(prev_messages), 2):
+ if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
+ history.append([prev_messages[i].content, prev_messages[i+1].content])
+ else:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+
+ if request.stream:
+ generate = predict(query, history, system, request)
+ return EventSourceResponse(generate, media_type="text/event-stream")
+
+ response, (prompt_length, response_length) = chat_model.chat(
+ query, history, system,
+ do_sample=request.do_sample,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_new_tokens=request.max_tokens,
+ num_return_sequences=request.n
+ )
+
+ usage = ChatCompletionResponseUsage(
+ prompt_tokens=prompt_length,
+ completion_tokens=response_length,
+ total_tokens=prompt_length+response_length
+ )
+
+ choices = [ChatCompletionResponseChoice(
+ index=i,
+ message=ChatMessage(role=Role.ASSISTANT, content=choice),
+ finish_reason=Finish.STOP
+ ) for i, choice in enumerate(response)]
+
+ return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+
+ async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(role=Role.ASSISTANT),
+ finish_reason=None
+ )
+ chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
+ yield to_json(chunk)
+
+ for new_text in chat_model.stream_chat(
+ query, history, system,
+ do_sample=request.do_sample,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_new_tokens=request.max_tokens
+ ):
+ if len(new_text) == 0:
+ continue
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(content=new_text),
+ finish_reason=None
+ )
+ chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
+ yield to_json(chunk)
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(),
+ finish_reason=Finish.STOP
+ )
+ chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
+ yield to_json(chunk)
+ yield "[DONE]"
+
+ return app
+
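+# Example request against a running server (the model id is arbitrary and simply echoed back):
+#   curl http://localhost:8000/v1/chat/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'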
+
+if __name__ == "__main__":
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
diff --git a/llm_rl/src/llmtuner/api/protocol.py b/llm_rl/src/llmtuner/api/protocol.py
new file mode 100644
index 00000000..6b99da40
--- /dev/null
+++ b/llm_rl/src/llmtuner/api/protocol.py
@@ -0,0 +1,83 @@
+import time
+from enum import Enum
+from pydantic import BaseModel, Field
+from typing import List, Optional
+
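+# Pydantic models mirroring the request/response schema of the OpenAI chat completions API.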
+
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+
+
+class Finish(str, Enum):
+ STOP = "stop"
+ LENGTH = "length"
+
+
+class ModelCard(BaseModel):
+ id: str
+ object: Optional[str] = "model"
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+ owned_by: Optional[str] = "owner"
+
+
+class ModelList(BaseModel):
+ object: Optional[str] = "list"
+ data: Optional[List[ModelCard]] = []
+
+
+class ChatMessage(BaseModel):
+ role: Role
+ content: str
+
+
+class DeltaMessage(BaseModel):
+ role: Optional[Role] = None
+ content: Optional[str] = None
+
+
+class ChatCompletionRequest(BaseModel):
+ model: str
+ messages: List[ChatMessage]
+ do_sample: Optional[bool] = True
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ n: Optional[int] = 1
+ max_tokens: Optional[int] = None
+ stream: Optional[bool] = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatMessage
+ finish_reason: Finish
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+ index: int
+ delta: DeltaMessage
+ finish_reason: Optional[Finish] = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+
+
+class ChatCompletionResponse(BaseModel):
+ id: Optional[str] = "chatcmpl-default"
+ object: Optional[str] = "chat.completion"
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseChoice]
+ usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+ id: Optional[str] = "chatcmpl-default"
+ object: Optional[str] = "chat.completion.chunk"
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseStreamChoice]
diff --git a/llm_rl/src/llmtuner/chat/__init__.py b/llm_rl/src/llmtuner/chat/__init__.py
new file mode 100644
index 00000000..ba240d05
--- /dev/null
+++ b/llm_rl/src/llmtuner/chat/__init__.py
@@ -0,0 +1 @@
+from llmtuner.chat.stream_chat import ChatModel
diff --git a/llm_rl/src/llmtuner/chat/stream_chat.py b/llm_rl/src/llmtuner/chat/stream_chat.py
new file mode 100644
index 00000000..cc815d1b
--- /dev/null
+++ b/llm_rl/src/llmtuner/chat/stream_chat.py
@@ -0,0 +1,109 @@
+import torch
+from typing import Any, Dict, Generator, List, Optional, Tuple
+from threading import Thread
+from transformers import GenerationConfig, TextIteratorStreamer
+
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
+from llmtuner.extras.template import get_template_and_fix_tokenizer
+from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
+
+
+class ChatModel:
+
+ def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+ model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
+ self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+ self.tokenizer.padding_side = "left"
+ self.model = dispatch_model(self.model)
+ self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
+ self.system_prompt = data_args.system_prompt
+
+ def process_args(
+ self,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ system: Optional[str] = None,
+ **input_kwargs
+ ) -> Tuple[Dict[str, Any], int]:
+ system = system or self.system_prompt
+ prompt, _ = self.template.encode_oneturn(
+ tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+ )
+ prompt_length = len(prompt)
+ input_ids = torch.tensor([prompt], device=self.model.device)
+
+ do_sample = input_kwargs.pop("do_sample", None)
+ temperature = input_kwargs.pop("temperature", None)
+ top_p = input_kwargs.pop("top_p", None)
+ top_k = input_kwargs.pop("top_k", None)
+ num_return_sequences = input_kwargs.pop("num_return_sequences", None)
+ repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+ max_length = input_kwargs.pop("max_length", None)
+ max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+
+ generating_args = self.generating_args.to_dict()
+ generating_args.update(dict(
+ do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+ temperature=temperature or generating_args["temperature"],
+ top_p=top_p or generating_args["top_p"],
+ top_k=top_k or generating_args["top_k"],
+ num_return_sequences=num_return_sequences or 1,
+ repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+ eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+ pad_token_id=self.tokenizer.pad_token_id
+ ))
+
+ if isinstance(num_return_sequences, int) and num_return_sequences > 1:
+ generating_args["do_sample"] = True
+
+ if max_length:
+ generating_args.pop("max_new_tokens", None)
+ generating_args["max_length"] = max_length
+
+ if max_new_tokens:
+ generating_args.pop("max_length", None)
+ generating_args["max_new_tokens"] = max_new_tokens
+
+ gen_kwargs = dict(
+ inputs=input_ids,
+ generation_config=GenerationConfig(**generating_args),
+ logits_processor=get_logits_processor()
+ )
+
+ return gen_kwargs, prompt_length
+
+ @torch.inference_mode()
+ def chat(
+ self,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ system: Optional[str] = None,
+ **input_kwargs
+ ) -> Tuple[List[str], Tuple[int, int]]:
+ gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
+ generate_output = self.model.generate(**gen_kwargs)
+ response_ids = generate_output[:, prompt_length:]
+ response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ response_length = 0
+ for i in range(len(response_ids)):
+ eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
+ response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
+
+ return response, (prompt_length, response_length)
+
+ @torch.inference_mode()
+ def stream_chat(
+ self,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ system: Optional[str] = None,
+ **input_kwargs
+ ) -> Generator[str, None, None]:
+ gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
+ streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+ gen_kwargs["streamer"] = streamer
+
+ thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
+ thread.start()
+
+ yield from streamer
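+
+
+# Minimal usage sketch (the model path and template below are placeholders):
+#   chat_model = ChatModel({"model_name_or_path": "path_to_model", "template": "default"})
+#   for new_text in chat_model.stream_chat("Hello!"):
+#       print(new_text, end="", flush=True)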
diff --git a/llm_rl/src/llmtuner/dsets/__init__.py b/llm_rl/src/llmtuner/dsets/__init__.py
new file mode 100644
index 00000000..cccbd745
--- /dev/null
+++ b/llm_rl/src/llmtuner/dsets/__init__.py
@@ -0,0 +1,3 @@
+from llmtuner.dsets.loader import get_dataset
+from llmtuner.dsets.preprocess import preprocess_dataset
+from llmtuner.dsets.utils import split_dataset
diff --git a/llm_rl/src/llmtuner/dsets/loader.py b/llm_rl/src/llmtuner/dsets/loader.py
new file mode 100644
index 00000000..834ef733
--- /dev/null
+++ b/llm_rl/src/llmtuner/dsets/loader.py
@@ -0,0 +1,145 @@
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from datasets import concatenate_datasets, interleave_datasets, load_dataset
+
+from llmtuner.dsets.utils import checksum, EXT2TYPE
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from llmtuner.hparams import ModelArguments, DataArguments
+
+
+logger = get_logger(__name__)
+
+
+def get_dataset(
+ model_args: "ModelArguments",
+ data_args: "DataArguments"
+) -> Union["Dataset", "IterableDataset"]:
+ max_samples = data_args.max_samples
+ all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
+
+ for dataset_attr in data_args.dataset_list:
+ logger.info("Loading dataset {}...".format(dataset_attr))
+
+ if dataset_attr.load_from == "hf_hub":
+ data_path = dataset_attr.dataset_name
+ data_name = dataset_attr.subset
+ data_files = None
+ elif dataset_attr.load_from == "script":
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ data_name = dataset_attr.subset
+ data_files = None
+ elif dataset_attr.load_from == "file":
+ data_path, data_name = None, None
+ data_files: List[str] = []
+ if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
+ for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+ data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
+ if data_path is None:
+ data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
+ else:
+ assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
+ elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
+ data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
+ data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
+ else:
+ raise ValueError("File not found.")
+
+ assert data_path, "File extension must be txt, csv, json or jsonl."
+ checksum(data_files, dataset_attr.dataset_sha1)
+ else:
+ raise NotImplementedError
+
+ dataset = load_dataset(
+ path=data_path,
+ name=data_name,
+ data_files=data_files,
+ split=data_args.split,
+ cache_dir=model_args.cache_dir,
+ streaming=data_args.streaming,
+ use_auth_token=True if model_args.use_auth_token else None
+ )
+
+ if max_samples is not None: # truncate dataset
+ dataset = dataset.select(range(min(len(dataset), max_samples)))
+
+ def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+ # convert dataset from sharegpt format to alpaca format
+ outputs = {"prompt": [], "query": [], "response": [], "history": []}
+ for msg_list in examples[dataset_attr.messages]:
+ msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
+ if len(msg_list) == 0:
+ continue
+
+ msg_pairs = []
+ user_role, assistant_role = None, None
+ for idx in range(0, len(msg_list), 2):
+ if user_role is None and assistant_role is None:
+ user_role = msg_list[idx][dataset_attr.role]
+ assistant_role = msg_list[idx + 1][dataset_attr.role]
+ else:
+ if (
+ msg_list[idx][dataset_attr.role] != user_role
+ or msg_list[idx+1][dataset_attr.role] != assistant_role
+ ):
+ raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
+ msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
+
+ if len(msg_pairs) != 0:
+ outputs["prompt"].append(msg_pairs[-1][0])
+ outputs["query"].append("")
+ outputs["response"].append(msg_pairs[-1][1])
+ outputs["history"].append(msg_pairs[:-1])
+
+ return outputs
+
+ if dataset_attr.formatting == "sharegpt": # convert format
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache),
+ desc="Converting format of dataset"
+ )
+
+ dataset = dataset.map(
+ convert_format,
+ batched=True,
+ remove_columns=column_names,
+ **kwargs
+ )
+ else:
+ for column_name in ["prompt", "query", "response", "history"]: # align dataset
+ if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
+ dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
+
+ if dataset_attr.system_prompt: # add system prompt
+ system_prompt = dataset_attr.system_prompt
+ if data_args.streaming:
+ dataset = dataset.map(lambda _: {"system": system_prompt})
+ else:
+ dataset = dataset.add_column("system", [system_prompt] * len(dataset))
+
+ all_datasets.append(dataset)
+
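+ # Merge the loaded datasets according to the configured mixing strategy.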
+ if len(data_args.dataset_list) == 1:
+ return all_datasets[0]
+ elif data_args.mix_strategy == "concat":
+ if data_args.streaming:
+ logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+ return concatenate_datasets(all_datasets)
+ elif data_args.mix_strategy.startswith("interleave"):
+ if not data_args.streaming:
+ logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+ return interleave_datasets(
+ datasets=all_datasets,
+ probabilities=data_args.interleave_probs,
+ seed=data_args.seed,
+ stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+ )
+ else:
+ raise ValueError("Unknown mixing strategy.")
diff --git a/llm_rl/src/llmtuner/dsets/preprocess.py b/llm_rl/src/llmtuner/dsets/preprocess.py
new file mode 100644
index 00000000..0484b78e
--- /dev/null
+++ b/llm_rl/src/llmtuner/dsets/preprocess.py
@@ -0,0 +1,268 @@
+import os
+import tiktoken
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+
+from datasets import load_from_disk
+
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
+from llmtuner.extras.template import get_template_and_fix_tokenizer
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import Seq2SeqTrainingArguments
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from llmtuner.hparams import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+def preprocess_dataset(
+ dataset: Union["Dataset", "IterableDataset"],
+ tokenizer: "PreTrainedTokenizer",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo"]
+) -> Union["Dataset", "IterableDataset"]:
+ template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
+
+ if data_args.train_on_prompt and template.efficient_eos:
+ raise ValueError("Current template does not support `train_on_prompt`.")
+
+ def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
+ for i in range(len(examples["prompt"])):
+ query, response = examples["prompt"][i], examples["response"][i]
+ query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+ history = examples["history"][i] if "history" in examples else None
+ system = examples["system"][i] if "system" in examples else None
+ yield query, response, history, system
+
+ def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
+ # build grouped texts with format `X1 X2 X3 ...`
+ if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+ kwargs = dict(allowed_special="all")
+ else:
+ kwargs = dict(add_special_tokens=True)
+
+ if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
+ setattr(tokenizer, "add_eos_token", True)
+
+ tokenized_examples = tokenizer(examples["prompt"], **kwargs)
+ concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+ total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+ block_size = data_args.cutoff_len
+ # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+ total_length = (total_length // block_size) * block_size
+ # split by chunks of cutoff_len
+ result = {
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ return result
+
+ def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
+ # build inputs with format ` X Y ` and labels with format ` ... Y `
+ # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+
+ for query, response, history, system in construct_example(examples):
+ if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
+ continue
+
+ input_ids, labels = [], []
+ for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+ tokenizer, query, response, history, system
+ )):
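+ # Truncate source and target proportionally so that each turn fits within cutoff_len.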
+ total_len = len(source_ids) + len(target_ids)
+ max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
+ max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
+
+ if len(source_ids) > max_source_len:
+ source_ids = source_ids[:max_source_len]
+ if len(target_ids) > max_target_len:
+ target_ids = target_ids[:max_target_len]
+
+ if data_args.train_on_prompt:
+ source_mask = source_ids
+ elif turn_idx != 0 and template.efficient_eos:
+ source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+ else:
+ source_mask = [IGNORE_INDEX] * len(source_ids)
+
+ input_ids += source_ids + target_ids
+ labels += source_mask + target_ids
+
+ if template.efficient_eos:
+ input_ids += [tokenizer.eos_token_id]
+ labels += [tokenizer.eos_token_id]
+
+ if len(input_ids) > data_args.cutoff_len:
+ input_ids = input_ids[:data_args.cutoff_len]
+ labels = labels[:data_args.cutoff_len]
+
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+
+ return model_inputs
+
+ def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
+ # build inputs with format ` X1 Y1 X2 Y2 `
+ # and labels with format ` ... Y1 ... Y2 `
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ input_ids, labels = [], []
+ for query, response, history, system in construct_example(examples):
+ if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
+ continue
+
+ for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+ tokenizer, query, response, history, system
+ )):
+ if data_args.train_on_prompt:
+ source_mask = source_ids
+ elif turn_idx != 0 and template.efficient_eos:
+ source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+ else:
+ source_mask = [IGNORE_INDEX] * len(source_ids)
+ input_ids += source_ids + target_ids
+ labels += source_mask + target_ids
+
+ if template.efficient_eos:
+ input_ids += [tokenizer.eos_token_id]
+ labels += [tokenizer.eos_token_id]
+
+ total_length = len(input_ids)
+ block_size = data_args.cutoff_len
+ # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+ total_length = (total_length // block_size) * block_size
+ # split by chunks of cutoff_len
+ for i in range(0, total_length, block_size):
+ model_inputs["input_ids"].append(input_ids[i: i + block_size])
+ model_inputs["attention_mask"].append([1] * block_size)
+ model_inputs["labels"].append(labels[i: i + block_size])
+
+ return model_inputs
+
+ def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
+ # build inputs with format ` X` and labels with format `Y `
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+
+ for query, response, history, system in construct_example(examples):
+ if not (isinstance(query, str) and query != ""):
+ continue
+
+ input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
+
+ if template.efficient_eos:
+ labels += [tokenizer.eos_token_id]
+
+ if len(input_ids) > data_args.cutoff_len:
+ input_ids = input_ids[:data_args.cutoff_len]
+ if len(labels) > data_args.cutoff_len:
+ labels = labels[:data_args.cutoff_len]
+
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+
+ return model_inputs
+
+ def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
+ # build input pairs with format ` X`, `Y1 ` and `Y2 `
+ model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+ for query, response, history, system in construct_example(examples):
+ if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
+ continue
+
+ prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+ _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
+
+ if template.efficient_eos:
+ chosen_ids += [tokenizer.eos_token_id]
+ rejected_ids += [tokenizer.eos_token_id]
+
+ total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
+ max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
+ max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
+
+ if len(prompt_ids) > max_source_len:
+ prompt_ids = prompt_ids[:max_source_len]
+ if len(chosen_ids) > max_target_len:
+ chosen_ids = chosen_ids[:max_target_len]
+ if len(rejected_ids) > max_target_len:
+ rejected_ids = rejected_ids[:max_target_len]
+
+ model_inputs["prompt_ids"].append(prompt_ids)
+ model_inputs["chosen_ids"].append(chosen_ids)
+ model_inputs["rejected_ids"].append(rejected_ids)
+
+ return model_inputs
+
+ def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print("labels:\n{}".format(
+ tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
+ ))
+
+ def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
+ print("prompt_ids:\n{}".format(example["prompt_ids"]))
+ print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+ print("chosen_ids:\n{}".format(example["chosen_ids"]))
+ print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+ print("rejected_ids:\n{}".format(example["rejected_ids"]))
+ print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
+
+ def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+
+ if stage == "pt":
+ preprocess_func = preprocess_pretrain_dataset
+ print_function = print_unsupervised_dataset_example
+ elif stage == "sft" and not training_args.predict_with_generate:
+ preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
+ print_function = print_supervised_dataset_example
+ elif stage == "rm":
+ preprocess_func = preprocess_pairwise_dataset
+ print_function = print_pairwise_dataset_example
+ else:
+ preprocess_func = preprocess_unsupervised_dataset
+ print_function = print_unsupervised_dataset_example
+
+ if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+ logger.warning("Loading dataset from disk will ignore other data arguments.")
+ return load_from_disk(data_args.cache_path)
+
+ with training_args.main_process_first(desc="dataset map pre-processing"):
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache),
+ desc="Running tokenizer on dataset"
+ )
+
+ dataset = dataset.map(
+ preprocess_func,
+ batched=True,
+ remove_columns=column_names,
+ **kwargs
+ )
+
+ if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+ if training_args.should_save:
+ dataset.save_to_disk(data_args.cache_path)
+ raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
+
+ if training_args.should_log:
+ try:
+ print_function(next(iter(dataset)))
+ except StopIteration:
+ raise RuntimeError("Empty dataset!")
+
+ return dataset
diff --git a/llm_rl/src/llmtuner/dsets/utils.py b/llm_rl/src/llmtuner/dsets/utils.py
new file mode 100644
index 00000000..bf337014
--- /dev/null
+++ b/llm_rl/src/llmtuner/dsets/utils.py
@@ -0,0 +1,59 @@
+import hashlib
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import TrainingArguments
+ from llmtuner.hparams import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+EXT2TYPE = {
+ "csv": "csv",
+ "json": "json",
+ "jsonl": "json",
+ "txt": "text"
+}
+
+
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+ if file_sha1 is None:
+ logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+ return
+
+ if len(data_files) != 1:
+ logger.warning("Checksum failed: too many files.")
+ return
+
+ with open(data_files[0], "rb") as f:
+ sha1 = hashlib.sha1(f.read()).hexdigest()
+ if sha1 != file_sha1:
+ logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+
+
+def split_dataset(
+ dataset: Union["Dataset", "IterableDataset"],
+ data_args: "DataArguments",
+ training_args: "TrainingArguments"
+) -> Dict[str, "Dataset"]:
+ if training_args.do_train:
+ if data_args.val_size > 1e-6: # Split the dataset
+ if data_args.streaming:
+ val_set = dataset.take(int(data_args.val_size))
+ train_set = dataset.skip(int(data_args.val_size))
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+ return {"train_dataset": train_set, "eval_dataset": val_set}
+ else:
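+ # val_size is treated as an absolute number of examples when > 1, otherwise as a fraction.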
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+ dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
+ return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+ else:
+ if data_args.streaming:
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+ return {"train_dataset": dataset}
+ else: # do_eval or do_predict
+ return {"eval_dataset": dataset}
diff --git a/llm_rl/src/llmtuner/extras/__init__.py b/llm_rl/src/llmtuner/extras/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_rl/src/llmtuner/extras/callbacks.py b/llm_rl/src/llmtuner/extras/callbacks.py
new file mode 100644
index 00000000..7398d424
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/callbacks.py
@@ -0,0 +1,155 @@
+import os
+import json
+import time
+from typing import TYPE_CHECKING
+from datetime import timedelta
+
+from transformers import TrainerCallback
+from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+
+from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from transformers import TrainingArguments, TrainerState, TrainerControl
+
+
+logger = get_logger(__name__)
+
+
+class SavePeftModelCallback(TrainerCallback):
+
+ def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after a checkpoint save.
+ """
+ if args.should_save:
+ output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+ model = kwargs.pop("model")
+ if getattr(model, "is_peft_model", False):
+ getattr(model, "pretrained_model").save_pretrained(output_dir)
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ if getattr(model, "is_peft_model", False):
+ getattr(model, "pretrained_model").save_pretrained(args.output_dir)
+
+
+class LogCallback(TrainerCallback):
+
+ def __init__(self, runner=None):
+ self.runner = runner
+ self.in_training = False
+ self.start_time = time.time()
+ self.cur_steps = 0
+ self.max_steps = 0
+ self.elapsed_time = ""
+ self.remaining_time = ""
+
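+ # Updates human-readable elapsed/remaining time estimates from the average time per step.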
+ def timing(self):
+ cur_time = time.time()
+ elapsed_time = cur_time - self.start_time
+ avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
+ remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+ self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+ self.remaining_time = str(timedelta(seconds=int(remaining_time)))
+
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the beginning of training.
+ """
+ if state.is_local_process_zero:
+ self.in_training = True
+ self.start_time = time.time()
+ self.max_steps = state.max_steps
+ if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
+ logger.warning("Previous log file in this folder will be deleted.")
+ os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if state.is_local_process_zero:
+ self.in_training = False
+ self.cur_steps = 0
+ self.max_steps = 0
+
+ def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of a substep during gradient accumulation.
+ """
+ if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of a training step.
+ """
+ if state.is_local_process_zero:
+ self.cur_steps = state.global_step
+ self.timing()
+ if self.runner is not None and self.runner.aborted:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after an evaluation phase.
+ """
+ if state.is_local_process_zero and not self.in_training:
+ self.cur_steps = 0
+ self.max_steps = 0
+
+ def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
+ r"""
+ Event called after a successful prediction.
+ """
+ if state.is_local_process_zero and not self.in_training:
+ self.cur_steps = 0
+ self.max_steps = 0
+
+ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
+ r"""
+ Event called after logging the last logs.
+ """
+ if not state.is_local_process_zero:
+ return
+
+ logs = dict(
+ current_steps=self.cur_steps,
+ total_steps=self.max_steps,
+ loss=state.log_history[-1].get("loss", None),
+ eval_loss=state.log_history[-1].get("eval_loss", None),
+ predict_loss=state.log_history[-1].get("predict_loss", None),
+ reward=state.log_history[-1].get("reward", None),
+ learning_rate=state.log_history[-1].get("learning_rate", None),
+ epoch=state.log_history[-1].get("epoch", None),
+ percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+ elapsed_time=self.elapsed_time,
+ remaining_time=self.remaining_time
+ )
+ if self.runner is not None:
+ logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
+ logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
+ ))
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
+ f.write(json.dumps(logs) + "\n")
+
+ def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after a prediction step.
+ """
+ eval_dataloader = kwargs.pop("eval_dataloader", None)
+ if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
+ if self.max_steps == 0:
+ self.max_steps = len(eval_dataloader)
+ self.cur_steps += 1
+ self.timing()
diff --git a/llm_rl/src/llmtuner/extras/constants.py b/llm_rl/src/llmtuner/extras/constants.py
new file mode 100644
index 00000000..dc55a080
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/constants.py
@@ -0,0 +1,92 @@
+IGNORE_INDEX = -100
+
+LOG_FILE_NAME = "trainer_log.jsonl"
+
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
+
+METHODS = ["full", "freeze", "lora"]
+
+TRAINING_STAGES = {
+ "Supervised Fine-Tuning": "sft",
+ "Reward Modeling": "rm",
+ "PPO": "ppo",
+ "DPO": "dpo",
+ "Pre-Training": "pt"
+}
+
+SUPPORTED_MODELS = {
+ "LLaMA-7B": "huggyllama/llama-7b",
+ "LLaMA-13B": "huggyllama/llama-13b",
+ "LLaMA-30B": "huggyllama/llama-30b",
+ "LLaMA-65B": "huggyllama/llama-65b",
+ "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+ "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+ "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+ "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+ "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+ "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+ "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+ "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+ "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+ "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
+ "BLOOM-560M": "bigscience/bloom-560m",
+ "BLOOM-3B": "bigscience/bloom-3b",
+ "BLOOM-7B1": "bigscience/bloom-7b1",
+ "BLOOMZ-560M": "bigscience/bloomz-560m",
+ "BLOOMZ-3B": "bigscience/bloomz-3b",
+ "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
+ "Falcon-7B": "tiiuae/falcon-7b",
+ "Falcon-40B": "tiiuae/falcon-40b",
+ "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
+ "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
+ "Baichuan-7B": "baichuan-inc/Baichuan-7B",
+ "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
+ "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
+ "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
+ "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
+ "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+ "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
+ "InternLM-7B": "internlm/internlm-7b",
+ "InternLM-20B": "internlm/internlm-20b",
+ "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+ "InternLM-20B-Chat": "internlm/internlm-chat-20b",
+ "Qwen-7B": "Qwen/Qwen-7B",
+ "Qwen-14B": "Qwen/Qwen-14B",
+ "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+ "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+ "XVERSE-13B": "xverse/XVERSE-13B",
+ "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
+ "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
+ "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
+ "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
+ "Phi1.5-1.3B": "microsoft/phi-1_5"
+}
+
+DEFAULT_MODULE = {
+ "LLaMA": "q_proj,v_proj",
+ "LLaMA2": "q_proj,v_proj",
+ "ChineseLLaMA2": "q_proj,v_proj",
+ "BLOOM": "query_key_value",
+ "BLOOMZ": "query_key_value",
+ "Falcon": "query_key_value",
+ "Baichuan": "W_pack",
+ "Baichuan2": "W_pack",
+ "InternLM": "q_proj,v_proj",
+ "Qwen": "c_attn",
+ "XVERSE": "q_proj,v_proj",
+ "ChatGLM2": "query_key_value",
+ "ChatGLM3": "query_key_value",
+ "Phi1.5": "Wqkv"
+}
+
+DEFAULT_TEMPLATE = {
+ "LLaMA2": "llama2",
+ "ChineseLLaMA2": "llama2_zh",
+ "Baichuan": "baichuan",
+ "Baichuan2": "baichuan2",
+ "InternLM": "intern",
+ "Qwen": "chatml",
+ "XVERSE": "xverse",
+ "ChatGLM2": "chatglm2",
+ "ChatGLM3": "chatglm3"
+}
diff --git a/llm_rl/src/llmtuner/extras/logging.py b/llm_rl/src/llmtuner/extras/logging.py
new file mode 100644
index 00000000..d6f185e6
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/logging.py
@@ -0,0 +1,43 @@
+import sys
+import logging
+
+
+class LoggerHandler(logging.Handler):
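+ # Collects formatted log records in an in-memory string buffer instead of writing them to a stream.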
+
+ def __init__(self):
+ super().__init__()
+ self.log = ""
+
+ def reset(self):
+ self.log = ""
+
+ def emit(self, record):
+ if record.name == "httpx":
+ return
+ log_entry = self.format(record)
+ self.log += log_entry
+ self.log += "\n\n"
+
+
+def reset_logging():
+ r"""
+ Removes the handlers and filters attached to the root logger.
+ """
+ root = logging.getLogger()
+ list(map(root.removeHandler, root.handlers[:])) # iterate over a copy, since removal mutates the list
+ list(map(root.removeFilter, root.filters[:]))
+
+
+def get_logger(name: str) -> logging.Logger:
+ formatter = logging.Formatter(
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S"
+ )
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(formatter)
+
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ logger.addHandler(handler)
+
+ return logger
diff --git a/llm_rl/src/llmtuner/extras/misc.py b/llm_rl/src/llmtuner/extras/misc.py
new file mode 100644
index 00000000..960d43ee
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/misc.py
@@ -0,0 +1,118 @@
+import gc
+import torch
+from typing import TYPE_CHECKING, Tuple
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+
+try:
+ from transformers.utils import (
+ is_torch_bf16_cpu_available,
+ is_torch_bf16_gpu_available,
+ is_torch_cuda_available,
+ is_torch_npu_available
+ )
+ _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
+ _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
+except ImportError:
+ _is_fp16_available = torch.cuda.is_available()
+ _is_bf16_available = torch.cuda.is_bf16_supported()
+
+if TYPE_CHECKING:
+ from transformers.modeling_utils import PreTrainedModel
+
+
+class AverageMeter:
+ r"""
+ Computes and stores the average and current value.
+ """
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+ r"""
+ Returns the number of trainable parameters and number of all parameters in the model.
+ """
+ trainable_params, all_param = 0, 0
+ for param in model.parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, "ds_numel"):
+ num_params = param.ds_numel
+
+ # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+ if param.__class__.__name__ == "Params4bit":
+ num_params = num_params * 2
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+
+ return trainable_params, all_param
+
+
+def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+ r"""
+ Infers the optimal dtype according to the model_dtype and device compatibility.
+ """
+ if _is_bf16_available and model_dtype == torch.bfloat16:
+ return torch.bfloat16
+ elif _is_fp16_available:
+ return torch.float16
+ else:
+ return torch.float32
+
+
+def get_logits_processor() -> LogitsProcessorList:
+ r"""
+ Gets logits processor that removes NaN and Inf logits.
+ """
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InfNanRemoveLogitsProcessor())
+ return logits_processor
+
+
+def torch_gc() -> None:
+ r"""
+ Collects GPU memory.
+ """
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+ r"""
+ Dispatches a pre-trained model to GPUs with balanced memory.
+ Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+ """
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+ return model
+
+ if torch.cuda.device_count() > 1:
+ from accelerate import dispatch_model
+ from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+ if model._no_split_modules is None:
+ raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+ kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+ max_memory = get_balanced_memory(model, **kwargs)
+ # Make sure tied weights are tied before creating the device map.
+ model.tie_weights()
+ device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+ return dispatch_model(model, device_map)
+ else:
+ return model.cuda()
diff --git a/llm_rl/src/llmtuner/extras/patches/__init__.py b/llm_rl/src/llmtuner/extras/patches/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_rl/src/llmtuner/extras/patches/llama_patch.py b/llm_rl/src/llmtuner/extras/patches/llama_patch.py
new file mode 100644
index 00000000..a8473311
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/patches/llama_patch.py
@@ -0,0 +1,218 @@
+import math
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+from transformers.utils import logging
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
+ from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
+except ImportError:
+ print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
+
+
+logger = logging.get_logger(__name__)
+
+
+# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+class LlamaShiftShortAttention(LlamaAttention):
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None: # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ if getattr(self, "num_key_value_groups"):
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+ groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+ assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+ num_groups = q_len // groupsz
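+ # LongLoRA-style shifted short attention: half of the heads are rolled by half a group so neighboring groups exchange information.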
+ def shift(state: torch.Tensor) -> torch.Tensor:
+ state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
+ state = torch.cat((
+ state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+ ), dim=2)
+ return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+ query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ), dim=2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # LlamaFlashAttention2 attention does not support output_attentions
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None: # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # cast to half precision
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+            logger.warning_once("The input hidden states seem to have been silently cast to float32.")
+ query_states = query_states.to(self.config.torch_dtype)
+ key_states = key_states.to(self.config.torch_dtype)
+ value_states = value_states.to(self.config.torch_dtype)
+
+ if getattr(self, "num_key_value_groups", None):
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+ key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+ value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+ groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+ assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+ num_groups = q_len // groupsz
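+            # Same shift-short-attention regrouping as in LlamaShiftShortAttention, except that the
+            # states are already in the (bsz, seq_len, n_heads, head_dim) layout FlashAttention expects.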
+ def shift(state: torch.Tensor) -> torch.Tensor:
+ state = torch.cat((
+ state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+ ), dim=2)
+ return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
+
+ query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+ if attention_mask is not None:
+ attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
+
+ if attention_mask is not None:
+ logger.warning_once("Padded sequences are less efficient in FlashAttention.")
+ # -q_len: assumes left padding when q_len != kv_len
+ unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
+ unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
+ unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
+ attn_output_unpad = flash_attn_varlen_func(
+ unpadded_q,
+ unpadded_k,
+ unpadded_v,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
+ )
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ), dim=2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as flash attention
+# takes a boolean padding_mask. Fills in the past kv length for use in forward.
+def _prepare_decoder_attention_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_shape: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int
+) -> torch.Tensor:
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # This uses the faster call when training with full samples
+
+ return attention_mask
diff --git a/llm_rl/src/llmtuner/extras/ploting.py b/llm_rl/src/llmtuner/extras/ploting.py
new file mode 100644
index 00000000..82530e45
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/ploting.py
@@ -0,0 +1,52 @@
+import os
+import math
+import json
+import matplotlib.pyplot as plt
+from typing import List, Optional
+from transformers.trainer import TRAINER_STATE_NAME
+
+from llmtuner.extras.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def smooth(scalars: List[float]) -> List[float]:
+ r"""
+ EMA implementation according to TensorBoard.
+ """
+ last = scalars[0]
+ smoothed = list()
+ weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
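+    # this weight tends to 0.9 for long curves and towards 0 for very short ones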
+ for next_val in scalars:
+ smoothed_val = last * weight + (1 - weight) * next_val
+ smoothed.append(smoothed_val)
+ last = smoothed_val
+ return smoothed
+
+
+def plot_loss(save_directory: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
+
+    with open(os.path.join(save_directory, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ for key in keys:
+ steps, metrics = [], []
+ for i in range(len(data["log_history"])):
+ if key in data["log_history"][i]:
+ steps.append(data["log_history"][i]["step"])
+ metrics.append(data["log_history"][i][key])
+
+ if len(metrics) == 0:
+ logger.warning(f"No metric {key} to plot.")
+ continue
+
+ plt.figure()
+ plt.plot(steps, metrics, alpha=0.4, label="original")
+ plt.plot(steps, smooth(metrics), label="smoothed")
+        plt.title("training {} of {}".format(key, save_directory))
+ plt.xlabel("step")
+ plt.ylabel(key)
+ plt.legend()
+        plt.savefig(os.path.join(save_directory, "training_{}.png".format(key)), format="png", dpi=100)
+        print("Figure saved:", os.path.join(save_directory, "training_{}.png".format(key)))
diff --git a/llm_rl/src/llmtuner/extras/save_and_load.py b/llm_rl/src/llmtuner/extras/save_and_load.py
new file mode 100644
index 00000000..6d819ce6
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/save_and_load.py
@@ -0,0 +1,21 @@
+import os
+import torch
+from transformers.trainer import WEIGHTS_NAME
+
+from llmtuner.extras.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
+ vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
+ if not os.path.exists(vhead_file):
+ logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
+ return False
+ vhead_params = torch.load(vhead_file, map_location="cpu")
+ model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+ model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+ model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+ model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+ return True
diff --git a/llm_rl/src/llmtuner/extras/template.py b/llm_rl/src/llmtuner/extras/template.py
new file mode 100644
index 00000000..401750ce
--- /dev/null
+++ b/llm_rl/src/llmtuner/extras/template.py
@@ -0,0 +1,713 @@
+import tiktoken
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class Template:
+
+ prefix: List[Union[str, Dict[str, str]]]
+ prompt: List[Union[str, Dict[str, str]]]
+ system: str
+ sep: List[Union[str, Dict[str, str]]]
+ stop_words: List[str]
+ use_history: bool
+ efficient_eos: bool
+
+ def encode_oneturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ query: str,
+ resp: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ system: Optional[str] = None
+ ) -> Tuple[List[int], List[int]]:
+ r"""
+ Returns a single pair of token ids representing prompt and response respectively.
+ """
+ system, history = self._format(query, resp, history, system)
+ encoded_pairs = self._encode(tokenizer, system, history)
+ prompt_ids = []
+ for query_ids, resp_ids in encoded_pairs[:-1]:
+ prompt_ids = prompt_ids + query_ids + resp_ids
+ prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+ return prompt_ids, answer_ids
+
+ def encode_multiturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ query: str,
+ resp: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ system: Optional[str] = None
+ ) -> List[Tuple[List[int], List[int]]]:
+ r"""
+ Returns multiple pairs of token ids representing prompts and responses respectively.
+ """
+ system, history = self._format(query, resp, history, system)
+ encoded_pairs = self._encode(tokenizer, system, history)
+ return encoded_pairs
+
+ def _format(
+ self,
+ query: str,
+ resp: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ system: Optional[str] = None
+ ) -> Tuple[str, List[Tuple[str, str]]]:
+ r"""
+ Aligns inputs to the standard format.
+ """
+ system = system or self.system # use system if provided
+ history = history if (history and self.use_history) else []
+ history = history + [(query, resp)]
+ return system, history
+
+ def _get_special_ids(
+ self,
+ tokenizer: "PreTrainedTokenizer"
+ ) -> Tuple[List[int], List[int]]:
+ if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
+ bos_ids = [tokenizer.bos_token_id]
+ else: # baichuan, qwen and gpt2 models have no bos token
+ bos_ids = []
+
+ if tokenizer.eos_token_id is None:
+ raise ValueError("EOS token is required.")
+
+ if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
+ eos_ids = []
+ else:
+ eos_ids = [tokenizer.eos_token_id]
+
+ return bos_ids, eos_ids
+
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ system: str,
+ history: List[Tuple[str, str]]
+ ) -> List[Tuple[List[int], List[int]]]:
+ r"""
+ Encodes formatted inputs to pairs of token ids.
+ Turn 0: bos + prefix + sep + query resp + eos
+ Turn t: sep + bos + query resp + eos
+ """
+ bos_ids, eos_ids = self._get_special_ids(tokenizer)
+ sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+ encoded_pairs = []
+ for turn_idx, (query, resp) in enumerate(history):
+ if turn_idx == 0:
+ prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+ if len(prefix_ids) != 0: # has prefix
+ prefix_ids = bos_ids + prefix_ids + sep_ids
+ else:
+ prefix_ids = bos_ids
+ else:
+ prefix_ids = sep_ids + bos_ids
+
+ query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
+ resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+ encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
+ return encoded_pairs
+
+ def _convert_inputs_to_ids(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ context: List[Union[str, Dict[str, str]]],
+ system: Optional[str] = None,
+ query: Optional[str] = None,
+ idx: Optional[str] = None
+ ) -> List[int]:
+ r"""
+ Converts context to token ids.
+ """
+ if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+ kwargs = dict(allowed_special="all")
+ else:
+ kwargs = dict(add_special_tokens=False)
+
+ token_ids = []
+ for elem in context:
+ if isinstance(elem, str):
+ elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+ elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+ elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
+ if len(elem) != 0:
+ token_ids = token_ids + tokenizer.encode(elem, **kwargs)
+ elif isinstance(elem, dict):
+ token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+ else:
+ raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
+
+ return token_ids
+
+
+@dataclass
+class Llama2Template(Template):
+
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ system: str,
+ history: List[Tuple[str, str]]
+ ) -> List[Tuple[List[int], List[int]]]:
+ r"""
+ Encodes formatted inputs to pairs of token ids.
+ Turn 0: bos + prefix + query resp + eos
+ Turn t: bos + query resp + eos
+ """
+ bos_ids, eos_ids = self._get_special_ids(tokenizer)
+ encoded_pairs = []
+ for turn_idx, (query, resp) in enumerate(history):
+ if turn_idx == 0: # llama2 template has no sep_ids
+ query = self.prefix[0].replace("{{system}}", system) + query
+ query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+ resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+ encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
+ return encoded_pairs
+
+
+templates: Dict[str, Template] = {}
+
+
+def register_template(
+ name: str,
+ prefix: List[Union[str, Dict[str, str]]],
+ prompt: List[Union[str, Dict[str, str]]],
+ system: str,
+ sep: List[Union[str, Dict[str, str]]],
+ stop_words: Optional[List[str]] = [],
+ use_history: Optional[bool] = True,
+ efficient_eos: Optional[bool] = False
+) -> None:
+ template_class = Llama2Template if "llama2" in name else Template
+ templates[name] = template_class(
+ prefix=prefix,
+ prompt=prompt,
+ system=system,
+ sep=sep,
+ stop_words=stop_words,
+ use_history=use_history,
+ efficient_eos=efficient_eos
+ )
+
+
+def get_template_and_fix_tokenizer(
+ name: str,
+ tokenizer: "PreTrainedTokenizer"
+) -> Template:
+ if tokenizer.eos_token_id is None:
+ tokenizer.eos_token = "<|endoftext|>"
+ logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+ if name is None:
+ return None
+
+ template = templates.get(name, None)
+ assert template is not None, "Template {} does not exist.".format(name)
+ tokenizer.add_special_tokens(
+ dict(additional_special_tokens=template.stop_words),
+ replace_additional_special_tokens=False
+ )
+ return template
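+
+# A minimal usage sketch (illustrative): pick one of the template names registered below, e.g.
+#   template = get_template_and_fix_tokenizer("default", tokenizer)
+#   prompt_ids, answer_ids = template.encode_oneturn(tokenizer, query="Hello", resp="Hi there!")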
+
+
+r"""
+Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
+"""
+register_template(
+ name="alpaca",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "### Instruction:\n{{query}}\n\n### Response:\n"
+ ],
+ system=(
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request."
+ ),
+ sep=[
+ "\n\n"
+ ]
+)
+
+
+r"""
+Supports: https://huggingface.co/BAAI/AquilaChat-7B
+ https://huggingface.co/BAAI/AquilaChat2-7B
+ https://huggingface.co/BAAI/AquilaChat2-34B
+"""
+register_template(
+ name="aquila",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "Human: {{query}}###Assistant:"
+ ],
+ system=(
+ "A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions."
+ ),
+ sep=[
+ "###"
+ ],
+ stop_words=[
+        "</s>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
+"""
+register_template(
+ name="baichuan",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+        {"token": "<reserved_102>"}, # user token
+ "{{query}}",
+        {"token": "<reserved_103>"} # assistant token
+ ],
+ system="",
+ sep=[],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
+ https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
+"""
+register_template(
+ name="baichuan2",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+        {"token": "<reserved_106>"}, # user token
+ "{{query}}",
+        {"token": "<reserved_107>"} # assistant token
+ ],
+ system="",
+ sep=[],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
+"""
+register_template(
+ name="belle",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "Human: {{query}}\n\nBelle: "
+ ],
+ system="",
+ sep=[
+ "\n\n"
+ ]
+)
+
+
+r"""
+Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
+"""
+register_template(
+ name="bluelm",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ {"token": "[|Human|]:"},
+ "{{query}}",
+ {"token": "[|AI|]:"}
+ ],
+ system="",
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+ name="chatglm2",
+ prefix=[
+ {"token": "[gMASK]"},
+ {"token": "sop"},
+ "{{system}}"
+ ],
+ prompt=[
+ "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
+ ],
+ system="",
+ sep=[
+ "\n\n"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm3-6b
+"""
+register_template(
+ name="chatglm3",
+ prefix=[
+ {"token": "[gMASK]"},
+ {"token": "sop"},
+ "{{system}}"
+ ],
+ prompt=[
+ {"token": "<|user|>"},
+ "\n",
+ "{{query}}",
+ {"token": "<|assistant|>"}
+ ],
+ system="",
+ sep=[],
+ stop_words=[
+ "<|user|>",
+ "<|observation|>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
+ https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
+ https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
+"""
+register_template(
+ name="deepseek",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "### Instruction:\n{{query}}\n\n### Response:\n"
+ ],
+ system=(
+ "You are an AI programming assistant, utilizing the Deepseek Coder model, "
+ "developed by Deepseek Company, and you only answer questions related to computer science. "
+ "For politically sensitive questions, security and privacy issues, "
+ "and other non-computer science questions, you will refuse to answer."
+ ),
+ sep=[
+ "\n",
+ {"token": "<|EOT|>"},
+ "\n\n"
+ ],
+ stop_words=[
+ "<|EOT|>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Default template.
+"""
+register_template(
+ name="default",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "Human: {{query}}\nAssistant:"
+ ],
+ system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
+ sep=[
+ "\n"
+ ]
+)
+
+
+r"""
+Supports: https://huggingface.co/internlm/internlm-chat-7b
+ https://huggingface.co/internlm/internlm-chat-20b
+"""
+register_template(
+ name="intern",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "<|User|>:{{query}}",
+        {"token": "<eoh>"},
+ "\n<|Bot|>:"
+ ],
+ system="",
+ sep=[
+        {"token": "<eoa>"},
+ "\n"
+ ],
+ stop_words=[
+        "<eoa>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
+ https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
+ https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
+"""
+register_template(
+ name="llama2",
+ prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+ ],
+ prompt=[
+ "[INST] {{query}} [/INST]"
+ ],
+ system=(
+ "You are a helpful, respectful and honest assistant. "
+ "Always answer as helpfully as possible, while being safe. "
+ "Your answers should not include any harmful, unethical, "
+ "racist, sexist, toxic, dangerous, or illegal content. "
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
+ "If a question does not make any sense, or is not factually coherent, "
+ "explain why instead of answering something not correct. "
+ "If you don't know the answer to a question, please don't share false information."
+ ),
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+ https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
+"""
+register_template(
+ name="llama2_zh",
+ prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+ ],
+ prompt=[
+ "[INST] {{query}} [/INST]"
+ ],
+ system="You are a helpful assistant. 你是一个乐于助人的助手。",
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
+"""
+register_template(
+ name="mistral",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "[INST] {{query}} [/INST]"
+ ],
+ system="",
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/openchat/openchat_3.5
+"""
+register_template(
+ name="openchat",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "GPT4 Correct User: {{query}}",
+ {"token": "<|end_of_turn|>"},
+ "GPT4 Correct Assistant:"
+ ],
+ system="You are a helpful assistant.",
+ sep=[
+ {"token": "<|end_of_turn|>"}
+ ],
+ stop_words=[
+ "<|end_of_turn|>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
+ https://huggingface.co/Qwen/Qwen-14B-Chat
+"""
+register_template(
+ name="qwen",
+ prefix=[
+ {"token": "<|im_start|>"},
+ "system\n{{system}}"
+ ],
+ prompt=[
+ {"token": "<|im_start|>"},
+ "user\n{{query}}",
+ {"token": "<|im_end|>"},
+ "\n",
+ {"token": "<|im_start|>"},
+ "assistant\n"
+ ],
+ system="You are a helpful assistant.",
+ sep=[
+ {"token": "<|im_end|>"},
+ "\n"
+ ],
+ stop_words=[
+ "<|im_end|>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
+ https://huggingface.co/HuggingFaceH4/starchat-beta
+"""
+register_template(
+ name="starchat",
+ prefix=[
+ {"token": "<|system|>"},
+ "\n{{system}}",
+ ],
+ prompt=[
+ {"token": "<|user|>"},
+ "\n{{query}}",
+ {"token": "<|end|>"},
+ "\n",
+ {"token": "<|assistant|>"}
+ ],
+ system="",
+ sep=[
+ {"token": "<|end|>"},
+ "\n"
+ ],
+ stop_words=[
+ "<|end|>"
+ ],
+ efficient_eos=True
+)
+
+
+r"""
+Supports language model inference without histories.
+"""
+register_template(
+ name="vanilla",
+ prefix=[],
+ prompt=[
+ "{{query}}"
+ ],
+ system="",
+ sep=[],
+ use_history=False
+)
+
+
+r"""
+Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
+ https://huggingface.co/lmsys/vicuna-13b-v1.5
+"""
+register_template(
+ name="vicuna",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "USER: {{query}} ASSISTANT:"
+ ],
+ system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
+ https://huggingface.co/xverse/XVERSE-13B-Chat
+"""
+register_template(
+ name="xverse",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+ "Human: {{query}}\n\nAssistant: "
+ ],
+ system="",
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
+ https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
+"""
+register_template(
+ name="zephyr",
+ prefix=[
+ {"token": "<|system|>"},
+ "\n{{system}}",
+        {"token": "</s>"}
+ ],
+ prompt=[
+ {"token": "<|user|>"},
+ "\n{{query}}",
+        {"token": "</s>"},
+ {"token": "<|assistant|>"}
+ ],
+ system="You are a friendly chatbot who always responds in the style of a pirate",
+ sep=[]
+)
+
+
+r"""
+Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
+ https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
+ https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
+"""
+register_template(
+ name="ziya",
+ prefix=[
+ "{{system}}"
+ ],
+ prompt=[
+        {"token": "<human>"},
+ ":{{query}}\n",
+        {"token": "<bot>"},
+ ":"
+ ],
+ system="",
+ sep=[
+ "\n"
+ ]
+)
diff --git a/llm_rl/src/llmtuner/hparams/__init__.py b/llm_rl/src/llmtuner/hparams/__init__.py
new file mode 100644
index 00000000..f0547cc5
--- /dev/null
+++ b/llm_rl/src/llmtuner/hparams/__init__.py
@@ -0,0 +1,4 @@
+from .data_args import DataArguments
+from .finetuning_args import FinetuningArguments
+from .generating_args import GeneratingArguments
+from .model_args import ModelArguments
diff --git a/llm_rl/src/llmtuner/hparams/data_args.py b/llm_rl/src/llmtuner/hparams/data_args.py
new file mode 100644
index 00000000..4c67dd65
--- /dev/null
+++ b/llm_rl/src/llmtuner/hparams/data_args.py
@@ -0,0 +1,169 @@
+import os
+import json
+from typing import List, Literal, Optional
+from dataclasses import dataclass, field
+
+
+@dataclass
+class DatasetAttr:
+
+ load_from: str
+ dataset_name: Optional[str] = None
+ dataset_sha1: Optional[str] = None
+ system_prompt: Optional[str] = None
+ subset: Optional[str] = None
+ ranking: Optional[bool] = False
+ formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+
+ prompt: Optional[str] = "instruction"
+ query: Optional[str] = "input"
+ response: Optional[str] = "output"
+ history: Optional[str] = None
+ messages: Optional[str] = "conversations"
+ role: Optional[str] = "from"
+ content: Optional[str] = "value"
+
+ def __repr__(self) -> str:
+ return self.dataset_name
+
+
+@dataclass
+class DataArguments:
+ r"""
+    Arguments pertaining to what data we are going to input to our model for training and evaluation.
+ """
+ template: Optional[str] = field(
+ default=None,
+ metadata={"help": "Which template to use for constructing prompts in training and inference."}
+ )
+ dataset: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
+ )
+ dataset_dir: Optional[str] = field(
+ default="data",
+ metadata={"help": "The name of the folder containing datasets."}
+ )
+ split: Optional[str] = field(
+ default="train",
+ metadata={"help": "Which dataset split to use for training and evaluation."}
+ )
+ cutoff_len: Optional[int] = field(
+ default=1024,
+ metadata={"help": "The maximum length of the model inputs after tokenization."}
+ )
+ train_on_prompt: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to disable the mask on the prompt or not."}
+ )
+ streaming: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Enable dataset streaming."}
+ )
+ buffer_size: Optional[int] = field(
+ default=16384,
+ metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
+ )
+ mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
+ default="concat",
+ metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
+ )
+ interleave_probs: Optional[str] = field(
+ default=None,
+ metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
+ )
+ overwrite_cache: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Overwrite the cached training and evaluation sets."}
+ )
+ preprocessing_num_workers: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of processes to use for the preprocessing."}
+ )
+ max_samples: Optional[int] = field(
+ default=None,
+ metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
+ )
+ eval_num_beams: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
+ )
+ ignore_pad_token_for_loss: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
+ )
+ system_prompt: Optional[str] = field(
+ default=None,
+ metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
+ )
+ val_size: Optional[float] = field(
+ default=0,
+ metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
+ )
+ sft_packing: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
+ )
+ cache_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to save or load the preprocessed datasets."}
+ )
+
+ def __post_init__(self):
+ if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
+ raise ValueError("Streaming mode should have an integer val size.")
+
+ if self.streaming and self.max_samples is not None:
+ raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+ if self.streaming and self.cache_path:
+ raise ValueError("`cache_path` is incompatible with `streaming`.")
+
+ def init_for_training(self, seed: int): # support mixing multiple datasets
+ self.seed = seed
+ dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
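+        # dataset_info.json maps each dataset name to an entry such as (illustrative):
+        #   {"alpaca_en": {"hf_hub_url": "tatsu-lab/alpaca", "columns": {"prompt": "instruction"}}}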
+ try:
+ with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+ dataset_info = json.load(f)
+ except Exception:
+ if self.dataset is not None:
+ raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
+ dataset_info = None
+
+ prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+ prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+        assert len(prompt_list) == len(dataset_names), "Number of system prompts should equal the number of datasets, or be 1."
+
+ if self.interleave_probs is not None:
+ self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
+
+ self.dataset_list: List[DatasetAttr] = []
+ for i, name in enumerate(dataset_names):
+ if name not in dataset_info:
+ raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
+
+ if "hf_hub_url" in dataset_info[name]:
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+ elif "script_url" in dataset_info[name]:
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+ else:
+ dataset_attr = DatasetAttr(
+ "file",
+ dataset_name=dataset_info[name]["file_name"],
+ dataset_sha1=dataset_info[name].get("file_sha1", None)
+ )
+
+ if "columns" in dataset_info[name]:
+ dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
+ dataset_attr.query = dataset_info[name]["columns"].get("query", None)
+ dataset_attr.response = dataset_info[name]["columns"].get("response", None)
+ dataset_attr.history = dataset_info[name]["columns"].get("history", None)
+ dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
+ dataset_attr.role = dataset_info[name]["columns"].get("role", None)
+ dataset_attr.content = dataset_info[name]["columns"].get("content", None)
+
+ dataset_attr.subset = dataset_info[name].get("subset", None)
+ dataset_attr.ranking = dataset_info[name].get("ranking", False)
+ dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
+ dataset_attr.system_prompt = prompt_list[i]
+ self.dataset_list.append(dataset_attr)
diff --git a/llm_rl/src/llmtuner/hparams/finetuning_args.py b/llm_rl/src/llmtuner/hparams/finetuning_args.py
new file mode 100644
index 00000000..d5ef323d
--- /dev/null
+++ b/llm_rl/src/llmtuner/hparams/finetuning_args.py
@@ -0,0 +1,107 @@
+import json
+from typing import Literal, Optional
+from dataclasses import asdict, dataclass, field
+
+
+@dataclass
+class FinetuningArguments:
+ r"""
+    Arguments pertaining to which techniques we are going to fine-tune with.
+ """
+ stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+ default="sft",
+ metadata={"help": "Which stage will be performed in training."}
+ )
+ finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
+ default="lora",
+ metadata={"help": "Which fine-tuning method to use."}
+ )
+ num_layer_trainable: Optional[int] = field(
+ default=3,
+ metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
+ )
+ name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
+ default="mlp",
+ metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+ LLaMA choices: [\"mlp\", \"self_attn\"], \
+ BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
+ Qwen choices: [\"mlp\", \"attn\"], \
+ Phi-1.5 choices: [\"mlp\", \"mixer\"], \
+ LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
+ )
+ lora_rank: Optional[int] = field(
+ default=8,
+ metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+ )
+ lora_alpha: Optional[float] = field(
+ default=32.0,
+ metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
+ )
+ lora_dropout: Optional[float] = field(
+ default=0.1,
+ metadata={"help": "Dropout rate for the LoRA fine-tuning."}
+ )
+ lora_target: Optional[str] = field(
+ default=None,
+ metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
+ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+ BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
+ Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
+ LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+ )
+ additional_target: Optional[str] = field(
+ default=None,
+ metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
+ )
+ resume_lora_training: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+ )
+ ppo_score_norm: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Use score normalization in PPO training."}
+ )
+ ppo_logger: Optional[str] = field(
+ default=None,
+ metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
+ )
+ ppo_target: Optional[float] = field(
+ default=6.0,
+ metadata={"help": "Target KL value for adaptive KL control in PPO training."}
+ )
+ dpo_beta: Optional[float] = field(
+ default=0.1,
+ metadata={"help": "The beta parameter for the DPO loss."}
+ )
+ upcast_layernorm: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+ )
+ neft_alpha: Optional[float] = field(
+ default=0,
+ metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
+ )
+
+ def __post_init__(self):
+ if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
+ self.lora_target = [target.strip() for target in self.lora_target.split(",")]
+
+ if isinstance(self.additional_target, str):
+ self.additional_target = [target.strip() for target in self.additional_target.split(",")]
+
+ assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+
+ def save_to_json(self, json_path: str):
+ r"""Saves the content of this instance in JSON format inside `json_path`."""
+ json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
+ with open(json_path, "w", encoding="utf-8") as f:
+ f.write(json_string)
+
+ @classmethod
+ def load_from_json(cls, json_path: str):
+ r"""Creates an instance from the content of `json_path`."""
+ with open(json_path, "r", encoding="utf-8") as f:
+ text = f.read()
+ return cls(**json.loads(text))
diff --git a/llm_rl/src/llmtuner/hparams/general_args.py b/llm_rl/src/llmtuner/hparams/general_args.py
new file mode 100644
index 00000000..c0c1a0de
--- /dev/null
+++ b/llm_rl/src/llmtuner/hparams/general_args.py
@@ -0,0 +1,13 @@
+from typing import Literal, Optional
+from dataclasses import dataclass, field
+
+
+@dataclass
+class GeneralArguments:
+ r"""
+ Arguments pertaining to which stage we are going to perform.
+ """
+ stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+ default="sft",
+ metadata={"help": "Which stage will be performed in training."}
+ )
diff --git a/llm_rl/src/llmtuner/hparams/generating_args.py b/llm_rl/src/llmtuner/hparams/generating_args.py
new file mode 100644
index 00000000..c04a5c36
--- /dev/null
+++ b/llm_rl/src/llmtuner/hparams/generating_args.py
@@ -0,0 +1,53 @@
+from typing import Any, Dict, Optional
+from dataclasses import asdict, dataclass, field
+
+
+@dataclass
+class GeneratingArguments:
+ r"""
+    Arguments pertaining to the decoding parameters.
+ """
+ do_sample: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
+ )
+ temperature: Optional[float] = field(
+ default=0.95,
+ metadata={"help": "The value used to modulate the next token probabilities."}
+ )
+ top_p: Optional[float] = field(
+ default=0.7,
+ metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
+ )
+ top_k: Optional[int] = field(
+ default=50,
+ metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
+ )
+ num_beams: Optional[int] = field(
+ default=1,
+ metadata={"help": "Number of beams for beam search. 1 means no beam search."}
+ )
+ max_length: Optional[int] = field(
+ default=512,
+ metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
+ )
+ max_new_tokens: Optional[int] = field(
+ default=512,
+ metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
+ )
+ repetition_penalty: Optional[float] = field(
+ default=1.0,
+ metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
+ )
+ length_penalty: Optional[float] = field(
+ default=1.0,
+ metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
+ )
+
+ def to_dict(self) -> Dict[str, Any]:
+ args = asdict(self)
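+        # max_new_tokens takes precedence: keep only one of max_length/max_new_tokens for `generate`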
+ if args.get("max_new_tokens", -1) > 0:
+ args.pop("max_length", None)
+ else:
+ args.pop("max_new_tokens", None)
+ return args
diff --git a/llm_rl/src/llmtuner/hparams/model_args.py b/llm_rl/src/llmtuner/hparams/model_args.py
new file mode 100644
index 00000000..7c25fad1
--- /dev/null
+++ b/llm_rl/src/llmtuner/hparams/model_args.py
@@ -0,0 +1,93 @@
+from typing import Literal, Optional
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ModelArguments:
+ r"""
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+ """
+ model_name_or_path: str = field(
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
+ )
+ use_fast_tokenizer: Optional[bool] = field(
+ default=True,
+        metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."}
+ )
+ split_special_tokens: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
+ )
+ use_auth_token: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
+ )
+ model_revision: Optional[str] = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
+ )
+ quantization_bit: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of bits to quantize the model."}
+ )
+ quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use in int4 training."}
+ )
+ double_quantization: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to use double quantization in int4 training or not."}
+ )
+ rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+ default=None,
+ metadata={"help": "Adopt scaled rotary positional embeddings."}
+ )
+ checkpoint_dir: Optional[str] = field(
+ default=None,
+        metadata={"help": "Path to the directory or directories containing the delta model checkpoints as well as the configurations."}
+ )
+ flash_attn: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Enable FlashAttention-2 for faster training."}
+ )
+ shift_attn: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
+ )
+ reward_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+ )
+ plot_loss: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+ )
+ hf_auth_token: Optional[str] = field(
+ default=None,
+ metadata={"help": "Auth token to log in with Hugging Face Hub."}
+ )
+ export_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the directory to save the exported model."}
+ )
+
+ def __post_init__(self):
+ self.compute_dtype = None
+ self.model_max_length = None
+
+ if self.split_special_tokens and self.use_fast_tokenizer:
+ raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
+ if self.checkpoint_dir is not None: # support merging multiple lora weights
+ self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
+
+ if self.quantization_bit is not None:
+ assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+        if self.use_auth_token and self.hf_auth_token is not None:
+ from huggingface_hub.hf_api import HfFolder # lazy load
+ HfFolder.save_token(self.hf_auth_token)
diff --git a/llm_rl/src/llmtuner/tuner/__init__.py b/llm_rl/src/llmtuner/tuner/__init__.py
new file mode 100644
index 00000000..4d5a83e4
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.tune import export_model, run_exp
diff --git a/llm_rl/src/llmtuner/tuner/core/__init__.py b/llm_rl/src/llmtuner/tuner/core/__init__.py
new file mode 100644
index 00000000..bd1c5cf0
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/core/__init__.py
@@ -0,0 +1,2 @@
+from llmtuner.tuner.core.parser import get_train_args, get_infer_args
+from llmtuner.tuner.core.loader import load_model_and_tokenizer
diff --git a/llm_rl/src/llmtuner/tuner/core/adapter.py b/llm_rl/src/llmtuner/tuner/core/adapter.py
new file mode 100644
index 00000000..4fcc6e62
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/core/adapter.py
@@ -0,0 +1,101 @@
+import torch
+from typing import TYPE_CHECKING
+
+from peft import (
+ PeftModel,
+ TaskType,
+ LoraConfig,
+ get_peft_model
+)
+
+from llmtuner.extras.logging import get_logger
+from llmtuner.tuner.core.utils import find_all_linear_modules
+
+if TYPE_CHECKING:
+ from transformers.modeling_utils import PreTrainedModel
+ from llmtuner.hparams import ModelArguments, FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+def init_adapter(
+ model: "PreTrainedModel",
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
+ is_mergeable: bool
+) -> "PreTrainedModel":
+ r"""
+ Initializes the adapters.
+
+ Support full-parameter, freeze and LoRA training.
+
+ Note that the trainable parameters must be cast to float32.
+ """
+
+ if finetuning_args.finetuning_type == "none" and is_trainable:
+ raise ValueError("You cannot use finetuning_type=none while training.")
+
+ if finetuning_args.finetuning_type == "full" and is_trainable:
+ logger.info("Fine-tuning method: Full")
+ model = model.float()
+
+ if finetuning_args.finetuning_type == "freeze":
+ logger.info("Fine-tuning method: Freeze")
+ num_layers = getattr(model.config, "num_layers")
+ if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
+ trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
+ else: # fine-tuning the first n layers if num_layer_trainable < 0
+ trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
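+        # e.g. num_layers = 32 with num_layer_trainable = 3 marks layers 29, 30 and 31 (the last three) as trainable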
+
+ trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
+ for name, param in model.named_parameters():
+ if not any(trainable_layer in name for trainable_layer in trainable_layers):
+ param.requires_grad_(False)
+ else:
+ param.data = param.data.to(torch.float32)
+
+ if finetuning_args.finetuning_type == "lora":
+ logger.info("Fine-tuning method: LoRA")
+ latest_checkpoint = None
+
+ if model_args.checkpoint_dir is not None:
+ if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
+ checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
+ else:
+ checkpoints_to_merge = model_args.checkpoint_dir
+
+ for checkpoint in checkpoints_to_merge:
+ model = PeftModel.from_pretrained(model, checkpoint)
+ model = model.merge_and_unload()
+
+ if len(checkpoints_to_merge) > 0:
+ logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
+
+ if latest_checkpoint is not None: # resume lora training or quantized inference
+ model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
+
+ if is_trainable and latest_checkpoint is None: # create new lora weights while training
+ if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
+ target_modules = find_all_linear_modules(model, model_args.quantization_bit)
+ else:
+ target_modules = finetuning_args.lora_target
+
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ inference_mode=False,
+ r=finetuning_args.lora_rank,
+ lora_alpha=finetuning_args.lora_alpha,
+ lora_dropout=finetuning_args.lora_dropout,
+ target_modules=target_modules,
+ modules_to_save=finetuning_args.additional_target
+ )
+ model = get_peft_model(model, lora_config)
+ if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
+ model.base_model.peft_config = model.peft_config
+
+ if model_args.checkpoint_dir is not None:
+ logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+
+ return model
diff --git a/llm_rl/src/llmtuner/tuner/core/loader.py b/llm_rl/src/llmtuner/tuner/core/loader.py
new file mode 100644
index 00000000..e77c4945
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/core/loader.py
@@ -0,0 +1,244 @@
+import os
+import math
+import torch
+from types import MethodType
+from typing import TYPE_CHECKING, Literal, Optional, Tuple
+
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ BitsAndBytesConfig,
+ PretrainedConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase
+)
+from transformers.models.llama import modeling_llama as LlamaModule
+from transformers.utils.versions import require_version
+from trl import AutoModelForCausalLMWithValueHead
+
+try:
+ from transformers.integrations import is_deepspeed_zero3_enabled
+except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
+
+from llmtuner.extras.logging import reset_logging, get_logger
+from llmtuner.extras.misc import count_parameters, infer_optim_dtype
+from llmtuner.extras.patches import llama_patch as LlamaPatches
+from llmtuner.extras.save_and_load import load_valuehead_params
+from llmtuner.hparams import FinetuningArguments
+from llmtuner.tuner.core.adapter import init_adapter
+from llmtuner.tuner.core.utils import prepare_model_for_training
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+ from llmtuner.hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
+require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
+require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
+require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
+require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
+
+
+def load_model_and_tokenizer(
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: Optional[bool] = False,
+ stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
+) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
+ r"""
+ Loads pretrained model and tokenizer.
+
+ Support both training and inference.
+ """
+ if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("No checkpoint found for evaluation; loading the original model.")
+ finetuning_args = FinetuningArguments(finetuning_type="none")
+
+ config_kwargs = {
+ "trust_remote_code": True,
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "use_auth_token": True if model_args.use_auth_token else None,
+ }
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ split_special_tokens=model_args.split_special_tokens,
+ padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
+ **config_kwargs
+ )
+
+ if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+ model_to_load = model_args.checkpoint_dir[0]
+ else:
+ model_to_load = model_args.model_name_or_path
+
+ config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
+
+ # Fix tokenizer (for ChatGLM2 and ChatGLM3)
+ if getattr(config, "model_type", None) == "chatglm":
+ tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
+ # Set model dtype
+ if model_args.compute_dtype is not None: # for training
+ setattr(config, "torch_dtype", model_args.compute_dtype)
+ else: # for evaluation, priority: bf16 > fp16 > fp32
+ model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+
+ # Fix config (for Qwen)
+ if getattr(config, "model_type", None) == "qwen":
+ for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+ setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+
+ # Set RoPE scaling
+ if model_args.rope_scaling is not None:
+ if hasattr(config, "use_dynamic_ntk"): # for Qwen models
+ if is_trainable:
+ logger.warning("Qwen model does not support RoPE scaling in training.")
+ else:
+ setattr(config, "use_dynamic_ntk", True)
+ setattr(config, "use_logn_attn", True)
+ logger.info("Using dynamic NTK scaling.")
+
+ elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
+ if is_trainable:
+ if model_args.rope_scaling == "dynamic":
+ logger.warning(
+ "Dynamic NTK may not work well with fine-tuning. "
+ "See: https://github.com/huggingface/transformers/pull/24653"
+ )
+
+ current_max_length = getattr(config, "max_position_embeddings", None)
+ if current_max_length and model_args.model_max_length > current_max_length:
+ scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
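+                    # e.g. model_max_length = 8192 with max_position_embeddings = 4096 gives scaling_factor = 2.0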
+ else:
+                    logger.warning("Input length is smaller than max length. Consider increasing the input length.")
+ scaling_factor = 1.0
+ else:
+ scaling_factor = 2.0
+
+ setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+ logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+ model_args.rope_scaling, scaling_factor
+ ))
+
+ else:
+ logger.warning("Current model does not support RoPE scaling.")
+
+ # Set FlashAttention-2
+ if model_args.flash_attn:
+ if getattr(config, "model_type", None) == "llama":
+ LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+ LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+ logger.info("Using FlashAttention-2 for faster training and inference.")
+ elif getattr(config, "model_type", None) == "qwen":
+ logger.info("Qwen models automatically enable FlashAttention if installed.")
+ else:
+ logger.warning("Current model does not support FlashAttention-2.")
+ elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
+ LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
+        logger.warning("Consider using `--flash_attn` for faster training with long context lengths.")
+
+ # Set shift short attention (S^2-Attn)
+ if is_trainable and model_args.shift_attn:
+ if getattr(config, "model_type", None) == "llama":
+ setattr(config, "group_size_ratio", 0.25)
+ logger.info("Using shift short attention with group_size_ratio=1/4.")
+ else:
+ logger.warning("Current model does not support shift short attention.")
+
+ # Quantization configurations (using bitsandbytes library).
+ is_mergeable = True
+ if model_args.quantization_bit is not None:
+ if is_deepspeed_zero3_enabled():
+ raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
+ if model_args.quantization_bit == 8:
+ require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+ config_kwargs["load_in_8bit"] = True
+ config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+
+ elif model_args.quantization_bit == 4:
+ require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+ config_kwargs["load_in_4bit"] = True
+ config_kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=model_args.compute_dtype,
+ bnb_4bit_use_double_quant=model_args.double_quantization,
+ bnb_4bit_quant_type=model_args.quantization_type
+ )
+
+ is_mergeable = False
+ config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
+ logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+
+ # Load and prepare pre-trained models (without valuehead).
+ model = AutoModelForCausalLM.from_pretrained(
+ model_to_load,
+ config=config,
+ torch_dtype=model_args.compute_dtype,
+ low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
+ **config_kwargs
+ )
+
+ # Disable custom generate method (for Qwen and Baichuan2)
+ if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
+ model.generate = MethodType(PreTrainedModel.generate, model)
+
+ # Fix LM head (for ChatGLM2 and ChatGLM3)
+ if getattr(config, "model_type", None) == "chatglm":
+ setattr(model, "lm_head", model.transformer.output_layer)
+ setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+ # Register auto class to save the custom code files.
+ if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
+ config.__class__.register_for_auto_class()
+ if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
+ model.__class__.register_for_auto_class()
+ if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+ tokenizer.__class__.register_for_auto_class()
+
+ # Initialize adapters
+ model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
+ model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
+ model = model.train() if is_trainable else model.eval()
+
+ # Prepare model with valuehead for RLHF
+ if stage == "rm" or stage == "ppo":
+ model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+ reset_logging()
+ if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
+ logger.warning("Only the last checkpoint containing valuehead will be loaded.")
+ if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+ model.v_head.load_state_dict({
+ "summary.weight": getattr(model, "reward_head_weight"),
+ "summary.bias": getattr(model, "reward_head_bias")
+ })
+
+ if stage == "ppo": # load reward model
+ logger.info("Load reward model from {}".format(model_args.reward_model))
+ if getattr(model, "is_peft_model", False):
+ model.pretrained_model.load_adapter(model_args.reward_model, "reward")
+ assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
+
+ # Prepare model for inference
+ if not is_trainable:
+ model.requires_grad_(False) # fix all model params
+ model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
+
+ trainable_params, all_param = count_parameters(model)
+ logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+ trainable_params, all_param, 100 * trainable_params / all_param
+ ))
+
+ if not is_trainable:
+ logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
+
+ return model, tokenizer
diff --git a/llm_rl/src/llmtuner/tuner/core/parser.py b/llm_rl/src/llmtuner/tuner/core/parser.py
new file mode 100644
index 00000000..603fc1bc
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/core/parser.py
@@ -0,0 +1,226 @@
+import os
+import sys
+import torch
+import datasets
+import transformers
+from typing import Any, Dict, Optional, Tuple
+from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers.trainer_utils import get_last_checkpoint
+
+from llmtuner.extras.logging import get_logger
+from llmtuner.hparams import (
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
+)
+
+
+logger = get_logger(__name__)
+
+
+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+ if args is not None:
+ return parser.parse_dict(args)
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+ return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+ elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+ else:
+ return parser.parse_args_into_dataclasses()
+
+
+def parse_train_args(
+ args: Optional[Dict[str, Any]] = None
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ Seq2SeqTrainingArguments,
+ FinetuningArguments,
+ GeneratingArguments
+]:
+ parser = HfArgumentParser((
+ ModelArguments,
+ DataArguments,
+ Seq2SeqTrainingArguments,
+ FinetuningArguments,
+ GeneratingArguments
+ ))
+ return _parse_args(parser, args)
+
+
+def parse_infer_args(
+ args: Optional[Dict[str, Any]] = None
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
+]:
+ parser = HfArgumentParser((
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
+ ))
+ return _parse_args(parser, args)
+
+
+def get_train_args(
+ args: Optional[Dict[str, Any]] = None
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ Seq2SeqTrainingArguments,
+ FinetuningArguments,
+ GeneratingArguments
+]:
+ model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
+
+ # Setup logging
+ if training_args.should_log:
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+ transformers.utils.logging.set_verbosity_info()
+
+ log_level = training_args.get_process_log_level()
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Check arguments
+ data_args.init_for_training(training_args.seed)
+
+ if finetuning_args.stage != "pt" and data_args.template is None:
+ raise ValueError("Please specify which `template` to use.")
+
+ if finetuning_args.stage != "sft" and training_args.predict_with_generate:
+ raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
+
+ if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+ raise ValueError("Please enable `predict_with_generate` to save model predictions.")
+
+ if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+ raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
+
+ if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+ raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+
+ if finetuning_args.stage == "ppo" and not training_args.do_train:
+ raise ValueError("PPO training does not support evaluation.")
+
+ if finetuning_args.stage in ["rm", "dpo"]:
+ for dataset_attr in data_args.dataset_list:
+ if not dataset_attr.ranking:
+ raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
+
+ if finetuning_args.stage == "ppo" and model_args.reward_model is None:
+ raise ValueError("Reward model is necessary for PPO training.")
+
+ if finetuning_args.stage == "ppo" and model_args.shift_attn:
+ raise ValueError("PPO training is incompatible with S^2-Attn.")
+
+ if training_args.max_steps == -1 and data_args.streaming:
+ raise ValueError("Please specify `max_steps` in streaming mode.")
+
+ if training_args.do_train and training_args.predict_with_generate:
+ raise ValueError("`predict_with_generate` cannot be set as True while training.")
+
+ if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
+ raise ValueError("Please specify `lora_target` in LoRA training.")
+
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantization is only compatible with the LoRA method.")
+
+ if model_args.checkpoint_dir is not None:
+ if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+ if model_args.quantization_bit is not None:
+ if len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
+
+ if not finetuning_args.resume_lora_training:
+ raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
+
+ if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
+ logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+
+ if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
+ logger.warning("We recommend enable mixed precision training.")
+
+ if (not training_args.do_train) and model_args.quantization_bit is not None:
+ logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+
+ # postprocess training_args
+ if (
+ training_args.local_rank != -1
+ and training_args.ddp_find_unused_parameters is None
+ and finetuning_args.finetuning_type == "lora"
+ ):
+ logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+ training_args_dict = training_args.to_dict()
+ training_args_dict.update(dict(ddp_find_unused_parameters=False))
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
+
+ if (
+ training_args.resume_from_checkpoint is None
+ and training_args.do_train
+ and os.path.isdir(training_args.output_dir)
+ and not training_args.overwrite_output_dir
+ ):
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
+
+ if last_checkpoint is not None:
+ training_args_dict = training_args.to_dict()
+ training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
+ logger.info(
+ "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
+ )
+
+ # postprocess model_args
+ model_args.compute_dtype = (
+ torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
+ )
+ model_args.model_max_length = data_args.cutoff_len
+
+ # Log on each process the small summary:
+ logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
+ training_args.local_rank, training_args.device, training_args.n_gpu,
+ bool(training_args.local_rank != -1), str(model_args.compute_dtype)
+ ))
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Set seed before initializing model.
+ transformers.set_seed(training_args.seed)
+
+ return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def get_infer_args(
+ args: Optional[Dict[str, Any]] = None
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
+]:
+ model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
+
+ if data_args.template is None:
+ raise ValueError("Please specify which `template` to use.")
+
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantization is only compatible with the LoRA method.")
+
+ if model_args.checkpoint_dir is not None:
+ if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+ if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
+
+ return model_args, data_args, finetuning_args, generating_args
diff --git a/llm_rl/src/llmtuner/tuner/core/utils.py b/llm_rl/src/llmtuner/tuner/core/utils.py
new file mode 100644
index 00000000..d9a1aac9
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/core/utils.py
@@ -0,0 +1,94 @@
+import torch
+from types import MethodType
+from typing import TYPE_CHECKING, List, Optional
+
+from llmtuner.extras.constants import LAYERNORM_NAMES
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from transformers.modeling_utils import PreTrainedModel
+ from llmtuner.hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+def find_all_linear_modules(
+ model: "PreTrainedModel",
+ quantization_bit: Optional[int] = None,
+ output_layer_name: Optional[str] = "lm_head"
+) -> List[str]:
+ if quantization_bit is not None:
+ import bitsandbytes as bnb
+ linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
+ else:
+ linear_cls = torch.nn.Linear
+
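+ # Collect the leaf names of all linear layers (excluding the output head); these are the usual LoRA target modules.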
+ module_names = set()
+ for name, module in model.named_modules():
+ if output_layer_name not in name and isinstance(module, linear_cls):
+ module_names.add(name.split(".")[-1])
+
+ if output_layer_name in module_names:
+ module_names.remove(output_layer_name)
+
+ return list(module_names)
+
+
+def prepare_model_for_training(
+ model: "PreTrainedModel",
+ finetuning_args: "FinetuningArguments",
+ output_layer_name: Optional[str] = "lm_head",
+ use_gradient_checkpointing: Optional[bool] = True,
+ layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
+) -> "PreTrainedModel":
+ r"""
+ Includes:
+ (1) cast the layernorm weights to fp32
+ (2) make the output embedding layer require grads
+ (3) upcast the lm_head outputs to fp32
+ Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
+ """
+ if finetuning_args.upcast_layernorm:
+ for name, param in model.named_parameters():
+ if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
+ param.data = param.data.to(torch.float32)
+ logger.info("Upcasting weights in layernorm in float32.")
+
+ if finetuning_args.neft_alpha > 1e-6:
+ input_embed = model.get_input_embeddings()
+ if isinstance(input_embed, torch.nn.Embedding):
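+ # NEFTune-style noisy embeddings: during training, add uniform noise in [-mag_norm, mag_norm],
+ # where mag_norm = neft_alpha / sqrt(num_embeddings * embedding_dim).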
+ def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
+ embeddings = input_embed.__class__.forward(self, x)
+ if self.training:
+ dims = self.num_embeddings * self.embedding_dim
+ mag_norm = finetuning_args.neft_alpha / (dims ** 0.5)
+ embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
+ return embeddings
+
+ input_embed.forward = MethodType(noisy_forward, input_embed)
+ logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
+ else:
+ logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.")
+
+ if use_gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+ output.requires_grad_(True)
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ model.gradient_checkpointing_enable()
+ model.config.use_cache = False # turn off when gradient checkpointing is enabled
+ logger.info("Gradient checkpointing enabled.")
+
+ if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
+ output_layer = getattr(model, output_layer_name)
+ if isinstance(output_layer, torch.nn.Linear):
+ def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor:
+ return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32)
+
+ output_layer.forward = MethodType(forward_in_fp32, output_layer)
+
+ return model
diff --git a/llm_rl/src/llmtuner/tuner/dpo/__init__.py b/llm_rl/src/llmtuner/tuner/dpo/__init__.py
new file mode 100644
index 00000000..f2b5cfb5
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/dpo/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.dpo.workflow import run_dpo
diff --git a/llm_rl/src/llmtuner/tuner/dpo/collator.py b/llm_rl/src/llmtuner/tuner/dpo/collator.py
new file mode 100644
index 00000000..5c862b4f
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/dpo/collator.py
@@ -0,0 +1,51 @@
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence, Tuple
+from transformers import DataCollatorForSeq2Seq
+
+
+@dataclass
+class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
+ r"""
+ Data collator for pairwise data.
+ """
+
+ def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
+ padded_labels = []
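+ # Keep labels only at the answer positions; all other positions are set to label_pad_token_id so they are ignored by the loss.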
+ for feature, (prompt_len, answer_len) in zip(batch, positions):
+ if self.tokenizer.padding_side == "left":
+ start, end = feature.size(0) - answer_len, feature.size(0)
+ else:
+ start, end = prompt_len, prompt_len + answer_len
+ padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
+ padded_tensor[start:end] = feature[start:end]
+ padded_labels.append(padded_tensor)
+ return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+ r"""
+ Pads batched data to the longest sequence in the batch.
+
+ We generate 2 * n examples where the first n examples represent chosen examples and
+ the last n examples represent rejected examples.
+ """
+ concatenated_features = []
+ label_positions = []
+ for key in ("chosen_ids", "rejected_ids"):
+ for feature in features:
+ prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
+ concatenated_features.append({
+ "input_ids": feature["prompt_ids"] + feature[key],
+ "attention_mask": [1] * (prompt_len + answer_len)
+ })
+ label_positions.append((prompt_len, answer_len))
+
+ batch = self.tokenizer.pad(
+ concatenated_features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors=self.return_tensors,
+ )
+ batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
+ return batch
diff --git a/llm_rl/src/llmtuner/tuner/dpo/trainer.py b/llm_rl/src/llmtuner/tuner/dpo/trainer.py
new file mode 100644
index 00000000..8a9f8dd6
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/dpo/trainer.py
@@ -0,0 +1,104 @@
+import torch
+import deepspeed # type: ignore
+from copy import deepcopy
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from transformers import BatchEncoding, Trainer
+from trl import DPOTrainer
+from trl.trainer.utils import disable_dropout_in_model
+
+from llmtuner.extras.constants import IGNORE_INDEX
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+ from trl import PreTrainedModelWrapper
+
+
+class CustomDPOTrainer(DPOTrainer):
+
+ def __init__(
+ self,
+ beta: float,
+ model: Union["PreTrainedModel", torch.nn.Module],
+ ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+ disable_dropout: Optional[bool] = True,
+ loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
+ **kwargs
+ ):
+ if disable_dropout:
+ disable_dropout_in_model(model)
+ if ref_model is not None:
+ disable_dropout_in_model(ref_model)
+
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ self.ref_model = ref_model
+ self.use_dpo_data_collator = True # hack to avoid warning
+ self.generate_during_eval = False # disable at evaluation
+ self.label_pad_token_id = IGNORE_INDEX
+ self.padding_value = 0
+ self.beta = beta
+ self.loss_type = loss_type
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ Trainer.__init__(self, model=model, **kwargs)
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Please update `transformers`.")
+
+ if ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+ if model is not None:
+ if hasattr(model, "config"):
+ hidden_size = (
+ max(model.config.hidden_sizes)
+ if getattr(model.config, "hidden_sizes", None)
+ else getattr(model.config, "hidden_size", None)
+ )
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+ config_kwargs.update(
+ {
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+ }
+ )
+
+ # If ZeRO-3 is used, we shard both the active and reference model.
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+ if config_kwargs["zero_optimization"]["stage"] != 3:
+ config_kwargs["zero_optimization"]["stage"] = 0
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+ model.eval()
+ return model
+
+ def concatenated_forward(
+ self,
+ model: Optional[torch.nn.Module] = None,
+ batch: Optional[Dict[str, torch.Tensor]] = None
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
+
+ all_logits = model(
+ input_ids=batch_copied["input_ids"],
+ attention_mask=batch_copied["attention_mask"],
+ return_dict=True
+ ).logits.to(torch.float32)
+
+ all_logps = self._get_batch_logps(
+ all_logits,
+ batch["labels"],
+ average_log_prob=False
+ )
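+ # The collator places the n chosen examples in the first half of the batch and the n rejected ones in the second half.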
+ batch_size = batch["input_ids"].size(0) // 2
+ chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+ chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+ return chosen_logps, rejected_logps, chosen_logits, rejected_logits
diff --git a/llm_rl/src/llmtuner/tuner/dpo/workflow.py b/llm_rl/src/llmtuner/tuner/dpo/workflow.py
new file mode 100644
index 00000000..6e16dd18
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/dpo/workflow.py
@@ -0,0 +1,66 @@
+# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+
+from copy import deepcopy
+from peft import PeftModel
+from typing import TYPE_CHECKING, Optional, List
+from transformers import Seq2SeqTrainingArguments
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
+from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
+
+if TYPE_CHECKING:
+ from transformers import TrainerCallback
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
+
+def run_dpo(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None
+):
+ dataset = get_dataset(model_args, data_args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+ dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+ data_collator = DPODataCollatorWithPadding(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=4,
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+ )
+
+ training_args_dict = training_args.to_dict()
+ training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
+
+ # Initialize our Trainer
+ trainer = CustomDPOTrainer(
+ beta=finetuning_args.dpo_beta,
+ model=model,
+ ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **split_dataset(dataset, data_args, training_args)
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ trainer.save_model()
+ if trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
diff --git a/llm_rl/src/llmtuner/tuner/ppo/__init__.py b/llm_rl/src/llmtuner/tuner/ppo/__init__.py
new file mode 100644
index 00000000..11519bab
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/ppo/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.ppo.workflow import run_ppo
diff --git a/llm_rl/src/llmtuner/tuner/ppo/trainer.py b/llm_rl/src/llmtuner/tuner/ppo/trainer.py
new file mode 100644
index 00000000..372c4891
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/ppo/trainer.py
@@ -0,0 +1,310 @@
+import os
+import sys
+import math
+import torch
+from tqdm import tqdm
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+from trl import PPOTrainer
+from trl.core import PPODecorators, logprobs_from_logits
+
+from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
+from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
+from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+ from trl import AutoModelForCausalLMWithValueHead
+ from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomPPOTrainer(PPOTrainer, Trainer):
+ r"""
+ Inherits PPOTrainer.
+ """
+
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: List["TrainerCallback"],
+ **kwargs
+ ):
+ PPOTrainer.__init__(self, **kwargs)
+ self.args = training_args
+ self.model_args = model_args
+ self.finetuning_args = finetuning_args
+ self.generation_config = GenerationConfig(
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+ **generating_args.to_dict()
+ )
+ self.state = TrainerState()
+ self.control = TrainerControl()
+ self.log_callback, self.save_callback = callbacks[0], callbacks[1]
+ assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+ if self.args.max_steps > 0:
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+ def ppo_train(self) -> None:
+ r"""
+ Implements the training loop for the PPO stage, similar to _inner_training_loop() in Hugging Face's Trainer.
+ """
+ total_train_batch_size = (
+ self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
+ )
+ if self.args.max_steps > 0:
+ num_examples = total_train_batch_size * self.args.max_steps
+ num_train_epochs = sys.maxsize
+ max_steps = self.args.max_steps
+ steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
+ else:
+ len_dataloader = len(self.dataloader)
+ num_examples = len(self.dataset)
+ num_train_epochs = self.args.num_train_epochs
+ max_steps = math.ceil(num_train_epochs * len_dataloader)
+ steps_in_epoch = len_dataloader
+
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ if self.is_world_process_zero():
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps}")
+ logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
+
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+ dataiter = iter(self.dataloader)
+ loss_meter = AverageMeter()
+ reward_meter = AverageMeter()
+ self.log_callback.on_train_begin(self.args, self.state, self.control)
+
+ for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
+ try:
+ batch = next(dataiter)
+ except StopIteration:
+ dataiter = iter(self.dataloader)
+ batch = next(dataiter)
+
+ # Cast to inference mode
+ unwrapped_model.gradient_checkpointing_disable()
+ unwrapped_model.config.use_cache = True
+ self.model.eval()
+
+ # Get inputs
+ queries, responses = self.get_inputs(batch)
+ self.tokenizer.padding_side = "right" # change padding side
+ rewards = self.get_rewards(queries, responses, unwrapped_model)
+
+ # Cast to training mode
+ unwrapped_model.gradient_checkpointing_enable()
+ unwrapped_model.config.use_cache = False
+ self.model.train()
+
+ # Run PPO step
+ stats = self.step(queries, responses, rewards)
+ self.tokenizer.padding_side = "left" # restore padding side
+ loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
+ reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
+
+ if self.config.log_with is not None:
+ try:
+ batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
+ batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
+ self.log_stats(stats, batch, rewards)
+ except Exception:
+ logger.warning("Failed to save stats due to unknown errors.")
+
+ self.state.global_step += 1
+ self.log_callback.on_step_end(self.args, self.state, self.control)
+
+ if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
+ logs = dict(
+ loss=round(loss_meter.avg, 4),
+ reward=round(reward_meter.avg, 4),
+ learning_rate=stats["ppo/learning_rate"],
+ epoch=round(step / steps_in_epoch, 2)
+ )
+ tqdm.write(str(logs))
+ logs["step"] = step
+ self.state.log_history.append(logs)
+ self.log_callback.on_log(self.args, self.state, self.control)
+ loss_meter.reset()
+ reward_meter.reset()
+
+ if (step+1) % self.args.save_steps == 0: # save checkpoint
+ self.save_model(os.path.join(
+ self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
+ ))
+ self.save_callback.on_save(
+ self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+ )
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+
+ self.log_callback.on_train_end(self.args, self.state, self.control)
+ self.save_callback.on_train_end(
+ self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+ )
+
+ @torch.no_grad()
+ def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ r"""
+ Generates model's responses given queries.
+ """
+ if self.finetuning_args.upcast_layernorm:
+ layernorm_params = dump_layernorm(self.model)
+
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+ response: torch.Tensor = unwrapped_model.generate(
+ generation_config=self.generation_config,
+ logits_processor=get_logits_processor(),
+ **batch
+ )
+
+ if self.finetuning_args.upcast_layernorm:
+ restore_layernorm(self.model, layernorm_params)
+
+ query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+ queries, responses = [], []
+ for i in range(len(query)):
+ query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+ response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+
+ if len(response_index) == 0:
+ response_length = 1 # allow empty response
+ elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+ response_length = response_index[-1].item() + 2 # save the EOS token
+ else:
+ response_length = response_index[-1].item() + 1
+
+ queries.append(query[i, query_length:]) # remove padding from left
+ responses.append(response[i, :response_length]) # remove padding from right
+
+ return queries, responses
+
+ @torch.no_grad()
+ def get_rewards(
+ self,
+ queries: List[torch.Tensor],
+ responses: List[torch.Tensor],
+ unwrapped_model: "AutoModelForCausalLMWithValueHead"
+ ) -> List[torch.Tensor]:
+ r"""
+ Computes scores using given reward model.
+ """
+ replace_model(unwrapped_model, target="reward")
+ batch = self.prepare_model_inputs(queries, responses)
+
+ with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
+ _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+
+ if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
+ values = torch.transpose(values, 0, 1)
+
+ rewards = []
+ for i in range(values.size(0)):
+ end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
+ rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
+
+ replace_model(unwrapped_model, target="default")
+ return rewards
+
+ @PPODecorators.empty_cuda_cache()
+ def batched_forward_pass(
+ self,
+ model: "AutoModelForCausalLMWithValueHead",
+ queries: torch.Tensor,
+ responses: torch.Tensor,
+ model_inputs: dict,
+ return_logits: Optional[bool] = False,
+ response_masks: Optional[torch.Tensor] = None
+ ):
+ r"""
+ Calculates model outputs in multiple batches.
+
+ Subclass and override to inject custom behavior.
+ """
+ bs = len(queries)
+ fbs = self.config.mini_batch_size
+ all_logprobs = []
+ all_logits = []
+ all_masks = []
+ all_values = []
+
+ for i in range(math.ceil(bs / fbs)):
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+ query_batch = queries[i * fbs : (i + 1) * fbs]
+ response_batch = responses[i * fbs : (i + 1) * fbs]
+ if response_masks is not None:
+ response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
+ input_ids = input_kwargs["input_ids"]
+ attention_mask = input_kwargs["attention_mask"]
+
+ with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
+ logits, _, values = model(**input_kwargs)
+
+ if values.size(0) != input_ids.size(0): # adapt to chatglm2
+ values = torch.transpose(values, 0, 1)
+
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
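+ # Shift the attention mask by one so masks[:, t] flags the log-prob of predicting token t + 1.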
+ masks = torch.zeros_like(attention_mask)
+ masks[:, :-1] = attention_mask[:, 1:]
+
+ for j in range(len(query_batch)):
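+ # Restrict the mask to response tokens: zero out the query prefix and everything past the response.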
+ start = len(query_batch[j]) - 1
+ if attention_mask[j, 0] == 0: # offset left padding
+ start += attention_mask[j, :].nonzero()[0].item()
+ end = start + len(response_batch[j])
+
+ if response_masks is not None:
+ response_masks_batch = torch.cat(
+ (torch.zeros_like(query_batch[j]), response_masks_batch[j])
+ )[1:]
+
+ masks[j, :start] = 0
+ masks[j, end:] = 0
+ if response_masks is not None:
+ masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
+
+ if return_logits:
+ all_logits.append(logits)
+ else:
+ del logits
+
+ all_values.append(values)
+ all_logprobs.append(logprobs)
+ all_masks.append(masks)
+
+ return (
+ torch.cat(all_logprobs),
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
+ torch.cat(all_values)[:, :-1],
+ torch.cat(all_masks)[:, :-1],
+ )
+
+ def save_model(self, output_dir: Optional[str] = None) -> None:
+ r"""
+ Saves model checkpoint.
+
+ Subclass and override to inject custom behavior.
+ """
+ if self.args.should_save:
+ self._save(output_dir)
diff --git a/llm_rl/src/llmtuner/tuner/ppo/utils.py b/llm_rl/src/llmtuner/tuner/ppo/utils.py
new file mode 100644
index 00000000..74453a39
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/ppo/utils.py
@@ -0,0 +1,35 @@
+import torch
+from typing import TYPE_CHECKING, Dict, Literal, Optional
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+ from trl import AutoModelForCausalLMWithValueHead
+
+
+def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
+ if target == "reward": # save default head temporarily
+ valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
+ setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
+ setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
+
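+ # Activate the LoRA adapter for the requested role and load the matching value-head weights.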
+ model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
+ model.v_head.load_state_dict({
+ "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
+ "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
+ })
+
+
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+ layer_norm_params = {}
+ for name, param in model.named_parameters():
+ if param.data.dtype == torch.float32:
+ layer_norm_params[name] = param.data.detach().clone()
+ param.data = param.data.to(model.config.torch_dtype)
+
+ return layer_norm_params
+
+
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+ for name, param in model.named_parameters():
+ if layernorm_params is not None and name in layernorm_params:
+ param.data = layernorm_params[name]
diff --git a/llm_rl/src/llmtuner/tuner/ppo/workflow.py b/llm_rl/src/llmtuner/tuner/ppo/workflow.py
new file mode 100644
index 00000000..4c35f628
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/ppo/workflow.py
@@ -0,0 +1,92 @@
+# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
+
+import math
+from trl import PPOConfig
+from torch.optim import AdamW
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorWithPadding
+from transformers.optimization import get_scheduler
+
+from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+
+def run_ppo(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None
+):
+ dataset = get_dataset(model_args, data_args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
+ dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
+
+ tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
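+ # One PPO rollout batch = per-device batch size * gradient accumulation steps; optimization uses mini-batches of per-device size.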
+ ppo_config = PPOConfig(
+ model_name=model_args.model_name_or_path,
+ learning_rate=training_args.learning_rate,
+ mini_batch_size=training_args.per_device_train_batch_size,
+ batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+ ppo_epochs=1,
+ max_grad_norm=training_args.max_grad_norm,
+ seed=training_args.seed,
+ optimize_cuda_cache=True,
+ target=finetuning_args.ppo_target,
+ log_with=finetuning_args.ppo_logger,
+ use_score_scaling=finetuning_args.ppo_score_norm,
+ use_score_norm=finetuning_args.ppo_score_norm,
+ accelerator_kwargs={"step_scheduler_with_optimizer": False}
+ )
+
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+ if training_args.max_steps > 0:
+ num_training_steps = training_args.max_steps
+ else:
+ total_train_batch_size = (
+ training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+ )
+ num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+
+ lr_scheduler = get_scheduler(
+ training_args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps
+ )
+
+ # Initialize our Trainer
+ ppo_trainer = CustomPPOTrainer(
+ model_args=model_args,
+ training_args=training_args,
+ finetuning_args=finetuning_args,
+ generating_args=generating_args,
+ callbacks=callbacks + [SavePeftModelCallback()],
+ config=ppo_config,
+ model=model,
+ ref_model=None,
+ tokenizer=tokenizer,
+ dataset=dataset,
+ data_collator=data_collator,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler
+ )
+
+ # Training
+ if training_args.do_train:
+ ppo_trainer.ppo_train()
+ ppo_trainer.save_model()
+ ppo_trainer.save_state() # must be called after save_model to have a folder
+ if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/llm_rl/src/llmtuner/tuner/pt/__init__.py b/llm_rl/src/llmtuner/tuner/pt/__init__.py
new file mode 100644
index 00000000..8ce509db
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/pt/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.pt.workflow import run_pt
diff --git a/llm_rl/src/llmtuner/tuner/pt/workflow.py b/llm_rl/src/llmtuner/tuner/pt/workflow.py
new file mode 100644
index 00000000..66d08de7
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/pt/workflow.py
@@ -0,0 +1,58 @@
+# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
+
+import math
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForLanguageModeling, Trainer
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
+
+def run_pt(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None
+):
+ dataset = get_dataset(model_args, data_args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
+ dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ # Initialize our Trainer
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **split_dataset(dataset, data_args, training_args)
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ trainer.save_model()
+ if trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ try:
+ perplexity = math.exp(metrics["eval_loss"])
+ except OverflowError:
+ perplexity = float("inf")
+
+ metrics["perplexity"] = perplexity
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
diff --git a/llm_rl/src/llmtuner/tuner/rm/__init__.py b/llm_rl/src/llmtuner/tuner/rm/__init__.py
new file mode 100644
index 00000000..54d3d943
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/rm/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.rm.workflow import run_rm
diff --git a/llm_rl/src/llmtuner/tuner/rm/collator.py b/llm_rl/src/llmtuner/tuner/rm/collator.py
new file mode 100644
index 00000000..161f003d
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/rm/collator.py
@@ -0,0 +1,27 @@
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, Sequence
+from transformers import DataCollatorWithPadding
+
+
+@dataclass
+class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
+ r"""
+ Data collator for pairwise data.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+ r"""
+ Pads batched data to the longest sequence in the batch.
+
+ We generate 2 * n examples where the first n examples represent chosen examples and
+ the last n examples represent rejected examples.
+ """
+ features = [
+ {
+ "input_ids": feature["prompt_ids"] + feature[key],
+ "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
+ }
+ for key in ("chosen_ids", "rejected_ids") for feature in features
+ ]
+ return super().__call__(features)
diff --git a/llm_rl/src/llmtuner/tuner/rm/metric.py b/llm_rl/src/llmtuner/tuner/rm/metric.py
new file mode 100644
index 00000000..db9c9243
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/rm/metric.py
@@ -0,0 +1,7 @@
+import numpy as np
+from typing import Dict, Sequence, Tuple, Union
+
+
+def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+ preds, _ = eval_preds
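+ # preds holds (chosen_scores, rejected_scores); accuracy is the fraction of pairs where the chosen score is higher.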
+ return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
diff --git a/llm_rl/src/llmtuner/tuner/rm/trainer.py b/llm_rl/src/llmtuner/tuner/rm/trainer.py
new file mode 100644
index 00000000..80502937
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/rm/trainer.py
@@ -0,0 +1,105 @@
+import os
+import json
+import torch
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from transformers import Trainer
+
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from transformers.trainer import PredictionOutput
+ from transformers.modeling_utils import PreTrainedModel
+
+
+logger = get_logger(__name__)
+
+
+class PairwiseTrainer(Trainer):
+ r"""
+ Inherits Trainer to compute pairwise loss.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.can_return_loss = True # override property to return eval_loss
+
+ def compute_loss(
+ self,
+ model: "PreTrainedModel",
+ inputs: Dict[str, torch.Tensor],
+ return_outputs: Optional[bool] = False
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+ r"""
+ Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
+
+ Subclass and override to inject custom behavior.
+
+ Note that the first element will be removed from the output tuple.
+ See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
+ """
+ # Compute rewards
+ _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+ if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
+ values = torch.transpose(values, 0, 1)
+
+ # Split the inputs and rewards into two parts, chosen and rejected
+ batch_size = inputs["input_ids"].size(0) // 2
+ chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
+ chosen_attn_mask, rejected_attn_mask = (
+ inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
+ )
+ chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
+ chosen_scores, rejected_scores = [], []
+
+ # Compute pairwise loss. Only backprop on the different tokens before padding
+ # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
+ loss = 0
+ for i in range(batch_size):
+ chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
+ rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+ check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
+
+ if len(check_divergence) == 0:
+ end_index = chosen_length
+ div_index = end_index - 1
+ else:
+ end_index = max(chosen_length, rejected_length)
+ div_index = check_divergence[0]
+
+ assert div_index > 0
+ chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
+ rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
+ if return_outputs: # use the score on the EOS token for inference
+ chosen_scores.append(chosen_rewards[i, chosen_length-1])
+ rejected_scores.append(rejected_rewards[i, rejected_length-1])
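+ # Pairwise (Bradley-Terry style) loss: -log(sigmoid(r_chosen - r_rejected)), averaged over the divergent token span.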
+ loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
+
+ loss = loss / batch_size
+ if return_outputs:
+ chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+ return loss, [loss, chosen_scores, rejected_scores]
+
+ return loss
+
+ def save_predictions(
+ self,
+ predict_results: "PredictionOutput"
+ ) -> None:
+ r"""
+ Saves model predictions to `output_dir`.
+
+ A custom behavior that is not contained in the base Trainer.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+ logger.info(f"Saving prediction results to {output_prediction_file}")
+
+ chosen_scores, rejected_scores = predict_results.predictions
+
+ with open(output_prediction_file, "w", encoding="utf-8") as writer:
+ res: List[str] = []
+ for c_score, r_score in zip(chosen_scores, rejected_scores):
+ res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
+ writer.write("\n".join(res))
diff --git a/llm_rl/src/llmtuner/tuner/rm/workflow.py b/llm_rl/src/llmtuner/tuner/rm/workflow.py
new file mode 100644
index 00000000..6d2c4422
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/rm/workflow.py
@@ -0,0 +1,68 @@
+# Inspired by:
+# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+
+from typing import TYPE_CHECKING, Optional, List
+from transformers import Seq2SeqTrainingArguments
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.rm.metric import compute_accuracy
+from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
+from llmtuner.tuner.rm.trainer import PairwiseTrainer
+
+if TYPE_CHECKING:
+ from transformers import TrainerCallback
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
+
+def run_rm(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None
+):
+ dataset = get_dataset(model_args, data_args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
+ dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+ data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
+
+ training_args_dict = training_args.to_dict()
+ training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
+
+ # Initialize our Trainer
+ trainer = PairwiseTrainer(
+ model=model,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks + [SavePeftModelCallback()],
+ compute_metrics=compute_accuracy,
+ **split_dataset(dataset, data_args, training_args)
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ trainer.save_model()
+ if trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ if training_args.do_predict:
+ predict_results = trainer.predict(dataset, metric_key_prefix="predict")
+ trainer.log_metrics("predict", predict_results.metrics)
+ trainer.save_metrics("predict", predict_results.metrics)
+ trainer.save_predictions(predict_results)
diff --git a/llm_rl/src/llmtuner/tuner/sft/__init__.py b/llm_rl/src/llmtuner/tuner/sft/__init__.py
new file mode 100644
index 00000000..493dd1a7
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/sft/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.sft.workflow import run_sft
diff --git a/llm_rl/src/llmtuner/tuner/sft/metric.py b/llm_rl/src/llmtuner/tuner/sft/metric.py
new file mode 100644
index 00000000..812896ee
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/sft/metric.py
@@ -0,0 +1,53 @@
+import numpy as np
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+
+import jieba
+from rouge_chinese import Rouge
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+from llmtuner.extras.constants import IGNORE_INDEX
+
+if TYPE_CHECKING:
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+
+@dataclass
+class ComputeMetrics:
+ r"""
+ Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
+ """
+
+ tokenizer: "PreTrainedTokenizer"
+
+ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+ r"""
+ Uses the model predictions to compute metrics.
+ """
+ preds, labels = eval_preds
+ score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+
+ preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+ labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+
+ decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+ decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+ for pred, label in zip(decoded_preds, decoded_labels):
+ hypothesis = list(jieba.cut(pred))
+ reference = list(jieba.cut(label))
+
+ if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
+ result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
+ else:
+ rouge = Rouge()
+ scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
+ result = scores[0]
+
+ for k, v in result.items():
+ score_dict[k].append(round(v["f"] * 100, 4))
+
+ bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
+ score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+ return {k: float(np.mean(v)) for k, v in score_dict.items()}
diff --git a/llm_rl/src/llmtuner/tuner/sft/trainer.py b/llm_rl/src/llmtuner/tuner/sft/trainer.py
new file mode 100644
index 00000000..c65cd255
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/sft/trainer.py
@@ -0,0 +1,92 @@
+import os
+import json
+import torch
+import numpy as np
+import torch.nn as nn
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer
+
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+ from transformers.trainer import PredictionOutput
+
+
+logger = get_logger(__name__)
+
+
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
+ r"""
+ Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
+ """
+
+ def prediction_step(
+ self,
+ model: nn.Module,
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ r"""
+ Removes the prompt part in the generated tokens.
+
+ Subclass and override to inject custom behavior.
+ """
+ labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
+ if self.args.predict_with_generate:
+ assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+ prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+ if prompt_len > label_len:
+ inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+ if label_len > prompt_len:
+ inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
+
+ loss, generated_tokens, _ = super().prediction_step(
+ model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+ )
+ if generated_tokens is not None and self.args.predict_with_generate:
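+ # Generated sequences contain the prompt as a prefix; overwrite it with pad tokens so only the response is decoded for metrics.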
+ generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+ generated_tokens = generated_tokens.contiguous()
+
+ return loss, generated_tokens, labels
+
+ def _pad_tensors_to_target_len(
+ self,
+ src_tensor: torch.Tensor,
+ tgt_tensor: torch.Tensor
+ ) -> torch.Tensor:
+ r"""
+ Pads the tensor to the same length as the target tensor.
+ """
+ assert self.tokenizer.pad_token_id is not None, "Pad token is required."
+ padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
+ padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
+ return padded_tensor.contiguous() # in contiguous memory
+
+ def save_predictions(
+ self,
+ predict_results: "PredictionOutput"
+ ) -> None:
+ r"""
+ Saves model predictions to `output_dir`.
+
+ A custom behavior that is not contained in Seq2SeqTrainer.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+ logger.info(f"Saving prediction results to {output_prediction_file}")
+
+ preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
+ labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
+
+ decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+ with open(output_prediction_file, "w", encoding="utf-8") as writer:
+ res: List[str] = []
+ for pred, label in zip(decoded_preds, decoded_labels):
+ res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+ writer.write("\n".join(res))
diff --git a/llm_rl/src/llmtuner/tuner/sft/workflow.py b/llm_rl/src/llmtuner/tuner/sft/workflow.py
new file mode 100644
index 00000000..8d53605d
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/sft/workflow.py
@@ -0,0 +1,90 @@
+# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
+
+from typing import TYPE_CHECKING, Optional, List
+from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.sft.metric import ComputeMetrics
+from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
+
+if TYPE_CHECKING:
+ from transformers import TrainerCallback
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+
+def run_sft(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None
+):
+ dataset = get_dataset(model_args, data_args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+ dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
+
+ if training_args.predict_with_generate:
+ tokenizer.padding_side = "left" # use left-padding in generation
+
+ data_collator = DataCollatorForSeq2Seq(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+ )
+
+ # Override the decoding parameters of Seq2SeqTrainer
+ training_args_dict = training_args.to_dict()
+ training_args_dict.update(dict(
+ generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
+ generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
+ ))
+ training_args = Seq2SeqTrainingArguments(**training_args_dict)
+
+ # Initialize our Trainer
+ trainer = CustomSeq2SeqTrainer(
+ model=model,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+ **split_dataset(dataset, data_args, training_args)
+ )
+
+ # Keyword arguments for `model.generate`
+ gen_kwargs = generating_args.to_dict()
+ gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+ gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+ gen_kwargs["logits_processor"] = get_logits_processor()
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ trainer.save_model()
+ if trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
+ if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
+ metrics.pop("eval_loss", None)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ if training_args.do_predict:
+ predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
+ if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
+ predict_results.metrics.pop("predict_loss", None)
+ trainer.log_metrics("predict", predict_results.metrics)
+ trainer.save_metrics("predict", predict_results.metrics)
+ trainer.save_predictions(predict_results)
diff --git a/llm_rl/src/llmtuner/tuner/tune.py b/llm_rl/src/llmtuner/tuner/tune.py
new file mode 100644
index 00000000..4eb7f78f
--- /dev/null
+++ b/llm_rl/src/llmtuner/tuner/tune.py
@@ -0,0 +1,51 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.logging import get_logger
+from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner.pt import run_pt
+from llmtuner.tuner.sft import run_sft
+from llmtuner.tuner.rm import run_rm
+from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.dpo import run_dpo
+
+if TYPE_CHECKING:
+ from transformers import TrainerCallback
+
+
+logger = get_logger(__name__)
+
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
+ model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+ callbacks = [LogCallback()] if callbacks is None else callbacks
+
+ if finetuning_args.stage == "pt":
+ run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
+ elif finetuning_args.stage == "sft":
+ run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+ elif finetuning_args.stage == "rm":
+ run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
+ elif finetuning_args.stage == "ppo":
+ run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+ elif finetuning_args.stage == "dpo":
+ run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+ else:
+ raise ValueError("Unknown task.")
+
+
+def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
+ model_args, _, finetuning_args, _ = get_infer_args(args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
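+ # make sure the exported model has use_cache enabled for inference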
+ model.config.use_cache = True
+ model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
+ try:
+ tokenizer.padding_side = "left" # restore padding side
+ tokenizer.init_kwargs["padding_side"] = "left"
+ tokenizer.save_pretrained(model_args.export_dir)
+ except Exception:
+ logger.warning("Cannot save tokenizer, please copy the files manually.")
+
+
+if __name__ == "__main__":
+ run_exp()
diff --git a/llm_rl/src/llmtuner/webui/__init__.py b/llm_rl/src/llmtuner/webui/__init__.py
new file mode 100644
index 00000000..a27c7f6e
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/__init__.py
@@ -0,0 +1 @@
+from llmtuner.webui.interface import create_ui, create_web_demo
diff --git a/llm_rl/src/llmtuner/webui/chatter.py b/llm_rl/src/llmtuner/webui/chatter.py
new file mode 100644
index 00000000..57eadb01
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/chatter.py
@@ -0,0 +1,101 @@
+import gradio as gr
+from gradio.components import Component # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+
+from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.extras.misc import torch_gc
+from llmtuner.hparams import GeneratingArguments
+from llmtuner.webui.common import get_save_dir
+from llmtuner.webui.locales import ALERTS
+
+if TYPE_CHECKING:
+ from llmtuner.webui.manager import Manager
+
+
+class WebChatModel(ChatModel):
+
+ def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
+ self.manager = manager
+ self.model = None
+ self.tokenizer = None
+ self.generating_args = GeneratingArguments()
+ if not lazy_init:
+ super().__init__()
+
+ @property
+ def loaded(self) -> bool:
+ return self.model is not None
+
+ def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+ get = lambda name: data[self.manager.get_elem_by_name(name)]
+ lang = get("top.lang")
+ error = ""
+ if self.loaded:
+ error = ALERTS["err_exists"][lang]
+ elif not get("top.model_name"):
+ error = ALERTS["err_no_model"][lang]
+ elif not get("top.model_path"):
+ error = ALERTS["err_no_path"][lang]
+
+ if error:
+ gr.Warning(error)
+ yield error
+ return
+
+ if get("top.checkpoints"):
+ checkpoint_dir = ",".join([
+ get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+ ])
+ else:
+ checkpoint_dir = None
+
+ yield ALERTS["info_loading"][lang]
+ args = dict(
+ model_name_or_path=get("top.model_path"),
+ checkpoint_dir=checkpoint_dir,
+ finetuning_type=get("top.finetuning_type"),
+ quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ template=get("top.template"),
+ system_prompt=get("top.system_prompt"),
+ flash_attn=get("top.flash_attn"),
+ shift_attn=get("top.shift_attn"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
+ )
+ super().__init__(args)
+
+ yield ALERTS["info_loaded"][lang]
+
+ def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+ lang = data[self.manager.get_elem_by_name("top.lang")]
+ yield ALERTS["info_unloading"][lang]
+ self.model = None
+ self.tokenizer = None
+ torch_gc()
+ yield ALERTS["info_unloaded"][lang]
+
+ def predict(
+ self,
+ chatbot: List[Tuple[str, str]],
+ query: str,
+ history: List[Tuple[str, str]],
+ system: str,
+ max_new_tokens: int,
+ top_p: float,
+ temperature: float
+ ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+ chatbot.append([query, ""])
+ response = ""
+ for new_text in self.stream_chat(
+ query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ ):
+ response += new_text
+ new_history = history + [(query, response)]
+ chatbot[-1] = [query, self.postprocess(response)]
+ yield chatbot, new_history
+
+ def postprocess(self, response: str) -> str:
+ blocks = response.split("```")
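+ # segments at even indices are outside code fences; escape angle brackets only there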
+ for i, block in enumerate(blocks):
+ if i % 2 == 0:
+ blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
+ return "```".join(blocks)
diff --git a/llm_rl/src/llmtuner/webui/common.py b/llm_rl/src/llmtuner/webui/common.py
new file mode 100644
index 00000000..5a6c16d3
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/common.py
@@ -0,0 +1,103 @@
+import os
+import json
+import gradio as gr
+from typing import Any, Dict, Optional
+from transformers.utils import (
+ WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ ADAPTER_WEIGHTS_NAME,
+ ADAPTER_SAFE_WEIGHTS_NAME
+)
+
+from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
+
+
+DEFAULT_CACHE_DIR = "cache"
+DEFAULT_DATA_DIR = "data"
+DEFAULT_SAVE_DIR = "saves"
+USER_CONFIG = "user.config"
+DATA_CONFIG = "dataset_info.json"
+CKPT_NAMES = [
+ WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ ADAPTER_WEIGHTS_NAME,
+ ADAPTER_SAFE_WEIGHTS_NAME
+]
+
+
+def get_save_dir(*args) -> os.PathLike:
+ return os.path.join(DEFAULT_SAVE_DIR, *args)
+
+
+def get_config_path() -> os.PathLike:
+ return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
+
+
+def load_config() -> Dict[str, Any]:
+ try:
+ with open(get_config_path(), "r", encoding="utf-8") as f:
+ return json.load(f)
+ except Exception:
+ return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+
+
+def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+ os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
+ user_config = load_config()
+ user_config["lang"] = lang or user_config["lang"]
+ if model_name:
+ user_config["last_model"] = model_name
+ user_config["path_dict"][model_name] = model_path
+ with open(get_config_path(), "w", encoding="utf-8") as f:
+ json.dump(user_config, f, indent=2, ensure_ascii=False)
+
+
+def get_model_path(model_name: str) -> str:
+ user_config = load_config()
+ return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
+
+
+def get_module(model_name: str) -> str:
+ return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
+
+
+def get_template(model_name: str) -> str:
+ if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+ return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+ return "default"
+
+
+def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+ checkpoints = []
+ if model_name:
+ save_dir = get_save_dir(model_name, finetuning_type)
+ if save_dir and os.path.isdir(save_dir):
+ for checkpoint in os.listdir(save_dir):
+ if (
+ os.path.isdir(os.path.join(save_dir, checkpoint))
+ and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
+ ):
+ checkpoints.append(checkpoint)
+ return gr.update(value=[], choices=checkpoints)
+
+
+def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
+ try:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ return json.load(f)
+ except Exception:
+ print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
+ return {}
+
+
+def list_dataset(
+ dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
+) -> Dict[str, Any]:
+ dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
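+ # rm and dpo stages require preference (ranking) datasets, other stages use standard ones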
+ ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
+ datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
+ return gr.update(value=[], choices=datasets)
diff --git a/llm_rl/src/llmtuner/webui/components/__init__.py b/llm_rl/src/llmtuner/webui/components/__init__.py
new file mode 100644
index 00000000..32228b8e
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/__init__.py
@@ -0,0 +1,6 @@
+from llmtuner.webui.components.top import create_top
+from llmtuner.webui.components.train import create_train_tab
+from llmtuner.webui.components.eval import create_eval_tab
+from llmtuner.webui.components.infer import create_infer_tab
+from llmtuner.webui.components.export import create_export_tab
+from llmtuner.webui.components.chatbot import create_chat_box
diff --git a/llm_rl/src/llmtuner/webui/components/chatbot.py b/llm_rl/src/llmtuner/webui/components/chatbot.py
new file mode 100644
index 00000000..13e2dd4d
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/chatbot.py
@@ -0,0 +1,49 @@
+import gradio as gr
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
+
+if TYPE_CHECKING:
+ from gradio.blocks import Block
+ from gradio.components import Component
+ from llmtuner.webui.engine import Engine
+
+
+def create_chat_box(
+ engine: "Engine",
+ visible: Optional[bool] = False
+) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
+ with gr.Box(visible=visible) as chat_box:
+ chatbot = gr.Chatbot()
+ history = gr.State([])
+ with gr.Row():
+ with gr.Column(scale=4):
+ system = gr.Textbox(show_label=False)
+ query = gr.Textbox(show_label=False, lines=8)
+ submit_btn = gr.Button(variant="primary")
+
+ with gr.Column(scale=1):
+ clear_btn = gr.Button()
+ gen_kwargs = engine.chatter.generating_args
+ max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
+ top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
+ temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+
+ submit_btn.click(
+ engine.chatter.predict,
+ [chatbot, query, history, system, max_new_tokens, top_p, temperature],
+ [chatbot, history],
+ show_progress=True
+ ).then(
+ lambda: gr.update(value=""), outputs=[query]
+ )
+
+ clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
+
+ return chat_box, chatbot, history, dict(
+ system=system,
+ query=query,
+ submit_btn=submit_btn,
+ clear_btn=clear_btn,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ temperature=temperature
+ )
diff --git a/llm_rl/src/llmtuner/webui/components/data.py b/llm_rl/src/llmtuner/webui/components/data.py
new file mode 100644
index 00000000..effa39da
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/data.py
@@ -0,0 +1,103 @@
+import os
+import json
+import gradio as gr
+from typing import TYPE_CHECKING, Any, Dict, Tuple
+
+from llmtuner.webui.common import DATA_CONFIG
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+PAGE_SIZE = 2
+
+
+def prev_page(page_index: int) -> int:
+ return page_index - 1 if page_index > 0 else page_index
+
+
+def next_page(page_index: int, total_num: int) -> int:
+ return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
+
+
+def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ dataset_info = json.load(f)
+
+ if (
+ len(dataset) > 0
+ and "file_name" in dataset_info[dataset[0]]
+ and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
+ ):
+ return gr.update(interactive=True)
+ else:
+ return gr.update(interactive=False)
+
+
+def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ dataset_info = json.load(f)
+
+ data_file: str = dataset_info[dataset[0]]["file_name"]
+ with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
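+ # parse the file according to its suffix: .json, .jsonl, or plain text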
+ if data_file.endswith(".json"):
+ data = json.load(f)
+ elif data_file.endswith(".jsonl"):
+ data = [json.loads(line) for line in f]
+ else:
+ data = [line for line in f]
+ return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
+
+
+def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
+ data_preview_btn = gr.Button(interactive=False, scale=1)
+ with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
+ with gr.Row():
+ preview_count = gr.Number(value=0, interactive=False, precision=0)
+ page_index = gr.Number(value=0, interactive=False, precision=0)
+
+ with gr.Row():
+ prev_btn = gr.Button()
+ next_btn = gr.Button()
+ close_btn = gr.Button()
+
+ with gr.Row():
+ preview_samples = gr.JSON(interactive=False)
+
+ dataset.change(
+ can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
+ ).then(
+ lambda: 0, outputs=[page_index], queue=False
+ )
+ data_preview_btn.click(
+ get_preview,
+ [dataset_dir, dataset, page_index],
+ [preview_count, preview_samples, preview_box],
+ queue=False
+ )
+ prev_btn.click(
+ prev_page, [page_index], [page_index], queue=False
+ ).then(
+ get_preview,
+ [dataset_dir, dataset, page_index],
+ [preview_count, preview_samples, preview_box],
+ queue=False
+ )
+ next_btn.click(
+ next_page, [page_index, preview_count], [page_index], queue=False
+ ).then(
+ get_preview,
+ [dataset_dir, dataset, page_index],
+ [preview_count, preview_samples, preview_box],
+ queue=False
+ )
+ close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
+ return dict(
+ data_preview_btn=data_preview_btn,
+ preview_count=preview_count,
+ page_index=page_index,
+ prev_btn=prev_btn,
+ next_btn=next_btn,
+ close_btn=close_btn,
+ preview_samples=preview_samples
+ )
diff --git a/llm_rl/src/llmtuner/webui/components/eval.py b/llm_rl/src/llmtuner/webui/components/eval.py
new file mode 100644
index 00000000..36c994a6
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/eval.py
@@ -0,0 +1,70 @@
+import gradio as gr
+from typing import TYPE_CHECKING, Dict
+
+from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
+from llmtuner.webui.components.data import create_preview_box
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+ from llmtuner.webui.engine import Engine
+
+
+def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+ dataset = gr.Dropdown(multiselect=True, scale=4)
+ preview_elems = create_preview_box(dataset_dir, dataset)
+
+ dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+
+ input_elems.update({dataset_dir, dataset})
+ elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
+
+ with gr.Row():
+ cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+ max_samples = gr.Textbox(value="100000")
+ batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
+ predict = gr.Checkbox(value=True)
+
+ input_elems.update({cutoff_len, max_samples, batch_size, predict})
+ elem_dict.update(dict(
+ cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
+ ))
+
+ with gr.Row():
+ max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
+ top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
+ temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
+
+ input_elems.update({max_new_tokens, top_p, temperature})
+ elem_dict.update(dict(
+ max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ ))
+
+ with gr.Row():
+ cmd_preview_btn = gr.Button()
+ start_btn = gr.Button()
+ stop_btn = gr.Button()
+
+ with gr.Row():
+ resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+ process_bar = gr.Slider(visible=False, interactive=False)
+
+ with gr.Box():
+ output_box = gr.Markdown()
+
+ output_elems = [output_box, process_bar]
+ elem_dict.update(dict(
+ cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
+ resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
+ ))
+
+ cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
+ start_btn.click(engine.runner.run_eval, input_elems, output_elems)
+ stop_btn.click(engine.runner.set_abort, queue=False)
+ resume_btn.change(engine.runner.monitor, outputs=output_elems)
+
+ return elem_dict
diff --git a/llm_rl/src/llmtuner/webui/components/export.py b/llm_rl/src/llmtuner/webui/components/export.py
new file mode 100644
index 00000000..d16fa3d1
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/export.py
@@ -0,0 +1,79 @@
+import gradio as gr
+from typing import TYPE_CHECKING, Dict, Generator, List
+
+from llmtuner.tuner import export_model
+from llmtuner.webui.common import get_save_dir
+from llmtuner.webui.locales import ALERTS
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+ from llmtuner.webui.engine import Engine
+
+
+def save_model(
+ lang: str,
+ model_name: str,
+ model_path: str,
+ checkpoints: List[str],
+ finetuning_type: str,
+ template: str,
+ max_shard_size: int,
+ export_dir: str
+) -> Generator[str, None, None]:
+ error = ""
+ if not model_name:
+ error = ALERTS["err_no_model"][lang]
+ elif not model_path:
+ error = ALERTS["err_no_path"][lang]
+ elif not checkpoints:
+ error = ALERTS["err_no_checkpoint"][lang]
+ elif not export_dir:
+ error = ALERTS["err_no_export_dir"][lang]
+
+ if error:
+ gr.Warning(error)
+ yield error
+ return
+
+ args = dict(
+ model_name_or_path=model_path,
+ checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
+ finetuning_type=finetuning_type,
+ template=template,
+ export_dir=export_dir
+ )
+
+ yield ALERTS["info_exporting"][lang]
+ export_model(args, max_shard_size="{}GB".format(max_shard_size))
+ yield ALERTS["info_exported"][lang]
+
+
+def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
+ with gr.Row():
+ export_dir = gr.Textbox()
+ max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
+
+ export_btn = gr.Button()
+ info_box = gr.Textbox(show_label=False, interactive=False)
+
+ export_btn.click(
+ save_model,
+ [
+ engine.manager.get_elem_by_name("top.lang"),
+ engine.manager.get_elem_by_name("top.model_name"),
+ engine.manager.get_elem_by_name("top.model_path"),
+ engine.manager.get_elem_by_name("top.checkpoints"),
+ engine.manager.get_elem_by_name("top.finetuning_type"),
+ engine.manager.get_elem_by_name("top.template"),
+ max_shard_size,
+ export_dir
+ ],
+ [info_box]
+ )
+
+ return dict(
+ export_dir=export_dir,
+ max_shard_size=max_shard_size,
+ export_btn=export_btn,
+ info_box=info_box
+ )
diff --git a/llm_rl/src/llmtuner/webui/components/infer.py b/llm_rl/src/llmtuner/webui/components/infer.py
new file mode 100644
index 00000000..d6dd7eed
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/infer.py
@@ -0,0 +1,39 @@
+import gradio as gr
+from typing import TYPE_CHECKING, Dict
+
+from llmtuner.webui.components.chatbot import create_chat_box
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+ from llmtuner.webui.engine import Engine
+
+
+def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ load_btn = gr.Button()
+ unload_btn = gr.Button()
+
+ info_box = gr.Textbox(show_label=False, interactive=False)
+ elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+
+ chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
+ elem_dict.update(dict(chat_box=chat_box, **chat_elems))
+
+ load_btn.click(
+ engine.chatter.load_model, input_elems, [info_box]
+ ).then(
+ lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
+ )
+
+ unload_btn.click(
+ engine.chatter.unload_model, input_elems, [info_box]
+ ).then(
+ lambda: ([], []), outputs=[chatbot, history]
+ ).then(
+ lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
+ )
+
+ return elem_dict
diff --git a/llm_rl/src/llmtuner/webui/components/top.py b/llm_rl/src/llmtuner/webui/components/top.py
new file mode 100644
index 00000000..c6299cab
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/top.py
@@ -0,0 +1,74 @@
+import gradio as gr
+from typing import TYPE_CHECKING, Dict
+
+from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
+from llmtuner.extras.template import templates
+from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
+from llmtuner.webui.utils import can_quantize
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+def create_top() -> Dict[str, "Component"]:
+ available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+
+ with gr.Row():
+ lang = gr.Dropdown(choices=["en", "zh"], scale=1)
+ model_name = gr.Dropdown(choices=available_models, scale=3)
+ model_path = gr.Textbox(scale=3)
+
+ with gr.Row():
+ finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
+ checkpoints = gr.Dropdown(multiselect=True, scale=5)
+ refresh_btn = gr.Button(scale=1)
+
+ with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+ with gr.Row():
+ quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
+ template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
+ system_prompt = gr.Textbox(scale=2)
+
+ with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
+ with gr.Row():
+ with gr.Column():
+ flash_attn = gr.Checkbox(value=False)
+ shift_attn = gr.Checkbox(value=False)
+ rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
+
+ model_name.change(
+ list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+ ).then(
+ get_model_path, [model_name], [model_path], queue=False
+ ).then(
+ get_template, [model_name], [template], queue=False
+ ) # do not save the config here; the model_path.change handler below saves it
+
+ model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
+
+ finetuning_type.change(
+ list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+ ).then(
+ can_quantize, [finetuning_type], [quantization_bit], queue=False
+ )
+
+ refresh_btn.click(
+ list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+ )
+
+ return dict(
+ lang=lang,
+ model_name=model_name,
+ model_path=model_path,
+ finetuning_type=finetuning_type,
+ checkpoints=checkpoints,
+ refresh_btn=refresh_btn,
+ advanced_tab=advanced_tab,
+ quantization_bit=quantization_bit,
+ template=template,
+ system_prompt=system_prompt,
+ llama_tab=llama_tab,
+ flash_attn=flash_attn,
+ shift_attn=shift_attn,
+ rope_scaling=rope_scaling
+ )
diff --git a/llm_rl/src/llmtuner/webui/components/train.py b/llm_rl/src/llmtuner/webui/components/train.py
new file mode 100644
index 00000000..11109c97
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/components/train.py
@@ -0,0 +1,154 @@
+import gradio as gr
+from typing import TYPE_CHECKING, Dict
+from transformers.trainer_utils import SchedulerType
+
+from llmtuner.extras.constants import TRAINING_STAGES
+from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
+from llmtuner.webui.components.data import create_preview_box
+from llmtuner.webui.utils import gen_plot
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+ from llmtuner.webui.engine import Engine
+
+
+def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ training_stage = gr.Dropdown(
+ choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
+ )
+ dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+ dataset = gr.Dropdown(multiselect=True, scale=4)
+ preview_elems = create_preview_box(dataset_dir, dataset)
+
+ training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+ dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+
+ input_elems.update({training_stage, dataset_dir, dataset})
+ elem_dict.update(dict(
+ training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
+ ))
+
+ with gr.Row():
+ cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+ learning_rate = gr.Textbox(value="5e-5")
+ num_train_epochs = gr.Textbox(value="3.0")
+ max_samples = gr.Textbox(value="100000")
+ compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+
+ input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
+ elem_dict.update(dict(
+ cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
+ max_samples=max_samples, compute_type=compute_type
+ ))
+
+ with gr.Row():
+ batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
+ gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
+ lr_scheduler_type = gr.Dropdown(
+ choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
+ )
+ max_grad_norm = gr.Textbox(value="1.0")
+ val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+
+ input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
+ elem_dict.update(dict(
+ batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
+ lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
+ ))
+
+ with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+ with gr.Row():
+ logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+ save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
+ warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
+ neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
+
+ with gr.Column():
+ train_on_prompt = gr.Checkbox(value=False)
+ upcast_layernorm = gr.Checkbox(value=False)
+
+ input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
+ elem_dict.update(dict(
+ advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
+ neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
+ ))
+
+ with gr.Accordion(label="LoRA config", open=False) as lora_tab:
+ with gr.Row():
+ lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+ lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
+ lora_target = gr.Textbox(scale=1)
+ additional_target = gr.Textbox(scale=1)
+ resume_lora_training = gr.Checkbox(value=True, scale=1)
+
+ input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
+ elem_dict.update(dict(
+ lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
+ additional_target=additional_target, resume_lora_training=resume_lora_training,
+ ))
+
+ with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+ with gr.Row():
+ dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
+ reward_model = gr.Dropdown(scale=3)
+ refresh_btn = gr.Button(scale=1)
+
+ refresh_btn.click(
+ list_checkpoint,
+ [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
+ [reward_model],
+ queue=False
+ )
+
+ input_elems.update({dpo_beta, reward_model})
+ elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
+
+ with gr.Row():
+ cmd_preview_btn = gr.Button()
+ start_btn = gr.Button()
+ stop_btn = gr.Button()
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row():
+ output_dir = gr.Textbox()
+
+ with gr.Row():
+ resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+ process_bar = gr.Slider(visible=False, interactive=False)
+
+ with gr.Box():
+ output_box = gr.Markdown()
+
+ with gr.Column(scale=1):
+ loss_viewer = gr.Plot()
+
+ input_elems.add(output_dir)
+ output_elems = [output_box, process_bar]
+
+ cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
+ start_btn.click(engine.runner.run_train, input_elems, output_elems)
+ stop_btn.click(engine.runner.set_abort, queue=False)
+ resume_btn.change(engine.runner.monitor, outputs=output_elems)
+
+ elem_dict.update(dict(
+ cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
+ resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
+ ))
+
+ output_box.change(
+ gen_plot,
+ [
+ engine.manager.get_elem_by_name("top.model_name"),
+ engine.manager.get_elem_by_name("top.finetuning_type"),
+ output_dir
+ ],
+ loss_viewer,
+ queue=False
+ )
+
+ return elem_dict
diff --git a/llm_rl/src/llmtuner/webui/css.py b/llm_rl/src/llmtuner/webui/css.py
new file mode 100644
index 00000000..c86fb96b
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/css.py
@@ -0,0 +1,20 @@
+CSS = r"""
+.modal-box {
+ position: fixed !important;
+ top: 50%;
+ left: 50%;
+ transform: translate(-50%, -50%); /* center horizontally */
+ max-width: 1000px;
+ max-height: 750px;
+ overflow-y: auto;
+ background-color: var(--input-background-fill);
+ flex-wrap: nowrap !important;
+ border: 2px solid black !important;
+ z-index: 1000;
+ padding: 10px;
+}
+
+.dark .modal-box {
+ border: 2px solid white !important;
+}
+"""
diff --git a/llm_rl/src/llmtuner/webui/engine.py b/llm_rl/src/llmtuner/webui/engine.py
new file mode 100644
index 00000000..661dfb48
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/engine.py
@@ -0,0 +1,57 @@
+import gradio as gr
+from gradio.components import Component # cannot use TYPE_CHECKING here
+from typing import Any, Dict, Generator, Optional
+
+from llmtuner.webui.chatter import WebChatModel
+from llmtuner.webui.common import get_model_path, list_dataset, load_config
+from llmtuner.webui.locales import LOCALES
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.runner import Runner
+from llmtuner.webui.utils import get_time
+
+
+class Engine:
+
+ def __init__(self, pure_chat: Optional[bool] = False) -> None:
+ self.pure_chat = pure_chat
+ self.manager: "Manager" = Manager()
+ self.runner: "Runner" = Runner(self.manager)
+ self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
+
+ def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
+ return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
+
+ def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
+ user_config = load_config()
+ lang = user_config.get("lang", None) or "en"
+
+ init_dict = {
+ "top.lang": {"value": lang},
+ "infer.chat_box": {"visible": self.chatter.loaded}
+ }
+
+ if not self.pure_chat:
+ init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
+ init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
+
+ if user_config.get("last_model", None):
+ init_dict["top.model_name"] = {"value": user_config["last_model"]}
+ init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
+
+ yield self._form_dict(init_dict)
+
+ if not self.pure_chat:
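+ # if a job is still running, restore its inputs and re-trigger the monitor via resume_btn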
+ if self.runner.alive:
+ yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
+ if self.runner.do_train:
+ yield self._form_dict({"train.resume_btn": {"value": True}})
+ else:
+ yield self._form_dict({"eval.resume_btn": {"value": True}})
+ else:
+ yield self._form_dict({"train.output_dir": {"value": get_time()}})
+
+ def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
+ return {
+ component: gr.update(**LOCALES[name][lang])
+ for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
+ }
diff --git a/llm_rl/src/llmtuner/webui/interface.py b/llm_rl/src/llmtuner/webui/interface.py
new file mode 100644
index 00000000..ba663f24
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/interface.py
@@ -0,0 +1,66 @@
+import gradio as gr
+from transformers.utils.versions import require_version
+
+from llmtuner.webui.components import (
+ create_top,
+ create_train_tab,
+ create_eval_tab,
+ create_infer_tab,
+ create_export_tab,
+ create_chat_box
+)
+from llmtuner.webui.common import save_config
+from llmtuner.webui.css import CSS
+from llmtuner.webui.engine import Engine
+
+
+require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
+
+
+def create_ui() -> gr.Blocks:
+ engine = Engine(pure_chat=False)
+
+ with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+ engine.manager.all_elems["top"] = create_top()
+ lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
+
+ with gr.Tab("Train"):
+ engine.manager.all_elems["train"] = create_train_tab(engine)
+
+ with gr.Tab("Evaluate"):
+ engine.manager.all_elems["eval"] = create_eval_tab(engine)
+
+ with gr.Tab("Chat"):
+ engine.manager.all_elems["infer"] = create_infer_tab(engine)
+
+ with gr.Tab("Export"):
+ engine.manager.all_elems["export"] = create_export_tab(engine)
+
+ demo.load(engine.resume, outputs=engine.manager.list_elems())
+ lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+ lang.input(save_config, inputs=[lang], queue=False)
+
+ return demo
+
+
+def create_web_demo() -> gr.Blocks:
+ engine = Engine(pure_chat=True)
+
+ with gr.Blocks(title="Web Demo", css=CSS) as demo:
+ lang = gr.Dropdown(choices=["en", "zh"])
+ engine.manager.all_elems["top"] = dict(lang=lang)
+
+ chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
+ engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
+
+ demo.load(engine.resume, outputs=engine.manager.list_elems())
+ lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+ lang.input(save_config, inputs=[lang], queue=False)
+
+ return demo
+
+
+if __name__ == "__main__":
+ demo = create_ui()
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
diff --git a/llm_rl/src/llmtuner/webui/locales.py b/llm_rl/src/llmtuner/webui/locales.py
new file mode 100644
index 00000000..cc2a1842
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/locales.py
@@ -0,0 +1,698 @@
+LOCALES = {
+ "lang": {
+ "en": {
+ "label": "Lang"
+ },
+ "zh": {
+ "label": "语言"
+ }
+ },
+ "model_name": {
+ "en": {
+ "label": "Model name"
+ },
+ "zh": {
+ "label": "模型名称"
+ }
+ },
+ "model_path": {
+ "en": {
+ "label": "Model path",
+ "info": "Path to pretrained model or model identifier from Hugging Face."
+ },
+ "zh": {
+ "label": "模型路径",
+ "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
+ }
+ },
+ "finetuning_type": {
+ "en": {
+ "label": "Finetuning method"
+ },
+ "zh": {
+ "label": "微调方法"
+ }
+ },
+ "checkpoints": {
+ "en": {
+ "label": "Checkpoints"
+ },
+ "zh": {
+ "label": "模型断点"
+ }
+ },
+ "refresh_btn": {
+ "en": {
+ "value": "Refresh checkpoints"
+ },
+ "zh": {
+ "value": "刷新断点"
+ }
+ },
+ "advanced_tab": {
+ "en": {
+ "label": "Advanced configurations"
+ },
+ "zh": {
+ "label": "高级设置"
+ }
+ },
+ "quantization_bit": {
+ "en": {
+ "label": "Quantization bit",
+ "info": "Enable 4/8-bit model quantization (QLoRA)."
+ },
+ "zh": {
+ "label": "量化等级",
+ "info": "启用 4/8 比特模型量化(QLoRA)。"
+ }
+ },
+ "template": {
+ "en": {
+ "label": "Prompt template",
+ "info": "The template used in constructing prompts."
+ },
+ "zh": {
+ "label": "提示模板",
+ "info": "构建提示词时使用的模板"
+ }
+ },
+ "system_prompt": {
+ "en": {
+ "label": "System prompt (optional)",
+ "info": "A sequence used as the default system prompt."
+ },
+ "zh": {
+ "label": "系统提示词(非必填)",
+ "info": "默认使用的系统提示词"
+ }
+ },
+ "llama_tab": {
+ "en": {
+ "label": "Model configurations (LLaMA only)"
+ },
+ "zh": {
+ "label": "模型设置(仅LLaMA)"
+ }
+ },
+ "flash_attn": {
+ "en": {
+ "label": "Use FlashAttention-2"
+ },
+ "zh": {
+ "label": "使用 FlashAttention-2"
+ }
+ },
+ "shift_attn": {
+ "en": {
+ "label": "Use shift short attention (S^2-Attn)"
+ },
+ "zh": {
+ "label": "使用 shift short attention (S^2-Attn)"
+ }
+ },
+ "rope_scaling": {
+ "en": {
+ "label": "RoPE scaling"
+ },
+ "zh": {
+ "label": "RoPE 插值方法"
+ }
+ },
+ "training_stage": {
+ "en": {
+ "label": "Stage",
+ "info": "The stage to perform in training."
+ },
+ "zh": {
+ "label": "训练阶段",
+ "info": "目前采用的训练方式。"
+ }
+ },
+ "dataset_dir": {
+ "en": {
+ "label": "Data dir",
+ "info": "Path of the data directory."
+ },
+ "zh": {
+ "label": "数据路径",
+ "info": "数据文件夹的路径。"
+ }
+ },
+ "dataset": {
+ "en": {
+ "label": "Dataset"
+ },
+ "zh": {
+ "label": "数据集"
+ }
+ },
+ "data_preview_btn": {
+ "en": {
+ "value": "Preview dataset"
+ },
+ "zh": {
+ "value": "预览数据集"
+ }
+ },
+ "preview_count": {
+ "en": {
+ "label": "Count"
+ },
+ "zh": {
+ "label": "数量"
+ }
+ },
+ "page_index": {
+ "en": {
+ "label": "Page"
+ },
+ "zh": {
+ "label": "页数"
+ }
+ },
+ "prev_btn": {
+ "en": {
+ "value": "Prev"
+ },
+ "zh": {
+ "value": "上一页"
+ }
+ },
+ "next_btn": {
+ "en": {
+ "value": "Next"
+ },
+ "zh": {
+ "value": "下一页"
+ }
+ },
+ "close_btn": {
+ "en": {
+ "value": "Close"
+ },
+ "zh": {
+ "value": "关闭"
+ }
+ },
+ "preview_samples": {
+ "en": {
+ "label": "Samples"
+ },
+ "zh": {
+ "label": "样例"
+ }
+ },
+ "cutoff_len": {
+ "en": {
+ "label": "Cutoff length",
+ "info": "Max tokens in input sequence."
+ },
+ "zh": {
+ "label": "截断长度",
+ "info": "输入序列分词后的最大长度。"
+ }
+ },
+ "learning_rate": {
+ "en": {
+ "label": "Learning rate",
+ "info": "Initial learning rate for AdamW."
+ },
+ "zh": {
+ "label": "学习率",
+ "info": "AdamW 优化器的初始学习率。"
+ }
+ },
+ "num_train_epochs": {
+ "en": {
+ "label": "Epochs",
+ "info": "Total number of training epochs to perform."
+ },
+ "zh": {
+ "label": "训练轮数",
+ "info": "需要执行的训练总轮数。"
+ }
+ },
+ "max_samples": {
+ "en": {
+ "label": "Max samples",
+ "info": "Maximum samples per dataset."
+ },
+ "zh": {
+ "label": "最大样本数",
+ "info": "每个数据集最多使用的样本数。"
+ }
+ },
+ "compute_type": {
+ "en": {
+ "label": "Compute type",
+ "info": "Whether to use fp16 or bf16 mixed precision training."
+ },
+ "zh": {
+ "label": "计算类型",
+ "info": "是否启用 FP16 或 BF16 混合精度训练。"
+ }
+ },
+ "batch_size": {
+ "en": {
+ "label": "Batch size",
+ "info": "Number of samples to process per GPU."
+ },
+ "zh": {
+ "label": "批处理大小",
+ "info": "每块 GPU 上处理的样本数量。"
+ }
+ },
+ "gradient_accumulation_steps": {
+ "en": {
+ "label": "Gradient accumulation",
+ "info": "Number of gradient accumulation steps."
+ },
+ "zh": {
+ "label": "梯度累积",
+ "info": "梯度累积的步数。"
+ }
+ },
+ "lr_scheduler_type": {
+ "en": {
+ "label": "LR Scheduler",
+ "info": "Name of learning rate scheduler.",
+ },
+ "zh": {
+ "label": "学习率调节器",
+ "info": "采用的学习率调节器名称。"
+ }
+ },
+ "max_grad_norm": {
+ "en": {
+ "label": "Maximum gradient norm",
+ "info": "Norm for gradient clipping."
+ },
+ "zh": {
+ "label": "最大梯度范数",
+ "info": "用于梯度裁剪的范数。"
+ }
+ },
+ "val_size": {
+ "en": {
+ "label": "Val size",
+ "info": "Proportion of data in the dev set."
+ },
+ "zh": {
+ "label": "验证集比例",
+ "info": "验证集占全部样本的百分比。"
+ }
+ },
+ "logging_steps": {
+ "en": {
+ "label": "Logging steps",
+ "info": "Number of steps between two logs."
+ },
+ "zh": {
+ "label": "日志间隔",
+ "info": "每两次日志输出间的更新步数。"
+ }
+ },
+ "save_steps": {
+ "en": {
+ "label": "Save steps",
+ "info": "Number of steps between two checkpoints."
+ },
+ "zh": {
+ "label": "保存间隔",
+ "info": "每两次断点保存间的更新步数。"
+ }
+ },
+ "warmup_steps": {
+ "en": {
+ "label": "Warmup steps",
+ "info": "Number of steps used for warmup."
+ },
+ "zh": {
+ "label": "预热步数",
+ "info": "学习率预热采用的步数。"
+ }
+ },
+ "neft_alpha": {
+ "en": {
+ "label": "NEFTune Alpha",
+ "info": "Magnitude of the noise added to embedding vectors."
+ },
+ "zh": {
+ "label": "NEFTune 噪声参数",
+ "info": "嵌入向量所添加的噪声大小。"
+ }
+ },
+ "train_on_prompt": {
+ "en": {
+ "label": "Train on prompt",
+ "info": "Compute loss on the prompt tokens in supervised fine-tuning."
+ },
+ "zh": {
+ "label": "计算输入损失",
+ "info": "在监督微调时候计算输入序列的损失。"
+ }
+ },
+ "upcast_layernorm": {
+ "en": {
+ "label": "Upcast LayerNorm",
+ "info": "Upcast the weights of LayerNorm to float32."
+ },
+ "zh": {
+ "label": "缩放归一化层",
+ "info": "将归一化层权重缩放至 32 位浮点数。"
+ }
+ },
+ "lora_tab": {
+ "en": {
+ "label": "LoRA configurations"
+ },
+ "zh": {
+ "label": "LoRA 参数设置"
+ }
+ },
+ "lora_rank": {
+ "en": {
+ "label": "LoRA rank",
+ "info": "The rank of LoRA matrices."
+ },
+ "zh": {
+ "label": "LoRA 秩",
+ "info": "LoRA 矩阵的秩。"
+ }
+ },
+ "lora_dropout": {
+ "en": {
+ "label": "LoRA Dropout",
+ "info": "Dropout ratio of LoRA weights."
+ },
+ "zh": {
+ "label": "LoRA 随机丢弃",
+ "info": "LoRA 权重随机丢弃的概率。"
+ }
+ },
+ "lora_target": {
+ "en": {
+ "label": "LoRA modules (optional)",
+ "info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
+ },
+ "zh": {
+ "label": "LoRA 作用模块(非必填)",
+ "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
+ }
+ },
+ "additional_target": {
+ "en": {
+ "label": "Additional modules (optional)",
+ "info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
+ },
+ "zh": {
+ "label": "附加模块(非必填)",
+ "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
+ }
+ },
+ "resume_lora_training": {
+ "en": {
+ "label": "Resume LoRA training",
+ "info": "Whether to resume training from the last LoRA weights or create new LoRA weights."
+ },
+ "zh": {
+ "label": "继续上次的训练",
+ "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
+ }
+ },
+ "rlhf_tab": {
+ "en": {
+ "label": "RLHF configurations"
+ },
+ "zh": {
+ "label": "RLHF 参数设置"
+ }
+ },
+ "dpo_beta": {
+ "en": {
+ "label": "DPO beta",
+ "info": "Value of the beta parameter in the DPO loss."
+ },
+ "zh": {
+ "label": "DPO beta 参数",
+ "info": "DPO 损失函数中 beta 超参数大小。"
+ }
+ },
+ "reward_model": {
+ "en": {
+ "label": "Reward model",
+ "info": "Checkpoint of the reward model for PPO training. (Refresh checkpoints to update the list.)"
+ },
+ "zh": {
+ "label": "奖励模型",
+ "info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
+ }
+ },
+ "cmd_preview_btn": {
+ "en": {
+ "value": "Preview command"
+ },
+ "zh": {
+ "value": "预览命令"
+ }
+ },
+ "start_btn": {
+ "en": {
+ "value": "Start"
+ },
+ "zh": {
+ "value": "开始"
+ }
+ },
+ "stop_btn": {
+ "en": {
+ "value": "Abort"
+ },
+ "zh": {
+ "value": "中断"
+ }
+ },
+ "output_dir": {
+ "en": {
+ "label": "Checkpoint name",
+ "info": "Name of the directory to save the checkpoint."
+ },
+ "zh": {
+ "label": "断点名称",
+ "info": "保存模型断点的文件夹名称。"
+ }
+ },
+ "output_box": {
+ "en": {
+ "value": "Ready."
+ },
+ "zh": {
+ "value": "准备就绪。"
+ }
+ },
+ "loss_viewer": {
+ "en": {
+ "label": "Loss"
+ },
+ "zh": {
+ "label": "损失"
+ }
+ },
+ "predict": {
+ "en": {
+ "label": "Save predictions"
+ },
+ "zh": {
+ "label": "保存预测结果"
+ }
+ },
+ "load_btn": {
+ "en": {
+ "value": "Load model"
+ },
+ "zh": {
+ "value": "加载模型"
+ }
+ },
+ "unload_btn": {
+ "en": {
+ "value": "Unload model"
+ },
+ "zh": {
+ "value": "卸载模型"
+ }
+ },
+ "info_box": {
+ "en": {
+ "value": "Model unloaded, please load a model first."
+ },
+ "zh": {
+ "value": "模型未加载,请先加载模型。"
+ }
+ },
+ "system": {
+ "en": {
+ "placeholder": "System prompt (optional)"
+ },
+ "zh": {
+ "placeholder": "系统提示词(非必填)"
+ }
+ },
+ "query": {
+ "en": {
+ "placeholder": "Input..."
+ },
+ "zh": {
+ "placeholder": "输入..."
+ }
+ },
+ "submit_btn": {
+ "en": {
+ "value": "Submit"
+ },
+ "zh": {
+ "value": "提交"
+ }
+ },
+ "clear_btn": {
+ "en": {
+ "value": "Clear history"
+ },
+ "zh": {
+ "value": "清空历史"
+ }
+ },
+ "max_length": {
+ "en": {
+ "label": "Maximum length"
+ },
+ "zh": {
+ "label": "最大长度"
+ }
+ },
+ "max_new_tokens": {
+ "en": {
+ "label": "Maximum new tokens"
+ },
+ "zh": {
+ "label": "最大生成长度"
+ }
+ },
+ "top_p": {
+ "en": {
+ "label": "Top-p"
+ },
+ "zh": {
+ "label": "Top-p 采样值"
+ }
+ },
+ "temperature": {
+ "en": {
+ "label": "Temperature"
+ },
+ "zh": {
+ "label": "温度系数"
+ }
+ },
+ "export_dir": {
+ "en": {
+ "label": "Export dir",
+ "info": "Directory to save the exported model."
+ },
+ "zh": {
+ "label": "导出目录",
+ "info": "保存导出模型的文件夹路径。"
+ }
+ },
+ "max_shard_size": {
+ "en": {
+ "label": "Max shard size (GB)",
+ "info": "The maximum size for a model file."
+ },
+ "zh": {
+ "label": "最大分块大小(GB)",
+ "info": "模型文件的最大大小。"
+ }
+ },
+ "export_btn": {
+ "en": {
+ "value": "Export"
+ },
+ "zh": {
+ "value": "开始导出"
+ }
+ }
+}
+
+
+ALERTS = {
+ "err_conflict": {
+ "en": "A process is already running, please abort it first.",
+ "zh": "任务已存在,请先中断训练。"
+ },
+ "err_exists": {
+ "en": "You have loaded a model, please unload it first.",
+ "zh": "模型已存在,请先卸载模型。"
+ },
+ "err_no_model": {
+ "en": "Please select a model.",
+ "zh": "请选择模型。"
+ },
+ "err_no_path": {
+ "en": "Model not found.",
+ "zh": "模型未找到。"
+ },
+ "err_no_dataset": {
+ "en": "Please choose a dataset.",
+ "zh": "请选择数据集。"
+ },
+ "err_no_checkpoint": {
+ "en": "Please select a checkpoint.",
+ "zh": "请选择断点。"
+ },
+ "err_no_export_dir": {
+ "en": "Please provide an export directory.",
+ "zh": "请填写导出目录。"
+ },
+ "err_failed": {
+ "en": "Failed.",
+ "zh": "训练出错。"
+ },
+ "info_aborting": {
+ "en": "Aborted, waiting for the thread to terminate...",
+ "zh": "训练中断,正在等待线程结束……"
+ },
+ "info_aborted": {
+ "en": "Ready.",
+ "zh": "准备就绪。"
+ },
+ "info_finished": {
+ "en": "Finished.",
+ "zh": "训练完毕。"
+ },
+ "info_loading": {
+ "en": "Loading model...",
+ "zh": "加载中……"
+ },
+ "info_unloading": {
+ "en": "Unloading model...",
+ "zh": "卸载中……"
+ },
+ "info_loaded": {
+ "en": "Model loaded, now you can chat with your model!",
+ "zh": "模型已加载,可以开始聊天了!"
+ },
+ "info_unloaded": {
+ "en": "Model unloaded.",
+ "zh": "模型已卸载。"
+ },
+ "info_exporting": {
+ "en": "Exporting model...",
+ "zh": "正在导出模型……"
+ },
+ "info_exported": {
+ "en": "Model exported.",
+ "zh": "模型导出完成。"
+ }
+}
diff --git a/llm_rl/src/llmtuner/webui/manager.py b/llm_rl/src/llmtuner/webui/manager.py
new file mode 100644
index 00000000..ca067aea
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/manager.py
@@ -0,0 +1,35 @@
+from typing import TYPE_CHECKING, Dict, List, Set
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+class Manager:
+
+ def __init__(self) -> None:
+ self.all_elems: Dict[str, Dict[str, "Component"]] = {}
+
+ def get_elem_by_name(self, name: str) -> "Component":
+ r"""
+ Example: top.lang, train.dataset
+ """
+ tab_name, elem_name = name.split(".")
+ return self.all_elems[tab_name][elem_name]
+
+ def get_base_elems(self) -> Set["Component"]:
+ return {
+ self.all_elems["top"]["lang"],
+ self.all_elems["top"]["model_name"],
+ self.all_elems["top"]["model_path"],
+ self.all_elems["top"]["checkpoints"],
+ self.all_elems["top"]["finetuning_type"],
+ self.all_elems["top"]["quantization_bit"],
+ self.all_elems["top"]["template"],
+ self.all_elems["top"]["system_prompt"],
+ self.all_elems["top"]["flash_attn"],
+ self.all_elems["top"]["shift_attn"],
+ self.all_elems["top"]["rope_scaling"]
+ }
+
+ def list_elems(self) -> List["Component"]:
+ return [elem for elems in self.all_elems.values() for elem in elems.values()]
diff --git a/llm_rl/src/llmtuner/webui/runner.py b/llm_rl/src/llmtuner/webui/runner.py
new file mode 100644
index 00000000..ab9e9ffc
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/runner.py
@@ -0,0 +1,254 @@
+import os
+import time
+import logging
+import gradio as gr
+from threading import Thread
+from gradio.components import Component # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
+
+import transformers
+from transformers.trainer import TRAINING_ARGS_NAME
+
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.constants import TRAINING_STAGES
+from llmtuner.extras.logging import LoggerHandler
+from llmtuner.extras.misc import torch_gc
+from llmtuner.tuner import run_exp
+from llmtuner.webui.common import get_module, get_save_dir, load_config
+from llmtuner.webui.locales import ALERTS
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
+
+if TYPE_CHECKING:
+ from llmtuner.webui.manager import Manager
+
+
+class Runner:
+
+ def __init__(self, manager: "Manager") -> None:
+ self.manager = manager
+ """ Resume """
+ self.thread: "Thread" = None
+ self.do_train = True
+ self.running_data: Dict["Component", Any] = None
+ self.monitor_inputs: Dict[str, str] = None
+ """ State """
+ self.aborted = False
+ self.running = False
+ """ Handler """
+ self.logger_handler = LoggerHandler()
+ self.logger_handler.setLevel(logging.INFO)
+ logging.root.addHandler(self.logger_handler)
+ transformers.logging.add_handler(self.logger_handler)
+
+ @property
+ def alive(self) -> bool:
+ return self.thread is not None
+
+ def set_abort(self) -> None:
+ self.aborted = True
+ self.running = False
+
+ def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
+ get = lambda name: data[self.manager.get_elem_by_name(name)]
+ lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+ dataset = get("train.dataset") if do_train else get("eval.dataset")
+
+ if self.running:
+ return ALERTS["err_conflict"][lang]
+
+ if not model_name:
+ return ALERTS["err_no_model"][lang]
+
+ if not model_path:
+ return ALERTS["err_no_path"][lang]
+
+ if len(dataset) == 0:
+ return ALERTS["err_no_dataset"][lang]
+
+ self.aborted = False
+ self.logger_handler.reset()
+ self.trainer_callback = LogCallback(self)
+ return ""
+
+ def _finalize(self, lang: str, finish_info: str) -> str:
+ self.thread = None
+ self.running = False
+ torch_gc()
+ if self.aborted:
+ return ALERTS["info_aborted"][lang]
+ else:
+ return finish_info
+
+ def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+ get = lambda name: data[self.manager.get_elem_by_name(name)]
+ user_config = load_config()
+
+ if get("top.checkpoints"):
+ checkpoint_dir = ",".join([
+ get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+ ])
+ else:
+ checkpoint_dir = None
+
+ args = dict(
+ stage=TRAINING_STAGES[get("train.training_stage")],
+ model_name_or_path=get("top.model_path"),
+ do_train=True,
+ cache_dir=user_config.get("cache_dir", None),
+ checkpoint_dir=checkpoint_dir,
+ finetuning_type=get("top.finetuning_type"),
+ quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ template=get("top.template"),
+ system_prompt=get("top.system_prompt"),
+ flash_attn=get("top.flash_attn"),
+ shift_attn=get("top.shift_attn"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+ dataset_dir=get("train.dataset_dir"),
+ dataset=",".join(get("train.dataset")),
+ cutoff_len=get("train.cutoff_len"),
+ learning_rate=float(get("train.learning_rate")),
+ num_train_epochs=float(get("train.num_train_epochs")),
+ max_samples=int(get("train.max_samples")),
+ per_device_train_batch_size=get("train.batch_size"),
+ gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
+ lr_scheduler_type=get("train.lr_scheduler_type"),
+ max_grad_norm=float(get("train.max_grad_norm")),
+ logging_steps=get("train.logging_steps"),
+ save_steps=get("train.save_steps"),
+ warmup_steps=get("train.warmup_steps"),
+ neft_alpha=get("train.neft_alpha"),
+ train_on_prompt=get("train.train_on_prompt"),
+ upcast_layernorm=get("train.upcast_layernorm"),
+ lora_rank=get("train.lora_rank"),
+ lora_dropout=get("train.lora_dropout"),
+ lora_target=get("train.lora_target") or get_module(get("top.model_name")),
+ additional_target=get("train.additional_target") if get("train.additional_target") else None,
+ resume_lora_training=get("train.resume_lora_training"),
+ output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
+ )
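+ # the compute type chosen in the UI (e.g. "fp16" or "bf16") is passed as a boolean training argument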
+ args[get("train.compute_type")] = True
+ args["disable_tqdm"] = True
+
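+ # RM / PPO / DPO: resume the previous LoRA weights only when the model is quantized (otherwise a fresh adapter is trained)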
+ if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
+ args["resume_lora_training"] = (args["quantization_bit"] is not None)
+
+ if args["quantization_bit"] is not None:
+ args["upcast_layernorm"] = True
+
+ if args["stage"] == "ppo":
+ args["reward_model"] = get("train.reward_model")
+
+ if args["stage"] == "dpo":
+ args["dpo_beta"] = get("train.dpo_beta")
+
+ if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+ args["val_size"] = get("train.val_size")
+ args["evaluation_strategy"] = "steps"
+ args["eval_steps"] = get("train.save_steps")
+ args["load_best_model_at_end"] = True
+
+ return args
+
+ def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+ get = lambda name: data[self.manager.get_elem_by_name(name)]
+ user_config = load_config()
+
+ if get("top.checkpoints"):
+ checkpoint_dir = ",".join([
+ get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+ ])
+ output_dir = get_save_dir(
+ get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
+ )
+ else:
+ checkpoint_dir = None
+ output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
+
+ args = dict(
+ stage="sft",
+ model_name_or_path=get("top.model_path"),
+ do_eval=True,
+ predict_with_generate=True,
+ cache_dir=user_config.get("cache_dir", None),
+ checkpoint_dir=checkpoint_dir,
+ finetuning_type=get("top.finetuning_type"),
+ quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ template=get("top.template"),
+ system_prompt=get("top.system_prompt"),
+ flash_attn=get("top.flash_attn"),
+ shift_attn=get("top.shift_attn"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+ dataset_dir=get("eval.dataset_dir"),
+ dataset=",".join(get("eval.dataset")),
+ cutoff_len=get("eval.cutoff_len"),
+ max_samples=int(get("eval.max_samples")),
+ per_device_eval_batch_size=get("eval.batch_size"),
+ max_new_tokens=get("eval.max_new_tokens"),
+ top_p=get("eval.top_p"),
+ temperature=get("eval.temperature"),
+ output_dir=output_dir
+ )
+
+ if get("eval.predict"):
+ args.pop("do_eval", None)
+ args["do_predict"] = True
+
+ return args
+
+ def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ error = self._initialize(data, do_train)
+ if error:
+ gr.Warning(error)
+ yield error, gr.update(visible=False)
+ else:
+ args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
+ yield gen_cmd(args), gr.update(visible=False)
+
+ def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ error = self._initialize(data, do_train)
+ if error:
+ gr.Warning(error)
+ yield error, gr.update(visible=False)
+ else:
+ args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
+ run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+ self.running = True
+ self.do_train, self.running_data = do_train, data
+ self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
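+ # run the experiment in a background thread so monitor() can stream logs and progress to the UI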
+ self.thread = Thread(target=run_exp, kwargs=run_kwargs)
+ self.thread.start()
+ yield from self.monitor()
+
+ def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ yield from self._preview(data, do_train=True)
+
+ def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ yield from self._preview(data, do_train=False)
+
+ def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ yield from self._launch(data, do_train=True)
+
+ def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ yield from self._launch(data, do_train=False)
+
+ def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
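+ # poll the training thread, streaming captured logs and progress-bar updates until it finishes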
+ while self.thread.is_alive():
+ time.sleep(2)
+ if self.aborted:
+ yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+ else:
+ yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+ if self.do_train:
+ if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+ finish_info = ALERTS["info_finished"][lang]
+ else:
+ finish_info = ALERTS["err_failed"][lang]
+ else:
+ if os.path.exists(os.path.join(output_dir, "all_results.json")):
+ finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
+ else:
+ finish_info = ALERTS["err_failed"][lang]
+
+ yield self._finalize(lang, finish_info), gr.update(visible=False)
diff --git a/llm_rl/src/llmtuner/webui/utils.py b/llm_rl/src/llmtuner/webui/utils.py
new file mode 100644
index 00000000..933d951d
--- /dev/null
+++ b/llm_rl/src/llmtuner/webui/utils.py
@@ -0,0 +1,85 @@
+import os
+import json
+import gradio as gr
+import matplotlib.figure
+import matplotlib.pyplot as plt
+from typing import TYPE_CHECKING, Any, Dict
+from datetime import datetime
+
+from llmtuner.extras.ploting import smooth
+from llmtuner.webui.common import get_save_dir
+
+if TYPE_CHECKING:
+ from llmtuner.extras.callbacks import LogCallback
+
+
+def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
+ if not callback.max_steps:
+ return gr.update(visible=False)
+
+ percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
+ label = "Running {:d}/{:d}: {} < {}".format(
+ callback.cur_steps,
+ callback.max_steps,
+ callback.elapsed_time,
+ callback.remaining_time
+ )
+ return gr.update(label=label, value=percentage, visible=True)
+
+
+def get_time() -> str:
+ return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+
+
+def can_quantize(finetuning_type: str) -> Dict[str, Any]:
+ if finetuning_type != "lora":
+ return gr.update(value="None", interactive=False)
+ else:
+ return gr.update(interactive=True)
+
+
+def gen_cmd(args: Dict[str, Any]) -> str:
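+ # reconstruct the equivalent `python src/train_bash.py` command as a fenced bash block for preview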
+ args.pop("disable_tqdm", None)
+ args["plot_loss"] = args.get("do_train", None)
+ cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
+ for k, v in args.items():
+ if v is not None and v != "":
+ cmd_lines.append(" --{} {} ".format(k, str(v)))
+ cmd_text = "\\\n".join(cmd_lines)
+ cmd_text = "```bash\n{}\n```".format(cmd_text)
+ return cmd_text
+
+
+def get_eval_results(path: os.PathLike) -> str:
+ with open(path, "r", encoding="utf-8") as f:
+ result = json.dumps(json.load(f), indent=4)
+ return "```json\n{}\n```\n".format(result)
+
+
+def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
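+ # plot the training loss curve (raw and smoothed) from trainer_log.jsonl in the output directory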
+ if not base_model:
+ return
+ log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
+ if not os.path.isfile(log_file):
+ return
+
+ plt.close("all")
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ steps, losses = [], []
+ with open(log_file, "r", encoding="utf-8") as f:
+ for line in f:
+ log_info = json.loads(line)
+ if log_info.get("loss", None):
+ steps.append(log_info["current_steps"])
+ losses.append(log_info["loss"])
+
+ if len(losses) == 0:
+ return None
+
+ ax.plot(steps, losses, alpha=0.4, label="original")
+ ax.plot(steps, smooth(losses), label="smoothed")
+ ax.legend()
+ ax.set_xlabel("step")
+ ax.set_ylabel("loss")
+ return fig
diff --git a/llm_rl/src/train_bash.py b/llm_rl/src/train_bash.py
new file mode 100644
index 00000000..9ddd0586
--- /dev/null
+++ b/llm_rl/src/train_bash.py
@@ -0,0 +1,14 @@
+from llmtuner import run_exp
+
+
+def main():
+ run_exp()
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ main()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/src/train_web.py b/llm_rl/src/train_web.py
new file mode 100644
index 00000000..38efd64d
--- /dev/null
+++ b/llm_rl/src/train_web.py
@@ -0,0 +1,11 @@
+from llmtuner import create_ui
+
+
+def main():
+ demo = create_ui()
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/src/web_demo.py b/llm_rl/src/web_demo.py
new file mode 100644
index 00000000..257536ab
--- /dev/null
+++ b/llm_rl/src/web_demo.py
@@ -0,0 +1,11 @@
+from llmtuner import create_web_demo
+
+
+def main():
+ demo = create_web_demo()
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm_rl/tests/cal_flops.py b/llm_rl/tests/cal_flops.py
new file mode 100644
index 00000000..01b005af
--- /dev/null
+++ b/llm_rl/tests/cal_flops.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+# Calculates the flops of pre-trained models.
+# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
+
+import fire
+import torch
+from typing import Optional
+from deepspeed.accelerator import get_accelerator # type: ignore
+from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
+
+from llmtuner import ChatModel
+
+
+def calculate(
+ model_name_or_path: str,
+ batch_size: Optional[int] = 1,
+ seq_length: Optional[int] = 256,
+ flash_attn: Optional[bool] = False
+):
+ with get_accelerator().device(0):
+ chat_model = ChatModel(dict(
+ model_name_or_path=model_name_or_path,
+ template="vanilla",
+ flash_attn=flash_attn
+ ))
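+ # profile a single forward pass on dummy token ids (labels included, so the loss computation is counted)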
+ fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
+ input_dict = {
+ "input_ids": fake_input,
+ "labels": fake_input.clone()
+ }
+ flops, macs, params = get_model_profile(
+ chat_model.model,
+ kwargs=input_dict,
+ print_profile=True,
+ detailed=True
+ )
+ print("FLOPs:", flops)
+ print("MACs:", macs)
+ print("Params:", params)
+
+
+if __name__ == "__main__":
+ fire.Fire(calculate)
diff --git a/llm_rl/tests/llamafy_baichuan2.py b/llm_rl/tests/llamafy_baichuan2.py
new file mode 100644
index 00000000..d08eee1c
--- /dev/null
+++ b/llm_rl/tests/llamafy_baichuan2.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Converts the Baichuan2-7B model to the same format as LLaMA2-7B.
+# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
+# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
+# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+
+import os
+import fire
+import json
+import torch
+from collections import OrderedDict
+from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from typing import Any, Dict
+
+
+CONFIG_NAME = "config.json"
+
+
+def save_weight(
+ input_dir: str,
+ output_dir: str,
+ shard_size: str
+):
+ baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+ for filepath in os.listdir(input_dir):
+ if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
+ shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
+ baichuan2_state_dict.update(shard_weight)
+
+ llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+ for key, value in baichuan2_state_dict.items():
+ if "W_pack" in key:
+ proj_size = value.size(0) // 3
+ llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
+ llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
+ llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
+ elif "lm_head" in key:
+ llama2_state_dict[key] = torch.nn.functional.normalize(value)
+ else:
+ llama2_state_dict[key] = value
+
+ shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
+ for shard_file, shard in shards.items():
+ torch.save(shard, os.path.join(output_dir, shard_file))
+
+ if index is None:
+ print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+ else:
+ with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+ json.dump(index, f, indent=2, sort_keys=True)
+ print("Model weights saved in {}".format(output_dir))
+
+
+def save_config(
+ input_dir: str,
+ output_dir: str
+):
+ with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+ llama2_config_dict: Dict[str, Any] = json.load(f)
+
+ llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
+ llama2_config_dict.pop("auto_map", None)
+ llama2_config_dict.pop("tokenizer_class", None)
+ llama2_config_dict["model_type"] = "llama"
+
+ with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
+ json.dump(llama2_config_dict, f, indent=2)
+ print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+
+
+def llamafy_baichuan2(
+ input_dir: str,
+ output_dir: str,
+ shard_size: str
+):
+ try:
+ os.makedirs(output_dir, exist_ok=False)
+ except FileExistsError:
+ raise RuntimeError("Output dir already exists: {}".format(output_dir))
+
+ save_weight(input_dir, output_dir, shard_size)
+ save_config(input_dir, output_dir)
+
+
+if __name__ == "__main__":
+ fire.Fire(llamafy_baichuan2)
diff --git a/llm_rl/tests/llamafy_qwen.py b/llm_rl/tests/llamafy_qwen.py
new file mode 100644
index 00000000..8b9fc395
--- /dev/null
+++ b/llm_rl/tests/llamafy_qwen.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Converts the Qwen models to the same format as LLaMA2.
+# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
+
+import os
+import fire
+import json
+import torch
+from collections import OrderedDict
+from safetensors import safe_open
+from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.utils import check_min_version
+from typing import Any, Dict
+
+try:
+ check_min_version("4.34.0")
+except Exception:
+ raise ValueError("Please upgrade `transformers` to at least 4.34.0")
+
+
+CONFIG_NAME = "config.json"
+
+
+def save_weight(
+ input_dir: str,
+ output_dir: str,
+ shard_size: str
+) -> str:
+ qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+ for filepath in os.listdir(input_dir):
+ if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
+ with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
+ for key in f.keys():
+ qwen_state_dict[key] = f.get_tensor(key)
+
+ llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+ torch_dtype = None
+ for key, value in qwen_state_dict.items():
+ if torch_dtype is None:
+ torch_dtype = value.dtype
+ if "wte" in key:
+ llama2_state_dict["model.embed_tokens.weight"] = value
+ elif "ln_f" in key:
+ llama2_state_dict["model.norm.weight"] = value
+ else:
+ key = key.replace("transformer.h", "model.layers")
+ if "attn.c_attn" in key:
+ proj_size = value.size(0) // 3
+ llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
+ llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
+ llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
+ elif "attn.c_proj" in key:
+ llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
+ llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
+ torch.zeros_like(value[:, 0]).squeeze()
+ )
+ elif "ln_1" in key:
+ llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
+ elif "ln_2" in key:
+ llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
+ elif "mlp.w1" in key:
+ llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
+ elif "mlp.w2" in key:
+ llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
+ elif "mlp.c_proj" in key:
+ llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
+ elif "lm_head" in key:
+ llama2_state_dict[key] = value
+ else:
+ raise KeyError("Unable to process key {}".format(key))
+
+ shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
+ for shard_file, shard in shards.items():
+ torch.save(shard, os.path.join(output_dir, shard_file))
+
+ if index is None:
+ print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+ else:
+ with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+ json.dump(index, f, indent=2, sort_keys=True)
+ print("Model weights saved in {}".format(output_dir))
+
+ return str(torch_dtype).replace("torch.", "")
+
+
+def save_config(
+ input_dir: str,
+ output_dir: str,
+ torch_dtype: str
+):
+ with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+ qwen_config_dict: Dict[str, Any] = json.load(f)
+
+ llama2_config_dict: Dict[str, Any] = OrderedDict()
+ llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
+ llama2_config_dict["hidden_act"] = "silu"
+ llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
+ llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
+ llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
+ llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
+ llama2_config_dict["model_type"] = "llama"
+ llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
+ llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
+ llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
+ llama2_config_dict["pretraining_tp"] = 1
+ llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
+ llama2_config_dict["rope_scaling"] = None
+ llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
+ llama2_config_dict["torch_dtype"] = torch_dtype
+ llama2_config_dict["transformers_version"] = "4.34.0"
+ llama2_config_dict["use_cache"] = True
+ llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
+ llama2_config_dict["attention_bias"] = True
+
+ with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
+ json.dump(llama2_config_dict, f, indent=2)
+ print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+
+
+def llamafy_qwen(
+ input_dir: str,
+ output_dir: str,
+ shard_size: str
+):
+ try:
+ os.makedirs(output_dir, exist_ok=False)
+ except FileExistsError:
+ raise RuntimeError("Output dir already exists: {}".format(output_dir))
+
+ torch_dtype = save_weight(input_dir, output_dir, shard_size)
+ save_config(input_dir, output_dir, torch_dtype)
+
+
+if __name__ == "__main__":
+ fire.Fire(llamafy_qwen)
diff --git a/llm_rl/tests/quantize.py b/llm_rl/tests/quantize.py
new file mode 100644
index 00000000..25321cf3
--- /dev/null
+++ b/llm_rl/tests/quantize.py
@@ -0,0 +1,50 @@
+# coding=utf-8
+# Quantizes models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ).
+# Usage: python quantize.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json
+# --max_length 1024 --max_samples 1024
+# dataset format: instruction (string), input (string), output (string), history (list of [query, response] string pairs)
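+# example record (illustrative): {"instruction": "Summarize the text.", "input": "...", "output": "...",
+#                                 "history": [["previous question", "previous answer"]]}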
+
+
+import fire
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+
+
+def quantize(input_dir: str, output_dir: str, data_file: str, max_length: int, max_samples: int):
+ tokenizer = AutoTokenizer.from_pretrained(input_dir, use_fast=False, padding_side="left")
+
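+ # build chat-style calibration prompts: system prefix, optional history turns, then the current instruction and response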
+ def format_example(examples):
+ prefix=("A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.")
+ texts = []
+ for i in range(len(examples["instruction"])):
+ prompt = prefix + "\n"
+ if "history" in examples:
+ for user_query, bot_resp in examples["history"][i]:
+ prompt += "Human: {}\nAssistant: {}\n".format(user_query, bot_resp)
+ prompt += "Human: {}\nAssistant: {}".format(
+ examples["instruction"][i] + "\n" + examples["input"][i], examples["output"][i]
+ )
+ texts.append(prompt)
+ return tokenizer(texts, truncation=True, max_length=max_length)
+
+ dataset = load_dataset("json", data_files=data_file)["train"]
+ column_names = list(dataset.column_names)
+ dataset = dataset.select(range(min(len(dataset), max_samples)))
+ dataset = dataset.map(format_example, batched=True, remove_columns=column_names)
+ dataset = dataset.shuffle()
+
+ quantize_config = BaseQuantizeConfig(
+ bits=4,
+ group_size=128,
+ desc_act=False
+ )
+
+ model = AutoGPTQForCausalLM.from_pretrained(input_dir, quantize_config, trust_remote_code=True)
+ model.quantize(dataset)
+ model.save_quantized(output_dir)
+
+
+if __name__ == "__main__":
+ fire.Fire(quantize)