Fix docstring related to t in thd (#1111)
fix typos regarding t in thd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
cyanguwa authored Aug 15, 2024
1 parent a326e35 commit 941364d
Showing 2 changed files with 3 additions and 3 deletions.
@@ -22,7 +22,7 @@ extern "C" {
 /*! \enum NVTE_QKV_Layout
  *  \brief Memory layouts of QKV tensors.
  *  `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, number of heads,
- *  head size, and the total number of sequences in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
+ *  head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
  *  `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
  *  or padded to the same length, and `THD`-based layouts are used when sequences have
  *  different lengths in a batch.
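The distinction the fix makes is easy to see with a small sketch. The following minimal PyTorch example (the sequence lengths, `h`, and `d` are made-up values, not from this commit) shows that `t` counts tokens rather than sequences, and how cumulative sequence lengths delimit each sequence inside a packed `thd` tensor:

import torch

# Hypothetical per-sequence lengths for a batch of b = 3 sequences.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)

b = seqlens.numel()     # number of sequences in the batch: 3
t = int(seqlens.sum())  # total number of tokens: 5 + 3 + 7 = 15

# In a THD layout, Q/K/V are packed along the token dimension, so a
# query tensor has shape [t, h, d] instead of the padded [b, s, h, d].
h, d = 16, 64
q_thd = torch.empty(t, h, d)

# Cumulative sequence lengths mark where each sequence starts and ends:
# sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of q_thd.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, 0, dtype=torch.int32)]
)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)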
transformer_engine/pytorch/attention.py: 2 additions & 2 deletions
@@ -3122,7 +3122,7 @@ def get_qkv_layout(
     qkv_format: str, default = `sbhd`
         Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
         the sequence length dimension, `b` batch size, `h` the number of attention heads,
-        `d` head size, and `t` the total number of sequences in a batch, i.e.
+        `d` head size, and `t` the total number of tokens in a batch, i.e.
         `t = sum(s_i) for i = 0...b-1`.

     Returns
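As a concrete illustration of the packing that `get_qkv_layout`'s docstring now describes, here is a sketch of converting a padded `bshd` tensor into the `thd` format. The helper name `bshd_to_thd` and all the dimensions are illustrative, not part of Transformer Engine:

import torch

def bshd_to_thd(x_bshd: torch.Tensor, seqlens: torch.Tensor) -> torch.Tensor:
    """Pack a padded [b, s, h, d] tensor into [t, h, d] by dropping each
    sequence's padding and concatenating along the token axis."""
    chunks = [x_bshd[i, : int(seqlens[i])] for i in range(x_bshd.size(0))]
    return torch.cat(chunks, dim=0)  # shape [sum(seqlens), h, d]

b, s, h, d = 3, 7, 16, 64
seqlens = torch.tensor([5, 3, 7])     # hypothetical lengths, max(seqlens) == s
q_bshd = torch.randn(b, s, h, d)      # padded bshd layout
q_thd = bshd_to_thd(q_bshd, seqlens)  # packed thd layout
assert q_thd.shape == (int(seqlens.sum()), h, d)  # t = 15 tokens, no padding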
@@ -5232,7 +5232,7 @@ class DotProductAttention(TransformerEngineBaseModule):
     qkv_format: str, default = `sbhd`
         dimension format for `query_layer`, `key_layer` and `value_layer`,
         {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size,
-        `h` the number of heads, `d` head size, and `t` the total number of sequences
+        `h` the number of heads, `d` head size, and `t` the total number of tokens
         in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats
         are used for when sequences in a batch are of equal length or padded to
         equal length, and the `thd` format is used for when sequences in a batch
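To tie the `DotProductAttention` docstring to usage, here is a hypothetical sketch of calling the module with `qkv_format="thd"` on packed tensors. The constructor arguments `num_attention_heads` and `kv_channels` and the forward arguments `cu_seqlens_q`/`cu_seqlens_kv` reflect a reading of the Transformer Engine API rather than anything in this diff, so treat them as assumptions to verify against the TE documentation (a CUDA device is also required):

import torch
from transformer_engine.pytorch import DotProductAttention

h, d = 16, 64
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
t = int(seqlens.sum())  # 15 tokens in total
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, 0, dtype=torch.int32)]
).cuda()

# Packed Q/K/V in thd layout: [t, h, d], with no padding between sequences.
q = torch.randn(t, h, d, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = DotProductAttention(num_attention_heads=h, kv_channels=d, qkv_format="thd")
# Assumed forward signature: cu_seqlens delimit sequences within the t axis.
out = attn(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)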
