Commit e5d40a6

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 16, 2024
1 parent ade0c02, commit e5d40a6
Showing 3 changed files with 22 additions and 10 deletions.
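
All three diffs below are mechanical style fixes of the kind pre-commit.ci commits automatically: blank-line normalization and line rewrapping. As a rough illustration (assuming black is among the configured hooks and the project wraps at 100 characters, neither of which this page confirms), the same kind of rewrite can be reproduced locally:

import black

# A sketch, not part of this commit: feed an over-long statement through
# black and print the wrapped result. The 100-character limit is an
# assumption, not confirmed by this page.
src = (
    "attn_mask = torch.randint(2, (model_config.batch_size, 1,"
    " model_config.sequence_length, model_config.sequence_length),"
    " dtype=torch.bool, device='cuda')\n"
)
print(black.format_str(src, mode=black.Mode(line_length=100)))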

tests/pytorch/test_cuda_graphs.py (8 additions & 1 deletion)

@@ -141,6 +141,7 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
     "linear_op",
 ]
 
+
 def _test_cuda_graphs(
     *,
     graph_mode: str,
@@ -331,6 +332,7 @@ def test_make_graphed_callables(
     "mha",
 ]
 
+
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize(
     "module",
@@ -478,7 +480,12 @@ def _test_cuda_graphs_with_kwargs(
     grad_output = generate_data(model_config, dtype, requires_grad=False)
     attn_mask = torch.randint(
         2,
-        (model_config.batch_size, 1, model_config.sequence_length, model_config.sequence_length),
+        (
+            model_config.batch_size,
+            1,
+            model_config.sequence_length,
+            model_config.sequence_length,
+        ),
         dtype=torch.bool,
         device="cuda",
     )
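
The only substantive hunk in this file rewraps the attention-mask construction. A standalone sketch of that pattern (illustrative sizes, not the test's real model_config; assumes a CUDA device is available):

import torch

# torch.randint draws values in {0, 1}; dtype=torch.bool turns them into a
# random boolean mask of shape (batch, 1, seq_len, seq_len). Swap
# device="cuda" for "cpu" if no GPU is available.
batch_size, seq_len = 2, 8
attn_mask = torch.randint(
    2,
    (batch_size, 1, seq_len, seq_len),
    dtype=torch.bool,
    device="cuda",
)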

transformer_engine/pytorch/graph.py (2 additions & 0 deletions)

@@ -492,6 +492,7 @@ def save_fp8_tensors(
     with adjusted amax history sizes.
     """
     from .ops import Sequential, FusibleOperation  # Avoid circular import
+
     fp8_tensors = []
     for module in modules:
         for m in module.modules():
@@ -521,6 +522,7 @@ def restore_fp8_tensors(
 ) -> None:
     """Restore FP8 tensors."""
     from .ops import Sequential, FusibleOperation  # Avoid circular import
+
     for module in modules:
         for m in module.modules():
             module_tensors = fp8_tensors.pop(0)
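
Aside from the added blank lines, the context shows the pairing convention between the two functions: save_fp8_tensors appends one entry per traversed submodule, and restore_fp8_tensors consumes them in the same order with pop(0). A generic sketch of that FIFO convention (hypothetical helper and attribute names, not the Transformer Engine API):

import torch.nn as nn

def save_states(modules: list[nn.Module]) -> list:
    # One entry per submodule, in nn.Module traversal order.
    saved = []
    for module in modules:
        for m in module.modules():
            saved.append(getattr(m, "extra_state", None))
    return saved

def restore_states(modules: list[nn.Module], saved: list) -> None:
    # Consume entries in the same traversal order; pop(0) keeps the
    # save/restore pairing first-in, first-out.
    for module in modules:
        for m in module.modules():
            state = saved.pop(0)
            if state is not None:
                m.extra_state = state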

transformer_engine/pytorch/ops/op.py (12 additions & 9 deletions)

@@ -286,9 +286,9 @@ def _maybe_update_fp8_meta(
             FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
                 fp8_meta_key
             ].amax_history[0]
-            FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
-                fp8_meta[fp8_meta_key].amax_history
-            )
+            FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[
+                fp8_meta_key
+            ].amax_history
 
     def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
         """FP8 metadata
@@ -335,16 +335,19 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
         Tensor copies should be generated with _save_fp8_metas.
         """
-        assert (self._fp8_metas is None) == (fp8_metas is None), \
-            "Saved FP8 metadata does not match operation's FP8 metadata"
+        assert (self._fp8_metas is None) == (
+            fp8_metas is None
+        ), "Saved FP8 metadata does not match operation's FP8 metadata"
         if fp8_metas is None:
             return
         for mode, fp8_meta in fp8_metas.items():
-            assert mode in self._fp8_metas, \
-                f"Found an unexpected key ({mode=}) in saved FP8 metadata"
+            assert (
+                mode in self._fp8_metas
+            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
             for fp8_meta_key, tensors in fp8_meta.items():
-                assert fp8_meta_key in self._fp8_metas[mode], \
-                    f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
+                assert (
+                    fp8_meta_key in self._fp8_metas[mode]
+                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                 scale, scale_inv, amax_history = tensors
                 self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                 self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)
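
The assert rewrites in this file change only layout, not behavior: black replaces backslash continuations with a parenthesized condition. A minimal before/after, identical at runtime:

value = 3
# Before (backslash continuation):
assert value > 0, \
    "value must be positive"
# After (black's parenthesized style):
assert (
    value > 0
), "value must be positive"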
