[PyTorch] Branching operations #1027
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
```diff
@@ -389,6 +406,32 @@ def _functional_forward(
         "are not compatible"
     )

+    # Check output tensor dims
```
I wonder if we need to do this here (same for input) or maybe we could rely on the error checking on the C++ side to minimize CPU overhead?
I think that would be a good optimization in the future, especially since the linear functional API is used in multiple operations.
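To make the trade-off discussed above concrete, here is a hedged sketch of what a Python-side output-dim check in a linear functional API might look like. The function name `check_output_dims` and its signature are hypothetical, not the actual TransformerEngine implementation; the point is that this per-call shape validation runs on the CPU in Python, which is the overhead the comment suggests could be delegated to the C++ side.

```python
import torch

# Hypothetical sketch (not the actual TE code): validate that a
# caller-provided output tensor matches the expected GEMM output shape
# before dispatching to the C++ kernel. For x of shape (..., in_features)
# and weight of shape (out_features, in_features), the output must have
# shape (..., out_features).
def check_output_dims(x: torch.Tensor, weight: torch.Tensor, out: torch.Tensor) -> None:
    expected = (*x.shape[:-1], weight.shape[0])
    if tuple(out.shape) != expected:
        raise ValueError(
            f"Output tensor has dims {tuple(out.shape)}, expected {expected}"
        )

x = torch.randn(4, 8)
w = torch.randn(16, 8)
out = torch.empty(4, 16)
check_output_dims(x, w, out)  # matching dims: no error raised
```

Moving such checks into the C++ entry point would remove this Python-level work from the hot path while still surfacing the same errors.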
Output tensor dtype and device take precedence over the weight tensor in the linear functional API. Move some index calculation to the fuser constructor. Avoid some unnecessary dereferences.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Force-pushed from 82b83c9 to 2679fbf
/te-ci pytorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Could you comment on how the change from your last commit helped with the unit test failures? The change from a list comprehension to a for loop should not change the behavior, right?
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
/te-ci pytorch
Description
This PR modifies the operation-based API (#707) to support some simple branching behavior: operations can now accept extra tensor inputs and generate extra tensor outputs. This enables fusions like GEMMs with `beta=1`.
Support for multiple inputs will also be necessary for cross-attention (and SSMs?). Note that we are not planning to support more complicated structures, since that would take us down the road of general graph compilers.
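The `beta=1` fusion mentioned above can be sketched with plain PyTorch. This is an illustration of the semantics, not TE's implementation: an operation takes an extra tensor input and accumulates the GEMM result into it, which maps onto a single GEMM call with `beta=1` (here via `torch.addmm`). The function name `linear_with_accumulate` is hypothetical.

```python
import torch

# Hypothetical sketch of the branching behavior: a linear op that
# consumes an extra tensor input and accumulates into it.
# torch.addmm computes beta * input + alpha * (mat1 @ mat2); with the
# defaults beta=1, alpha=1 this is out = accumulate + x @ w.T in one
# GEMM call, i.e. a GEMM with beta=1.
def linear_with_accumulate(x: torch.Tensor,
                           weight: torch.Tensor,
                           accumulate: torch.Tensor) -> torch.Tensor:
    return torch.addmm(accumulate, x, weight.t())

x = torch.randn(4, 8)
w = torch.randn(16, 8)     # (out_features, in_features)
acc = torch.randn(4, 16)   # extra tensor input to accumulate into
out = linear_with_accumulate(x, w, acc)
assert torch.allclose(out, acc + x @ w.t())
```

Expressing the accumulation as an extra input to the operation is what lets the fuser fold it into the GEMM epilogue instead of emitting a separate elementwise add.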
Type of change
Changes
- Support fusions like GEMMs with `beta=1`
Checklist: