
Commit

Respect autograd dtype
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 committed Sep 26, 2024
1 parent 7d8e08b commit fd4e541
Showing 2 changed files with 8 additions and 2 deletions.
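
In short: the fused linear forward now queries the autocast state and, when autocast is active, uses the autocast compute dtype instead of unconditionally using the linear op's parameter dtype; the resolved dtype is also recorded on the op context for later use in backward. A minimal sketch (not from this commit) of the two PyTorch calls the diff relies on — note that torch.get_autocast_dtype("cuda") is the current spelling, while older PyTorch releases expose torch.get_autocast_gpu_dtype() instead:

import torch

# Inside an autocast region, PyTorch reports the compute dtype it will
# apply to matmul-like ops.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    assert torch.is_autocast_enabled()
    assert torch.get_autocast_dtype("cuda") == torch.bfloat16

# Outside the region the flag is cleared again.
assert not torch.is_autocast_enabled()
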
@@ -667,7 +667,7 @@ def fuser_backward(
     weight_requires_grad=weight_requires_grad,
     bias_requires_grad=(bias_op is not None),
     device=linear_op.device,
-    dtype=linear_op.dtype,
+    dtype=linear_op.dtype,  # TODO: Use linear_op_ctx.dtype when PR #1202 lands
     grad_weight=grad_weight,
     accumulate_into_grad_weight=accumulate_into_main_grad,
     tensor_parallel_mode=self.tensor_parallel_mode,
@@ -471,6 +471,11 @@ def fuser_forward(
     if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
         grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")

+    # Get autocast dtype if needed
+    dtype = None
+    if torch.is_autocast_enabled():
+        dtype = torch.get_autocast_dtype("cuda")
+
     # Userbuffers options
     if linear_op._userbuffers_options is None:
         raise RuntimeError("Linear op is missing dict for Userbuffers options")
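
When autocast is disabled, dtype stays None, which presumably tells the downstream op to fall back to its parameter dtype. A hypothetical helper (resolve_compute_dtype is not a TransformerEngine function) that makes that resolution explicit under the same assumption:

import torch

def resolve_compute_dtype(param_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical mirror of the hunk above: prefer the active autocast
    # dtype, otherwise fall back to the parameter dtype (the assumed
    # meaning of dtype=None in the diff).
    if torch.is_autocast_enabled():
        return torch.get_autocast_dtype("cuda")
    return param_dtype

# With autocast off, the fallback wins.
assert resolve_compute_dtype(torch.float32) == torch.float32
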
@@ -481,7 +486,7 @@
     weight=linear_op.weight,
     bias=bias,
     device=linear_op.device,
-    dtype=linear_op.dtype,
+    dtype=dtype,
     tensor_parallel_mode=self.tensor_parallel_mode,
     tensor_parallel_group=self.tensor_parallel_group,
     tensor_parallel_size=self.tensor_parallel_size,
@@ -500,6 +505,7 @@
     linear_op_ctx.weight_fp8_meta = weight_fp8_meta
     linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
     linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
+    linear_op_ctx.dtype = dtype
     linear_op_ctx.input_dims = input_.size()
     linear_op_ctx.input_requires_grad = input_.requires_grad
     linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
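
Taken together, the forward pass records the resolved dtype on the op context (linear_op_ctx.dtype above), and the TODO in the first file points at reusing it in backward once the referenced PR lands. A self-contained, hypothetical sketch of that forward/backward pattern using a plain torch.autograd.Function rather than TransformerEngine's fused op:

import torch

class LinearSketch(torch.autograd.Function):
    # Hypothetical illustration of the commit's pattern: record the compute
    # dtype on the autograd context in forward, reuse it in backward.

    @staticmethod
    def forward(ctx, x, weight):
        dtype = None
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        ctx.dtype = dtype if dtype is not None else weight.dtype
        ctx.save_for_backward(x, weight)
        return x.to(ctx.dtype) @ weight.to(ctx.dtype).t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        # Compute gradients in the dtype recorded during forward instead of
        # the parameter dtype, then cast back to match the inputs.
        g = grad_output.to(ctx.dtype)
        grad_x = g @ weight.to(ctx.dtype)
        grad_w = g.t() @ x.to(ctx.dtype)
        return grad_x.to(x.dtype), grad_w.to(weight.dtype)

Under torch.autocast(device_type="cuda", dtype=torch.bfloat16), ctx.dtype records torch.bfloat16 and both passes compute in bf16; with autocast off, both fall back to the weight's dtype.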
