From e5d40a69517dac4474b3ab86edb92fe5ef1bf27f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 16 Aug 2024 01:53:46 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/test_cuda_graphs.py    |  9 ++++++++-
 transformer_engine/pytorch/graph.py  |  2 ++
 transformer_engine/pytorch/ops/op.py | 21 ++++++++++++---------
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 1af004f1ad..010050baea 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -141,6 +141,7 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
     "linear_op",
 ]
 
+
 def _test_cuda_graphs(
     *,
     graph_mode: str,
@@ -331,6 +332,7 @@ def test_make_graphed_callables(
     "mha",
 ]
 
+
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize(
     "module",
@@ -478,7 +480,12 @@ def _test_cuda_graphs_with_kwargs(
     grad_output = generate_data(model_config, dtype, requires_grad=False)
     attn_mask = torch.randint(
         2,
-        (model_config.batch_size, 1, model_config.sequence_length, model_config.sequence_length),
+        (
+            model_config.batch_size,
+            1,
+            model_config.sequence_length,
+            model_config.sequence_length,
+        ),
         dtype=torch.bool,
         device="cuda",
     )
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index b8b383ad6e..7193d33476 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -492,6 +492,7 @@ def save_fp8_tensors(
     with adjusted amax history sizes.
     """
     from .ops import Sequential, FusibleOperation  # Avoid circular import
+
     fp8_tensors = []
     for module in modules:
         for m in module.modules():
@@ -521,6 +522,7 @@ def restore_fp8_tensors(
 ) -> None:
     """Restore FP8 tensors."""
     from .ops import Sequential, FusibleOperation  # Avoid circular import
+
     for module in modules:
         for m in module.modules():
             module_tensors = fp8_tensors.pop(0)
diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py
index d1f5f2c719..dd3307c0fb 100644
--- a/transformer_engine/pytorch/ops/op.py
+++ b/transformer_engine/pytorch/ops/op.py
@@ -286,9 +286,9 @@ def _maybe_update_fp8_meta(
                     FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[
                         fp8_meta_key
                     ].amax_history[0]
-                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
-                        fp8_meta[fp8_meta_key].amax_history
-                    )
+                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[
+                        fp8_meta_key
+                    ].amax_history
 
     def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]:
         """FP8 metadata
@@ -335,16 +335,19 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None:
         Tensor copies should be generated with _save_fp8_metas.
 
         """
-        assert (self._fp8_metas is None) == (fp8_metas is None), \
-            "Saved FP8 metadata does not match operation's FP8 metadata"
+        assert (self._fp8_metas is None) == (
+            fp8_metas is None
+        ), "Saved FP8 metadata does not match operation's FP8 metadata"
         if fp8_metas is None:
             return
         for mode, fp8_meta in fp8_metas.items():
-            assert mode in self._fp8_metas, \
-                f"Found an unexpected key ({mode=}) in saved FP8 metadata"
+            assert (
+                mode in self._fp8_metas
+            ), f"Found an unexpected key ({mode=}) in saved FP8 metadata"
             for fp8_meta_key, tensors in fp8_meta.items():
-                assert fp8_meta_key in self._fp8_metas[mode], \
-                    f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
+                assert (
+                    fp8_meta_key in self._fp8_metas[mode]
+                ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata"
                 scale, scale_inv, amax_history = tensors
                 self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
                 self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv)