[finetune] basic upgrade for supporting DDP and FSDP on PVC (intel#74)
* fix bugs and update codes for enabling workflow on borealis

* update

* update

* update

* update

* update

* add file

* Update README.finetune.gpu.md

* update

* update

* add gpu workflow yml

* up

* up

* up

* up

* up

* just disable for debugging

* update

* update

* update

* [common] add device option for TorchConfig (intel#126)

* add device option for TorchConfig

* update

* update

* update

* fix bugs and update codes for enabling workflow on borealis

* update

* update

* update

* up

* update

* update

* update

* update

* update

* fix comments

* update

* update

* update

* update

* update

* update

* update

* update
harborn authored Nov 23, 2023
1 parent 122741e commit 1e390e0
Showing 9 changed files with 296 additions and 34 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/workflow_finetune_gpu.yml
@@ -0,0 +1,20 @@
name: Finetune on PVC

on:
  workflow_call:

jobs:
  finetune:
    name: finetune on gpu test
    runs-on: self-hosted
    steps:
      - name: Checkout
        uses: actions/checkout@v2

      - name: Start remote task
        run: |
          source ~/borealis-runner/init_conda.sh
          python ~/borealis-runner/gpu_task_on_pvc.py
      - name: Test Summary
        run: echo "to be continued"
21 changes: 13 additions & 8 deletions README.md
@@ -64,18 +64,23 @@ If deploying a ray cluster on multiple nodes, please download the workflow repos

#### 1. Prepare Dataset

Now, the workflow only supports datasets in the specified format

The format of dataset similar to [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k). This type of data is used for finetuning in prompt mode and this type of data is characterized by containing `instruction` `context` and `response` fields where `instruction` and `response` are required fields and `context` is an optional field. In the data preprocessing stage, the three fields will be concatenated to the corresponding format according to [dolly](https://github.com/databrickslabs/dolly/blob/master/training/trainer.py#LL93).
The workflow only supports datasets with JSONL (JSON Lines) format, where each line is a separate JSON object. Here’s the structure each line should follow:

``` json
{"instruction":"<User Input>", "context":"<Additional Information>", "response":"<Expected Output>"}
```

The meaning of the above three columns:
+ Instruction Column: The column in the dataset is the user input, such as a question or a command.
+ Context Column: This column is other information used by instruction, such as the options used in the question and so on. It can be empty.
+ Response: The column in the dataset containing the expected output.
- Instruction: This is the user's input, such as a question, command, or prompt for content generation.
- Context: Supplementary information that aids the instruction. This can include previous conversation parts, background details, or specificities influencing the response. It's optional and can be left empty.
- Response: The model's expected output in response to the 'instruction', considering the 'context' if provided.

##### Examples:
``` json
{"instruction":"Which is a species of fish? Tope or Rope", "context":"", "response":"Tope"}
{"instruction":"What is the average lifespan of a Golden Retriever?","context":"Golden Retrievers are a generally healthy breed; they have an average lifespan of 12 to 13 years. Irresponsible breeding to meet high demand has led to the prevalence of inherited health problems in some breed lines, including allergic skin conditions, eye problems and sometimes snappiness. These problems are rarely encountered in dogs bred from responsible breeders.","response":"The average lifespan of a Golden Retriever is 12 to 13 years."}
```

Therefore, if the your data meets the above two formats, you can use the data by configuring the local data path or huggingface dataset. If not, please refer to the following **Adopt to Your Dataset**.
An example dataset can be accessed at `examples/data/sample_finetune_data.jsonl`. Ensure each line in your dataset follows the above format.
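
As an illustration of the JSONL layout described above, here is a minimal Python sketch that writes in-memory records to a file with one JSON object per line. The helper name `write_jsonl` and the choice to default a missing `context` to an empty string are assumptions made for this example; they are not part of the repository.

```python
import json
import tempfile
from pathlib import Path


def write_jsonl(records, path):
    """Write one JSON object per line with the three expected keys."""
    with Path(path).open("w", encoding="utf-8") as f:
        for rec in records:
            row = {
                "instruction": rec["instruction"],
                # Assumption for this sketch: a missing context becomes "".
                "context": rec.get("context", ""),
                "response": rec["response"],
            }
            # json.dumps escapes embedded newlines, so each record stays on one line.
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


records = [
    {"instruction": "Which is a species of fish? Tope or Rope", "response": "Tope"},
]

with tempfile.TemporaryDirectory() as tmpdir:
    out_path = Path(tmpdir) / "my_finetune_data.jsonl"  # hypothetical file name
    write_jsonl(records, out_path)
    written_lines = out_path.read_text(encoding="utf-8").splitlines()
```

Each entry in `written_lines` can be parsed back with `json.loads`, which is a quick way to confirm the file matches the format before pointing `train_file` at it.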

#### 2. Finetune

18 changes: 10 additions & 8 deletions common/trainer/default_trainer.py
@@ -49,7 +49,7 @@ def recovery(self, config):

# update lr_scheduler status
if Path.exists(checkpoint_dir / "lr_scheduler.pt") and hasattr(self, "lr_scheduler"):
scheduler_state = torch.load(checkpoint_dir / "lr_schduler.pt", map_location="cpu")
scheduler_state = torch.load(checkpoint_dir / "lr_scheduler.pt", map_location="cpu")
self.lr_scheduler.load_state_dict(scheduler_state)

# update current epoch
@@ -111,12 +111,14 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
lr_scheduler = None

model.train()
self.model, self.optimizer, self.lr_scheduler = accelerator.prepare(
model, optimizer, lr_scheduler
)

self.train_dataloader, self.eval_dataloader = accelerator.prepare(
train_dataloader, eval_dataloader,
# self.model, self.optimizer, self.lr_scheduler, ..., are prepared with 2 steps
# because it is recommended way to prepare model and optimizer while using FSDP.
# https://huggingface.co/docs/accelerate/usage_guides/fsdp#a-few-caveats-to-be-aware-of
self.model = accelerator.prepare(model)

self.optimizer, self.train_dataloader, self.eval_dataloader, self.lr_scheduler = accelerator.prepare(
optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

checkpoint = self.config.get("checkpoint")
Expand Down Expand Up @@ -144,7 +146,7 @@ def train(self):
self.lr_scheduler.step()
self.optimizer.zero_grad()
if step % log_step == 0:
logger.info(f"train epoch:[{idx}/{num_train_epochs}]\tstep:[{step}/{total_steps}]\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}")
logger.info(f"train epoch:[{idx}/{num_train_epochs}]\tstep:[{step}/{total_steps}]\tloss:{loss:.6f}\tppl:{math.exp(loss):.6f}\ttime:{time.time()-start:.6f}")
report({"train_epoch": idx, "total_epochs": num_train_epochs, "train_step": step, "total_steps": min(max_train_step, total_steps) if max_train_step else total_steps})
start = time.time()
if max_train_step is not None:
Expand Down Expand Up @@ -207,7 +209,7 @@ def save(self, config, epoch = 0):
torch.save(self.optimizer.state_dict(), os.path.join(tmpdir, "optim.pt"))
torch.save({"epoch": epoch}, os.path.join(tmpdir, "epoch.pt"))
if self.lr_scheduler:
torch.save(self.lr_scheduler.state_dict(), os.path.join(tmpdir, "lr_schduler.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(tmpdir, "lr_scheduler.pt"))
checkpoint = Checkpoint.from_directory(tmpdir)
checkpoint.to_directory(local_checkpoint_path)
logger.info(f"save checkpoint to {local_checkpoint_path} finished")
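
Reading the hunks above, `save` previously wrote `lr_schduler.pt` while the existence check in `recovery` looked for `lr_scheduler.pt`, so the scheduler state would apparently have been skipped on recovery. One way to guard against this kind of mismatch is to define the file name once and use it on both sides. The sketch below is not the repository's code: `LR_SCHEDULER_FILE`, `save_scheduler_state`, and `load_scheduler_state` are hypothetical names, and `pickle` stands in for `torch.save`/`torch.load` so the example needs only the standard library.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical constant: a single definition shared by the save and recovery paths.
LR_SCHEDULER_FILE = "lr_scheduler.pt"


def save_scheduler_state(state, checkpoint_dir):
    """Persist the scheduler state (pickle stands in for torch.save here)."""
    with (Path(checkpoint_dir) / LR_SCHEDULER_FILE).open("wb") as f:
        pickle.dump(state, f)


def load_scheduler_state(checkpoint_dir):
    """Return the saved state, or None when the checkpoint has no scheduler file."""
    path = Path(checkpoint_dir) / LR_SCHEDULER_FILE
    if not path.exists():
        return None
    with path.open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmpdir:
    missing = load_scheduler_state(tmpdir)          # nothing saved yet
    save_scheduler_state({"last_epoch": 3}, tmpdir)
    restored = load_scheduler_state(tmpdir)         # round trip through the same name
```
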
111 changes: 111 additions & 0 deletions docs/README.finetune.intel.gpu.md
@@ -0,0 +1,111 @@
The LLM on Ray finetuning workflow supports running on Intel GPUs.

The following example shows how to set up the environment and finetune an LLM.

## Hardware and Software requirements

### Hardware Requirements

[Intel® Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series/products.html)

|Product Name|Launch Date|Memory Size|Xe-cores|
|---|---|---|---|
|[Intel® Data Center GPU Max 1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html)|Q1'23|128 GB|128|
|[Intel® Data Center GPU Max 1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html)|Q2'23|48 GB|56|

### Software Requirements
The workflow has been tested on SUSE Linux Enterprise Server 15 SP4 (Linux borealis-uan1 5.14.21-150400.22-default)
- conda
- Python 3.9

## Setup

### Build conda environment
``` bash
conda create -n llm_ray_gpu python=3.9
# after install necessary python modules
conda activate llm_ray_gpu
```

### Clone the code
``` bash
git clone https://github.com/intel-sandbox/llm-ray.git llm-ray-gpu
cd llm-ray-gpu
```

### Install dependencies

This step assumes that the [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) is already installed.

``` bash
pip install -r requirements.intel.gpu.txt
# install ipex/oneccl/torch/torchvision with intel gpu version
pip install torch==2.0.1a0 torchvision==0.15.2a0 intel-extension-for-pytorch==2.0.110+xpu oneccl-bind-pt --extra-index-url https://developer.intel.com/ipex-whl-stable-xpu
```
More versions are available [here](https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torch/). More detailed information is available [here](https://intel.github.io/intel-extension-for-pytorch/).

## Prepare the Dataset for Finetuning

To finetune the model, your dataset must be in a JSONL (JSON Lines) format, where each line is a separate JSON object. Here’s the structure each line should follow:

``` json
{"instruction":"<User Input>", "context":"<Additional Information>", "response":"<Expected Output>"}
```

- Instruction: This is the user's input, such as a question, command, or prompt for content generation.
- Context: Supplementary information that aids the instruction. This can include previous conversation parts, background details, or specificities influencing the response. It's optional and can be left empty.
- Response: The model's expected output in response to the 'instruction', considering the 'context' if provided.

### Examples:
``` json
{"instruction":"Which is a species of fish? Tope or Rope", "context":"", "response":"Tope"}
{"instruction":"What is the average lifespan of a Golden Retriever?","context":"Golden Retrievers are a generally healthy breed; they have an average lifespan of 12 to 13 years. Irresponsible breeding to meet high demand has led to the prevalence of inherited health problems in some breed lines, including allergic skin conditions, eye problems and sometimes snappiness. These problems are rarely encountered in dogs bred from responsible breeders.","response":"The average lifespan of a Golden Retriever is 12 to 13 years."}
```

An example dataset can be accessed at `examples/data/sample_finetune_data.jsonl`. Ensure each line in your dataset follows the above format.
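
Before starting a run, it can help to check that every line of a dataset parses and carries the expected keys. The following is a minimal sketch of such a check, not a tool shipped with the workflow. The function name `validate_jsonl_lines` is hypothetical, and treating `instruction` and `response` as required non-empty strings is an assumption based on the field descriptions above.

```python
import json


def validate_jsonl_lines(lines):
    """Return a list of (line_number, problem) tuples; an empty list means all lines passed."""
    problems = []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # this sketch ignores blank lines
        try:
            row = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((lineno, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(row, dict):
            problems.append((lineno, "line is not a JSON object"))
            continue
        # Assumption: instruction and response must be non-empty strings.
        for key in ("instruction", "response"):
            if not isinstance(row.get(key), str) or not row[key]:
                problems.append((lineno, f"missing or empty '{key}'"))
        # context is optional and may be empty, but should be a string when present.
        if not isinstance(row.get("context", ""), str):
            problems.append((lineno, "'context' must be a string"))
    return problems


sample = [
    '{"instruction":"Which is a species of fish? Tope or Rope", "context":"", "response":"Tope"}',
    '{"instruction":"A question without an answer"}',
]
issues = validate_jsonl_lines(sample)
```

To check a file, the same function can be called on an open file handle, since iterating a text file yields its lines.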

## Configurations for Finetuning

The workflow is configuration-driven.

Detailed finetuning configuration parameters are described [here](../docs/finetune_parameters.md).

To finetune on Intel GPU, the following options may need to be modified:

``` json
{
    "General": {
        "base_model": "EleutherAI/gpt-j-6b",
    },
    "Dataset": {
        "train_file": "examples/data/sample_finetune_data.jsonl",
    },
    "Training": {
        "device": "GPU",
        "num_training_workers": 2,
        "accelerate_mode": "GPU_DDP",
        "resources_per_worker": {
            "CPU": 1,
            "GPU": 1,
        },
    },
}
```
```
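
One way to picture these options is as overrides layered on top of the CPU defaults listed in `docs/finetune_parameters.md`. The sketch below shows a small recursive merge under that reading. `merge_config` is a hypothetical helper, not a function in this repository, and the workflow itself may simply expect the values to be edited directly in the config file.

```python
import copy


def merge_config(base, overrides):
    """Return a new dict with `overrides` applied on top of `base`, recursing into nested dicts."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Defaults taken from the parameter table (device, workers, resources, accelerate mode).
cpu_defaults = {
    "Training": {
        "device": "CPU",
        "num_training_workers": 2,
        "accelerate_mode": "CPU_DDP",
        "resources_per_worker": {"CPU": 32},
    },
}

# The GPU-specific values shown in the configuration above.
gpu_overrides = {
    "Training": {
        "device": "GPU",
        "accelerate_mode": "GPU_DDP",
        "resources_per_worker": {"CPU": 1, "GPU": 1},
    },
}

gpu_config = merge_config(cpu_defaults, gpu_overrides)
```

The merge copies its inputs, so `cpu_defaults` is left unchanged and can be reused for a CPU run.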

### Models

The finetuning workflow supports many models as the base model. The following models have been verified:
- gpt-j-6b
- pythia-70m/160m/410m/1b/1.4b/2.8b/6.9b/12b
- Llama-2-7b

### Example

We provide an [example](../examples/finetune/gpt_j_6b/finetune_intel_gpu.conf) configuration for finetuning `gpt-j-6b` on Intel GPUs.

## Finetuning

``` bash
python finetune/finetune.py --config_path examples/finetune/gpt_j_6b/finetune_intel_gpu.conf
```

12 changes: 4 additions & 8 deletions docs/finetune_parameters.md
@@ -32,14 +32,10 @@ The following are the parameters supported in the finetuning workflow.
|learning_rate|1e-5|Initial learning rate to use.|
|lr_scheduler|linear|The scheduler type to use, supported value: "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"|
|weight_decay|0.0|Weight decay is a regularization technique that adds an L2 norm of all model weights to the loss function while increasing the probability of improving the model generalization.|
|device|CPU|The device type used, can be CPU or GPU.|
|num_training_workers|2|The number of the training process|
|resources_per_worker|{"CPU": 32}|A dict to specify the resources for each worker. If `device` is GPU, please set it like {"CPU": 32, "GPU": 1}.|
|device|CPU|The device type used, can be "CPU", "GPU".|
|num_training_workers|2|The number of the training process.|
|resources_per_worker|{"CPU": 32}|A dict to specify the resources for each worker. If `device` is "GPU", please set it like {"CPU": 32, "GPU": 1}.|
|accelerate_mode|CPU_DDP|The accelerate mode for training model, available options are: "CPU_DDP", "GPU_DDP", "GPU_FSDP".|
|max_train_steps|None|Total number of training steps to perform. If provided, overrides epochs.|
|gradient_accumulation_steps|1|Number of updates steps to accumulate before performing a backward/update pass.|
|seed|None|A seed for reproducible training.|
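
The `device`, `resources_per_worker`, and `accelerate_mode` rows interact, so a small consistency check can catch mismatched settings before a run starts. The sketch below is an illustration only. `check_training_options` is a hypothetical helper, and pairing each accelerate mode with a device by its name prefix is an assumption rather than something the table states.

```python
DEVICES = ("CPU", "GPU")
# Assumption: modes are paired with devices by their name prefix.
MODES_BY_DEVICE = {"CPU": ("CPU_DDP",), "GPU": ("GPU_DDP", "GPU_FSDP")}


def check_training_options(training):
    """Return a list of human-readable problems; an empty list means the options look consistent."""
    errors = []
    # Fall back to the defaults listed in the parameter table.
    device = training.get("device", "CPU")
    mode = training.get("accelerate_mode", "CPU_DDP")
    resources = training.get("resources_per_worker", {"CPU": 32})

    if device not in DEVICES:
        errors.append(f"device must be one of {DEVICES}, got {device!r}")
        return errors
    if mode not in MODES_BY_DEVICE[device]:
        errors.append(f"accelerate_mode {mode!r} does not match device {device!r}")
    if device == "GPU" and not resources.get("GPU"):
        errors.append('resources_per_worker needs a "GPU" entry when device is "GPU"')
    return errors


ok = check_training_options(
    {"device": "GPU", "accelerate_mode": "GPU_DDP", "resources_per_worker": {"CPU": 1, "GPU": 1}}
)
bad = check_training_options(
    {"device": "GPU", "accelerate_mode": "CPU_DDP", "resources_per_worker": {"CPU": 32}}
)
```
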





41 changes: 41 additions & 0 deletions examples/finetune/gpt_j_6b/finetune_intel_gpu.conf
@@ -0,0 +1,41 @@
{
    "General": {
        "base_model": "EleutherAI/gpt-j-6b",
        # fix issue: https://github.com/huggingface/transformers/issues/22482
        # transformers version 4.26.0 is required for gpt2, gpt-j-6B, pythia...
        "gpt_base_model": True,
        "output_dir": "/tmp/llm-ray/output",
        "checkpoint_dir": "/tmp/llm-ray/checkpoint",
        "config": {
            "trust_remote_code": False,
            "use_auth_token": None,
        },
        "lora_config": {
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 32,
            "lora_dropout": 0.1
        }
    },
    "Dataset": {
        "train_file": "examples/data/sample_finetune_data.jsonl",
        "validation_file": None,
        "validation_split_percentage": 5
    },
    "Training": {
        "optimizer": "AdamW",
        "batch_size": 4,
        "epochs": 3,
        "learning_rate": 1e-5,
        "lr_scheduler": "linear",
        "weight_decay": 0.0,
        "device": "GPU",
        "num_training_workers": 2,
        "accelerate_mode": "GPU_DDP",
        "resources_per_worker": {
            "CPU": 1,
            "GPU": 1,
        },
    },
}
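
Note that this file is not strict JSON: it uses `#` comments, the Python literals `True` and `None`, and trailing commas, all of which a JSON parser would reject. How `finetune.py` loads it is not shown in this commit. As one safe way to read such a file for inspection, the sketch below parses a shortened copy with `ast.literal_eval`, which accepts Python literals without executing code. `load_conf` is a hypothetical helper name.

```python
import ast

# Shortened copy of the config above, kept in the same Python-literal style.
CONF_TEXT = """
{
    "General": {
        "base_model": "EleutherAI/gpt-j-6b",
        # comments and trailing commas are accepted by the Python parser
        "gpt_base_model": True,
        "config": {
            "trust_remote_code": False,
            "use_auth_token": None,
        },
    },
    "Training": {
        "learning_rate": 1e-5,
        "device": "GPU",
        "resources_per_worker": {
            "CPU": 1,
            "GPU": 1,
        },
    },
}
"""


def load_conf(text):
    """Parse a Python-literal config string into a dict without evaluating arbitrary code."""
    config = ast.literal_eval(text.strip())
    if not isinstance(config, dict):
        raise ValueError("top-level config must be a dict")
    return config


conf = load_conf(CONF_TEXT)
```

Unlike `eval`, `ast.literal_eval` raises an error on anything other than literals, so a config containing function calls or names would be rejected rather than executed.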

4 changes: 3 additions & 1 deletion finetune/finetune.conf
@@ -32,7 +32,9 @@
"device": "CPU",
"num_training_workers": 2,
"resources_per_worker": {
"CPU": 32
"CPU": 32,
# "GPU": 1,
},
"accelerate_mode": "CPU_DDP",
},
}