Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Branching operations #1027

Merged
merged 15 commits into from
Aug 10, 2024
Merged
381 changes: 371 additions & 10 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
"""

from transformer_engine.pytorch.ops.basic import (
AddInPlace,
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
MakeExtraOutput,
ReduceScatter,
Reshape,
)
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

"""Single tensor operations supported by the operation fuser."""

from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .make_extra_output import MakeExtraOutput
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
79 changes: 79 additions & 0 deletions transformer_engine/pytorch/ops/basic/add_in_place.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for in-place add."""

from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional

import torch

from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)


class AddInPlace(BasicOperation):
    """Add in-place

    This operation requires an extra tensor input to the operation
    fuser. The main input is added in-place to the extra input, and a
    view of the extra input is output.

    This operation is considered an advanced feature and most users
    are discouraged from using it. In-place operations break some
    autograd assumptions and they can result in subtle, esoteric bugs.

    Compare to `MakeExtraOutput`, which does a similar operation in
    the backward pass.

    """

    # Operation expects buffer for output tensor
    num_extra_inputs: int = 1

    def op_forward(self, *args, **kwargs) -> None:
        # This op consumes an extra tensor input, so the fuser must
        # dispatch through fuser_forward; op_forward is never valid.
        # NOTE: the first string fragment needs the f-prefix, otherwise
        # the literal "{self.__class__.__name__}" appears in the message.
        raise RuntimeError(
            f"{self.__class__.__name__} operation has "
            f"{self.num_extra_inputs} extra tensor inputs "
            f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_forward` instead of `op_forward`."
        )

    def op_backward(self, *args, **kwargs) -> None:
        # Same reasoning as op_forward: backward must go through
        # fuser_backward so the extra input's gradient can be routed.
        raise RuntimeError(
            f"{self.__class__.__name__} operation has "
            f"{self.num_extra_inputs} extra tensor inputs "
            f"and {self.num_extra_outputs} extra tensor outputs. "
            "It overrides `fuser_backward` instead of `op_backward`."
        )

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        basic_op_prev_ops: list[Optional[BasicOperation]],
        basic_op_next_ops: list[Optional[BasicOperation]],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
        """Accumulate the main input into the extra input buffer.

        Returns a detached view of the extra input (mutated in place by
        `+=`) as the op's output, and no extra outputs. `detach()` keeps
        the in-place add out of autograd's view of the extra input;
        gradient routing is handled explicitly in `fuser_backward`.
        """
        output = basic_op_extra_inputs[0][0].detach()
        output += input_
        return output, [()]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        """Route the output gradient to both addends.

        Since forward computes `extra += input`, the gradient of the
        loss w.r.t. both the main input and the extra input is exactly
        `grad_output`. No parameter gradients exist for this op.
        """
        return grad_output, [], [(grad_output,)]
Loading
Loading