Skip to content

Commit

Permalink
Merge branch 'main' into fix_get_attn_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
cyanguwa committed Sep 27, 2024
2 parents e5635ce + 7b152a8 commit f96c1f0
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 19 deletions.
3 changes: 3 additions & 0 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7878,7 +7878,7 @@ class MultiheadAttention(torch.nn.Module):
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
qkv_format: str, default = `sbhd`
Expand Down
36 changes: 30 additions & 6 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,21 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
Expand Down Expand Up @@ -252,9 +264,21 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
auto h = q_shape[q_shape.size() - 2];
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,11 +528,11 @@ class GroupedLinear(TransformerEngineBaseModule):
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initilizeing weights.
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Expand All @@ -548,7 +548,7 @@ class GroupedLinear(TransformerEngineBaseModule):
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class LayerNorm(torch.nn.Module):
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
"""
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Expand All @@ -832,7 +832,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ class Linear(TransformerEngineBaseModule):
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initilizeing weights.
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Expand All @@ -662,7 +662,7 @@ class Linear(TransformerEngineBaseModule):
names that end in `_weight` or `_bias`, so trailing underscores are
stripped from any provided names.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Expand All @@ -678,7 +678,7 @@ class Linear(TransformerEngineBaseModule):
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class RMSNorm(torch.nn.Module):
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
"""
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class TransformerLayer(torch.nn.Module):
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
Expand Down

0 comments on commit f96c1f0

Please sign in to comment.