Commit

Merge branch 'main' into add_descales
cyanguwa authored Sep 27, 2024
2 parents 5bcc355 + 7b152a8 commit 3dbee25
Showing 17 changed files with 164 additions and 64 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/trigger-ci.yml
@@ -36,6 +36,9 @@ jobs:
|| github.actor == 'yaox12'
|| github.actor == 'huanghua1994'
|| github.actor == 'mgoldfarb-nvidia'
|| github.actor == 'pggPL'
|| github.actor == 'vasunvidia'
|| github.actor == 'erhoo82'
)
steps:
- name: Check if comment is issued by authorized person
10 changes: 5 additions & 5 deletions build_tools/pytorch.py
@@ -80,15 +80,15 @@ def setup_pytorch_extension(
)
)

if "80" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if "90" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

# Libraries
library_dirs = []
libraries = []
if os.getenv("NVTE_UB_WITH_MPI"):
if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
3 changes: 3 additions & 0 deletions docs/api/pytorch.rst
@@ -9,6 +9,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
9 changes: 6 additions & 3 deletions docs/examples/te_llama/te_llama.py
@@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
vanilla_model = cls(config).to(kwargs["torch_dtype"])
is_local = os.path.isdir(pretrained_model_name_or_path)
# Before loading the model, set the default dtype for torch
torch.set_default_dtype(kwargs["torch_dtype"])

# Load the vanilla model weights
vanilla_model = cls(config)
subfolder = ""
variant = None
if os.path.isfile(
@@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
archive_file,
)
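The change above sets the default torch dtype before constructing the vanilla model, so parameters are allocated directly in the target dtype instead of being created in fp32 and cast afterwards. A minimal sketch of that pattern (hypothetical module, not the tutorial code); restoring the previous default is left to the caller:

```python
# Sketch: allocate parameters directly in the target dtype (assumed pattern).
import torch

prev_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = torch.nn.Linear(16, 16)  # stand-in for cls(config)
    assert model.weight.dtype == torch.bfloat16  # no transient fp32 copy
finally:
    torch.set_default_dtype(prev_dtype)  # avoid leaking the global default
```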
65 changes: 46 additions & 19 deletions docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -247,15 +247,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
@@ -554,7 +563,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "bdb34b91",
"metadata": {},
"outputs": [
@@ -573,15 +582,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
@@ -653,15 +671,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
"\n",
56 changes: 51 additions & 5 deletions docs/examples/te_llama/utils.py
@@ -25,7 +25,10 @@
class HyperParameters:
def __init__(self):
self.mixed_precision = "bf16"
# self.model_name = "" # <== Add model weight location here

# Set to Meta Llama 2 by default.
self.model_name = "meta-llama/Llama-2-7b-hf"

self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
@@ -35,6 +38,10 @@ def __init__(self):
self.num_warmup_steps = 5
self.num_training_steps = 10

# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""


hyperparams = HyperParameters()

@@ -76,13 +83,49 @@ def tokenize(element):
return train_dataloader


def ensure_model_is_downloaded(hyperparams):
assert hyperparams.model_name in [
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Llama-2-7b-hf",
], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"

# Login using Huggingface Hub API
from huggingface_hub import login

try:
login(hyperparams.hf_access_token)
except Exception as e:
if "Invalid token passed!" in str(e):
print(
"Please pass a valid HF Access Token! More info at"
" https://huggingface.co/docs/hub/en/security-tokens."
)
else:
print(f"Exception is {e}")

# Download the model if it doesn't exist
from huggingface_hub import snapshot_download

supplied_cache_dir = (
hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
)
hyperparams.weights_cache_dir = snapshot_download(
repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
)

print(f"Model cache directory : {hyperparams.weights_cache_dir}")


def init_baseline_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)

# Init the model
config = AutoConfig.from_pretrained(hyperparams.model_name)
config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
# make sure to use flash_attention to do iso comparison with TELlamaModel
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
hyperparams.model_name,
hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
@@ -94,13 +137,16 @@ def init_baseline_model(hyperparams):


def init_te_llama_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)

# Init the model
from te_llama import TELlamaForCausalLM

config = AutoConfig.from_pretrained(hyperparams.model_name)
config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name,
hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
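Putting the updated helpers together, a typical call sequence looks roughly like this (a sketch only; it assumes Transformer Engine and the tutorial files are importable, and the token value is a placeholder):

```python
# Sketch: end-to-end use of the updated utils.py helpers.
from utils import hyperparams, init_te_llama_model

hyperparams.hf_access_token = "hf_..."                    # placeholder; required for gated Llama weights
hyperparams.model_name = "meta-llama/Llama-2-7b-hf"
hyperparams.weights_cache_dir = ""                        # empty -> default Hugging Face cache

model = init_te_llama_model(hyperparams)                  # downloads weights if needed, then builds TELlamaForCausalLM
print(hyperparams.weights_cache_dir)                      # now points at the local snapshot directory
```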
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -112,6 +112,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/attention.py
@@ -7901,7 +7901,7 @@ class MultiheadAttention(torch.nn.Module):
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
qkv_format: str, default = `sbhd`
36 changes: 30 additions & 6 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -95,9 +95,21 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
@@ -252,9 +264,21 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
auto h = q_shape[q_shape.size() - 2];
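The old loop dropped every dimension equal to 3 from the packed QKV shape, which implicitly assumed a 3HD layout and breaks when some other dimension (for example, 3 attention heads) also equals 3. The new code locates the packing dimension from the QKV layout group instead. A small Python re-implementation of the same logic, for illustration only:

```python
# Derive the per-tensor Q shape from a packed QKV shape (illustrative re-implementation of the C++ above).
def q_shape_from_qkv(qkv_shape, layout_group):
    # 3HD packs as [..., 3, h, d]; H3D packs as [..., h, 3, d].
    if layout_group == "3HD":
        loc_3 = len(qkv_shape) - 3
    elif layout_group == "H3D":
        loc_3 = len(qkv_shape) - 2
    else:
        raise ValueError("Invalid QKV layout group.")
    return [dim for i, dim in enumerate(qkv_shape) if i != loc_3]

print(q_shape_from_qkv([1024, 3, 16, 64], "3HD"))  # [1024, 16, 64]
print(q_shape_from_qkv([1024, 3, 3, 64], "H3D"))   # [1024, 3, 64]; the old loop would drop both 3s
```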
18 changes: 5 additions & 13 deletions transformer_engine/pytorch/distributed.py
@@ -354,12 +354,8 @@ def backward(

# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with (
torch.enable_grad(),
ctx.recompute_ctx,
ctx.torch_gpu_amp_ctx,
ctx.torch_cpu_amp_ctx,
activation_recompute_forward(activation_recompute=True, recompute_phase=True),
with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
activation_recompute=True, recompute_phase=True
):
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

@@ -680,13 +676,9 @@ def checkpoint(
torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

def recompute_fn(*args, **kwargs):
with (
torch.autograd.enable_grad(),
te_recompute_ctx,
user_recompute_ctx,
torch_gpu_amp_forward_ctx,
torch_cpu_amp_forward_ctx,
):
with torch.autograd.enable_grad(), (
te_recompute_ctx
), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
function(*args, **kwargs)

# Initialize a new checkpoint frame for each new forward pass.
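Both hunks above replace parenthesized multi-line `with` statements with single-statement forms. A plausible motivation (not stated in the diff) is compatibility with interpreters and formatters that reject parenthesized context-manager groups, which only became official syntax in Python 3.10. A version-agnostic way to enter several context managers, sketched here with stand-in contexts, is `contextlib.ExitStack`:

```python
# Sketch: entering several context managers without parenthesized `with` syntax.
import contextlib
import torch

def run_with_contexts(fn, *ctxs):
    """Enter every context in `ctxs` (e.g. autocast/recompute contexts), then call fn."""
    with contextlib.ExitStack() as stack:
        for ctx in ctxs:
            stack.enter_context(ctx)
        return fn()

# Example usage with a couple of stand-in contexts:
result = run_with_contexts(lambda: torch.ones(2) * 2, torch.enable_grad(), contextlib.nullcontext())
```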
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -528,11 +528,11 @@ class GroupedLinear(TransformerEngineBaseModule):
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initilizeing weights.
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
@@ -548,7 +548,7 @@
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm.py
@@ -110,7 +110,7 @@ class LayerNorm(torch.nn.Module):
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
"""
(Diffs for the remaining changed files are not loaded in this view.)
