[WIP] OPT training improvements #858

Open
wants to merge 18 commits into base: main
1 change: 1 addition & 0 deletions alpa/collective/collective_group/base_collective_group.py
@@ -113,6 +113,7 @@ def get_access_counter(self):
def destroy_store(self):
"""Delete the named actor."""
ray.kill(self._store)
# ray.get(self._store.__ray_terminate__.remote())
self._store = None


2 changes: 1 addition & 1 deletion alpa/collective/collective_group/nccl_collective_group.py
@@ -725,7 +725,7 @@ def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None):
"NCCLUniqueID has been broadcasted. The "
"NCCLUniqueIDStore will go out of context and be "
"destroyed.")
- rendezvous.destroy_store()
+ # rendezvous.destroy_store()
return nccl_uid


70 changes: 68 additions & 2 deletions alpa/device_mesh.py
@@ -30,6 +30,7 @@
import time
from typing import Any, List, Union, Sequence, Tuple, Optional

import jax
from jax import core, xla, device_put
from jax._src.api import ShapeDtypeStruct
from jax._src.lib import xla_bridge as xb, xla_extension as xe
@@ -203,6 +204,7 @@ def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int],
shard_shape.append(dim_size)
arys[b][device_id] = (self.backend.buffer_from_pyval(
np.full(shard_shape, 1e-8, dtype),
# np.random.normal(0, 0.0006, shard_shape).astype(dtype),
self.local_devices[device_id]))
for uuid, ary in zip(uuids, arys):
self.buffers[uuid] = ary
@@ -231,6 +233,34 @@ def get_buffers(self,
for uuid, local_ids in zip(uuids, device_indices)
]

def copy_buffer(self,
shape,
src_indices,
dst_indices,
target_uuid,
src_uuid,
dtype):
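"""Copy the buffer group at `src_uuid` into a new group stored under
`target_uuid`, casting every shard to `dtype` and resharding through a
host-side gather when `src_indices` and `dst_indices` differ."""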
# print(f"target_uuid: {target_uuid}, src_uuid: {src_uuid}...")
datas = self.buffers[src_uuid]
# print(f"old data: {datas} , size: {len(datas)}")
assert len(datas) == self.num_devices
assert len(datas) == len(src_indices)
assert len(datas) == len(dst_indices)
new_datas = []

if src_indices == dst_indices:
logger.debug("Indices are the same...")
for i, data in enumerate(datas):
new_datas.append(self.backend.buffer_from_pyval(np.array(data, dtype=dtype), data.device()))
else:
logger.debug("Indices are different... Resharding!")
src_array = np.zeros(shape, dtype=dtype)
for device_id, ind in enumerate(src_indices):
src_array[ind] = np.array(datas[device_id])
for i, data in enumerate(datas):
new_datas.append(self.backend.buffer_from_pyval(src_array[dst_indices[i]], data.device()))
self.buffers[target_uuid] = new_datas

def delete_buffers(self, uuids: Union[Sequence[int], int]):
if isinstance(uuids, Iterable):
for uuid in uuids:
@@ -1485,9 +1515,9 @@ class DistributedArray:
a normal numpy array.

Internally, it stores a pointer to all remote buffers.
- The buffers are stored distributedly on remote workers' device memeory.
+ The buffers are stored distributedly on remote workers' device memory.
When users require the value of the array, these buffers will be gathered
- to the dirver.
+ to the driver.
"""

def __init__(self,
@@ -1778,6 +1808,42 @@ def prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray,
array._fetched_np_buffers = np_value # pylint: disable=protected-access


def copy_distributed_array(src_array: Union[DistributedArray, ReplicatedDistributedArray],
target_sharding_spec: ShardingSpec,
target_dtype: jnp.dtype):
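"""Return a copy of `src_array` that uses `target_sharding_spec` and `target_dtype`.

Handles both DistributedArray and ReplicatedDistributedArray inputs; the actual
copy is performed on each mesh worker via `copy_buffer`.
"""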
aval = jax.core.ShapedArray(src_array.aval.shape, target_dtype)
if isinstance(src_array, DistributedArray):
mesh = src_array.device_mesh
src_spec = src_array.sharding_spec
ary_refs, ary_uuid = create_remote_array_refs(mesh)
dst_array = DistributedArray(mesh, aval, target_sharding_spec, ary_refs[0])
if src_array.sharding_spec != target_sharding_spec:
print("Sharding spec changed. Will need resharding..."
f"src: {src_array.sharding_spec}, dst: {dst_array.sharding_spec}")
print(f"src_shape {src_array.aval.shape}, dst_shape {dst_array.aval.shape}, "
f"src_array_indices: {src_array.indices}, dst_array indices: {dst_array.indices}")
# Do actual copy
for w in mesh.workers:
w.copy_buffer.remote(dst_array.aval.shape,
src_array.indices,
dst_array.indices,
dst_array.remote_ref.uuid,
src_array.remote_ref.uuid,
target_dtype)
else:
assert isinstance(src_array, ReplicatedDistributedArray)
meshes = []
arrays = []
for mesh in src_array._mesh_array_map:
meshes.append(mesh)
ary = copy_distributed_array(src_array._mesh_array_map[mesh],
target_sharding_spec,
target_dtype)
arrays.append(ary)
dst_array = ReplicatedDistributedArray(meshes, arrays)
return dst_array
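
A minimal usage sketch of copy_distributed_array (illustrative only: `params` is assumed to be a pytree of fp16 DistributedArray / ReplicatedDistributedArray leaves produced by an Alpa-parallelized executable; the helper keeps each array's current layout and only upcasts the dtype):

def upcast_param_to_fp32(x):
    # Same sharding layout, fp32 element type.
    spec = (x.sharding_spec if isinstance(x, DistributedArray)
            else x.replica.sharding_spec)
    return copy_distributed_array(x, spec, jnp.float32)

master_params = jax.tree_util.tree_map(upcast_param_to_fp32, params)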


########################################
##### Physical Mesh Group #####
########################################
6 changes: 3 additions & 3 deletions alpa/global_env.py
@@ -12,7 +12,7 @@ def __init__(self):

# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
self.xla_client_mem_fraction = float(
- os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.9))
+ os.environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION", 0.8))
self.xla_client_client_preallocate = os.environ.get(
"XLA_PYTHON_CLIENT_PREALLOCATE", "true")
# The threshold to trigger a batched deletion on workers.
@@ -72,10 +72,10 @@ def __init__(self):
self.use_local_allgather = True
# Cross mesh resharding mode. Possible choices: {"send_recv",
# "broadcast"}
- self.resharding_mode = "send_recv"
+ self.resharding_mode = "broadcast"
# Which nccl to use. Possible choices: {"cupy",
# "xla_extension"}
- self.nccl_mode = "cupy"
+ self.nccl_mode = "xla_extension"
self.enable_overlapping = False
# Cross mesh resharding load balancing mode.
# Possible choices: {"normal", "no_loadbalance",
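
The changed defaults above can still be overridden per run; a rough sketch, assuming the usual `alpa.global_config` entry point and that the environment variable is set before the XLA backend allocates GPU memory:

import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"  # must be set before backend init

import alpa

alpa.global_config.resharding_mode = "send_recv"  # revert to the previous defaults
alpa.global_config.nccl_mode = "cupy"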
59 changes: 59 additions & 0 deletions alpa/model/model_util.py
@@ -4,7 +4,9 @@
import functools
from typing import Any, Callable, Optional, Tuple, Union, Sequence

import alpa.device_mesh
from alpa.api import value_and_grad
from alpa.device_mesh import copy_distributed_array
import flax
from flax.training import train_state, dynamic_scale as dynamic_scale_lib
from flax.training.dynamic_scale import DynamicScaleResult
@@ -348,6 +350,30 @@ def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs):
**kwargs,
)

@classmethod
def create_distributed(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs):
"""The distributed version of create. It assumes the inputs are DistributedArrays."""
if use_master_copy:
dtype = jax.tree_util.tree_flatten(params)[0][0].dtype
assert dtype == jnp.float16
# create the master copy distributedly
# Keep each parameter's existing sharding spec; only upcast the dtype to fp32.
master_copy = jax.tree_util.tree_map(
lambda x: copy_distributed_array(
x, x.sharding_spec if isinstance(x, alpa.device_mesh.DistributedArray)
else x.replica.sharding_spec, jnp.float32),
params)
# TODO (Hao): handle opt_state
opt_state = tx.init(master_copy)
else:
master_copy = None
opt_state = tx.init(params)
return cls(
step=np.array(0, dtype=np.int32),
apply_fn=apply_fn,
params=params,
master_copy=master_copy,
tx=tx,
opt_state=opt_state,
**kwargs,
)
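
A rough usage sketch for create_distributed (hypothetical names: `model` is a Flax module, `fp16_params` a pytree of distributed fp16 parameters, and optax supplies the optimizer):

import optax

state = TrainState.create_distributed(
    apply_fn=model.apply,
    params=fp16_params,                  # DistributedArray leaves in fp16
    tx=optax.adamw(learning_rate=1e-4),
    use_master_copy=True)                # keep an fp32 master copy for updates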

@classmethod
def create_aval(cls,
*,
@@ -377,6 +403,39 @@ def create_aval(cls,
**kwargs,
)

@classmethod
def create_from(cls,
*,
train_state,
params,
use_master_copy=False,
**kwargs):
"""Create a new instance where everything except master_copy is given."""
if use_master_copy:
dtype = jax.tree_util.tree_flatten(params)[0][0].dtype
assert dtype == jnp.float16

def get_sharding_spec(array):
if isinstance(array, alpa.device_mesh.DistributedArray):
return array.sharding_spec
else:
assert isinstance(array, alpa.device_mesh.ReplicatedDistributedArray)
return array.replica.sharding_spec

# create the master copy distributedly
master_copy = jax.tree_util.tree_map(
lambda x, y: copy_distributed_array(x, get_sharding_spec(y), jnp.float32), params, train_state.master_copy)
else:
master_copy = None
return cls(
step=train_state.step,
apply_fn=train_state.apply_fn,
params=params,
master_copy=master_copy,
tx=train_state.tx,
opt_state=train_state.opt_state,
**kwargs
)
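
A rough sketch of how create_from might be used (hypothetical scenario: rebuild the state after the fp16 parameters have been copied or re-sharded, reusing the old step, optimizer state, and the master copy's sharding specs):

new_state = TrainState.create_from(
    train_state=old_state,            # existing distributed TrainState
    params=resharded_fp16_params,     # fp16 params under the new sharding
    use_master_copy=True)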

class DynamicScale(struct.PyTreeNode):
"""This is the same as flax.optim.DynamicScale, except that
21 changes: 16 additions & 5 deletions alpa/parallel_method.py
@@ -13,7 +13,7 @@
- PipeshardParallel: which combines pipeline parallelism and shard parallelism.
"""
from abc import ABC, abstractmethod
- from typing import Callable, Optional, Sequence, Union, Any
+ from typing import Callable, Optional, Sequence, Union, Any, List

from jax import linear_util as lu
from jax._src import traceback_util
@@ -248,9 +248,11 @@ def get_3d_parallel_method(num_micro_batches: int,
data_parallel: int,
operator_parallel: int,
pipeline_parallel: int,
- allow_degenerate_into_shard_parallel: bool = True):
+ allow_degenerate_into_shard_parallel: bool = True,
+ use_manual_layer_option: bool = False,
+ forward_stage_layer_ids: List[List[int]] = None):
"""
- Get a parallel method for 3D parallelism, which reguarlly combines
+ Get a parallel method for 3D parallelism, which regularly combines
data parallelism, operator parallelism and pipeline parallelism.
"""
# Validity check
@@ -259,6 +261,7 @@ def get_3d_parallel_method(num_micro_batches: int,
num_devices_per_host = virtual_mesh.num_devices_per_host
if data_parallel == -1:
data_parallel = (num_devices // operator_parallel // pipeline_parallel)
print(f"num_devices {num_devices} dp {data_parallel}, op {operator_parallel}, pp {pipeline_parallel}")
assert num_devices % data_parallel == 0
assert num_devices % operator_parallel == 0
assert num_devices % pipeline_parallel == 0
@@ -287,7 +290,15 @@ def get_3d_parallel_method(num_micro_batches: int,
[data_parallel, operator_parallel]))

# Return pipeshard parallel
- layer_option = AutoLayerOption(layer_num=pp, eps=0.1)
+ if use_manual_layer_option:
+ # We assume each layer has already been annotated with mark_pipeline_boundary().
+ layer_option = ManualLayerOption()
+ assert forward_stage_layer_ids, "forward_stage_layer_ids must be provided " \
+ "when using manual annotation."
+ else:
+ # Note: this eps needs some tuning.
+ layer_option = AutoLayerOption(layer_num=pp, eps=0.1)
+ forward_stage_layer_ids = [[i] for i in range(pp)]
return PipeshardParallel(
devices=virtual_mesh,
num_micro_batches=num_micro_batches,
Expand All @@ -297,7 +308,7 @@ def get_3d_parallel_method(num_micro_batches: int,
),
layer_option=layer_option,
stage_option=ManualStageOption(
- forward_stage_layer_ids=[[i] for i in range(pp)],
+ forward_stage_layer_ids=forward_stage_layer_ids,
submesh_physical_shapes=[physical_mesh_shape] * pp,
submesh_logical_shapes=[logical_mesh_shape] * pp,
submesh_autosharding_option_dicts=[{}] * pp))
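
A hedged example of the new manual path (it assumes the model's forward function already calls alpa.mark_pipeline_boundary() between layers; the layer-to-stage grouping below is illustrative):

method = get_3d_parallel_method(
    num_micro_batches=8,
    data_parallel=2,
    operator_parallel=2,
    pipeline_parallel=2,
    use_manual_layer_option=True,
    # Two pipeline stages, each owning two manually marked layers.
    forward_stage_layer_ids=[[0, 1], [2, 3]])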
12 changes: 8 additions & 4 deletions alpa/pipeline_parallel/computation.py
@@ -28,7 +28,7 @@
get_compile_options, jaxpr_to_hlo,
setup_computation_alias, compile_dummy_zero_constant,
get_var_mapping, undefined_sharding_spec_proto,
- new_jaxpr_eqn)
+ new_jaxpr_eqn, replicated_sharding_spec_proto)
from alpa.wrapped_hlo import HloStatus, WrappedHlo

# pylint: disable=redefined-builtin
@@ -750,9 +750,13 @@ def generate_sharded_xla_computations_arguments(
hlo.set_input_shardings(sharding_protos)

if output_sharding_dict:
- sharding_protos = [
- output_sharding_dict[x].sharding_proto() for x in outvars
- ]
+ sharding_protos = []
+ for x in outvars:
+ spec = output_sharding_dict.get(x, None)
+ if spec is None:
+ sharding_protos.append(replicated_sharding_spec_proto())
+ else:
+ sharding_protos.append(spec.sharding_proto())
hlo.set_output_shardings(sharding_protos)

if stage_input_sharding:
7 changes: 7 additions & 0 deletions alpa/util.py
@@ -614,6 +614,13 @@ def undefined_sharding_spec_proto():
return proto


def replicated_sharding_spec_proto():
"""Return a proto of ShardingSpec which represents a replicated spec."""
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
return proto
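
A small sanity check of the new helper (this is the value that computation.py now falls back to for outvars missing from output_sharding_dict):

proto = replicated_sharding_spec_proto()
assert proto.type == xc.OpSharding.Type.REPLICATED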


########################################
##### Jaxpr Utilities
########################################
2 changes: 1 addition & 1 deletion examples/llm_serving/model/opt_utils.py
@@ -4,7 +4,7 @@
from jax import xla, jit
from jax.core import Primitive
from jax._src.lib import xla_client as xc
- from transformers.generation_utils import dataclass
+ from dataclasses import dataclass


def sync(device_id=0):
4 changes: 3 additions & 1 deletion examples/llm_serving/model/wrapper.py
@@ -13,7 +13,9 @@
jax_index_select)
from tqdm import tqdm
from transformers import OPTForCausalLM, BloomForCausalLM
- from transformers.generation_utils import GenerationMixin, ModelOutput, dataclass
+ from transformers import GenerationMixin
+ from transformers.utils import ModelOutput
+ from dataclasses import dataclass

import alpa
from alpa.device_mesh import DistributedArray
29 changes: 29 additions & 0 deletions examples/opt_finetune/config_125m.json
@@ -0,0 +1,29 @@
{
"_name_or_path": "facebook/opt-125m",
"weight_path": "/home/ubuntu/dataset/opt_weights/125M_np",
"activation_dropout": 0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0,
"bos_token_id": 2,
"do_layer_norm_before": true,
"dropout": 0.1,
"eos_token_id": 2,
"ffn_dim": 3072,
"hidden_size": 768,
"init_std": 0.02,
"layerdrop": 0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.21.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 768
}
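
A sketch of how a config like this might be consumed in the fine-tuning example (hypothetical loader; `weight_path` is an Alpa-specific key rather than part of the standard Hugging Face OPTConfig, so it is split off before building the config):

import json

from transformers import OPTConfig

with open("examples/opt_finetune/config_125m.json") as f:
    raw = json.load(f)

weight_path = raw.pop("weight_path")   # path to the converted numpy weights
config = OPTConfig(**raw)
print(config.hidden_size, config.num_hidden_layers, weight_path)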