diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index fdf65db21e..e9fb11e3b9 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -354,12 +354,8 @@ def backward(
 
         # Compute the forward pass.
         detached_inputs = detach_variable(inputs)
-        with (
-            torch.enable_grad(),
-            ctx.recompute_ctx,
-            ctx.torch_gpu_amp_ctx,
-            ctx.torch_cpu_amp_ctx,
-            activation_recompute_forward(activation_recompute=True, recompute_phase=True),
+        with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
+            activation_recompute=True, recompute_phase=True
         ):
             outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
 
@@ -680,13 +676,9 @@ def checkpoint(
     torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()
 
     def recompute_fn(*args, **kwargs):
-        with (
-            torch.autograd.enable_grad(),
-            te_recompute_ctx,
-            user_recompute_ctx,
-            torch_gpu_amp_forward_ctx,
-            torch_cpu_amp_forward_ctx,
-        ):
+        with torch.autograd.enable_grad(), (
+            te_recompute_ctx
+        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx:
             function(*args, **kwargs)
 
     # Initialize a new checkpoint frame for each new forward pass.
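
For context, both hunks replace a parenthesized context-manager list (officially part of the grammar only since Python 3.10) with the older comma-separated `with` form. The sketch below illustrates the difference, assuming the motivation is compatibility with pre-3.10 interpreters (the PR description is not shown here); `ctx` and the names used are hypothetical helpers for illustration, not part of the patched module.

import contextlib

@contextlib.contextmanager
def ctx(name):  # hypothetical context manager, for illustration only
    print(f"enter {name}")
    yield
    print(f"exit {name}")

# Parenthesized form removed by this patch. On Python 3.8 the parenthesized
# expression parses as a *tuple*, which has no __enter__, so it fails at
# runtime (and with `as` bindings it is a SyntaxError):
#
#     with (
#         ctx("a"),
#         ctx("b"),
#     ):
#         ...

# Portable form the patch switches to: a single comma-separated manager
# list, wrapping long lines inside a call's own argument parentheses.
with ctx("a"), ctx("b"):
    pass

# contextlib.ExitStack is another portable alternative when the list of
# managers grows long or is built dynamically.
with contextlib.ExitStack() as stack:
    for name in ("a", "b"):
        stack.enter_context(ctx(name))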