From 60fc3be2434fbdae09dfb7e79edaae694b41a80c Mon Sep 17 00:00:00 2001 From: zhisbug Date: Tue, 3 Jan 2023 16:08:27 -0500 Subject: [PATCH 01/16] add some dev essentials --- alpa/parallel_method.py | 1 + .../coreweave/hao-dev-1node2a100-128ram.yaml | 175 ++++++++++++++++++ .../coreweave/opt_finetune_cuda113.Dockerfile | 95 ++++++++++ examples/opt_finetune/requirements.txt | 3 + examples/opt_finetune/run_125m_shard.sh | 2 +- examples/opt_finetune/run_2.7b_shard.sh | 2 +- 6 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml create mode 100644 examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile create mode 100644 examples/opt_finetune/requirements.txt diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index 3b92902b9..12ae7f1c5 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -259,6 +259,7 @@ def get_3d_parallel_method(num_micro_batches: int, num_devices_per_host = virtual_mesh.num_devices_per_host if data_parallel == -1: data_parallel = (num_devices // operator_parallel // pipeline_parallel) + print(f"num_devices {num_devices} dp {data_parallel}, op {operator_parallel}, pp {pipeline_parallel}") assert num_devices % data_parallel == 0 assert num_devices % operator_parallel == 0 assert num_devices % pipeline_parallel == 0 diff --git a/examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml b/examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml new file mode 100644 index 000000000..7d9854e96 --- /dev/null +++ b/examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml @@ -0,0 +1,175 @@ +apiVersion: v1 +kind: Service +metadata: + namespace: tenant-jiaohpc-jd # TODO: Change to your namespace + name: service-ray-cluster + labels: + app: ray-cluster +spec: + ports: + - name: dashboard + protocol: TCP + port: 8265 + targetPort: 8265 + - name: gcs-server + protocol: TCP + port: 6380 + targetPort: 6380 + selector: + app: ray-cluster + component: ray-head +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: tenant-jiaohpc-jd # TODO: Change to your namespace + name: deployment-ray-head + labels: + app: ray-cluster + ray-node: head +spec: + # Do not change this - Ray currently only supports one head node per cluster. + replicas: 1 + selector: + matchLabels: + component: ray-head + type: ray + app: ray-cluster + template: + metadata: + labels: + component: ray-head + type: ray + app: ray-cluster + spec: + # If the head node goes down, the entire cluster (including all worker + # nodes) will go down as well. If you want Kubernetes to bring up a new + # head node in this case, set this to "Always," else set it to "Never." + restartPolicy: Always + + # This volume allocates shared memory for Ray to use for its plasma + # object store. If you do not provide this, Ray will fall back to + # /tmp which cause slowdowns if is not a shared memory volume. 
+ volumes: + - name: dshm + emptyDir: + medium: Memory + - name: alpa-opt + persistentVolumeClaim: + claimName: alpa-opt + containers: + - name: ray-head + image: jiaodong/alpa:v1 # TODO: Change to your Alpa docker image + imagePullPolicy: IfNotPresent + # This volume allocates shared memory for Ray to use for its plasma] + # --login in required to have access to conda to activate alpa env + command: + - /bin/bash + - -l + - -c + - -- + - | + conda activate alpa; + cd /mnt/alpa-opt/alpa; + pip install -e .; + ray start --head --port=6380 --num-cpus=$MY_CPU_REQUEST --dashboard-host=0.0.0.0 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --redis-password=''; + # This volume allocates shared memory for Ray to use for its plasma + # object store. If you do not provide this, Ray will fall back to + # /tmp which cause slowdowns if is not a shared memory volume. + volumeMounts: + - mountPath: /dev/shm + name: dshm + - mountPath: /mnt/alpa-opt + name: alpa-opt + env: + # This is used in the ray start command so that Ray can spawn the + # correct number of processes. Omitting this may lead to degraded + # performance. + - name: MY_CPU_REQUEST + valueFrom: + resourceFieldRef: + resource: requests.cpu + resources: + limits: + cpu: 32 + memory: 128Gi + nvidia.com/gpu: 2 + rdma/ib: 1 + # Refer to CoreWeave's documentation for more details about GPU node types and placement + # https://docs.coreweave.com/coreweave-kubernetes/node-types + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: gpu.nvidia.com/class + operator: In + values: + - A100_NVLINK_80GB +#--- +#apiVersion: apps/v1 +#kind: Deployment +#metadata: +# namespace: tenant-jiaohpc-jd # TODO: Change to your namespace +# name: deployment-ray-worker +# labels: +# app: ray-cluster +#spec: +# # Change this to scale the number of worker nodes started in the Ray cluster. +# replicas: 3 +# selector: +# matchLabels: +# component: ray-worker +# type: ray +# app: ray-cluster +# template: +# metadata: +# labels: +# component: ray-worker +# type: ray +# app: ray-cluster +# spec: +# restartPolicy: Always +# volumes: +# - name: dshm +# emptyDir: +# medium: Memory +# containers: +# - name: ray-worker +# image: jiaodong/alpa:v1 # TODO: Change to your Alpa docker image +# imagePullPolicy: IfNotPresent +# # --login in required to have access to conda to activate alpa env +# command: ["/bin/bash", "-l", "-c", "--"] +# args: +# - "conda activate alpa && ray start --num-cpus=$MY_CPU_REQUEST --address=service-ray-cluster:6380 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --block" +# # This volume allocates shared memory for Ray to use for its plasma +# # object store. If you do not provide this, Ray will fall back to +# # /tmp which cause slowdowns if is not a shared memory volume. +# volumeMounts: +# - mountPath: /dev/shm +# name: dshm +# env: +# # This is used in the ray start command so that Ray can spawn the +# # correct number of processes. Omitting this may lead to degraded +# # performance. 
+# - name: MY_CPU_REQUEST +# valueFrom: +# resourceFieldRef: +# resource: requests.cpu +# resources: +# limits: +# cpu: 32 +# memory: 64Gi +# nvidia.com/gpu: 8 +# rdma/ib: 1 +# # Refer to CoreWeave's documentation for more details about GPU node types and placement +# # https://docs.coreweave.com/coreweave-kubernetes/node-types +# affinity: +# nodeAffinity: +# requiredDuringSchedulingIgnoredDuringExecution: +# nodeSelectorTerms: +# - matchExpressions: +# - key: gpu.nvidia.com/class +# operator: In +# values: +# - A100_NVLINK_80GB \ No newline at end of file diff --git a/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile b/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile new file mode 100644 index 000000000..e5f40e5c9 --- /dev/null +++ b/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile @@ -0,0 +1,95 @@ +# base docker image +FROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04 + +# InfiniBand (IB) dependencies adopoted from CoreWeave's github +# https://github.com/coreweave/nccl-tests +ARG DEBIAN_FRONTEND=noninteractive +RUN apt-get -qq update && \ + apt-get -qq install -y --allow-change-held-packages --no-install-recommends \ + build-essential libtool autoconf automake autotools-dev unzip \ + ca-certificates \ + wget curl openssh-server vim environment-modules \ + iputils-ping net-tools \ + libnuma1 libsubunit0 libpci-dev \ + libpmix-dev \ + datacenter-gpu-manager + +# Mellanox OFED (latest) +RUN wget -qO - https://www.mellanox.com/downloads/ofed/RPM-GPG-KEY-Mellanox | apt-key add - +RUN cd /etc/apt/sources.list.d/ && wget https://linux.mellanox.com/public/repo/mlnx_ofed/latest/ubuntu18.04/mellanox_mlnx_ofed.list +RUN apt-get -qq update \ + && apt-get -qq install -y --no-install-recommends \ + ibverbs-utils libibverbs-dev libibumad3 libibumad-dev librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils \ + && rm -rf /var/lib/apt/lists/* + +# HPC-X (2.12) +ENV HPCX_VERSION=2.12 +RUN cd /tmp && \ + wget -q -O - http://blobstore.s3.ord1.coreweave.com/drivers/hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64.tbz | tar xjf - && \ + mv hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64 /opt/hpcx + +# GDRCopy userspace components (2.3) +RUN cd /tmp && \ + wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/gdrcopy-tests_2.3-1_amd64.cuda11_4.Ubuntu20_04.deb && \ + wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/libgdrapi_2.3-1_amd64.Ubuntu20_04.deb && \ + dpkg -i *.deb && \ + rm *.deb + +# Begin auto-generated paths +ENV HPCX_DIR=/opt/hpcx +ENV HPCX_UCX_DIR=/opt/hpcx/ucx +ENV HPCX_UCC_DIR=/opt/hpcx/ucc +ENV HPCX_SHARP_DIR=/opt/hpcx/sharp +ENV HPCX_NCCL_RDMA_SHARP_PLUGIN_DIR=/opt/hpcx/nccl_rdma_sharp_plugin +ENV HPCX_HCOLL_DIR=/opt/hpcx/hcoll +ENV HPCX_MPI_DIR=/opt/hpcx/ompi +ENV HPCX_OSHMEM_DIR=/opt/hpcx/ompi +ENV HPCX_MPI_TESTS_DIR=/opt/hpcx/ompi/tests +ENV HPCX_OSU_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8 +ENV HPCX_OSU_CUDA_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8-cuda +ENV HPCX_IPM_DIR=/opt/hpcx/ompi/tests/ipm-2.0.6 +ENV HPCX_CLUSTERKIT_DIR=/opt/hpcx/clusterkit +ENV OMPI_HOME=/opt/hpcx/ompi +ENV MPI_HOME=/opt/hpcx/ompi +ENV OSHMEM_HOME=/opt/hpcx/ompi +ENV OPAL_PREFIX=/opt/hpcx/ompi +ENV PATH=/opt/hpcx/clusterkit/bin:/opt/hpcx/hcoll/bin:/opt/hpcx/ucc/bin:/opt/hpcx/ucx/bin:/opt/hpcx/ompi/bin:$PATH +ENV 
LD_LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ucc/lib/ucc:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib/ucx:/opt/hpcx/ucx/lib:/opt/hpcx/sharp/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:$LD_LIBRARY_PATH +ENV LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ompi/lib:/opt/hpcx/sharp/lib:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:/usr/local/cuda/lib64/stubs +ENV CPATH=/opt/hpcx/ompi/include:/opt/hpcx/ucc/include:/opt/hpcx/ucx/include:/opt/hpcx/sharp/include:/opt/hpcx/hcoll/include: +ENV PKG_CONFIG_PATH=/opt/hpcx/hcoll/lib/pkgconfig:/opt/hpcx/sharp/lib/pkgconfig:/opt/hpcx/ucx/lib/pkgconfig:/opt/hpcx/ompi/lib/pkgconfig: +# End of auto-generated paths + +# install common tool & conda +RUN apt update && \ + apt install -y wget git vim screen && \ + wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \ + /bin/bash ~/anaconda.sh -b -p /opt/conda && \ + rm ~/anaconda.sh && \ + mkdir -p /opt/conda/envs/alpa && \ + ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ + echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ + echo "conda activate base" >> ~/.bashrc + +# Some of my own dev config +RUN wget --quiet https://raw.githubusercontent.com/zhisbug/RC/master/.screenrc -P /root/ && \ + wget --quiet https://raw.githubusercontent.com/zhisbug/RC/master/.vimrc -P /root/ + +# install conda alpa env +RUN . /opt/conda/etc/profile.d/conda.sh && \ + conda create --name alpa python=3.8 -y && \ + conda activate alpa && \ + apt install coinor-cbc -y && \ + pip3 install --upgrade pip && \ + pip3 install cupy-cuda113 && \ + pip3 install alpa && \ + pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html + +# install additional deps for opt finetuning +RUN conda activate alpa && \ + pip3 install datasets && \ + pip3 install transformers && \ + pip3 install tensorflow-gpu + +# Execute in Alpa conda env +ENV PATH /opt/conda/envs/alpa/bin:$PATH \ No newline at end of file diff --git a/examples/opt_finetune/requirements.txt b/examples/opt_finetune/requirements.txt new file mode 100644 index 000000000..3b8901a82 --- /dev/null +++ b/examples/opt_finetune/requirements.txt @@ -0,0 +1,3 @@ +datasets >= 1.1.3 +transformers +tensorflow-gpu diff --git a/examples/opt_finetune/run_125m_shard.sh b/examples/opt_finetune/run_125m_shard.sh index 22b32f64a..1d923dedb 100644 --- a/examples/opt_finetune/run_125m_shard.sh +++ b/examples/opt_finetune/run_125m_shard.sh @@ -8,7 +8,7 @@ python3 run_clm_flax.py \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ - --operator_parallel 4 \ + --operator_parallel 2 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ diff --git a/examples/opt_finetune/run_2.7b_shard.sh b/examples/opt_finetune/run_2.7b_shard.sh index 33f46e411..1749e82ea 100644 --- a/examples/opt_finetune/run_2.7b_shard.sh +++ b/examples/opt_finetune/run_2.7b_shard.sh @@ -8,7 +8,7 @@ python3 run_clm_flax.py \ --per_device_train_batch_size="20" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ - --operator_parallel 4 \ + --operator_parallel 2 \ --pipeline_parallel 1 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ From fd25c2e4537793055a063dce009de604e2b9a916 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Mon, 16 Jan 2023 21:13:54 -0500 Subject: [PATCH 02/16] finetuning works [skip ci] --- alpa/parallel_method.py | 20 +- ...-128ram.yaml => 
dev-1node2a100-128ram.yaml | 8 +- examples/llm_serving/model/opt_utils.py | 2 +- examples/llm_serving/model/wrapper.py | 4 +- examples/opt_finetune/config_125m.json | 29 +++ examples/opt_finetune/config_175b.json | 29 +++ examples/opt_finetune/config_2.7b.json | 30 +++ .../coreweave/opt_finetune_cuda113.Dockerfile | 7 +- examples/opt_finetune/load_params.py | 224 ++++++++++++++++++ examples/opt_finetune/run_125m_pipe.sh | 22 ++ examples/opt_finetune/run_175b_pipe.sh | 22 ++ examples/opt_finetune/run_clm_flax.py | 138 +++++++++-- examples/opt_finetune/test_save_load.py | 65 +++++ 13 files changed, 561 insertions(+), 39 deletions(-) rename examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml => dev-1node2a100-128ram.yaml (97%) create mode 100644 examples/opt_finetune/config_125m.json create mode 100644 examples/opt_finetune/config_175b.json create mode 100644 examples/opt_finetune/config_2.7b.json create mode 100644 examples/opt_finetune/load_params.py create mode 100644 examples/opt_finetune/run_125m_pipe.sh create mode 100644 examples/opt_finetune/run_175b_pipe.sh create mode 100644 examples/opt_finetune/test_save_load.py diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index 12ae7f1c5..713ec684d 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -13,7 +13,7 @@ - PipeshardParallel: which combines pipeline parallelism and shard parallelism. """ from abc import ABC, abstractmethod -from typing import Callable, Optional, Sequence, Union, Any +from typing import Callable, Optional, Sequence, Union, Any, List from jax import linear_util as lu from jax._src import traceback_util @@ -248,9 +248,11 @@ def get_3d_parallel_method(num_micro_batches: int, data_parallel: int, operator_parallel: int, pipeline_parallel: int, - allow_degenerate_into_shard_parallel: bool = True): + allow_degenerate_into_shard_parallel: bool = True, + use_manual_layer_option: bool = False, + forward_stage_layer_ids: List[List[int]] = None): """ - Get a parallel method for 3D parallelism, which reguarlly combines + Get a parallel method for 3D parallelism, which regularly combines data parallelism, operator parallelism and pipeline parallelism. """ # Validity check @@ -288,7 +290,15 @@ def get_3d_parallel_method(num_micro_batches: int, [data_parallel, operator_parallel])) # Return pipeshard parallel - layer_option = AutoLayerOption(layer_num=pp, eps=0.1) + if use_manual_layer_option: + # We assume each layer has been annotated using the mark_pipeline_boundary() + layer_option = ManualLayerOption() + assert forward_stage_layer_ids, "forward_stage_layer_ids must be provided " \ + "when using manual annotation." + else: + # Note: this eps need some tuning. 
+ layer_option = AutoLayerOption(layer_num=pp, eps=0.1) + forward_stage_layer_ids = [[i] for i in range(pp)] return PipeshardParallel( devices=virtual_mesh, num_micro_batches=num_micro_batches, @@ -298,7 +308,7 @@ def get_3d_parallel_method(num_micro_batches: int, ), layer_option=layer_option, stage_option=ManualStageOption( - forward_stage_layer_ids=[[i] for i in range(pp)], + forward_stage_layer_ids=forward_stage_layer_ids, submesh_physical_shapes=[physical_mesh_shape] * pp, submesh_logical_shapes=[logical_mesh_shape] * pp, submesh_autosharding_option_dicts=[{}] * pp)) diff --git a/examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml b/dev-1node2a100-128ram.yaml similarity index 97% rename from examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml rename to dev-1node2a100-128ram.yaml index 7d9854e96..f63ebb9bd 100644 --- a/examples/opt_finetune/coreweave/hao-dev-1node2a100-128ram.yaml +++ b/dev-1node2a100-128ram.yaml @@ -59,7 +59,7 @@ spec: claimName: alpa-opt containers: - name: ray-head - image: jiaodong/alpa:v1 # TODO: Change to your Alpa docker image + image: zhisbug/alpa-opt-finetune-dev:1 imagePullPolicy: IfNotPresent # This volume allocates shared memory for Ray to use for its plasma] # --login in required to have access to conda to activate alpa env @@ -72,7 +72,7 @@ spec: conda activate alpa; cd /mnt/alpa-opt/alpa; pip install -e .; - ray start --head --port=6380 --num-cpus=$MY_CPU_REQUEST --dashboard-host=0.0.0.0 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --redis-password=''; + ray start --head --port=6380 --num-cpus=$MY_CPU_REQUEST --dashboard-host=0.0.0.0 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --redis-password='' --block; # This volume allocates shared memory for Ray to use for its plasma # object store. If you do not provide this, Ray will fall back to # /tmp which cause slowdowns if is not a shared memory volume. 
@@ -92,7 +92,7 @@ spec: resources: limits: cpu: 32 - memory: 128Gi + memory: 256Gi nvidia.com/gpu: 2 rdma/ib: 1 # Refer to CoreWeave's documentation for more details about GPU node types and placement @@ -172,4 +172,4 @@ spec: # - key: gpu.nvidia.com/class # operator: In # values: -# - A100_NVLINK_80GB \ No newline at end of file +# - A100_NVLINK_80GB diff --git a/examples/llm_serving/model/opt_utils.py b/examples/llm_serving/model/opt_utils.py index 4fcd10e2c..a4a466f6c 100644 --- a/examples/llm_serving/model/opt_utils.py +++ b/examples/llm_serving/model/opt_utils.py @@ -4,7 +4,7 @@ from jax import xla, jit from jax.core import Primitive from jax._src.lib import xla_client as xc -from transformers.generation_utils import dataclass +from dataclasses import dataclass def sync(device_id=0): diff --git a/examples/llm_serving/model/wrapper.py b/examples/llm_serving/model/wrapper.py index a2629f9b8..ee4d2ffda 100644 --- a/examples/llm_serving/model/wrapper.py +++ b/examples/llm_serving/model/wrapper.py @@ -13,7 +13,9 @@ jax_index_select) from tqdm import tqdm from transformers import OPTForCausalLM, BloomForCausalLM -from transformers.generation_utils import GenerationMixin, ModelOutput, dataclass +from transformers import GenerationMixin +from transformers.utils import ModelOutput +from dataclasses import dataclass import alpa from alpa.device_mesh import DistributedArray diff --git a/examples/opt_finetune/config_125m.json b/examples/opt_finetune/config_125m.json new file mode 100644 index 000000000..facfb0f68 --- /dev/null +++ b/examples/opt_finetune/config_125m.json @@ -0,0 +1,29 @@ +{ +"_name_or_path": "facebook/opt-125m", +"weight_path": "/home/hao/opt_weights/125M_np", +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 3072, +"hidden_size": 768, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 12, +"num_hidden_layers": 12, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 768 +} \ No newline at end of file diff --git a/examples/opt_finetune/config_175b.json b/examples/opt_finetune/config_175b.json new file mode 100644 index 000000000..5ba97d66c --- /dev/null +++ b/examples/opt_finetune/config_175b.json @@ -0,0 +1,29 @@ +{ +"_name_or_path": "./", +"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 49152, +"hidden_size": 12288, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 96, +"num_hidden_layers": 96, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 12288 +} \ No newline at end of file diff --git a/examples/opt_finetune/config_2.7b.json b/examples/opt_finetune/config_2.7b.json new file mode 100644 index 000000000..b211cd3f5 --- /dev/null +++ b/examples/opt_finetune/config_2.7b.json @@ -0,0 +1,30 @@ +{ +"_name_or_path": "facebook/opt-2.7b", +"weight_path": "/home/hao/opt_weights/2.7B_np", +"_remove_final_layer_norm": false, +"activation_dropout": 
0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 10240, +"hidden_size": 2560, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 32, +"num_hidden_layers": 32, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 2560 +} diff --git a/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile b/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile index e5f40e5c9..18cc8f7a3 100644 --- a/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile +++ b/examples/opt_finetune/coreweave/opt_finetune_cuda113.Dockerfile @@ -83,13 +83,10 @@ RUN . /opt/conda/etc/profile.d/conda.sh && \ pip3 install --upgrade pip && \ pip3 install cupy-cuda113 && \ pip3 install alpa && \ - pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html - -# install additional deps for opt finetuning -RUN conda activate alpa && \ + pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html && \ pip3 install datasets && \ pip3 install transformers && \ pip3 install tensorflow-gpu # Execute in Alpa conda env -ENV PATH /opt/conda/envs/alpa/bin:$PATH \ No newline at end of file +ENV PATH /opt/conda/envs/alpa/bin:$PATH diff --git a/examples/opt_finetune/load_params.py b/examples/opt_finetune/load_params.py new file mode 100644 index 000000000..afbc8ff51 --- /dev/null +++ b/examples/opt_finetune/load_params.py @@ -0,0 +1,224 @@ +import os +import itertools + +import numpy as np +import alpa +from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, + MeshHostWorker, create_remote_array_refs) +import jax +import flax +from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves +from jax.interpreters import pxla + + +def load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes, + uuids, indices, mesh_ids): + """The worker function to load OPT parameters.""" + + def load_array(key): + return np.load(os.path.join(path, key)) + + def load_param(param_key, loaded_array, is_position_embedding=False): + i = prefix_to_idx[param_key] + + for j in range(len(mesh_ids[i])): + if self.mesh_id != mesh_ids[i][j]: + print(f"skipped because self.mesh_id {self.mesh_id}, mesh_ids: {mesh_ids[i][j]}") + continue + + if not is_position_embedding: + assert shapes[i][j] == loaded_array.shape, ( + f"{shapes[i][j]} vs. 
{loaded_array.shape}") + else: + if shapes[i][j] != loaded_array.shape: + assert shapes[i][j][1] == loaded_array.shape[1] + loaded_array = loaded_array[:shapes[i][j][0], :] + uuid = uuids[i][j] + datas = [] + for k in range(len(self.local_devices)): + idx = self.host_id * len(self.local_devices) + k + datas.append(loaded_array[indices[i][j][idx]]) + self.put_buffers(uuid, datas) + + load_param("model.decoder.embed_tokens.embedding", + load_array("decoder.embed_tokens.weight")) + load_param("model.decoder.embed_positions.embedding", + load_array("decoder.embed_positions.weight"), + is_position_embedding=True) + + # if config.version > 2: + load_param("model.decoder.final_layer_norm.scale", + load_array("decoder.layer_norm.weight")) + load_param("model.decoder.final_layer_norm.bias", + load_array("decoder.layer_norm.bias")) + + layers_per_stage = config.num_hidden_layers // config.pp + + for i in range(config.num_hidden_layers): + if i == 16: + print(f"stage_id: {i // layers_per_stage}, mesh id: {self.mesh_id}") + stage_id = i // layers_per_stage + if stage_id != self.mesh_id: + continue + + param_prefix = f"model.decoder.layers.{i}." + load_prefix = f"decoder.layers.{i}." + # Attention weights + wq = load_array(load_prefix + "self_attn.q_proj.weight") + wk = load_array(load_prefix + "self_attn.k_proj.weight") + wv = load_array(load_prefix + "self_attn.v_proj.weight") + # dim = wq.shape[-1] + # w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( + # (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) + # load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) + load_param(param_prefix + "self_attn.q_proj.kernel", wq.T) + load_param(param_prefix + "self_attn.k_proj.kernel", wk.T) + load_param(param_prefix + "self_attn.v_proj.kernel", wv.T) + + + bq = load_array(load_prefix + "self_attn.q_proj.bias") + bk = load_array(load_prefix + "self_attn.k_proj.bias") + bv = load_array(load_prefix + "self_attn.v_proj.bias") + # b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( + # (3, dim)).transpose([1, 0]).reshape((-1,)) + load_param(param_prefix + "self_attn.q_proj.bias", bq) + load_param(param_prefix + "self_attn.k_proj.bias", bk) + load_param(param_prefix + "self_attn.v_proj.bias", bv) + # load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) + load_param( + param_prefix + "self_attn.out_proj.kernel", + np.transpose(load_array(load_prefix + "self_attn.out_proj.weight"))) + load_param(param_prefix + "self_attn.out_proj.bias", + load_array(load_prefix + "self_attn.out_proj.bias")) + load_param(param_prefix + "self_attn_layer_norm.scale", + load_array(load_prefix + "self_attn_layer_norm.weight")) + load_param(param_prefix + "self_attn_layer_norm.bias", + load_array(load_prefix + "self_attn_layer_norm.bias")) + # FFN weights + load_param(param_prefix + "fc1.bias", + load_array(load_prefix + "fc1.bias")) + load_param(param_prefix + "fc1.kernel", + np.transpose(load_array(load_prefix + "fc1.weight"))) + load_param(param_prefix + "fc2.bias", + load_array(load_prefix + "fc2.bias")) + load_param(param_prefix + "fc2.kernel", + np.transpose(load_array(load_prefix + "fc2.weight"))) + load_param(param_prefix + "final_layer_norm.scale", + load_array(load_prefix + "final_layer_norm.weight")) + load_param(param_prefix + "final_layer_norm.bias", + load_array(load_prefix + "final_layer_norm.bias")) + + +setattr(MeshHostWorker, "load_opt_params_worker_func", + load_opt_params_worker_func) + +def load_params_dis_array(path, executable, params_aval, config, dummy=False): + """Load 
parameters with distributed arrays.""" + if dummy: + alpa.global_config.use_dummy_value_for_benchmarking = True + params_info, _ = executable.get_input_placement_specs() + flat_args, in_tree = tree_flatten(params_aval) + flat_info = tree_leaves(params_info) + if hasattr(executable, "mesh_group"): + ret = executable.mesh_group.shard_args_to_arrays( + flat_info, flat_args) + else: + ret = executable.physical_mesh.shard_args_to_arrays_ps( + flat_info, flat_args) + alpa.global_config.use_dummy_value_for_benchmarking = False + return ret + + params_info, _ = executable.get_input_placement_specs() + params_info = params_info.params + + prefix_to_flat_idx = {} + ct = itertools.count() + + def dfs(dict_tree, result_dict, cur_prefix): + if isinstance(dict_tree, (dict, flax.core.FrozenDict)): + for key in dict_tree.keys(): + dfs(dict_tree[key], result_dict, + cur_prefix + ("." if cur_prefix else "") + key) + else: + result_dict[cur_prefix] = next(ct) + + dfs(params_aval, prefix_to_flat_idx, "") + + flat_infos, in_tree = tree_flatten(params_info) + + flat_shapes = [] + flat_uuids = [] + flat_indices = [] + flat_mesh_ids = [] + flat_arrays = [] + + mesh_group = executable.mesh_group + + for info in flat_infos: + aval = info.aval + if len(info.mesh_ids) == 1: + mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0] + indices = pxla.spec_to_indices(aval.shape, spec) + ary_refs, ary_uuid = create_remote_array_refs(mesh) + flat_shapes.append([aval.shape]) + flat_uuids.append([ary_uuid[0]]) + flat_indices.append([indices]) + flat_mesh_ids.append([mesh.mesh_id]) + flat_arrays.append( + DistributedArray(mesh, aval, spec, ary_refs[0], indices)) + else: + tmp_shapes = [] + tmp_uuids = [] + tmp_indices = [] + tmp_mesh_ids = [] + tmp_arrays = [] + tmp_meshes = [] + for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs): + mesh = mesh_group[mesh_id] + indices = pxla.spec_to_indices(aval.shape, spec) + ary_refs, ary_uuid = create_remote_array_refs(mesh) + array = DistributedArray(mesh, aval, spec, ary_refs[0], indices) + tmp_shapes.append(aval.shape) + tmp_uuids.append(ary_uuid[0]) + tmp_indices.append(indices) + tmp_mesh_ids.append(mesh.mesh_id) + tmp_meshes.append(mesh) + tmp_arrays.append(array) + flat_shapes.append(tuple(tmp_shapes)) + flat_uuids.append(tuple(tmp_uuids)) + flat_indices.append(tuple(tmp_indices)) + flat_mesh_ids.append(tuple(tmp_mesh_ids)) + flat_arrays.append( + ReplicatedDistributedArray(tmp_meshes, tmp_arrays)) + + for m in executable.mesh_group.meshes: + for w in m.workers: + w.load_opt_params_worker_func.remote(path, prefix_to_flat_idx, + config, flat_shapes, + flat_uuids, flat_indices, + flat_mesh_ids) + + return tree_unflatten(in_tree, flat_arrays) + # return flat_arrays + + +def load_multi_executable_params_dis_array(path, + executables, + params_aval, + config, + dummy=False): + """Load parameters to workers that will be used by all executables. Accordingly, + we need to make sure the parameter sharding specs are identical for all executables. + """ + shared_input_shard_specs = None + # for executable in executables.values(): + # # stage_input_shard_specs = executable.stage_input_shard_specs + # stage_input_shard_specs = executable.get_input_placement_specs + # if shared_input_shard_specs is not None: + # assert shared_input_shard_specs == stage_input_shard_specs, \ + # "All executables must have the same input sharding specs." 
+ # else: + # shared_input_shard_specs = stage_input_shard_specs + return load_params_dis_array(path, + list(executables.values())[0], params_aval, + config, dummy) diff --git a/examples/opt_finetune/run_125m_pipe.sh b/examples/opt_finetune/run_125m_pipe.sh new file mode 100644 index 000000000..3253fa4f9 --- /dev/null +++ b/examples/opt_finetune/run_125m_pipe.sh @@ -0,0 +1,22 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_125m.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="1024" \ + --per_device_train_batch_size="20" \ + --per_device_eval_batch_size="20" \ + --num_micro_batches 4 \ + --operator_parallel 1 \ + --pipeline_parallel 1 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="8" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_175b_pipe.sh b/examples/opt_finetune/run_175b_pipe.sh new file mode 100644 index 000000000..e4808d472 --- /dev/null +++ b/examples/opt_finetune/run_175b_pipe.sh @@ -0,0 +1,22 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_175b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --use_manual_layer \ + --block_size="1024" \ + --per_device_train_batch_size="64" \ + --per_device_eval_batch_size="64" \ + --num_micro_batches 64 \ + --operator_parallel 1 \ + --pipeline_parallel 8 \ + --dtype="float16" \ + --learning_rate="1.2e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="1" \ + --save_steps="40" \ + --eval_steps="25" diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index d72961aea..ccaed98d1 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -32,7 +32,7 @@ import functools from itertools import chain from pathlib import Path -from typing import Callable, Optional +from typing import Callable, Optional, Tuple, Dict import datasets import numpy as np @@ -41,7 +41,7 @@ import alpa from alpa.model.model_util import DynamicScale, TrainState -from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption +from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption, mark_pipeline_boundary import jax import jax.numpy as jnp import optax @@ -50,6 +50,7 @@ from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import onehot, shard, shard_prng_key +from flax.core.frozen_dict import unfreeze, FrozenDict from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, @@ -62,6 +63,8 @@ set_seed, ) +from examples.opt_finetune.load_params import load_params_dis_array + alpa.init(cluster="ray") from transformers.testing_utils import CaptureLogger @@ -91,6 +94,8 @@ class TrainingArguments: ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + use_manual_layer: bool = field(default=False, metadata={"help": "Whether to use manual layer annotation."}) + alpa_init: bool = 
field(default=False, metadata={"help": "Whether to use Alpa's distributed gpu init."}) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} ) @@ -320,7 +325,7 @@ def create_learning_rate_fn( return schedule_fn -def monkey_patch_remat(): +def monkey_patch_remat(use_manual_layer=False): # Use monkey patch to add remat for all transformer layers. from transformers.models.opt.modeling_flax_opt import FlaxOPTDecoderLayer, FlaxOPTDecoderLayerCollection from flax.linen.partitioning import remat @@ -349,7 +354,7 @@ def call( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -365,6 +370,10 @@ def call( if output_attentions: all_self_attns += (layer_outputs[1],) + # add manual layer annotations + if use_manual_layer and i != len(self.layers) - 1: + mark_pipeline_boundary() + outputs = [hidden_states, all_hidden_states, all_self_attns] return outputs @@ -512,6 +521,11 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: + if "175b" in model_args.model_name_or_path.lower(): + raise RuntimeError("HuggingFace hub does not have OPT-175B model. If you want to finetune OPT-175B, " + "please pass --config_name instead and set it as the path to the config file " + "provided at examples/opt_finetune/config_175b.json. Please obtain the weights" + "by contacting Meta Inc.") config = AutoConfig.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, @@ -522,7 +536,7 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") if training_args.use_remat: - monkey_patch_remat() + monkey_patch_remat(training_args.use_manual_layer) if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( @@ -545,6 +559,7 @@ def main(): "You can do it from another script, save it, and load it from here, using --tokenizer_name." ) + do_init = not training_args.alpa_init if model_args.model_name_or_path: model = FlaxAutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, @@ -552,6 +567,7 @@ def main(): seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, + _do_init=do_init ) #from transformers import FlaxOPTForCausalLM #config.num_hidden_layers = 2 @@ -565,6 +581,7 @@ def main(): config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), + _do_init=do_init ) # Preprocessing the datasets. 
@@ -654,13 +671,12 @@ def group_texts(examples): max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) - if training_args.do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = lm_datasets["validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = lm_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) # Adjust batch size and num_micro_batches for small datasets num_devices = alpa.get_global_num_devices() @@ -698,8 +714,19 @@ def group_texts(examples): # Store some constant num_epochs = int(training_args.num_train_epochs) - train_batch_size = int(training_args.per_device_train_batch_size) * num_devices - eval_batch_size = int(training_args.per_device_eval_batch_size) * num_devices + # Infer data parallel + op = training_args.operator_parallel + pp = training_args.pipeline_parallel + dp = num_devices // op // pp + assert dp >= 1, "You settings of `operator_parallel` or `pipeline_parallel` is problematic. " \ + "Please make sure op * pp is divisible by num_devices" + # Copy the parallel config into configs + setattr(config, "op", op) + setattr(config, "pp", pp) + setattr(config, "dp", dp) + + train_batch_size = int(training_args.per_device_train_batch_size) * dp + eval_batch_size = int(training_args.per_device_eval_batch_size) * dp steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs @@ -754,8 +781,6 @@ def decay_mask_fn(params): alpa.global_config.flax_always_use_fp16_embedding = True else: use_master_copy = dynamic_scale = None - state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, - dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] @@ -806,18 +831,45 @@ def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, deterministic=True)[0] loss = loss_fn(logits, labels) - # summarize metrics metrics = {"loss": loss} return metrics - # Create parallel version of the train and eval step - method = alpa.get_3d_parallel_method( - num_micro_batches=training_args.num_micro_batches, - data_parallel=-1, - operator_parallel=training_args.operator_parallel, - pipeline_parallel=training_args.pipeline_parallel) + if training_args.alpa_init: + # In this case, params are not initialized yet, and we use shaped array to trigger alpa compilation + rngkey = jax.core.ShapedArray((2,), jnp.uint32) + input_ids = jax.core.ShapedArray((1, 128), jnp.int32) + attention_mask = jax.core.ShapedArray((1, 128), jnp.int32) + position_ids = jax.core.ShapedArray((1, 128), jnp.int32) + params_aval = unfreeze(jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"]) + state_aval = TrainState.create_aval(apply_fn=model.__call__, params=params_aval, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + # params_aval = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, getattr(jnp, model_args.dtype)), params_aval) + 
else: + # In this case, params have been initialized by HF in the CPU memory of the driver node. + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + # Create parallel version of the train and eval step + if training_args.use_manual_layer: + # make layer-stage aassignments + assert config.num_hidden_layers % pp == 0, f"There are {config.num_hidden_layers} layers but {pp} stages." + n_layers_per_stage = config.num_hidden_layers // pp + forward_stage_layer_ids = [list(range(i * n_layers_per_stage, (i + 1) * n_layers_per_stage)) for i in range(pp)] + print(f"n_layers_per_stage {n_layers_per_stage}, forward_stage_layer_ids {forward_stage_layer_ids}") + method = alpa.get_3d_parallel_method( + num_micro_batches=training_args.num_micro_batches, + data_parallel=-1, + operator_parallel=training_args.operator_parallel, + pipeline_parallel=training_args.pipeline_parallel, + use_manual_layer_option=True, + forward_stage_layer_ids=forward_stage_layer_ids) + else: + method = alpa.get_3d_parallel_method( + num_micro_batches=training_args.num_micro_batches, + data_parallel=-1, + operator_parallel=training_args.operator_parallel, + pipeline_parallel=training_args.pipeline_parallel) p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,)) @@ -825,6 +877,46 @@ def eval_step(params, batch): method=alpa.FollowParallel( p_train_step, num_micro_batches=eval_num_micro_batches)) + if training_args.alpa_init: + print(f" - Compile executables. ", end="", flush=True) + tic = time.time() + # trigger compile + seq_len = config.max_position_embeddings + train_input_shape = (train_batch_size, seq_len) + batch = {"input_ids": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + "position_ids": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + "attention_mask": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + "labels": + jax.core.ShapedArray( + train_input_shape, jnp.int32), + } + train_executable = p_train_step.get_executable(state_aval, batch) + eval_executable = p_eval_step.get_executable(state_aval.params, batch) + # load params using compiled sharding specs1 + print(f" Compilation takes {time.time() - tic:.2f} seconds.") + + # def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: + + print(" - Load parameters. 
", end="", flush=True) + tic = time.time() + # params = load_multi_executable_params_dis_array(path, exectuables, params_aval, config, dummy) + assert config.weight_path and os.path.exists(config.weight_path), f"Cannot find weight at {config.weight_path}" + params = load_params_dis_array(config.weight_path, train_executable, params_aval, config, dummy=False) + # from alpa.serialization import restore_checkpoint + # params = restore_checkpoint(path, ) + # Work around the model._initialized in HF + model._is_initialized = True + model.params = params + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + print(f" Load parameters takes {time.time() - tic:.2f} seconds.") + dump_debug_info_train_step = dump_debug_info_eval_step = True logger.info("***** Running training *****") diff --git a/examples/opt_finetune/test_save_load.py b/examples/opt_finetune/test_save_load.py new file mode 100644 index 000000000..4d7c82de5 --- /dev/null +++ b/examples/opt_finetune/test_save_load.py @@ -0,0 +1,65 @@ +import os +import tempfile + +import ray + +import alpa +from alpa import (init, shutdown, parallelize, DistributedArray, + PipeshardParallel, save_checkpoint, restore_checkpoint) +from alpa.device_mesh import get_global_cluster +from alpa.testing import get_bert_layer_train_state_and_step, assert_allclose +from alpa.parallel_method import get_3d_parallel_method + + +def _get_save_prefix(): + device_cluster = get_global_cluster() + if len(device_cluster.host_info) > 1: + raise RuntimeError("The multi-host test requires a mounted EFS! ") + else: + # Use tmp dir for the single-host test + save_prefix = "/tmp/" + return save_prefix + + +alpa.init() +ckpt_dir = "/mnt/alpa-opt/alpa/examples/opt_finetune/test_ckpt" +state, batch, train_step = get_bert_layer_train_state_and_step( + batch_size=16, + seq_len=8, + num_layers=2, + hidden_size=128, + num_heads=8, + clip_by_global_norm=False, + use_dynamic_scale=False, + add_manual_pipeline_marker=True) + +method = PipeshardParallel(num_micro_batches=2, layer_option="manual") + +serial_train_step = train_step +parallel_train_step = parallelize(train_step, method=method) +executable = parallel_train_step.get_executable(state, batch) + +serial_state = state +parallel_state = state +serial_state = serial_train_step(serial_state, batch)[0] +parallel_state = parallel_train_step(parallel_state, batch)[0] +# assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) + +with tempfile.TemporaryDirectory(prefix="/tmp/") as cache_dir: + # Save checkpoint + save_checkpoint(ckpt_dir, parallel_state, 1, cache_dir) + + # Sync all the move workers + executable.sync_move_workers() + + # Restore checkpoint + state_ps, _ = executable.get_input_placement_specs() + load_state = restore_checkpoint(ckpt_dir, 1, state_ps) + + # Run after load + serial_state = serial_train_step(serial_state, batch)[0] + load_state = parallel_train_step(load_state, batch)[0] + + # Check results + assert_allclose(serial_state.params, load_state.params, 1e-3, + 1e-3) From db589d9b24df4ba6f5b875abed95067fe1c0f429 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Wed, 18 Jan 2023 22:15:56 -0500 Subject: [PATCH 03/16] updates --- dev-1node2a100-128ram.yaml | 175 ------------------ examples/opt_finetune/estimate_peak_memory.py | 99 ++++++++++ examples/opt_finetune/run_125m_pipe.sh | 4 +- examples/opt_finetune/run_clm_flax.py | 3 +- 4 files changed, 103 insertions(+), 178 deletions(-) delete mode 100644 
dev-1node2a100-128ram.yaml create mode 100644 examples/opt_finetune/estimate_peak_memory.py diff --git a/dev-1node2a100-128ram.yaml b/dev-1node2a100-128ram.yaml deleted file mode 100644 index f63ebb9bd..000000000 --- a/dev-1node2a100-128ram.yaml +++ /dev/null @@ -1,175 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - namespace: tenant-jiaohpc-jd # TODO: Change to your namespace - name: service-ray-cluster - labels: - app: ray-cluster -spec: - ports: - - name: dashboard - protocol: TCP - port: 8265 - targetPort: 8265 - - name: gcs-server - protocol: TCP - port: 6380 - targetPort: 6380 - selector: - app: ray-cluster - component: ray-head ---- -apiVersion: apps/v1 -kind: Deployment -metadata: - namespace: tenant-jiaohpc-jd # TODO: Change to your namespace - name: deployment-ray-head - labels: - app: ray-cluster - ray-node: head -spec: - # Do not change this - Ray currently only supports one head node per cluster. - replicas: 1 - selector: - matchLabels: - component: ray-head - type: ray - app: ray-cluster - template: - metadata: - labels: - component: ray-head - type: ray - app: ray-cluster - spec: - # If the head node goes down, the entire cluster (including all worker - # nodes) will go down as well. If you want Kubernetes to bring up a new - # head node in this case, set this to "Always," else set it to "Never." - restartPolicy: Always - - # This volume allocates shared memory for Ray to use for its plasma - # object store. If you do not provide this, Ray will fall back to - # /tmp which cause slowdowns if is not a shared memory volume. - volumes: - - name: dshm - emptyDir: - medium: Memory - - name: alpa-opt - persistentVolumeClaim: - claimName: alpa-opt - containers: - - name: ray-head - image: zhisbug/alpa-opt-finetune-dev:1 - imagePullPolicy: IfNotPresent - # This volume allocates shared memory for Ray to use for its plasma] - # --login in required to have access to conda to activate alpa env - command: - - /bin/bash - - -l - - -c - - -- - - | - conda activate alpa; - cd /mnt/alpa-opt/alpa; - pip install -e .; - ray start --head --port=6380 --num-cpus=$MY_CPU_REQUEST --dashboard-host=0.0.0.0 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --redis-password='' --block; - # This volume allocates shared memory for Ray to use for its plasma - # object store. If you do not provide this, Ray will fall back to - # /tmp which cause slowdowns if is not a shared memory volume. - volumeMounts: - - mountPath: /dev/shm - name: dshm - - mountPath: /mnt/alpa-opt - name: alpa-opt - env: - # This is used in the ray start command so that Ray can spawn the - # correct number of processes. Omitting this may lead to degraded - # performance. 
- - name: MY_CPU_REQUEST - valueFrom: - resourceFieldRef: - resource: requests.cpu - resources: - limits: - cpu: 32 - memory: 256Gi - nvidia.com/gpu: 2 - rdma/ib: 1 - # Refer to CoreWeave's documentation for more details about GPU node types and placement - # https://docs.coreweave.com/coreweave-kubernetes/node-types - affinity: - nodeAffinity: - requiredDuringSchedulingIgnoredDuringExecution: - nodeSelectorTerms: - - matchExpressions: - - key: gpu.nvidia.com/class - operator: In - values: - - A100_NVLINK_80GB -#--- -#apiVersion: apps/v1 -#kind: Deployment -#metadata: -# namespace: tenant-jiaohpc-jd # TODO: Change to your namespace -# name: deployment-ray-worker -# labels: -# app: ray-cluster -#spec: -# # Change this to scale the number of worker nodes started in the Ray cluster. -# replicas: 3 -# selector: -# matchLabels: -# component: ray-worker -# type: ray -# app: ray-cluster -# template: -# metadata: -# labels: -# component: ray-worker -# type: ray -# app: ray-cluster -# spec: -# restartPolicy: Always -# volumes: -# - name: dshm -# emptyDir: -# medium: Memory -# containers: -# - name: ray-worker -# image: jiaodong/alpa:v1 # TODO: Change to your Alpa docker image -# imagePullPolicy: IfNotPresent -# # --login in required to have access to conda to activate alpa env -# command: ["/bin/bash", "-l", "-c", "--"] -# args: -# - "conda activate alpa && ray start --num-cpus=$MY_CPU_REQUEST --address=service-ray-cluster:6380 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --block" -# # This volume allocates shared memory for Ray to use for its plasma -# # object store. If you do not provide this, Ray will fall back to -# # /tmp which cause slowdowns if is not a shared memory volume. -# volumeMounts: -# - mountPath: /dev/shm -# name: dshm -# env: -# # This is used in the ray start command so that Ray can spawn the -# # correct number of processes. Omitting this may lead to degraded -# # performance. 
-# - name: MY_CPU_REQUEST -# valueFrom: -# resourceFieldRef: -# resource: requests.cpu -# resources: -# limits: -# cpu: 32 -# memory: 64Gi -# nvidia.com/gpu: 8 -# rdma/ib: 1 -# # Refer to CoreWeave's documentation for more details about GPU node types and placement -# # https://docs.coreweave.com/coreweave-kubernetes/node-types -# affinity: -# nodeAffinity: -# requiredDuringSchedulingIgnoredDuringExecution: -# nodeSelectorTerms: -# - matchExpressions: -# - key: gpu.nvidia.com/class -# operator: In -# values: -# - A100_NVLINK_80GB diff --git a/examples/opt_finetune/estimate_peak_memory.py b/examples/opt_finetune/estimate_peak_memory.py new file mode 100644 index 000000000..39dda3e16 --- /dev/null +++ b/examples/opt_finetune/estimate_peak_memory.py @@ -0,0 +1,99 @@ +import dataclasses +from dataclasses import dataclass + +@dataclass(frozen=True) +class OPTConfig: + # Inherited from OPT + num_hidden_layers: int = 12 + max_seq_len: int = 2048 + hidden_size: int = 768 + n_head: int = 12 + input_dim: int = 768 + ffn_embed_dim: int = 3072 + pad: int = 1 + activation_fn: str = 'relu' + dtype: any = jnp.float16 + use_stable_embedding: bool = False + no_scale_embedding: bool = True + decoder_learned_pos: bool = True + decoder_normalize_before: bool = True + share_decoder_input_output_embed: bool = True + # Added + version: int = 1 + vocab_size: int = 50272 + layer_norm_eps: float = 0.00001 + num_pp_stages: int = None + # parallelize + mark_boundary: bool = True + + +def get_config(name, **kwargs): + if name == "opt-125m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=12, n_head=12, + hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, + version=3, + ) + elif name == "opt-350m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=16, + hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4, + version=2, + ) + raise NotImplementedError("Not implemented because this model " + "has a different architecture") + elif name == "opt-1.3b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=32, + hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, + version=3, + ) + elif name == "opt-2.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, + version=3, + ) + elif name == "opt-6.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, + version=3, + ) + elif name == "opt-30b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=48, n_head=56, + hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, + version=3, + ) + elif name == "opt-66b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=64, n_head=72, + hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, + version=3, + ) + elif name == "opt-175b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=96, n_head=96, + hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, + version=3, + ) + else: + raise ValueError(f"Invalid model name: {name}") + + return dataclasses.replace(config, **kwargs) + +def estimate_peak_memory(model_name: str, mbs: int, seq_len: int): + config = get_config(model_name) + n_layer = config.num_hidden_layers + n_head = config.n_head + h = config.hidden_size + n_params = float(model_name.split("-")[:-1]) * 1e9 + + n_bytes = 16 * n_params + 2 * n_layer * mbs * seq_len * h + \ + seq_len * mbs * h * (34 + 5.0 * n_head * seq_len / h) + return n_bytes / 1e9 + + + 
+estimate_peak_memory("opt-125m", 2, 1024) \ No newline at end of file diff --git a/examples/opt_finetune/run_125m_pipe.sh b/examples/opt_finetune/run_125m_pipe.sh index 3253fa4f9..d66ac2efb 100644 --- a/examples/opt_finetune/run_125m_pipe.sh +++ b/examples/opt_finetune/run_125m_pipe.sh @@ -7,11 +7,11 @@ python3 run_clm_flax.py \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train \ --block_size="1024" \ - --per_device_train_batch_size="20" \ + --per_device_train_batch_size="64" \ --per_device_eval_batch_size="20" \ --num_micro_batches 4 \ --operator_parallel 1 \ - --pipeline_parallel 1 \ + --pipeline_parallel 4 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index ccaed98d1..d9c37cdc9 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -991,7 +991,8 @@ def eval_step(params, batch): f"Loss: {train_metric['loss'].mean():.4f}, " f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " f"Throughput: {throughput_tokens:.2f} token/s, " - f"{throughput_tflops:.2f} TFLOP/s" + f"{throughput_tflops:.2f} TFLOP/s, " + f"{throughput_tflops * 4 / 3:.2f} TFLOP/s." ) train_metrics = [] From 8047611c04f5979517c7d39fbc6daf5358f400c9 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Wed, 25 Jan 2023 00:21:12 -0500 Subject: [PATCH 04/16] some new updates --- alpa/global_env.py | 2 +- examples/opt_finetune/config_125m.json | 2 +- examples/opt_finetune/config_2.7b.json | 3 +- examples/opt_finetune/config_30b.json | 29 +++++ examples/opt_finetune/config_66b.json | 30 +++++ examples/opt_finetune/estimate_peak_memory.py | 15 ++- examples/opt_finetune/load_params.py | 115 +++++++++++++++++- examples/opt_finetune/run_125m_pipe.sh | 1 + examples/opt_finetune/run_2.7b_pipe.sh | 19 +-- examples/opt_finetune/run_30b_pipe.sh | 23 ++++ examples/opt_finetune/run_clm_flax.py | 86 +++++++++---- 11 files changed, 282 insertions(+), 43 deletions(-) create mode 100644 examples/opt_finetune/config_30b.json create mode 100644 examples/opt_finetune/config_66b.json create mode 100644 examples/opt_finetune/run_30b_pipe.sh diff --git a/alpa/global_env.py b/alpa/global_env.py index d543ee673..c5d4943d6 100644 --- a/alpa/global_env.py +++ b/alpa/global_env.py @@ -75,7 +75,7 @@ def __init__(self): self.resharding_mode = "send_recv" # Which nccl to use. Possible choices: {"cupy", # "xla_extension"} - self.nccl_mode = "cupy" + self.nccl_mode = "xla_extension" self.enable_overlapping = False # Cross mesh resharding load balancing mode. 
# Possible choices: {"normal", "no_loadbalance", diff --git a/examples/opt_finetune/config_125m.json b/examples/opt_finetune/config_125m.json index facfb0f68..08a88e353 100644 --- a/examples/opt_finetune/config_125m.json +++ b/examples/opt_finetune/config_125m.json @@ -1,6 +1,6 @@ { "_name_or_path": "facebook/opt-125m", -"weight_path": "/home/hao/opt_weights/125M_np", +"weight_path": "/home/ubuntu/dataset/opt_weights/125M_np", "activation_dropout": 0, "activation_function": "relu", "architectures": [ diff --git a/examples/opt_finetune/config_2.7b.json b/examples/opt_finetune/config_2.7b.json index b211cd3f5..0e100abb8 100644 --- a/examples/opt_finetune/config_2.7b.json +++ b/examples/opt_finetune/config_2.7b.json @@ -1,6 +1,5 @@ { -"_name_or_path": "facebook/opt-2.7b", -"weight_path": "/home/hao/opt_weights/2.7B_np", +"weight_path": "/home/ubuntu/tmp_dataset/opt_weights/2.7B_np", "_remove_final_layer_norm": false, "activation_dropout": 0, "activation_function": "relu", diff --git a/examples/opt_finetune/config_30b.json b/examples/opt_finetune/config_30b.json new file mode 100644 index 000000000..61935b928 --- /dev/null +++ b/examples/opt_finetune/config_30b.json @@ -0,0 +1,29 @@ +{ +"weight_path": "/home/ubuntu/tmp_dataset/opt_weights/30B_np", +"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 28672, +"hidden_size": 7168, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 56, +"num_hidden_layers": 48, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 7168 +} \ No newline at end of file diff --git a/examples/opt_finetune/config_66b.json b/examples/opt_finetune/config_66b.json new file mode 100644 index 000000000..c171d64e8 --- /dev/null +++ b/examples/opt_finetune/config_66b.json @@ -0,0 +1,30 @@ +{ +"_name_or_path": "facebook/opt-66b", +"weight_path": "/home/ubuntu/tmp_dataset/opt_weights/66B_np", +"_remove_final_layer_norm": false, +"activation_dropout": 0, +"activation_function": "relu", +"architectures": [ +"OPTForCausalLM" +], +"attention_dropout": 0, +"bos_token_id": 2, +"do_layer_norm_before": true, +"dropout": 0.1, +"eos_token_id": 2, +"ffn_dim": 36864, +"hidden_size": 9216, +"init_std": 0.02, +"layerdrop": 0, +"max_position_embeddings": 2048, +"model_type": "opt", +"num_attention_heads": 72, +"num_hidden_layers": 64, +"pad_token_id": 1, +"prefix": "", +"torch_dtype": "float16", +"transformers_version": "4.21.0.dev0", +"use_cache": true, +"vocab_size": 50272, +"word_embed_proj_dim": 9216 +} \ No newline at end of file diff --git a/examples/opt_finetune/estimate_peak_memory.py b/examples/opt_finetune/estimate_peak_memory.py index 39dda3e16..e46152f85 100644 --- a/examples/opt_finetune/estimate_peak_memory.py +++ b/examples/opt_finetune/estimate_peak_memory.py @@ -1,5 +1,6 @@ import dataclasses from dataclasses import dataclass +import jax.numpy as jnp @dataclass(frozen=True) class OPTConfig: @@ -88,12 +89,18 @@ def estimate_peak_memory(model_name: str, mbs: int, seq_len: int): n_layer = config.num_hidden_layers n_head = config.n_head h = config.hidden_size - n_params = float(model_name.split("-")[:-1]) * 1e9 + n_params = float(model_name.split("-")[-1][:-1]) * 1e9 n_bytes = 16 * n_params + 2 * 
n_layer * mbs * seq_len * h + \ seq_len * mbs * h * (34 + 5.0 * n_head * seq_len / h) - return n_bytes / 1e9 + mem_gb = n_bytes / 1e9 + print(f"model: {model_name}, micro bs: {mbs}, seq_len: {seq_len}, memory lower bound: {mem_gb} GB.") + return mem_gb - -estimate_peak_memory("opt-125m", 2, 1024) \ No newline at end of file +# estimate_peak_memory("opt-2.7b", 2, 1024) +estimate_peak_memory("opt-30b", 2, 1024) +estimate_peak_memory("opt-175b", 2, 1024) +estimate_peak_memory("opt-66b", 2, 1024) +estimate_peak_memory("opt-66b", 8, 1024) +estimate_peak_memory("opt-66b", 16, 1024) \ No newline at end of file diff --git a/examples/opt_finetune/load_params.py b/examples/opt_finetune/load_params.py index afbc8ff51..1e60e10b3 100644 --- a/examples/opt_finetune/load_params.py +++ b/examples/opt_finetune/load_params.py @@ -1,5 +1,6 @@ import os import itertools +import time import numpy as np import alpa @@ -11,6 +12,107 @@ from jax.interpreters import pxla +def load_opt_params_worker_func_66b(self, path, prefix_to_idx, config, shapes, + uuids, indices, mesh_ids): + """The worker function to load OPT parameters.""" + + def load_array(key): + return np.load(os.path.join(path, key)) + + def load_param(param_key, loaded_array, is_position_embedding=False): + i = prefix_to_idx[param_key] + + for j in range(len(mesh_ids[i])): + if self.mesh_id != mesh_ids[i][j]: + continue + + if not is_position_embedding: + assert shapes[i][j] == loaded_array.shape, ( + f"{shapes[i][j]} vs. {loaded_array.shape}") + else: + if shapes[i][j] != loaded_array.shape: + assert shapes[i][j][1] == loaded_array.shape[1] + loaded_array = loaded_array[:shapes[i][j][0], :] + uuid = uuids[i][j] + datas = [] + for k in range(len(self.local_devices)): + idx = self.host_id * len(self.local_devices) + k + datas.append(loaded_array[indices[i][j][idx]]) + self.put_buffers(uuid, datas) + + # print(f" === Reading on {self.mesh_id} ===") + tic = time.time() + load_param("model.decoder.embed_tokens.embedding", + load_array("decoder.embed_tokens.embedding")) + load_param("model.decoder.embed_positions.embedding", + load_array("decoder.embed_positions.embedding"), + is_position_embedding=True) + + # if config.version > 2: + load_param("model.decoder.final_layer_norm.scale", + load_array("decoder.final_layer_norm.scale")) + load_param("model.decoder.final_layer_norm.bias", + load_array("decoder.final_layer_norm.bias")) + + layers_per_stage = config.num_hidden_layers // config.pp + + for i in range(config.num_hidden_layers): + stage_id = i // layers_per_stage + if stage_id != self.mesh_id: + continue + + # print(f"===>Reading layer {i} for mesh {self.mesh_id}") + tic1 = time.time() + param_prefix = f"model.decoder.layers.{i}." + load_prefix = f"decoder.layers.{i}." 
+ # Attention weights + wq = load_array(load_prefix + "self_attn.q_proj.kernel") + wk = load_array(load_prefix + "self_attn.k_proj.kernel") + wv = load_array(load_prefix + "self_attn.v_proj.kernel") + # dim = wq.shape[-1] + # w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape( + # (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1)) + # load_param(param_prefix + "attention.self.qkv_combined.kernel", w_qkv) + load_param(param_prefix + "self_attn.q_proj.kernel", wq) + load_param(param_prefix + "self_attn.k_proj.kernel", wk) + load_param(param_prefix + "self_attn.v_proj.kernel", wv) + + + bq = load_array(load_prefix + "self_attn.q_proj.bias") + bk = load_array(load_prefix + "self_attn.k_proj.bias") + bv = load_array(load_prefix + "self_attn.v_proj.bias") + # b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape( + # (3, dim)).transpose([1, 0]).reshape((-1,)) + load_param(param_prefix + "self_attn.q_proj.bias", bq) + load_param(param_prefix + "self_attn.k_proj.bias", bk) + load_param(param_prefix + "self_attn.v_proj.bias", bv) + # load_param(param_prefix + "attention.self.qkv_combined.bias", b_qkv) + load_param( + param_prefix + "self_attn.out_proj.kernel", + load_array(load_prefix + "self_attn.out_proj.kernel")) + load_param(param_prefix + "self_attn.out_proj.bias", + load_array(load_prefix + "self_attn.out_proj.bias")) + load_param(param_prefix + "self_attn_layer_norm.scale", + load_array(load_prefix + "self_attn_layer_norm.scale")) + load_param(param_prefix + "self_attn_layer_norm.bias", + load_array(load_prefix + "self_attn_layer_norm.bias")) + # FFN weights + load_param(param_prefix + "fc1.bias", + load_array(load_prefix + "fc1.bias")) + load_param(param_prefix + "fc1.kernel", + load_array(load_prefix + "fc1.kernel")) + load_param(param_prefix + "fc2.bias", + load_array(load_prefix + "fc2.bias")) + load_param(param_prefix + "fc2.kernel", + load_array(load_prefix + "fc2.kernel")) + load_param(param_prefix + "final_layer_norm.scale", + load_array(load_prefix + "final_layer_norm.scale")) + load_param(param_prefix + "final_layer_norm.bias", + load_array(load_prefix + "final_layer_norm.bias")) + # print(f"===>Reading layer {i} for mesh {self.mesh_id} done: {time.time() - tic1}") + print(f" === Reading on {self.mesh_id} done: {time.time() - tic} ===") + + def load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes, uuids, indices, mesh_ids): """The worker function to load OPT parameters.""" @@ -23,7 +125,6 @@ def load_param(param_key, loaded_array, is_position_embedding=False): for j in range(len(mesh_ids[i])): if self.mesh_id != mesh_ids[i][j]: - print(f"skipped because self.mesh_id {self.mesh_id}, mesh_ids: {mesh_ids[i][j]}") continue if not is_position_embedding: @@ -40,6 +141,8 @@ def load_param(param_key, loaded_array, is_position_embedding=False): datas.append(loaded_array[indices[i][j][idx]]) self.put_buffers(uuid, datas) + # print(f" === Reading on {self.mesh_id} ===") + tic = time.time() load_param("model.decoder.embed_tokens.embedding", load_array("decoder.embed_tokens.weight")) load_param("model.decoder.embed_positions.embedding", @@ -55,12 +158,12 @@ def load_param(param_key, loaded_array, is_position_embedding=False): layers_per_stage = config.num_hidden_layers // config.pp for i in range(config.num_hidden_layers): - if i == 16: - print(f"stage_id: {i // layers_per_stage}, mesh id: {self.mesh_id}") stage_id = i // layers_per_stage if stage_id != self.mesh_id: continue + # print(f"===>Reading layer {i} for mesh {self.mesh_id}") + tic1 = time.time() param_prefix = 
f"model.decoder.layers.{i}." load_prefix = f"decoder.layers.{i}." # Attention weights @@ -107,10 +210,14 @@ def load_param(param_key, loaded_array, is_position_embedding=False): load_array(load_prefix + "final_layer_norm.weight")) load_param(param_prefix + "final_layer_norm.bias", load_array(load_prefix + "final_layer_norm.bias")) - + # print(f"===>Reading layer {i} for mesh {self.mesh_id} done: {time.time() - tic1}") + print(f" === Reading on {self.mesh_id} done: {time.time() - tic} ===") setattr(MeshHostWorker, "load_opt_params_worker_func", load_opt_params_worker_func) +# setattr(MeshHostWorker, "load_opt_params_worker_func", +# load_opt_params_worker_func_66b) + def load_params_dis_array(path, executable, params_aval, config, dummy=False): """Load parameters with distributed arrays.""" diff --git a/examples/opt_finetune/run_125m_pipe.sh b/examples/opt_finetune/run_125m_pipe.sh index d66ac2efb..f34cea180 100644 --- a/examples/opt_finetune/run_125m_pipe.sh +++ b/examples/opt_finetune/run_125m_pipe.sh @@ -3,6 +3,7 @@ python3 run_clm_flax.py \ --config_name="./config_125m.json" \ --tokenizer_name="facebook/opt-30b" \ --alpa_init \ + --use_manual_layer \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train \ diff --git a/examples/opt_finetune/run_2.7b_pipe.sh b/examples/opt_finetune/run_2.7b_pipe.sh index 4fe9ae47a..31b811d98 100644 --- a/examples/opt_finetune/run_2.7b_pipe.sh +++ b/examples/opt_finetune/run_2.7b_pipe.sh @@ -1,20 +1,23 @@ python3 run_clm_flax.py \ --output_dir="./output" \ - --model_name_or_path="facebook/opt-2.7b" \ + --config_name="./config_2.7b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ - --do_train --do_eval \ + --do_train \ --block_size="1024" \ - --per_device_train_batch_size="64" \ + --per_device_train_batch_size="128" \ --per_device_eval_batch_size="64" \ - --num_micro_batches 64 \ + --num_micro_batches 16 \ --operator_parallel 1 \ - --pipeline_parallel 2 \ + --pipeline_parallel 4 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="10" \ - --logging_steps="5" \ - --save_steps="40" \ - --eval_steps="25" + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_30b_pipe.sh b/examples/opt_finetune/run_30b_pipe.sh new file mode 100644 index 000000000..06ced36f9 --- /dev/null +++ b/examples/opt_finetune/run_30b_pipe.sh @@ -0,0 +1,23 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_30b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="1024" \ + --per_device_train_batch_size="64" \ + --per_device_eval_batch_size="64" \ + --num_micro_batches 16 \ + --operator_parallel 1 \ + --pipeline_parallel 4 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index d9c37cdc9..2b196efba 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -41,7 +41,8 @@ import 
alpa from alpa.model.model_util import DynamicScale, TrainState -from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption, mark_pipeline_boundary +from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption, mark_pipeline_boundary, CreateStateParallel +from alpa import global_config import jax import jax.numpy as jnp import optax @@ -94,6 +95,7 @@ class TrainingArguments: ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + use_dummy_value: bool = field(default=False, metadata={"help": "Whether to use dummy values for debugging."}) use_manual_layer: bool = field(default=False, metadata={"help": "Whether to use manual layer annotation."}) alpa_init: bool = field(default=False, metadata={"help": "Whether to use Alpa's distributed gpu init."}) per_device_train_batch_size: int = field( @@ -584,6 +586,30 @@ def main(): _do_init=do_init ) + # from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves + # + # file_path = "/tmp/66B_np" + # os.makedirs(file_path, exist_ok=True) + # + # def save_to_disk(params, prefix=""): + # if isinstance(params, dict): + # for key in params: + # if key == "model": + # new_prefix = prefix + # elif len(prefix) > 0: + # new_prefix = prefix + "." + key + # else: + # new_prefix = key + # save_to_disk(params[key], new_prefix) + # if isinstance(params, jax.numpy.DeviceArray): + # print(f"Save the tensor {prefix} to {file_path}") + # with open(os.path.join(file_path, prefix), "wb") as g: + # jnp.save(g, params) + # return + # + # save_to_disk(model.params) + # exit(1) + # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: @@ -782,6 +808,9 @@ def decay_mask_fn(params): else: use_master_copy = dynamic_scale = None + if training_args.use_dummy_value: + alpa.global_config.use_dummy_value_for_benchmarking = True + def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] @@ -844,7 +873,6 @@ def eval_step(params, batch): params_aval = unfreeze(jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"]) state_aval = TrainState.create_aval(apply_fn=model.__call__, params=params_aval, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) - # params_aval = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, getattr(jnp, model_args.dtype)), params_aval) else: # In this case, params have been initialized by HF in the CPU memory of the driver node. state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, @@ -852,7 +880,6 @@ def eval_step(params, batch): # Create parallel version of the train and eval step if training_args.use_manual_layer: - # make layer-stage aassignments assert config.num_hidden_layers % pp == 0, f"There are {config.num_hidden_layers} layers but {pp} stages." n_layers_per_stage = config.num_hidden_layers // pp forward_stage_layer_ids = [list(range(i * n_layers_per_stage, (i + 1) * n_layers_per_stage)) for i in range(pp)] @@ -881,9 +908,9 @@ def eval_step(params, batch): print(f" - Compile executables. 
", end="", flush=True) tic = time.time() # trigger compile - seq_len = config.max_position_embeddings + seq_len = 1024 train_input_shape = (train_batch_size, seq_len) - batch = {"input_ids": + batch_aval = {"input_ids": jax.core.ShapedArray( train_input_shape, jnp.int32), "position_ids": @@ -896,26 +923,36 @@ def eval_step(params, batch): jax.core.ShapedArray( train_input_shape, jnp.int32), } - train_executable = p_train_step.get_executable(state_aval, batch) - eval_executable = p_eval_step.get_executable(state_aval.params, batch) - # load params using compiled sharding specs1 + train_executable = p_train_step.get_executable(state_aval, batch_aval) + eval_executable = p_eval_step.get_executable(state_aval.params, batch_aval) + train_executable.sync() + model._is_initialized = True print(f" Compilation takes {time.time() - tic:.2f} seconds.") # def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: - print(" - Load parameters. ", end="", flush=True) - tic = time.time() - # params = load_multi_executable_params_dis_array(path, exectuables, params_aval, config, dummy) - assert config.weight_path and os.path.exists(config.weight_path), f"Cannot find weight at {config.weight_path}" - params = load_params_dis_array(config.weight_path, train_executable, params_aval, config, dummy=False) - # from alpa.serialization import restore_checkpoint - # params = restore_checkpoint(path, ) - # Work around the model._initialized in HF - model._is_initialized = True - model.params = params - state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, - dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) - print(f" Load parameters takes {time.time() - tic:.2f} seconds.") + # def create_state(): + # return TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + # dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + # p_create_state = alpa.parallelize(create_state, method=CreateStateParallel(p_train_step, batch)) + # state = p_create_state() + if not training_args.use_dummy_value: + print(" - Load parameters. 
", end="", flush=True) + tic = time.time() + # params = load_multi_executable_params_dis_array(path, exectuables, params_aval, config, dummy) + assert config.weight_path and os.path.exists(config.weight_path), f"Cannot find weight at {config.weight_path}" + params = load_params_dis_array(config.weight_path, train_executable, params_aval, config, dummy=False) + train_executable.sync() + print(f" Load parameters takes {time.time() - tic:.2f} seconds.") + # from alpa.serialization import restore_checkpoint + # params = restore_checkpoint(path, ) + # Work around the model._initialized in HF + model._is_initialized = True + model.params = params + tic = time.time() + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + print(f" Create train states takes {time.time() - tic:.2f} seconds.") dump_debug_info_train_step = dump_debug_info_eval_step = True @@ -951,7 +988,10 @@ def eval_step(params, batch): batch = next(train_loader) batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]) - 1 - state, train_metric = p_train_step(state, batch) + if training_args.use_dummy_value: + state_aval, train_metric = p_train_step(state_aval, batch_aval) + else: + state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) cur_step = epoch * (len(train_dataset) // train_batch_size) + step @@ -960,7 +1000,7 @@ def eval_step(params, batch): dump_debug_info_train_step = False executable = p_train_step.get_last_executable() executable.sync() - executable.dump_debug_info("alpa_debug_info") + executable.dump_debug_info("/tmp/alpa_debug_info") epochs.write(f"Initial compilation completed. " f"Time elapsed: {time.time() - train_start:.2f} s") From 4b2528966b5cef14ec93e8790dd4e01875a2147d Mon Sep 17 00:00:00 2001 From: zhisbug Date: Wed, 25 Jan 2023 00:53:46 -0500 Subject: [PATCH 05/16] dummy value update --- examples/opt_finetune/run_clm_flax.py | 28 +++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index 2b196efba..ff2018b05 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -982,14 +982,21 @@ def eval_step(params, batch): # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, train_min_batch_size, shuffle=True) - steps_per_epoch = len(train_dataset) // train_batch_size + + if training_args.use_dummy_value: + steps_per_epoch = 50 + else: + steps_per_epoch = len(train_dataset) // train_batch_size # train for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): - batch = next(train_loader) - batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * - batch["attention_mask"]) - 1 + if not training_args.use_dummy_value: + batch = next(train_loader) + batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + batch["attention_mask"]) - 1 + else: + batch = batch_aval if training_args.use_dummy_value: - state_aval, train_metric = p_train_step(state_aval, batch_aval) + state_aval, train_metric = p_train_step(state_aval, batch) else: state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) @@ -1116,11 +1123,12 @@ def eval_step(params, batch): json.dump(eval_metrics, f, indent=4, sort_keys=True) # Save the final model - epochs.write("\nSave the final 
model...") - alpa.prefetch(state.params) - params = alpa.util.map_to_nparray(state.params) - model.save_pretrained(training_args.output_dir, params=params) - tokenizer.save_pretrained(training_args.output_dir) + if not training_args.use_dummy_value: + epochs.write("\nSave the final model...") + alpa.prefetch(state.params) + params = alpa.util.map_to_nparray(state.params) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) if __name__ == "__main__": From fbb4621188ea0abcd63a310f8b964a23fb27cb46 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Thu, 26 Jan 2023 09:37:50 -0800 Subject: [PATCH 06/16] switch to use broadcast [skip ci] --- alpa/global_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alpa/global_env.py b/alpa/global_env.py index c5d4943d6..f66de356a 100644 --- a/alpa/global_env.py +++ b/alpa/global_env.py @@ -72,7 +72,7 @@ def __init__(self): self.use_local_allgather = True # Cross mesh resharding mode. Possible choices: {"send_recv", # "broadcast"} - self.resharding_mode = "send_recv" + self.resharding_mode = "broadcast" # Which nccl to use. Possible choices: {"cupy", # "xla_extension"} self.nccl_mode = "xla_extension" From bfdb2dddb3eca41d0b0c46f329ed16564f46a4a8 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Thu, 26 Jan 2023 09:39:11 -0800 Subject: [PATCH 07/16] expand the dataset size 10x --- examples/opt_finetune/run_clm_flax.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index ff2018b05..e294a7abd 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -33,6 +33,7 @@ from itertools import chain from pathlib import Path from typing import Callable, Optional, Tuple, Dict +import copy import datasets import numpy as np @@ -696,6 +697,10 @@ def group_texts(examples): if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) + new_datasets = [] + for i in range(10): + new_datasets.append(copy.deepcopy(train_dataset)) + train_dataset = datasets.concatenate_datasets(new_datasets) if "validation" not in tokenized_datasets: raise ValueError("--do_eval requires a validation dataset") From 901a483b6e15f130edeee5c84cfd9e9a5a409479 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Thu, 26 Jan 2023 09:40:28 -0800 Subject: [PATCH 08/16] do not destroy collective group --- .../collective_group/base_collective_group.py | 1 + .../collective_group/nccl_collective_group.py | 2 +- examples/opt_finetune/run_175b_pipe.sh | 16 +++++++------ examples/opt_finetune/run_30b_pipe.sh | 6 ++--- examples/opt_finetune/run_66b_pipe.sh | 24 +++++++++++++++++++ 5 files changed, 38 insertions(+), 11 deletions(-) create mode 100644 examples/opt_finetune/run_66b_pipe.sh diff --git a/alpa/collective/collective_group/base_collective_group.py b/alpa/collective/collective_group/base_collective_group.py index 8bfc2304a..4efdb249d 100644 --- a/alpa/collective/collective_group/base_collective_group.py +++ b/alpa/collective/collective_group/base_collective_group.py @@ -113,6 +113,7 @@ def get_access_counter(self): def destroy_store(self): """Delete the named actor.""" ray.kill(self._store) + # ray.get(self._store.__ray_terminate__.remote()) self._store = None diff --git a/alpa/collective/collective_group/nccl_collective_group.py b/alpa/collective/collective_group/nccl_collective_group.py 
index 460c8b94b..0c74edd18 100644 --- a/alpa/collective/collective_group/nccl_collective_group.py +++ b/alpa/collective/collective_group/nccl_collective_group.py @@ -725,7 +725,7 @@ def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None): "NCCLUniqueID has been broadcasted. The " "NCCLUniqueIDStore will go out of context and be " "destroyed.") - rendezvous.destroy_store() + # rendezvous.destroy_store() return nccl_uid diff --git a/examples/opt_finetune/run_175b_pipe.sh b/examples/opt_finetune/run_175b_pipe.sh index e4808d472..a7b0e7de6 100644 --- a/examples/opt_finetune/run_175b_pipe.sh +++ b/examples/opt_finetune/run_175b_pipe.sh @@ -2,21 +2,23 @@ python3 run_clm_flax.py \ --output_dir="./output" \ --config_name="./config_175b.json" \ --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --use_dummy_value \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train \ - --use_manual_layer \ --block_size="1024" \ - --per_device_train_batch_size="64" \ + --per_device_train_batch_size="8" \ --per_device_eval_batch_size="64" \ - --num_micro_batches 64 \ - --operator_parallel 1 \ + --num_micro_batches 8 \ + --operator_parallel 8 \ --pipeline_parallel 8 \ --dtype="float16" \ - --learning_rate="1.2e-4" --warmup_steps="2000" \ + --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="10" \ --logging_steps="1" \ - --save_steps="40" \ - --eval_steps="25" + --save_steps="888" \ + --eval_steps="888" diff --git a/examples/opt_finetune/run_30b_pipe.sh b/examples/opt_finetune/run_30b_pipe.sh index 06ced36f9..cb389f083 100644 --- a/examples/opt_finetune/run_30b_pipe.sh +++ b/examples/opt_finetune/run_30b_pipe.sh @@ -8,11 +8,11 @@ python3 run_clm_flax.py \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train \ --block_size="1024" \ - --per_device_train_batch_size="64" \ + --per_device_train_batch_size="1024" \ --per_device_eval_batch_size="64" \ - --num_micro_batches 16 \ + --num_micro_batches 256 \ --operator_parallel 1 \ - --pipeline_parallel 4 \ + --pipeline_parallel 16 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ diff --git a/examples/opt_finetune/run_66b_pipe.sh b/examples/opt_finetune/run_66b_pipe.sh new file mode 100644 index 000000000..2990026c4 --- /dev/null +++ b/examples/opt_finetune/run_66b_pipe.sh @@ -0,0 +1,24 @@ +python3 run_clm_flax.py \ + --output_dir="./output" \ + --config_name="./config_66b.json" \ + --tokenizer_name="facebook/opt-30b" \ + --alpa_init \ + --use_manual_layer \ + --use_dummy_value \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train \ + --block_size="1024" \ + --per_device_train_batch_size="1024" \ + --per_device_eval_batch_size="64" \ + --num_micro_batches 512 \ + --operator_parallel 4 \ + --pipeline_parallel 8 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="1" \ + --save_steps="888" \ + --eval_steps="888" From 40ef377215f72e0ef92942f202f8765983bde089 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Mon, 30 Jan 2023 11:04:27 -0800 Subject: [PATCH 09/16] 175B script for 32 node --- examples/opt_finetune/run_175b_pipe.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/examples/opt_finetune/run_175b_pipe.sh b/examples/opt_finetune/run_175b_pipe.sh index a7b0e7de6..cbbd620e2 100644 --- a/examples/opt_finetune/run_175b_pipe.sh +++ b/examples/opt_finetune/run_175b_pipe.sh @@ -9,11 +9,11 @@ python3 run_clm_flax.py \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train \ --block_size="1024" \ - --per_device_train_batch_size="8" \ + --per_device_train_batch_size="1024" \ --per_device_eval_batch_size="64" \ - --num_micro_batches 8 \ - --operator_parallel 8 \ - --pipeline_parallel 8 \ + --num_micro_batches 512 \ + --operator_parallel 1 \ + --pipeline_parallel 32 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ From 45306ebbc78d8fac67a460414c5e4cf4a37f3b24 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Wed, 1 Feb 2023 14:21:51 -0500 Subject: [PATCH 10/16] use real data for training instead of dummy --- examples/opt_finetune/run_clm_flax.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index e294a7abd..0d20a3de1 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -698,7 +698,7 @@ def group_texts(examples): max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) new_datasets = [] - for i in range(10): + for i in range(50): new_datasets.append(copy.deepcopy(train_dataset)) train_dataset = datasets.concatenate_datasets(new_datasets) @@ -999,7 +999,12 @@ def eval_step(params, batch): batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]) - 1 else: - batch = batch_aval + # batch = batch_aval + batch = next(train_loader) + batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + batch["attention_mask"]) - 1 + for b in batch: + batch[b] = batch[b].astype(jnp.int32) if training_args.use_dummy_value: state_aval, train_metric = p_train_step(state_aval, batch) else: From 37b2d9b1573087c9a967edcc9ac3c71e767001db Mon Sep 17 00:00:00 2001 From: zhisbug Date: Wed, 1 Feb 2023 14:45:47 -0500 Subject: [PATCH 11/16] opt init and change to 2048 --- alpa/device_mesh.py | 3 ++- examples/opt_finetune/run_175b_pipe.sh | 10 +++++----- examples/opt_finetune/run_clm_flax.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index 4991f81dc..448a4ef92 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -202,7 +202,8 @@ def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int], dim_size = len(range(*filled_slice)) shard_shape.append(dim_size) arys[b][device_id] = (self.backend.buffer_from_pyval( - np.full(shard_shape, 1e-8, dtype), + # np.full(shard_shape, 1e-8, dtype), + np.random.normal(0, 0.0006, shard_shape).astype(dtype), self.local_devices[device_id])) for uuid, ary in zip(uuids, arys): self.buffers[uuid] = ary diff --git a/examples/opt_finetune/run_175b_pipe.sh b/examples/opt_finetune/run_175b_pipe.sh index cbbd620e2..40d7f0554 100644 --- a/examples/opt_finetune/run_175b_pipe.sh +++ b/examples/opt_finetune/run_175b_pipe.sh @@ -8,12 +8,12 @@ python3 run_clm_flax.py \ --dataset_name="wikitext" \ --dataset_config_name="wikitext-2-raw-v1" \ --do_train \ - --block_size="1024" \ - --per_device_train_batch_size="1024" \ + --block_size="2048" \ + --per_device_train_batch_size="1536" \ --per_device_eval_batch_size="64" \ - --num_micro_batches 512 \ - 
--operator_parallel 1 \ - --pipeline_parallel 32 \ + --num_micro_batches 768 \ + --operator_parallel 4 \ + --pipeline_parallel 16 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index 0d20a3de1..2f763e300 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -913,7 +913,7 @@ def eval_step(params, batch): print(f" - Compile executables. ", end="", flush=True) tic = time.time() # trigger compile - seq_len = 1024 + seq_len = 2048 train_input_shape = (train_batch_size, seq_len) batch_aval = {"input_ids": jax.core.ShapedArray( From 6e8d4c2746e924c36a721b0144d768deb618d1f2 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Fri, 3 Feb 2023 17:28:21 -0500 Subject: [PATCH 12/16] update --- alpa/device_mesh.py | 4 +- examples/opt_finetune/config_30b.json | 2 +- examples/opt_finetune/estimate_throughput.py | 105 +++++++++++++++++++ 3 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 examples/opt_finetune/estimate_throughput.py diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index 448a4ef92..aa96aa1e7 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -202,8 +202,8 @@ def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int], dim_size = len(range(*filled_slice)) shard_shape.append(dim_size) arys[b][device_id] = (self.backend.buffer_from_pyval( - # np.full(shard_shape, 1e-8, dtype), - np.random.normal(0, 0.0006, shard_shape).astype(dtype), + np.full(shard_shape, 1e-8, dtype), + # np.random.normal(0, 0.0006, shard_shape).astype(dtype), self.local_devices[device_id])) for uuid, ary in zip(uuids, arys): self.buffers[uuid] = ary diff --git a/examples/opt_finetune/config_30b.json b/examples/opt_finetune/config_30b.json index 61935b928..e6eb0cd1c 100644 --- a/examples/opt_finetune/config_30b.json +++ b/examples/opt_finetune/config_30b.json @@ -1,5 +1,5 @@ { -"weight_path": "/home/ubuntu/tmp_dataset/opt_weights/30B_np", +"weight_path": "/home/ubuntu/dataset/opt_weights/30B_np", "_remove_final_layer_norm": false, "activation_dropout": 0, "activation_function": "relu", diff --git a/examples/opt_finetune/estimate_throughput.py b/examples/opt_finetune/estimate_throughput.py new file mode 100644 index 000000000..383a9fb50 --- /dev/null +++ b/examples/opt_finetune/estimate_throughput.py @@ -0,0 +1,105 @@ +import dataclasses +from dataclasses import dataclass +import jax.numpy as jnp +import alpa + +@dataclass(frozen=True) +class OPTConfig: + # Inherited from OPT + num_hidden_layers: int = 12 + max_seq_len: int = 2048 + hidden_size: int = 768 + n_head: int = 12 + input_dim: int = 768 + ffn_embed_dim: int = 3072 + pad: int = 1 + activation_fn: str = 'relu' + dtype: any = jnp.float16 + use_stable_embedding: bool = False + no_scale_embedding: bool = True + decoder_learned_pos: bool = True + decoder_normalize_before: bool = True + share_decoder_input_output_embed: bool = True + # Added + version: int = 1 + vocab_size: int = 50272 + layer_norm_eps: float = 0.00001 + num_pp_stages: int = None + # parallelize + mark_boundary: bool = True + + +def get_config(name, **kwargs): + if name == "opt-125m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=12, n_head=12, + hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, + version=3, + ) + elif name == "opt-350m": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=16, + hidden_size=1024, 
input_dim=1024, ffn_embed_dim=1024 * 4, + version=2, + ) + raise NotImplementedError("Not implemented because this model " + "has a different architecture") + elif name == "opt-1.3b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=24, n_head=32, + hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, + version=3, + ) + elif name == "opt-2.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, + version=3, + ) + elif name == "opt-6.7b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=32, n_head=32, + hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, + version=3, + ) + elif name == "opt-30b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=48, n_head=56, + hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, + version=3, + ) + elif name == "opt-66b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=64, n_head=72, + hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, + version=3, + ) + elif name == "opt-175b": + config = OPTConfig( + max_seq_len=2048, num_hidden_layers=96, n_head=96, + hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, + version=3, + ) + else: + raise ValueError(f"Invalid model name: {name}") + + return dataclasses.replace(config, **kwargs) + + +def estimate_throughput(model_name: str, gbs: int, seq_len: int, num_device: int, latency: float): + config = get_config(model_name) + n_layer = config.num_hidden_layers + h = config.hidden_size + + throughput_tflops = alpa.util.compute_gpt_tflops( + batch_size=gbs, + seq_len=seq_len, + num_layers=n_layer, + hidden_size=h, + vocab_size=50272, + num_gpus=num_device, + latency=latency) + print(f"Model {model_name}, gbs: {gbs}, seq_len: {seq_len}, Model TFlops: {throughput_tflops}, HW TFlops: {throughput_tflops * 4 / 3}..") + + +estimate_throughput("opt-175b", 1536, 2048, 128, 207) \ No newline at end of file From a9f64026b3f1ade8a0bd30fe75e4088d2bf1a99a Mon Sep 17 00:00:00 2001 From: zhisbug Date: Fri, 3 Feb 2023 17:29:39 -0500 Subject: [PATCH 13/16] reproduct bug --- examples/opt_finetune/run_clm_flax.py | 70 ++++++++++++++++++++------- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index 2f763e300..44ad9a35a 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -869,7 +869,9 @@ def eval_step(params, batch): metrics = {"loss": loss} return metrics - if training_args.alpa_init: + use_create_state_parallel = True + + if not use_create_state_parallel and training_args.alpa_init: # In this case, params are not initialized yet, and we use shaped array to trigger alpa compilation rngkey = jax.core.ShapedArray((2,), jnp.uint32) input_ids = jax.core.ShapedArray((1, 128), jnp.int32) @@ -878,6 +880,34 @@ def eval_step(params, batch): params_aval = unfreeze(jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"]) state_aval = TrainState.create_aval(apply_fn=model.__call__, params=params_aval, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + elif use_create_state_parallel: + def create_state(): + model = FlaxAutoModelForCausalLM.from_config( + config, + seed=training_args.seed, + dtype=getattr(jnp, model_args.dtype), + _do_init=do_init + ) + optimizer = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + 
b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn) + ) + rngkey = jnp.ones((2,), jnp.uint32) + input_ids = jnp.ones((1, 128), jnp.int32) + attention_mask = jnp.ones((1, 128), jnp.int32) + position_ids = jnp.ones((1, 128), jnp.int32) + # params_aval = unfreeze( + # jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"]) + params = unfreeze(model.module.init(rngkey, input_ids, attention_mask, position_ids)["params"]) + state_aval = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + return state_aval else: # In this case, params have been initialized by HF in the CPU memory of the driver node. state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, @@ -913,7 +943,7 @@ def eval_step(params, batch): print(f" - Compile executables. ", end="", flush=True) tic = time.time() # trigger compile - seq_len = 2048 + seq_len = data_args.block_size train_input_shape = (train_batch_size, seq_len) batch_aval = {"input_ids": jax.core.ShapedArray( @@ -928,23 +958,27 @@ def eval_step(params, batch): jax.core.ShapedArray( train_input_shape, jnp.int32), } + + batch_aval = { + "input_ids": jnp.ones(train_input_shape, jnp.int32), + "position_ids": jnp.ones(train_input_shape, jnp.int32), + "attention_mask": jnp.ones(train_input_shape, jnp.int32), + "labels": jnp.ones(train_input_shape, jnp.int32), + } + + if use_create_state_parallel: + p_create_state = alpa.parallelize(create_state, method=CreateStateParallel(p_train_step, batch_aval)) + state_aval = p_create_state() + train_executable = p_train_step.get_executable(state_aval, batch_aval) eval_executable = p_eval_step.get_executable(state_aval.params, batch_aval) train_executable.sync() model._is_initialized = True print(f" Compilation takes {time.time() - tic:.2f} seconds.") - # def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: - - # def create_state(): - # return TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, - # dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) - # p_create_state = alpa.parallelize(create_state, method=CreateStateParallel(p_train_step, batch)) - # state = p_create_state() if not training_args.use_dummy_value: print(" - Load parameters. 
", end="", flush=True) tic = time.time() - # params = load_multi_executable_params_dis_array(path, exectuables, params_aval, config, dummy) assert config.weight_path and os.path.exists(config.weight_path), f"Cannot find weight at {config.weight_path}" params = load_params_dis_array(config.weight_path, train_executable, params_aval, config, dummy=False) train_executable.sync() @@ -955,6 +989,8 @@ def eval_step(params, batch): model._is_initialized = True model.params = params tic = time.time() + + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) print(f" Create train states takes {time.time() - tic:.2f} seconds.") @@ -999,12 +1035,12 @@ def eval_step(params, batch): batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]) - 1 else: - # batch = batch_aval - batch = next(train_loader) - batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * - batch["attention_mask"]) - 1 - for b in batch: - batch[b] = batch[b].astype(jnp.int32) + batch = batch_aval + # batch = next(train_loader) + # batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + # batch["attention_mask"]) - 1 + # for b in batch: + # batch[b] = batch[b].astype(jnp.int32) if training_args.use_dummy_value: state_aval, train_metric = p_train_step(state_aval, batch) else: @@ -1046,7 +1082,7 @@ def eval_step(params, batch): epochs.write( f"Step... {cur_step} | " f"Loss: {train_metric['loss'].mean():.4f}, " - f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " + f"Learning Rate: {train_metric['learning_rate'].mean():.8f}, " f"Throughput: {throughput_tokens:.2f} token/s, " f"{throughput_tflops:.2f} TFLOP/s, " f"{throughput_tflops * 4 / 3:.2f} TFLOP/s." 
From ce155d49cbb8117d872c8d228ea7315e9987ec61 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 6 Feb 2023 00:57:52 +0000 Subject: [PATCH 14/16] Fix output_sharding_dict --- alpa/pipeline_parallel/computation.py | 12 ++++++++---- alpa/util.py | 7 +++++++ examples/opt_finetune/run_clm_flax.py | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/alpa/pipeline_parallel/computation.py b/alpa/pipeline_parallel/computation.py index 3dc2c7b68..feac7c966 100644 --- a/alpa/pipeline_parallel/computation.py +++ b/alpa/pipeline_parallel/computation.py @@ -28,7 +28,7 @@ get_compile_options, jaxpr_to_hlo, setup_computation_alias, compile_dummy_zero_constant, get_var_mapping, undefined_sharding_spec_proto, - new_jaxpr_eqn) + new_jaxpr_eqn, replicated_sharding_spec_proto) from alpa.wrapped_hlo import HloStatus, WrappedHlo # pylint: disable=redefined-builtin @@ -750,9 +750,13 @@ def generate_sharded_xla_computations_arguments( hlo.set_input_shardings(sharding_protos) if output_sharding_dict: - sharding_protos = [ - output_sharding_dict[x].sharding_proto() for x in outvars - ] + sharding_protos = [] + for x in outvars: + spec = output_sharding_dict.get(x, None) + if spec is None: + sharding_protos.append(replicated_sharding_spec_proto()) + else: + sharding_protos.append(spec.sharding_proto()) hlo.set_output_shardings(sharding_protos) if stage_input_sharding: diff --git a/alpa/util.py b/alpa/util.py index 72c58c3cb..1af4df5ad 100644 --- a/alpa/util.py +++ b/alpa/util.py @@ -614,6 +614,13 @@ def undefined_sharding_spec_proto(): return proto +def replicated_sharding_spec_proto(): + """Return a proto of ShardingSpec which represents a replicated spec.""" + proto = xc.OpSharding() + proto.type = xc.OpSharding.Type.REPLICATED + return proto + + ######################################## ##### Jaxpr Utilities ######################################## diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index 44ad9a35a..fc0d406b3 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -980,7 +980,7 @@ def create_state(): print(" - Load parameters. 
", end="", flush=True) tic = time.time() assert config.weight_path and os.path.exists(config.weight_path), f"Cannot find weight at {config.weight_path}" - params = load_params_dis_array(config.weight_path, train_executable, params_aval, config, dummy=False) + params = load_params_dis_array(config.weight_path, train_executable, state_aval.params, config, dummy=False) train_executable.sync() print(f" Load parameters takes {time.time() - tic:.2f} seconds.") # from alpa.serialization import restore_checkpoint From a70e5405b511d45eac0daf495d2f796c33b5d2b2 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Mon, 6 Feb 2023 14:57:32 -0500 Subject: [PATCH 15/16] add master copy distributed init --- alpa/device_mesh.py | 46 +++++++++++++++++++++++++-- alpa/model/model_util.py | 25 +++++++++++++++ examples/opt_finetune/run_clm_flax.py | 4 +-- 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index aa96aa1e7..f9aee0e84 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -30,6 +30,7 @@ import time from typing import Any, List, Union, Sequence, Tuple, Optional +import jax from jax import core, xla, device_put from jax._src.api import ShapeDtypeStruct from jax._src.lib import xla_bridge as xb, xla_extension as xe @@ -232,6 +233,24 @@ def get_buffers(self, for uuid, local_ids in zip(uuids, device_indices) ] + def copy_buffer(self, + target_uuid, + src_uuid, + dtype): + # print(f"target_uuid: {target_uuid}, src_uuid: {src_uuid}...") + datas = self.buffers[src_uuid] + # print(f"old data: {datas} , size: {len(datas)}") + assert len(datas) == self.num_devices + new_datas = [] + for i, data in enumerate(datas): + # print(f"device_id {data.device()}") + if data.dtype != dtype: + new_data = np.asarray(data, dtype) + # print(f"new_data: {new_data}, dtype: {new_data.dtype}...") + new_datas.append(self.backend.buffer_from_pyval(new_data, data.device())) + self.buffers[target_uuid] = new_datas + # print(f"putted: {self.buffers[target_uuid]}...") + def delete_buffers(self, uuids: Union[Sequence[int], int]): if isinstance(uuids, Iterable): for uuid in uuids: @@ -1486,9 +1505,9 @@ class DistributedArray: a normal numpy array. Internally, it stores a pointer to all remote buffers. - The buffers are stored distributedly on remote workers' device memeory. + The buffers are stored distributedly on remote workers' device memory. When users require the value of the array. These buffers will be gathered - to the dirver. + to the driver. 
""" def __init__(self, @@ -1779,6 +1798,29 @@ def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray, array._fetched_np_buffers = np_value # pylint: disable=protected-access +def copy_distributed_array(src_array: Union[DistributedArray, ReplicatedDistributedArray], target_dtype: jnp.dtype): + aval = jax.core.ShapedArray(src_array.aval.shape, target_dtype) + if isinstance(src_array, DistributedArray): + mesh = src_array.device_mesh + spec = src_array.sharding_spec + ary_refs, ary_uuid = create_remote_array_refs(mesh) + dst_array = DistributedArray(mesh, aval, spec, ary_refs[0]) + # Do actual copy + for w in mesh.workers: + w.copy_buffer.remote(dst_array.remote_ref.uuid, src_array.remote_ref.uuid, + target_dtype) + else: + assert isinstance(src_array, ReplicatedDistributedArray) + meshes = [] + arrays = [] + for mesh in src_array._mesh_array_map: + meshes.append(mesh) + ary = copy_distributed_array(src_array._mesh_array_map[mesh], target_dtype) + arrays.append(ary) + dst_array = ReplicatedDistributedArray(meshes, arrays) + return dst_array + + ######################################## ##### Physical Mesh Group ##### ######################################## diff --git a/alpa/model/model_util.py b/alpa/model/model_util.py index 7388d98c6..aaedf7663 100644 --- a/alpa/model/model_util.py +++ b/alpa/model/model_util.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Tuple, Optional, Union, Sequence from alpa.api import value_and_grad +from alpa.device_mesh import copy_distributed_array import flax from flax.training import train_state, dynamic_scale as dynamic_scale_lib from flax.training.dynamic_scale import DynamicScaleResult @@ -348,6 +349,30 @@ def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs): **kwargs, ) + @classmethod + def create_distributed(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs): + """The distributed version of create. 
It assumes the inputs are DistributedArrays.""" + if use_master_copy: + dtype = jax.tree_util.tree_flatten(params)[0][0].dtype + assert dtype == jnp.float16 + # create the master copy distributedly + master_copy = jax.tree_util.tree_map( + lambda x: copy_distributed_array(x, jnp.float32), params) + # TODO (Hao): handle opt_state + opt_state = tx.init(master_copy) + else: + master_copy = None + opt_state = tx.init(params) + return cls( + step=np.array(0, dtype=np.int32), + apply_fn=apply_fn, + params=params, + master_copy=master_copy, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + @classmethod def create_aval(cls, *, diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index 2f763e300..172ef1c74 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -955,8 +955,8 @@ def eval_step(params, batch): model._is_initialized = True model.params = params tic = time.time() - state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, - dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + state = TrainState.create_distributed(apply_fn=model.__call__, params=model.params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) print(f" Create train states takes {time.time() - tic:.2f} seconds.") dump_debug_info_train_step = dump_debug_info_eval_step = True From 256631bc675b73433e8ee58c67f8eef9602e22c0 Mon Sep 17 00:00:00 2001 From: zhisbug Date: Mon, 20 Feb 2023 18:25:41 -0500 Subject: [PATCH 16/16] update opt finetuning weight init code --- alpa/device_mesh.py | 53 ++++++++++++++++++-------- alpa/global_env.py | 2 +- alpa/model/model_util.py | 34 +++++++++++++++++ examples/opt_finetune/config_2.7b.json | 2 +- examples/opt_finetune/run_clm_flax.py | 41 +++++++++++++++----- 5 files changed, 105 insertions(+), 27 deletions(-) diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index f9aee0e84..b0b43c679 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -234,22 +234,32 @@ def get_buffers(self, ] def copy_buffer(self, - target_uuid, - src_uuid, - dtype): + shape, + src_indices, + dst_indices, + target_uuid, + src_uuid, + dtype): # print(f"target_uuid: {target_uuid}, src_uuid: {src_uuid}...") datas = self.buffers[src_uuid] # print(f"old data: {datas} , size: {len(datas)}") assert len(datas) == self.num_devices + assert len(datas) == len(src_indices) + assert len(datas) == len(dst_indices) new_datas = [] - for i, data in enumerate(datas): - # print(f"device_id {data.device()}") - if data.dtype != dtype: - new_data = np.asarray(data, dtype) - # print(f"new_data: {new_data}, dtype: {new_data.dtype}...") - new_datas.append(self.backend.buffer_from_pyval(new_data, data.device())) + + if src_indices == dst_indices: + logger.debug("Indices are the same...") + for i, data in enumerate(datas): + new_datas.append(self.backend.buffer_from_pyval(np.array(data, dtype=dtype), data.device())) + else: + logger.debug("Indices are different... 
Resharding!") + src_array = np.zeros(shape, dtype=dtype) + for device_id, ind in enumerate(src_indices): + src_array[ind] = np.array(datas[device_id]) + for i, data in enumerate(datas): + new_datas.append(self.backend.buffer_from_pyval(src_array[dst_indices[i]], data.device())) self.buffers[target_uuid] = new_datas - # print(f"putted: {self.buffers[target_uuid]}...") def delete_buffers(self, uuids: Union[Sequence[int], int]): if isinstance(uuids, Iterable): @@ -1798,16 +1808,27 @@ def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray, array._fetched_np_buffers = np_value # pylint: disable=protected-access -def copy_distributed_array(src_array: Union[DistributedArray, ReplicatedDistributedArray], target_dtype: jnp.dtype): +def copy_distributed_array(src_array: Union[DistributedArray, ReplicatedDistributedArray], + target_sharding_spec: ShardingSpec, + target_dtype: jnp.dtype): aval = jax.core.ShapedArray(src_array.aval.shape, target_dtype) if isinstance(src_array, DistributedArray): mesh = src_array.device_mesh - spec = src_array.sharding_spec + src_spec = src_array.sharding_spec ary_refs, ary_uuid = create_remote_array_refs(mesh) - dst_array = DistributedArray(mesh, aval, spec, ary_refs[0]) + dst_array = DistributedArray(mesh, aval, target_sharding_spec, ary_refs[0]) + if src_array.sharding_spec != target_sharding_spec: + print("Sharding spec changed. Will need resharding..." + f"src: {src_array.sharding_spec}, dst: {dst_array.sharding_spec}") + print(f"src_shape {src_array.aval.shape}, dst_shape {dst_array.aval.shape}, " + f"src_array_indices: {src_array.indices}, dst_array indices: {dst_array.indices}") # Do actual copy for w in mesh.workers: - w.copy_buffer.remote(dst_array.remote_ref.uuid, src_array.remote_ref.uuid, + w.copy_buffer.remote(dst_array.aval.shape, + src_array.indices, + dst_array.indices, + dst_array.remote_ref.uuid, + src_array.remote_ref.uuid, target_dtype) else: assert isinstance(src_array, ReplicatedDistributedArray) @@ -1815,7 +1836,9 @@ def copy_distributed_array(src_array: Union[DistributedArray, ReplicatedDistribu arrays = [] for mesh in src_array._mesh_array_map: meshes.append(mesh) - ary = copy_distributed_array(src_array._mesh_array_map[mesh], target_dtype) + ary = copy_distributed_array(src_array._mesh_array_map[mesh], + target_sharding_spec, + target_dtype) arrays.append(ary) dst_array = ReplicatedDistributedArray(meshes, arrays) return dst_array diff --git a/alpa/global_env.py b/alpa/global_env.py index f66de356a..839297aea 100644 --- a/alpa/global_env.py +++ b/alpa/global_env.py @@ -12,7 +12,7 @@ def __init__(self): # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html self.xla_client_mem_fraction = float( - os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.9)) + os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.8)) self.xla_client_client_preallocate = os.environ.get( "XLA_PYTHON_CLIENT_PREALLOCATE", "true") # The threshold to tigger a batched deletion on workers. 
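The copy_buffer change above covers two cases when materializing the fp32 master
copy from fp16 parameter shards: if the source and destination sharding indices
match, each device buffer is only cast to the new dtype; if they differ, the worker
rebuilds the full array from its source shards and re-slices it along the
destination indices. A single-process NumPy sketch of that per-worker logic
(reshard_copy is an illustrative name, not an Alpa function; the real code operates
on XLA device buffers inside a Ray worker):

    import numpy as np

    def reshard_copy(shape, src_shards, src_indices, dst_indices, dtype):
        if src_indices == dst_indices:
            # Same layout on both sides: only the dtype changes (e.g. fp16 -> fp32).
            return [np.asarray(s, dtype=dtype) for s in src_shards]
        # Different layout: assemble the full array from the source shards,
        # then cut it again along the destination indices.
        full = np.zeros(shape, dtype=dtype)
        for shard, idx in zip(src_shards, src_indices):
            full[idx] = np.asarray(shard)
        return [full[idx] for idx in dst_indices]

    # Toy usage: a 2x2 array sharded by rows at the source, by columns at the target.
    x = np.arange(4, dtype=np.float16).reshape(2, 2)
    src_idx = [(slice(0, 1), slice(None)), (slice(1, 2), slice(None))]
    dst_idx = [(slice(None), slice(0, 1)), (slice(None), slice(1, 2))]
    print(reshard_copy(x.shape, [x[i] for i in src_idx], src_idx, dst_idx, np.float32))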
diff --git a/alpa/model/model_util.py b/alpa/model/model_util.py index aaedf7663..5103b0119 100644 --- a/alpa/model/model_util.py +++ b/alpa/model/model_util.py @@ -4,6 +4,7 @@ import functools from typing import Any, Callable, Optional, Tuple, Optional, Union, Sequence +import alpa.device_mesh from alpa.api import value_and_grad from alpa.device_mesh import copy_distributed_array import flax @@ -402,6 +403,39 @@ def create_aval(cls, **kwargs, ) + @classmethod + def create_from(cls, + *, + train_state, + params, + use_master_copy=False, + **kwargs): + """Create a new instance where everything except master_copy is given.""" + if use_master_copy: + dtype = jax.tree_util.tree_flatten(params)[0][0].dtype + assert dtype == jnp.float16 + + def get_sharding_spec(array): + if isinstance(array, alpa.device_mesh.DistributedArray): + return array.sharding_spec + else: + assert isinstance(array, alpa.device_mesh.ReplicatedDistributedArray) + return array.replica.sharding_spec + + # create the master copy distributedly + master_copy = jax.tree_util.tree_map( + lambda x, y: copy_distributed_array(x, get_sharding_spec(y), jnp.float32), params, train_state.master_copy) + else: + master_copy = None + return cls( + step=train_state.step, + apply_fn=train_state.apply_fn, + params=params, + master_copy=master_copy, + tx=train_state.tx, + opt_state=train_state.opt_state, + **kwargs + ) class DynamicScale(struct.PyTreeNode): """This is the same as flax.optim.DynamicScale, except that diff --git a/examples/opt_finetune/config_2.7b.json b/examples/opt_finetune/config_2.7b.json index 0e100abb8..91aa0431b 100644 --- a/examples/opt_finetune/config_2.7b.json +++ b/examples/opt_finetune/config_2.7b.json @@ -1,5 +1,5 @@ { -"weight_path": "/home/ubuntu/tmp_dataset/opt_weights/2.7B_np", +"weight_path": "/home/ubuntu/dataset/opt_weights/2.7B_np", "_remove_final_layer_norm": false, "activation_dropout": 0, "activation_function": "relu", diff --git a/examples/opt_finetune/run_clm_flax.py b/examples/opt_finetune/run_clm_flax.py index 734de4063..b1fb26f08 100644 --- a/examples/opt_finetune/run_clm_flax.py +++ b/examples/opt_finetune/run_clm_flax.py @@ -906,7 +906,7 @@ def create_state(): # jax.eval_shape(model.module.init, rngkey, input_ids, attention_mask, position_ids)["params"]) params = unfreeze(model.module.init(rngkey, input_ids, attention_mask, position_ids)["params"]) state_aval = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer, - dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) return state_aval else: # In this case, params have been initialized by HF in the CPU memory of the driver node. 
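TrainState.create_from (added in the model_util.py diff above) builds a fresh fp32 master copy by mapping over two pytrees at once: the new fp16 params supply the values, and the old train state's master copy supplies the per-leaf sharding spec that copy_distributed_array should reproduce. Below is a small, runnable analogue of that two-tree tree_map pattern, with plain arrays and made-up spec strings standing in for DistributedArrays and ShardingSpecs.

import jax
import jax.numpy as jnp

# fp16 parameters freshly loaded from a checkpoint (toy values).
params = {"w": jnp.ones((2, 2), jnp.float16), "b": jnp.zeros((2,), jnp.float16)}
# Hypothetical per-leaf sharding specs taken from the previous master copy.
old_specs = {"w": "shard_axis_0", "b": "replicated"}

def make_master_leaf(p, spec):
    # In create_from this would be copy_distributed_array(p, spec, jnp.float32);
    # here we only upcast and report the spec that would be reused.
    print(f"building fp32 master leaf with spec {spec}")
    return p.astype(jnp.float32)

master_copy = jax.tree_util.tree_map(make_master_leaf, params, old_specs)
assert master_copy["w"].dtype == jnp.float32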
@@ -959,14 +959,13 @@ def create_state(): train_input_shape, jnp.int32), } - batch_aval = { - "input_ids": jnp.ones(train_input_shape, jnp.int32), - "position_ids": jnp.ones(train_input_shape, jnp.int32), - "attention_mask": jnp.ones(train_input_shape, jnp.int32), - "labels": jnp.ones(train_input_shape, jnp.int32), - } - if use_create_state_parallel: + # batch_aval = { + # "input_ids": jnp.ones(train_input_shape, jnp.int32), + # "position_ids": jnp.ones(train_input_shape, jnp.int32), + # "attention_mask": jnp.ones(train_input_shape, jnp.int32), + # "labels": jnp.ones(train_input_shape, jnp.int32), + # } p_create_state = alpa.parallelize(create_state, method=CreateStateParallel(p_train_step, batch_aval)) state_aval = p_create_state() @@ -989,10 +988,32 @@ def create_state(): model._is_initialized = True model.params = params tic = time.time() + # state = TrainState.create_from(train_state=state_aval, + # params=params, + # dynamic_scale=dynamic_scale, + # use_master_copy=True) + + state = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + + def compare_dict(state1, state2): + print("Compare params: ") + params1 = jax.tree_util.tree_flatten(state1) + params2 = jax.tree_util.tree_flatten(state2) + for p1, p2 in zip(params1[0], params2[0]): + assert type(p1) == type(p2) + if isinstance(p1, alpa.device_mesh.DistributedArray): + print(f"Sharding spec 1: {p1.sharding_spec}, sharding spec 2 {p2.sharding_spec}") + assert p1.sharding_spec == p2.sharding_spec, f"wrong at {p1} and {p2}" + else: + assert isinstance(p1, alpa.device_mesh.ReplicatedDistributedArray) + print(f"Sharding spec 1: {p1.replica.sharding_spec}, sharding spec 2 {p2.replica.sharding_spec}") + assert p1.replica.sharding_spec == p2.replica.sharding_spec, f"wrong at {p1} and {p2}" + + + # state = TrainState.create_distributed(apply_fn=model.__call__, params=model.params, tx=optimizer, # dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) - state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, - dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) print(f" Create train states takes {time.time() - tic:.2f} seconds.") dump_debug_info_train_step = dump_debug_info_eval_step = True
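The final hunk defines compare_dict but never calls it; a plausible (purely hypothetical) use would be compare_dict(state_aval.params, state.params) right after the TrainState.create call, to confirm that the parallelized create_state and the driver-side creation agree on every parameter's sharding spec. The same flatten-and-zip pattern on ordinary pytrees, checking shapes and dtypes instead of sharding specs, looks like this:

import jax
import jax.numpy as jnp

def compare_trees(tree1, tree2):
    """Assert two pytrees have identical structure and leaf shapes/dtypes."""
    leaves1, treedef1 = jax.tree_util.tree_flatten(tree1)
    leaves2, treedef2 = jax.tree_util.tree_flatten(tree2)
    assert treedef1 == treedef2, "tree structures differ"
    for l1, l2 in zip(leaves1, leaves2):
        assert l1.shape == l2.shape and l1.dtype == l2.dtype, f"mismatch: {l1} vs {l2}"

a = {"w": jnp.ones((2, 3), jnp.float16), "b": jnp.zeros((3,), jnp.float32)}
b = {"w": jnp.zeros((2, 3), jnp.float16), "b": jnp.ones((3,), jnp.float32)}
compare_trees(a, b)  # same structure, shapes and dtypes -> passes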