diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 828eb9d552..fab016ff3f 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5910,6 +5910,9 @@ def forward(
             assert (
                 fp8_meta is not None
             ), "FP8 metadata fp8_meta is required for FP8 attention!"
+            assert (
+                not context_parallel or fp8_meta["recipe"].reduce_amax
+            ), "Amax reduction across TP+CP group is necessary when using context parallelism with FP8!"
             if context_parallel:
                 assert (
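
For reference, a minimal sketch of a user-side configuration that satisfies the new assertion: with context parallelism and FP8, `reduce_amax` must stay enabled in the recipe, and the amax reduction group passed to `fp8_autocast` should span the TP+CP ranks. Here `tp_cp_group`, `model`, and `inp` are hypothetical placeholders for the user's combined process group, module, and input; the distributed setup itself is not shown.

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# reduce_amax=True (the default) is what the new assertion requires
# whenever context_parallel is enabled for FP8 attention.
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
    reduce_amax=True,
)

# tp_cp_group is assumed to be a torch.distributed process group covering
# both the tensor-parallel and context-parallel ranks (creation not shown).
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tp_cp_group):
    out = model(inp)
```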