[PyTorch] Custom kernel to compute reciprocal of a single float #1016

Closed
wants to merge 5 commits

Conversation

timmoon10
Collaborator

Description

FP8 training is frequently bottlenecked by CPU overheads, and a non-trivial fraction of that overhead comes from small PyTorch operations. For example, when I benchmark the forward pass of small Linear modules on an L40, I estimate ~20% of runtime is spent handling the FP8 scaling factors (mainly in reciprocal and clone operations). This PR attempts to mitigate these overheads by adding a scalar_reciprocal kernel that operates on a single float, bringing the kernel launch cost down from ~20 us to ~10 us. In my benchmark of Linear forward passes, I see an 8% reduction in runtime.
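As a rough illustration, a single-element reciprocal only needs a trivial one-thread kernel, so its cost is essentially the launch latency rather than PyTorch's generic elementwise dispatch. The sketch below is illustrative only; the names are hypothetical and this is not the kernel added in this PR:

#include <cuda_runtime.h>

// Hypothetical sketch: compute 1/x for a single device-resident float
// in place, avoiding a full torch.reciprocal call on a 1-element tensor.
__global__ void scalar_reciprocal_kernel(float *value) {
  *value = 1.0f / (*value);
}

void scalar_reciprocal(float *value, cudaStream_t stream) {
  // One thread is enough; the whole point is to keep the launch cheap.
  scalar_reciprocal_kernel<<<1, 1, 0, stream>>>(value);
}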

Alternative approaches:

  • Modify the cast and cast-transpose kernels to perform the scale-inv update, similar to how they perform the amax update (a sketch of this follows the list). Logically, the scale is part of the FP8 recipe and the scale-inv is part of the data.
  • Use torch.compile to fuse FP8 scale operations. This would require significant refactoring to avoid incurring extra overhead from graph breaks, especially in how we deal with the FP8TensorMeta class.
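For the first alternative, the idea would be for the cast kernel's epilogue to write scale-inv alongside the cast itself, just as it already writes amax. A hedged, simplified sketch with assumed names (the real cast/cast-transpose kernels are more involved):

#include <cuda_fp8.h>
#include <cuda_runtime.h>

// Simplified sketch of a cast-to-FP8 kernel that also updates scale_inv
// in its epilogue. Names and layout are assumptions, not the actual
// Transformer Engine kernels.
__global__ void cast_to_fp8_with_scale_inv(const float *__restrict__ input,
                                           __nv_fp8_e4m3 *__restrict__ output,
                                           const float *__restrict__ scale,
                                           float *__restrict__ scale_inv,
                                           size_t n) {
  const size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (i < n) {
    output[i] = __nv_fp8_e4m3(input[i] * (*scale));
  }
  // One thread writes scale_inv, mirroring how the amax update
  // piggybacks on the same launch, so no separate kernel is needed.
  if (i == 0) {
    *scale_inv = 1.0f / (*scale);
  }
}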

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Custom kernel to compute reciprocal of a single float

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the enhancement New feature or request label Jul 15, 2024
@timmoon10 timmoon10 requested a review from ksivaman July 15, 2024 18:28
@timmoon10
Collaborator Author

/te-ci pytorch

fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
inputmat_fp8_scale_inv,
0,
Member

Why?

Collaborator Author

fwd_scale_inverses is created with a clone operation (~20 us), while inputmat_fp8_scale_inv is created with the scalar reciprocal kernel (~10 us).

@@ -335,7 +343,7 @@ def forward(
weight,
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
inputmat_fp8_scale_inv,
Member

This is only set for the Float8Tensor case.

Collaborator Author

We handle three cases:

@timmoon10 timmoon10 marked this pull request as draft July 16, 2024 18:50
@timmoon10
Copy link
Collaborator Author

Closed by #1083

@timmoon10 timmoon10 closed this Aug 8, 2024