Review suggestions from @ptrendx
Output tensor dtype and device take precedence over the weight tensor in the linear functional API. Move some index calculation into the fuser constructor. Avoid some unnecessary dereferences.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 committed Aug 3, 2024
1 parent 8b78f65 commit 82b83c9
Showing 2 changed files with 14 additions and 14 deletions.
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -376,7 +376,7 @@ def _functional_forward(

# Check device
if device is None:
device = weight.device
device = weight.device if out is None else out.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
@@ -387,7 +387,7 @@

# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = weight.dtype if out is None else out.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
@@ -1062,7 +1062,7 @@ def _functional_backward(
_wait_async(dy_async)
_wait_async(x_async)
_wait_async(dx_async)
if grad_input is None:
if dx is not None and grad_input is None:
grad_input = reshape(dx, input_dims)
return grad_input, grad_weight

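As context for the basic_linear.py hunks above: when an explicit out tensor is provided, it now determines the default dtype and device instead of the weight. A minimal standalone sketch of that resolution order (the helper name is hypothetical, not the Transformer Engine API):

import torch

def _resolve_dtype_and_device(weight, out=None, dtype=None, device=None):
    # Illustrative sketch only: explicit arguments win, then `out`, then the weight.
    if device is None:
        device = weight.device if out is None else out.device
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(f"Only CUDA devices are supported (got {device})")
    if dtype is None:
        dtype = weight.dtype if out is None else out.dtype
    if dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
    return dtype, device

# Example: a bf16 output buffer overrides an fp16 weight.
# weight = torch.zeros(16, 16, dtype=torch.float16, device="cuda")
# out = torch.empty(8, 16, dtype=torch.bfloat16, device="cuda")
# assert _resolve_dtype_and_device(weight, out=out) == (torch.bfloat16, out.device)
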
22 changes: 11 additions & 11 deletions transformer_engine/pytorch/ops/fuser.py
@@ -46,7 +46,8 @@ def forward(
basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]],
num_params: int,
*extra_inputs: torch.nn.Parameter,
num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass
@@ -87,14 +88,13 @@
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]

# Unflatten list of parameters and extra tensor inputs
num_extra_inputs = sum(op.num_extra_inputs for op in basic_ops)
if len(extra_inputs) != num_params + num_extra_inputs:
if len(params_and_extra_inputs) != num_params + num_extra_inputs:
raise ValueError(
f"Expected {num_params + num_extra_inputs} extra tensor arguments "
f"({num_params} parameters, {num_extra_inputs} extra inputs), "
f"but got {len(extra_inputs)}"
)
_, extra_inputs = _split_tuple(extra_inputs, num_params)
_, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
basic_op_extra_inputs = []
for op in basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
@@ -184,11 +184,9 @@ def backward(
basic_op_ctxs = func_ctx.basic_op_ctxs

# Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None
del saved_tensors

# Unflatten list of extra tensor output grads
if len(grad_extra_outputs) != func_ctx.num_extra_outputs:
@@ -200,7 +198,6 @@
for op in basic_ops:
dys, grad_extra_outputs = _split_tuple(grad_extra_outputs, op.num_extra_outputs)
basic_op_grad_extra_outputs.append(dys)
del dys

# Apply backward ops
dx = grad_output
@@ -268,6 +265,7 @@
None, # basic_ops
None, # basic_op_kwargs
None, # num_params
None, # num_extra_inputs
*grad_params_flat,
*grad_extra_inputs_flat,
)
@@ -301,6 +299,9 @@ def __init__(
self._num_basic_ops: int = len(basic_ops)
self._basic_ops: list[BasicOperation] = basic_ops

# Number of extra tensor inputs
self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)

# Ops for forward and backward pass
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
@@ -351,9 +352,7 @@ def __call__(
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]

# Flatten list of parameters
params = []
for op in self._basic_ops:
params.extend(op.parameters())
params = [param for op in self._basic_ops for param in op.parameters()]

# Fuser forward pass
return _OperationFuserAutogradFunction.apply(
@@ -363,6 +362,7 @@
self._basic_ops,
basic_op_kwargs,
len(params),
self._num_extra_inputs,
*params,
*extra_inputs,
)
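
As context for the fuser.py hunks above: the count of extra tensor inputs is now computed once in the constructor and passed through apply() alongside the flattened parameter list, and the autograd function splits the flat tuple back into per-op pieces. A self-contained toy sketch of that bookkeeping (the toy classes and names are hypothetical, not Transformer Engine code):

from typing import Any, Sequence

def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
    # Split a tuple into its first `idx` items and the remainder.
    return t[:idx], t[idx:]

class _ToyOp:
    # Stand-in for a basic op: owns parameters and declares its extra tensor inputs.
    def __init__(self, params: list, num_extra_inputs: int):
        self._params = params
        self.num_extra_inputs = num_extra_inputs
    def parameters(self):
        return iter(self._params)

class ToyFuser:
    def __init__(self, basic_ops: Sequence[Any]):
        self._basic_ops = list(basic_ops)
        # Computed once here instead of on every forward call.
        self._num_extra_inputs = sum(op.num_extra_inputs for op in basic_ops)

    def __call__(self, *extra_inputs):
        # Flatten all op parameters, then pass one flat tuple downstream.
        params = [param for op in self._basic_ops for param in op.parameters()]
        flat = (*params, *extra_inputs)
        expected = len(params) + self._num_extra_inputs
        if len(flat) != expected:
            raise ValueError(f"Expected {expected} tensors, got {len(flat)}")
        # Downstream, the flat tuple is unflattened into per-op tuples.
        _, rest = _split_tuple(flat, len(params))
        per_op_extra = []
        for op in self._basic_ops:
            xs, rest = _split_tuple(rest, op.num_extra_inputs)
            per_op_extra.append(xs)
        return per_op_extra

# ops = [_ToyOp(["w0"], num_extra_inputs=1), _ToyOp([], num_extra_inputs=2)]
# ToyFuser(ops)("a", "b", "c")  # -> [("a",), ("b", "c")]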
