From 7690ef5e088a275ffb9dfcc5d0b7a6632eec1168 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Thu, 26 Sep 2024 04:24:02 +0000
Subject: [PATCH] Respect autograd dtype

Signed-off-by: Tim Moon
---
 .../pytorch/ops/fused/userbuffers_backward_linear.py | 2 +-
 .../pytorch/ops/fused/userbuffers_forward_linear.py  | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
index c884b66612..1b435bcb05 100644
--- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
+++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
@@ -667,7 +667,7 @@ def fuser_backward(
             weight_requires_grad=weight_requires_grad,
             bias_requires_grad=(bias_op is not None),
             device=linear_op.device,
-            dtype=linear_op.dtype,
+            dtype=linear_op.dtype,  # TODO: linear_op_ctx.dtype
             grad_weight=grad_weight,
             accumulate_into_grad_weight=accumulate_into_main_grad,
             tensor_parallel_mode=self.tensor_parallel_mode,
diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
index b1c15f433a..0741ee1200 100644
--- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
@@ -471,6 +471,11 @@ def fuser_forward(
         if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
             grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
 
+        # Get autocast dtype if needed
+        dtype = None
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_dtype("cuda")
+
         # Userbuffers options
         if linear_op._userbuffers_options is None:
             raise RuntimeError("Linear op is missing dict for Userbuffers options")
@@ -481,7 +486,7 @@
             weight=linear_op.weight,
             bias=bias,
             device=linear_op.device,
-            dtype=linear_op.dtype,
+            dtype=dtype,
             tensor_parallel_mode=self.tensor_parallel_mode,
             tensor_parallel_group=self.tensor_parallel_group,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -500,6 +505,7 @@
         linear_op_ctx.weight_fp8_meta = weight_fp8_meta
         linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
         linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
+        linear_op_ctx.dtype = dtype
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_.requires_grad
         linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
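
Note: the dtype selection added to fuser_forward above reduces to the standalone sketch below. The None fallback matches the patch; how a None dtype is resolved downstream is an assumption here (left to the callee), and the helper name is illustrative, not part of the patch.

    from typing import Optional

    import torch


    def resolve_autocast_dtype() -> Optional[torch.dtype]:
        """Return the active autocast dtype, or None when autocast is disabled.

        Mirrors the logic added in fuser_forward; resolving a None result to a
        concrete dtype is assumed to happen in the callee (not shown here).
        """
        if torch.is_autocast_enabled():
            # torch.get_autocast_dtype takes a device-type string ("cuda", "cpu", ...).
            return torch.get_autocast_dtype("cuda")
        return None


    # Usage (assumes a CUDA-enabled PyTorch build): under autocast, the fused op
    # should compute in the autocast dtype rather than the module's own dtype.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        assert resolve_autocast_dtype() == torch.bfloat16

Storing this dtype on linear_op_ctx in the forward pass is what lets the backward pass (see the TODO in userbuffers_backward_linear.py) eventually reuse the same dtype instead of linear_op.dtype.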