Implement fused kernel for FP8 scale update (#593)
* Implement fused kernel for FP8 scale update
* Add fused kernel for amax and scale update. Add unit test.
* Replace paddle.fluid imports with paddle.base
* Move fused kernel to core library
* Debug test
* Use FP8 update kernel in Paddle
* Debug FP8 scale update in Paddle
* Fix lint errors
* Debug Paddle test failures
* Make update kernel in-place for PyTorch
* Revert cudnn-frontend commit

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing 14 changed files with 607 additions and 74 deletions.
@@ -0,0 +1,164 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Optional

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    amax_and_scale_update,
    get_default_fp8_recipe,
)

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
    @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
    def test_amax_and_scale_update(
        self,
        amax_history_len: int,
        amax_compute_algo: str,
        is_first_microbatch: Optional[bool],
        margin: int = 2,
    ):

        # Construct linear module
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            interval=1,
            fp8_format=fp8_format,
            amax_history_len=amax_history_len,
            amax_compute_algo=amax_compute_algo,
        )
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            module = te.Linear(16, 16)
            y = module(torch.zeros([16, 16], device="cuda"))
        y.backward(torch.zeros_like(y))

        # Get amax history and scaling factors
        fp8_meta = module.fp8_meta
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        amax_history_forward = fp8_meta[forward_key].amax_history
        scale_forward = fp8_meta[forward_key].scale
        scale_inv_forward = fp8_meta[forward_key].scale_inv
        backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
        amax_history_backward = fp8_meta[backward_key].amax_history
        scale_backward = fp8_meta[backward_key].scale
        scale_inv_backward = fp8_meta[backward_key].scale_inv

        # Tweak amax history and scaling factors
        amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
        if amax_history_len > 1:
            amax_history_forward[1, 0].fill_(3)
        scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
        scale_inv_forward.copy_(torch.reciprocal(scale_forward))
        amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
        scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
        scale_inv_backward.copy_(torch.reciprocal(scale_backward))

        # Expected amax history after update
        ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
        ref_amax_history_forward[0].zero_()
        ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
        ref_amax_history_backward[0].zero_()

        # Expected scale and scale inverse
        if amax_compute_algo == "max":
            ref_amax_forward = amax_history_forward.max(dim=0).values
            ref_amax_backward = amax_history_backward.max(dim=0).values
        elif amax_compute_algo == "most_recent":
            ref_amax_forward = amax_history_forward[0]
            ref_amax_backward = amax_history_backward[0]
        else:
            raise ValueError(f"{amax_compute_algo=} is not supported")
        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
        ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
        update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
        if not update_weight_scale_inv:
            ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
        ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)

        # Make sure we are not trivially passing tests
        if amax_history_len > 1:
            with pytest.raises(AssertionError):
                torch.testing.assert_close(
                    amax_history_forward[1:],
                    ref_amax_history_forward[1:],
                )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                scale_forward,
                ref_scale_forward,
            )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                scale_inv_forward,
                ref_scale_inv_forward,
            )
        if amax_history_len > 1:
            with pytest.raises(AssertionError):
                torch.testing.assert_close(
                    fp8_meta[backward_key].amax_history[1:],
                    ref_amax_history_backward[1:],
                )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                fp8_meta[backward_key].scale,
                ref_scale_backward,
            )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                fp8_meta[backward_key].scale_inv,
                ref_scale_inv_backward,
            )

        # Perform forward and backward pass to update fp8_meta
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            x = torch.zeros([16, 16], device="cuda")
            y = module(x, is_first_microbatch=is_first_microbatch)
        y.backward(torch.zeros_like(y))

        # Check that fp8_meta matches expected values
        torch.testing.assert_close(
            fp8_meta[forward_key].amax_history[1:],
            ref_amax_history_forward[1:],
        )
        torch.testing.assert_close(
            fp8_meta[forward_key].scale,
            ref_scale_forward,
        )
        torch.testing.assert_close(
            fp8_meta[forward_key].scale_inv,
            ref_scale_inv_forward,
        )
        torch.testing.assert_close(
            fp8_meta[backward_key].amax_history[1:],
            ref_amax_history_backward[1:],
        )
        torch.testing.assert_close(
            fp8_meta[backward_key].scale,
            ref_scale_backward,
        )
        torch.testing.assert_close(
            fp8_meta[backward_key].scale_inv,
            ref_scale_inv_backward,
        )
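A note on the constants in the reference math above: the HYBRID format casts forward tensors to E4M3 and backward (gradient) tensors to E5M2, and the test reads the corresponding FP8 maxima from the recipe's Format enum. A minimal sketch for inspecting the concrete values, using only the max_fwd/max_bwd attributes the test already exercises:

import transformer_engine.common.recipe as te_recipe

fmt = te_recipe.Format.HYBRID
# Forward tensors use E4M3, whose largest representable value is 448.
print(fmt.value.max_fwd)   # 448.0
# Backward (gradient) tensors use E5M2, whose largest representable value is 57344.
print(fmt.value.max_bwd)   # 57344.0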
63 changes: 63 additions & 0 deletions
transformer_engine/common/include/transformer_engine/recipe.h
@@ -0,0 +1,63 @@
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file recipe.h
 *  \brief Functions handling FP8 recipes.
 */

#ifndef TRANSFORMER_ENGINE_RECIPE_H_
#define TRANSFORMER_ENGINE_RECIPE_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Update FP8 scaling factors with delayed scaling recipe.
 *
 *  The amax history is rotated by -1 (e.g. the first entry shifts to
 *  the last, the last entry shifts to the second to last) and the
 *  first entry is set to zero. The scaling factor is estimated so the
 *  FP8 tensor's maximum absolute value is
 *  @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
 *
 *  \param[in]  amax_history          History of maximum absolute values.
 *                                    Shape: [history_length, num_scales]
 *  \param[in]  scale                 Scaling factor for casting to FP8. Shape: [num_scales]
 *  \param[in]  scale_inv             Scaling factor for casting from FP8. Shape: [num_scales]
 *  \param[in]  scale_inv_mask        Boolean mask indicating scale_inv entries to update. May be
 *                                    empty, in which case all scale_inv entries are updated.
 *                                    Shape: [num_scales]
 *  \param[out] updated_amax_history  Updated history of maximum absolute values.
 *                                    Shape: [history_length, num_scales]
 *  \param[out] updated_scale         Updated scaling factor for casting to FP8.
 *                                    Shape: [num_scales]
 *  \param[out] updated_scale_inv     Updated scaling factor for casting from FP8.
 *                                    Shape: [num_scales]
 *  \param[in]  amax_compute_algo     Method to reduce amax history. Options are "max" and
 *                                    "most_recent".
 *  \param[in]  fp8_dtype             FP8 datatype.
 *  \param[in]  margin                Scaling factor margin.
 *  \param[in]  stream                CUDA stream.
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
                                                       const NVTETensor scale,
                                                       const NVTETensor scale_inv,
                                                       const NVTETensor scale_inv_mask,
                                                       NVTETensor updated_amax_history,
                                                       NVTETensor updated_scale,
                                                       NVTETensor updated_scale_inv,
                                                       const char* amax_compute_algo,
                                                       NVTEDType fp8_dtype,
                                                       float margin,
                                                       cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_RECIPE_H_
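The update documented by nvte_delayed_scaling_recipe_amax_and_scale_update above can be summarized in plain PyTorch. The sketch below is only an illustrative reference of the documented semantics, not the fused CUDA kernel; the helper name and the zero/non-finite amax guard are assumptions of this note, while the roll/reduce/scale steps mirror the header comment and the unit test's reference math.

import torch

def delayed_scaling_update(
    amax_history: torch.Tensor,    # [history_length, num_scales]
    scale: torch.Tensor,           # [num_scales], previous scale
    scale_inv: torch.Tensor,       # [num_scales], previous scale_inv
    scale_inv_mask: torch.Tensor,  # bool [num_scales], may be empty
    fp8_max: float,                # largest representable value of the FP8 dtype
    margin: float = 0.0,
    amax_compute_algo: str = "max",
):
    """Reference (non-fused) sketch of the delayed-scaling update."""
    # Reduce the amax history to one amax per tensor
    if amax_compute_algo == "max":
        amax = amax_history.max(dim=0).values
    elif amax_compute_algo == "most_recent":
        amax = amax_history[0]
    else:
        raise ValueError(f"{amax_compute_algo=} is not supported")

    # Rotate the history by -1 and clear the slot that receives the next amax
    new_amax_history = torch.roll(amax_history, -1, dims=0)
    new_amax_history[0].zero_()

    # Choose scale so the FP8 tensor's max abs value is 2**-margin * fp8_max;
    # fall back to the previous scale where amax is zero or non-finite
    # (assumed guard; the exact policy is up to the kernel).
    candidate = (fp8_max / amax) / (2.0 ** margin)
    new_scale = torch.where((amax > 0) & torch.isfinite(amax), candidate, scale)

    # Update scale_inv everywhere, or only where the boolean mask is set
    new_scale_inv = torch.reciprocal(new_scale)
    if scale_inv_mask.numel() > 0:
        new_scale_inv = torch.where(scale_inv_mask, new_scale_inv, scale_inv)

    return new_amax_history, new_scale, new_scale_inv

For instance, with an amax of 3.0, fp8_max = 448 (E4M3), and margin = 2, the new scale is 448 / 3 / 4 ≈ 37.3 and the new scale_inv is ≈ 0.027.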
File renamed without changes.