[PyTorch] Custom kernel to compute reciprocal of a single float #1016
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
```python
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
inputmat_fp8_scale_inv,
0,
```
Why?
`fwd_scale_inverses` is created with a clone operation (~20 us), while `inputmat_fp8_scale_inv` is created with the scalar reciprocal kernel (~10 us).
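The launch-cost gap can be eyeballed with a small PyTorch timing loop. A hedged sketch: it compares stock `clone` and `reciprocal` launches on a tiny CUDA tensor as stand-ins rather than calling the new `tex.scalar_reciprocal` kernel, so the absolute numbers will differ from the ~20 us / ~10 us figures above.

```python
# Minimal sketch: measure CPU-side launch cost of tiny CUDA ops.
# Stock PyTorch ops stand in for the new tex.scalar_reciprocal kernel.
import time
import torch

scale = torch.rand(16, device="cuda") + 0.5  # stand-in for FP8 scales

def cpu_us_per_call(fn, iters=1000):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - start  # no sync inside the loop:
    torch.cuda.synchronize()               # we want launch overhead,
    return elapsed / iters * 1e6           # not GPU execution time

print(f"clone:      {cpu_us_per_call(scale.clone):.1f} us/call")
print(f"reciprocal: {cpu_us_per_call(scale.reciprocal):.1f} us/call")
```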
```diff
@@ -335,7 +343,7 @@ def forward(
     weight,
     weight_fp8,
     weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
-    fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
+    inputmat_fp8_scale_inv,
```
This is only set for the `Float8Tensor` case.
We handle three cases:
- Non-FP8: scale-inv is `None`:
  `inputmat_fp8_scale_inv = None`
- FP8, `Float8Tensor` input: scale-inv is taken from the `Float8Tensor`:
  `inputmat_fp8_scale_inv = inputmat._scale_inv`
- FP8, non-`Float8Tensor` input: scale-inv is computed with the fast kernel:
  `inputmat_fp8_scale_inv = tex.scalar_reciprocal(…)`
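Putting the three cases together, the dispatch looks roughly like the sketch below. This is hedged: the names `fp8`, `inputmat`, `fp8_meta`, and `tex` come from the diff context, and the `scalar_reciprocal` argument list is left elided because the call is truncated in the comment above.

```python
# Sketch of the three-way dispatch described above; the
# scalar_reciprocal argument list is elided in the original comment.
if not fp8:
    # Non-FP8 execution: no scale-inverse is needed.
    inputmat_fp8_scale_inv = None
elif isinstance(inputmat, Float8Tensor):
    # A Float8Tensor already carries its scale-inverse; reuse it.
    inputmat_fp8_scale_inv = inputmat._scale_inv
else:
    # Plain high-precision input cast to FP8: compute 1/scale with the
    # lightweight single-element kernel instead of cloning the whole
    # scale_inv tensor from the FP8 meta.
    inputmat_fp8_scale_inv = tex.scalar_reciprocal(...)  # args elided
```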
Closed by #1083
Description
FP8 training is frequently bottlenecked by CPU overheads, and a non-trivial fraction of that overhead comes from small PyTorch operations. For example, when I benchmark the forward pass of small `Linear` modules on an L40, I estimate ~20% of runtime is spent handling the FP8 scaling factors (mainly in reciprocal and clone operations). This PR attempts to mitigate these overheads by adding a `scalar_reciprocal` kernel that operates on a single `float`, bringing the kernel launch cost down from ~20 us to ~10 us. In my benchmark of `Linear` forwards, I see an 8% reduction in runtime.
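A `Linear`-forward measurement along these lines can be reproduced with Transformer Engine's public `fp8_autocast` API. The sketch below is illustrative only: module sizes, iteration counts, and the recipe settings are assumptions, not the exact benchmark behind the numbers quoted above.

```python
# Illustrative micro-benchmark of a small FP8 Linear forward pass;
# not the exact script used for the figures quoted in this PR.
import time
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

linear = te.Linear(256, 256)
x = torch.randn(32, 256, device="cuda")
fp8_recipe = recipe.DelayedScaling()

# Warm up, then time wall clock per forward pass.
for _ in range(10):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        linear(x)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        linear(x)
torch.cuda.synchronize()
print(f"{(time.perf_counter() - start) / 100 * 1e6:.1f} us per forward")
```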
Alternative approaches:
- `torch.compile` to fuse FP8 scale operations. We would require significant refactoring to avoid incurring extra overhead from graph breaks, especially in how we deal with the `FP8TensorMeta` class.

Type of change
Changes
Checklist: