diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c25aa863ad..c7a85029e6 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -36,6 +36,9 @@ jobs: || github.actor == 'yaox12' || github.actor == 'huanghua1994' || github.actor == 'mgoldfarb-nvidia' + || github.actor == 'pggPL' + || github.actor == 'vasunvidia' + || github.actor == 'erhoo82' ) steps: - name: Check if comment is issued by authorized person diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 3725e58c87..9152229d2f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -80,15 +80,15 @@ def setup_pytorch_extension( ) ) - if "80" in cuda_architectures: - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if "90" in cuda_architectures: - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + for arch in cuda_architectures.split(";"): + if arch == "70": + continue # Already handled + nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) # Libraries library_dirs = [] libraries = [] - if os.getenv("NVTE_UB_WITH_MPI"): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index a210019dc1..b097f14475 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -9,6 +9,9 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs) :members: forward, set_tensor_parallel_group +.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs) + :members: forward, set_tensor_parallel_group + .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs) .. 
autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index cb384aa10c..4413bdfd00 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config).to(kwargs["torch_dtype"]) - is_local = os.path.isdir(pretrained_model_name_or_path) + # Before loading the model, set the default dtype for torch + torch.set_default_dtype(kwargs["torch_dtype"]) + + # Load the vanilla model weights + vanilla_model = cls(config) subfolder = "" variant = None if os.path.isfile( @@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k else: raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + resolved_archive_file, _ = get_checkpoint_shard_files( pretrained_model_name_or_path, archive_file, ) diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 57c1bf6601..7013e85ec6 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -247,15 +247,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -554,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "bdb34b91", "metadata": {}, "outputs": [ @@ -573,15 +582,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! 
`model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -653,15 +671,24 @@ "restart_jupyter_notebook()\n", "\n", "\n", - "# Import necessary packages and methods\n", + "# Import necessary packages, methods and variables\n", "from utils import *\n", "\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", - "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "# Provide Huggingface Access Token\n", + "hyperparams.hf_access_token = \"\"\n", + "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "hyperparams.weights_cache_dir = \"\"\n", + "\n", + "# For Llama 2, uncomment this line (also set by default)\n", + "hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", + "\n", + "# For Llama 3, uncomment this line\n", + "# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", + "\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index b6b3683d4c..1aebe13afb 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -25,7 +25,10 @@ class HyperParameters: def __init__(self): self.mixed_precision = "bf16" - # self.model_name = "" # <== Add model weight location here + + # Set to Meta Llama 2 by default. + self.model_name = "meta-llama/Llama-2-7b-hf" + self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 @@ -35,6 +38,10 @@ def __init__(self): self.num_warmup_steps = 5 self.num_training_steps = 10 + # This is either provided by the user or it will be set when the + # model weights are downloaded. 
+ self.weights_cache_dir = "" + hyperparams = HyperParameters() @@ -76,13 +83,49 @@ def tokenize(element): return train_dataloader +def ensure_model_is_downloaded(hyperparams): + assert hyperparams.model_name in [ + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Llama-2-7b-hf", + ], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!" + + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(hyperparams.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." + ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None + ) + hyperparams.weights_cache_dir = snapshot_download( + repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir + ) + + print(f"Model cache directory : {hyperparams.weights_cache_dir}") + + def init_baseline_model(hyperparams): + # Download and cache the weights + ensure_model_is_downloaded(hyperparams) + # Init the model - config = AutoConfig.from_pretrained(hyperparams.model_name) + config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir) # make sure to use flash_attention to do iso comparison with TELlamaModel config._attn_implementation = "flash_attention_2" model = AutoModelForCausalLM.from_pretrained( - hyperparams.model_name, + hyperparams.weights_cache_dir, config=config, torch_dtype=torch.bfloat16, ) @@ -94,13 +137,16 @@ def init_baseline_model(hyperparams): def init_te_llama_model(hyperparams): + # Download and cache the weights + ensure_model_is_downloaded(hyperparams) + # Init the model from te_llama import TELlamaForCausalLM - config = AutoConfig.from_pretrained(hyperparams.model_name) + config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir) config._attn_implementation = "flash_attention_2" model = TELlamaForCausalLM.from_pretrained_local( - hyperparams.model_name, + hyperparams.weights_cache_dir, config=config, torch_dtype=torch.bfloat16, ) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index d6358d1062..c1c18ffe47 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -112,6 +112,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") + if dtype == "fp8" and get_device_compute_capability() < (9, 0): + pytest.skip("FP8 attention is only supported on sm90+!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6c61469f90..b6a2b2dad5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -7901,7 +7901,7 @@ class MultiheadAttention(torch.nn.Module): bias : bool, default = `True` if set to `False`, the transformer layer will not learn any additive biases. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. 
It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. qkv_format: str, default = `sbhd` diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index fb1fc97a33..b2968a688d 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -95,9 +95,21 @@ std::vector fused_attn_fwd_qkvpacked( auto qkv_sizes = QKV.sizes().vec(); std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; std::vector q_shape; - for (auto i : qkv_shape) { - if (i != 3) { - q_shape.push_back(i); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + int loc_3 = 0; + switch (layout_group) { + case NVTE_3HD: + loc_3 = qkv_sizes.size() - 3; + break; + case NVTE_H3D: + loc_3 = qkv_sizes.size() - 2; + break; + default: + NVTE_ERROR("Invalid QKV layout group."); + } + for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { + if (it - qkv_shape.begin() != loc_3) { + q_shape.push_back(*it); } } std::vector o_shape{q_shape.begin(), q_shape.end()}; @@ -252,9 +264,21 @@ std::vector fused_attn_bwd_qkvpacked( auto qkv_sizes = QKV.sizes().vec(); std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; std::vector q_shape; - for (auto i : qkv_shape) { - if (i != 3) { - q_shape.push_back(i); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + int loc_3 = 0; + switch (layout_group) { + case NVTE_3HD: + loc_3 = qkv_sizes.size() - 3; + break; + case NVTE_H3D: + loc_3 = qkv_sizes.size() - 2; + break; + default: + NVTE_ERROR("Invalid QKV layout group."); + } + for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { + if (it - qkv_shape.begin() != loc_3) { + q_shape.push_back(*it); } } auto h = q_shape[q_shape.size() - 2]; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index fdf65db21e..e9fb11e3b9 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -354,12 +354,8 @@ def backward( # Compute the forward pass. detached_inputs = detach_variable(inputs) - with ( - torch.enable_grad(), - ctx.recompute_ctx, - ctx.torch_gpu_amp_ctx, - ctx.torch_cpu_amp_ctx, - activation_recompute_forward(activation_recompute=True, recompute_phase=True), + with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( + activation_recompute=True, recompute_phase=True ): outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) @@ -680,13 +676,9 @@ def checkpoint( torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts() def recompute_fn(*args, **kwargs): - with ( - torch.autograd.enable_grad(), - te_recompute_ctx, - user_recompute_ctx, - torch_gpu_amp_forward_ctx, - torch_cpu_amp_forward_ctx, - ): + with torch.autograd.enable_grad(), ( + te_recompute_ctx + ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx: function(*args, **kwargs) # Initialize a new checkpoint frame for each new forward pass. 
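The distributed.py hunks directly above replace the parenthesized multi-item `with` statements with the single-line form. A minimal, illustrative Python sketch (not taken from Transformer Engine) of why the two spellings behave identically; the only assumed motivation is parser support, since the parenthesized form is guaranteed to parse only on Python 3.10 and newer, while the unparenthesized form also works on older interpreters:

    from contextlib import contextmanager

    @contextmanager
    def tracer(name):
        # Toy context manager used only to show enter/exit ordering.
        print(f"enter {name}")
        try:
            yield
        finally:
            print(f"exit {name}")

    # Parenthesized form (Python >= 3.10 syntax):
    #     with (tracer("grad"), tracer("recompute"), tracer("amp")):
    #         ...
    #
    # Equivalent form the diff switches to, which also parses on older interpreters:
    with tracer("grad"), tracer("recompute"), tracer("amp"):
        print("recompute body")

Both spellings enter the context managers left to right and exit them in reverse order, so the rewrite changes syntax only, not behavior.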
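The attention.cu hunks above stop deriving the query shape by dropping every dimension equal to 3 and instead drop the dimension at the position implied by the QKV layout group (NVTE_3HD packs QKV as [..., 3, h, d], NVTE_H3D as [..., h, 3, d]). A hypothetical Python sketch of that shape logic, with the function name and example sizes invented purely for illustration:

    def q_shape_from_packed_qkv(qkv_shape, layout_group):
        """Derive the per-tensor (Q) shape from a packed QKV shape."""
        if layout_group == "3HD":      # QKV packed as [..., 3, h, d]
            loc_3 = len(qkv_shape) - 3
        elif layout_group == "H3D":    # QKV packed as [..., h, 3, d]
            loc_3 = len(qkv_shape) - 2
        else:
            raise ValueError("Invalid QKV layout group.")
        # Drop only the packing dimension, not every dimension that happens to equal 3.
        return [dim for i, dim in enumerate(qkv_shape) if i != loc_3]

    # With num_heads == 3, the old "drop any dimension equal to 3" rule would also have
    # removed the head dimension; keying off the layout group keeps it.
    assert q_shape_from_packed_qkv([128, 2, 3, 3, 64], "3HD") == [128, 2, 3, 64]
    assert q_shape_from_packed_qkv([128, 2, 3, 3, 64], "H3D") == [128, 2, 3, 64]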
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 10c8d91551..0bad1306c3 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -528,11 +528,11 @@ class GroupedLinear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initilizeing weights. + used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -548,7 +548,7 @@ class GroupedLinear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this GroupedLinear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index ec33ad2033..292fcd06de 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -110,7 +110,7 @@ class LayerNorm(torch.nn.Module): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. """ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index da77879e06..92030a7f7a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -816,7 +816,7 @@ class LayerNormLinear(TransformerEngineBaseModule): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -832,7 +832,7 @@ class LayerNormLinear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. 
- parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b802c972d4..6d5609ccd2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1193,7 +1193,7 @@ class LayerNormMLP(TransformerEngineBaseModule): y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a7be82ccf1..8e19a65a28 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -650,7 +650,7 @@ class Linear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. get_rng_state_tracker : Callable, default = `None` - used to get the random number generator state tracker for initilizeing weights. + used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None @@ -662,7 +662,7 @@ class Linear(TransformerEngineBaseModule): names that end in `_weight` or `_bias`, so trailing underscores are stripped from any provided names. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -678,7 +678,7 @@ class Linear(TransformerEngineBaseModule): `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'Column', 'Row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 969a468426..d5dc400206 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -120,7 +120,7 @@ class RMSNorm(torch.nn.Module): .. math:: y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma) device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. 
It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. """ diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 958c7019ba..020d262be2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -173,7 +173,7 @@ class TransformerLayer(torch.nn.Module): Type of activation used in MLP block. Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'. device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will allocated. It is the user's + The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
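The docstring fixes above also lowercase the accepted parallel_mode values. A hedged usage sketch, assuming transformer_engine is installed and a CUDA device is available, and where `tp_group` is a hypothetical, already-created tensor-parallel process group; it shows the lowercase spellings together with the set_tensor_parallel_group call the docstrings describe:

    import transformer_engine.pytorch as te

    # Column-parallel: the weight is split along the output-feature dimension.
    col_linear = te.Linear(1024, 4096, bias=True, parallel_mode="column")

    # Row-parallel: the weight is split along the input-feature dimension.
    row_linear = te.Linear(4096, 1024, bias=True, parallel_mode="row")

    # Before the forward pass, supply the tensor-parallel group as the docstrings require:
    # col_linear.set_tensor_parallel_group(tp_group)
    # row_linear.set_tensor_parallel_group(tp_group)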