Implement fused kernel for FP8 scale update (#593)
* Implement fused kernel for FP8 scale update

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add fused kernel for amax and scale update

Add unit test.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Replace paddle.fluid imports with paddle.base

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move fused kernel to core library

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use FP8 update kernel in Paddle

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug FP8 scale update in Paddle

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix lint errors

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Paddle test failures

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make update kernel in-place for PyTorch

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert cudnn-frontend commit

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
timmoon10 and ksivaman authored Feb 8, 2024
1 parent 379c1ee commit a950061
Showing 14 changed files with 607 additions and 74 deletions.
tests/paddle/test_operators.py: 6 changes (4 additions, 2 deletions)
@@ -1046,6 +1046,7 @@ def test_amax_and_scale_update(update_weight_scale_inv):
     num_gemm = 6
     history_len = 1024
     recipe = DelayedScaling()
+    fp8_dtype = tex.DType.kFloat8E4M3
     fp8_max = recipe.fp8_format.value.max_fwd
     non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))

@@ -1073,12 +1074,13 @@ def calc_ref(amax, scale, fp8_max, margin=0):
     scale_actual = paddle.zeros_like(scale_tensor)
     scale_inv_actual = paddle.zeros_like(scale_tensor)
 
+    if update_weight_scale_inv:
+        non_weight_mask = paddle.empty([0])
     tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
                                       _scale=scale_actual,
                                       _scale_inv=scale_inv_actual,
                                       non_weight_mask=non_weight_mask,
-                                      update_weight_scale_inv=update_weight_scale_inv,
-                                      fp8_max=fp8_max,
+                                      fp8_dtype=int(fp8_dtype),
                                       margin=0.,
                                       amax_compute="max")

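For context on the API change above: the updated binding takes an FP8 dtype and an optionally-empty mask instead of an fp8_max value and an update_weight_scale_inv flag. The reference math that a helper like calc_ref checks against is the usual delayed-scaling formula (the same one used by the PyTorch test below). A minimal sketch, assuming amax has already been reduced over the history and assuming the previous scale is kept where amax is zero:

import paddle

def calc_ref_scale(amax, scale, fp8_max, margin=0):
    # Target: after casting, a tensor with this amax maps to 2**-margin * fp8_max
    new_scale = (fp8_max / amax) / (2.0 ** margin)
    # Assumed fallback: keep the previous scale where no amax has been recorded yet
    return paddle.where(amax > 0, new_scale, scale)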
tests/pytorch/test_recipe.py: 164 changes (164 additions, 0 deletions)
@@ -0,0 +1,164 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Optional

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    amax_and_scale_update,
    get_default_fp8_recipe,
)

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
    @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
    def test_amax_and_scale_update(
        self,
        amax_history_len: int,
        amax_compute_algo: str,
        is_first_microbatch: Optional[bool],
        margin: int = 2,
    ):

        # Construct linear module
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            interval=1,
            fp8_format=fp8_format,
            amax_history_len=amax_history_len,
            amax_compute_algo=amax_compute_algo,
        )
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            module = te.Linear(16, 16)
            y = module(torch.zeros([16, 16], device="cuda"))
        y.backward(torch.zeros_like(y))

        # Get amax history and scaling factors
        fp8_meta = module.fp8_meta
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        amax_history_forward = fp8_meta[forward_key].amax_history
        scale_forward = fp8_meta[forward_key].scale
        scale_inv_forward = fp8_meta[forward_key].scale_inv
        backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
        amax_history_backward = fp8_meta[backward_key].amax_history
        scale_backward = fp8_meta[backward_key].scale
        scale_inv_backward = fp8_meta[backward_key].scale_inv

        # Tweak amax history and scaling factors
        amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
        if amax_history_len > 1:
            amax_history_forward[1, 0].fill_(3)
        scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
        scale_inv_forward.copy_(torch.reciprocal(scale_forward))
        amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
        scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
        scale_inv_backward.copy_(torch.reciprocal(scale_backward))

        # Expected amax history after update
        ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
        ref_amax_history_forward[0].zero_()
        ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
        ref_amax_history_backward[0].zero_()

        # Expected scale and scale inverse
        if amax_compute_algo == "max":
            ref_amax_forward = amax_history_forward.max(dim=0).values
            ref_amax_backward = amax_history_backward.max(dim=0).values
        elif amax_compute_algo == "most_recent":
            ref_amax_forward = amax_history_forward[0]
            ref_amax_backward = amax_history_backward[0]
        else:
            raise ValueError(f"{amax_compute_algo=} is not supported")
        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
        ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
        update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
        if not update_weight_scale_inv:
            ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
        ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)

        # Make sure we are not trivially passing tests
        if amax_history_len > 1:
            with pytest.raises(AssertionError):
                torch.testing.assert_close(
                    amax_history_forward[1:],
                    ref_amax_history_forward[1:],
                )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                scale_forward,
                ref_scale_forward,
            )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                scale_inv_forward,
                ref_scale_inv_forward,
            )
        if amax_history_len > 1:
            with pytest.raises(AssertionError):
                torch.testing.assert_close(
                    fp8_meta[backward_key].amax_history[1:],
                    ref_amax_history_backward[1:],
                )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                fp8_meta[backward_key].scale,
                ref_scale_backward,
            )
        with pytest.raises(AssertionError):
            torch.testing.assert_close(
                fp8_meta[backward_key].scale_inv,
                ref_scale_inv_backward,
            )

        # Perform forward and backward pass to update fp8_meta
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            x = torch.zeros([16, 16], device="cuda")
            y = module(x, is_first_microbatch=is_first_microbatch)
        y.backward(torch.zeros_like(y))

        # Check that fp8_meta matches expected values
        torch.testing.assert_close(
            fp8_meta[forward_key].amax_history[1:],
            ref_amax_history_forward[1:],
        )
        torch.testing.assert_close(
            fp8_meta[forward_key].scale,
            ref_scale_forward,
        )
        torch.testing.assert_close(
            fp8_meta[forward_key].scale_inv,
            ref_scale_inv_forward,
        )
        torch.testing.assert_close(
            fp8_meta[backward_key].amax_history[1:],
            ref_amax_history_backward[1:],
        )
        torch.testing.assert_close(
            fp8_meta[backward_key].scale,
            ref_scale_backward,
        )
        torch.testing.assert_close(
            fp8_meta[backward_key].scale_inv,
            ref_scale_inv_backward,
        )
transformer_engine/common/CMakeLists.txt: 3 changes (2 additions, 1 deletion)
@@ -34,7 +34,8 @@ list(APPEND transformer_engine_SOURCES
      fused_softmax/scaled_masked_softmax.cu
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
-     fused_rope/fused_rope.cu)
+     fused_rope/fused_rope.cu
+     recipe/delayed_scaling.cu)
 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")
transformer_engine/common/include/transformer_engine/recipe.h: 63 changes (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file recipe.h
 *  \brief Functions handling FP8 recipes.
 */

#ifndef TRANSFORMER_ENGINE_RECIPE_H_
#define TRANSFORMER_ENGINE_RECIPE_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Update FP8 scaling factors with delayed scaling recipe.
 *
 *  The amax history is rotated by -1 (e.g. the first entry shifts to
 *  the last, the last entry shifts to the second to last) and the
 *  first entry is set to zero. The scaling factor is estimated so the
 *  FP8 tensor's maximum absolute value is
 *  @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
 *
 *  \param[in]  amax_history          History of maximum absolute values.
 *                                    Shape: [history_length, num_scales]
 *  \param[in]  scale                 Scaling factor for casting to FP8. Shape: [num_scales]
 *  \param[in]  scale_inv             Scaling factor for casting from FP8. Shape: [num_scales]
 *  \param[in]  scale_inv_mask        Boolean mask indicating scale_inv entries to update. May be
 *                                    empty, in which case all scale_inv entries are updated.
 *                                    Shape: [num_scales]
 *  \param[out] updated_amax_history  Updated history of maximum absolute values.
 *                                    Shape: [history_length, num_scales]
 *  \param[out] updated_scale         Updated scaling factor for casting to FP8.
 *                                    Shape: [num_scales]
 *  \param[out] updated_scale_inv     Updated scaling factor for casting from FP8.
 *                                    Shape: [num_scales]
 *  \param[in]  amax_compute_algo     Method to reduce amax history. Options are "max" and
 *                                    "most_recent".
 *  \param[in]  fp8_dtype             FP8 datatype.
 *  \param[in]  margin                Scaling factor margin.
 *  \param[in]  stream                CUDA stream.
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
                                                       const NVTETensor scale,
                                                       const NVTETensor scale_inv,
                                                       const NVTETensor scale_inv_mask,
                                                       NVTETensor updated_amax_history,
                                                       NVTETensor updated_scale,
                                                       NVTETensor updated_scale_inv,
                                                       const char* amax_compute_algo,
                                                       NVTEDType fp8_dtype,
                                                       float margin,
                                                       cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif

#endif // TRANSFORMER_ENGINE_RECIPE_H_
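To make the documented behavior concrete, the following is a minimal NumPy sketch of the update this function describes. It illustrates the recipe, not the CUDA kernel; the zero-amax fallback and the FP8 maximum passed in are assumptions of the sketch.

import numpy as np

def delayed_scaling_update(amax_history, scale, scale_inv, scale_inv_mask,
                           amax_compute_algo="max", fp8_max=448.0, margin=0.0):
    """Roll the amax history, reduce it to an amax, and recompute scales.

    fp8_max is the largest representable value of the FP8 dtype (e.g. 448 for E4M3).
    """
    # Reduce the amax history before rotating it
    if amax_compute_algo == "max":
        amax = amax_history.max(axis=0)
    else:  # "most_recent"
        amax = amax_history[0]

    # Rotate the history by -1 and zero the first row for the next iteration
    updated_amax_history = np.roll(amax_history, -1, axis=0)
    updated_amax_history[0] = 0.0

    # Choose scale so the FP8 tensor's max abs value is 2**-margin * fp8_max;
    # keeping the old scale where amax is zero is an assumption of this sketch
    safe_amax = np.where(amax > 0.0, amax, 1.0)
    updated_scale = np.where(amax > 0.0, (fp8_max / safe_amax) / 2.0**margin, scale)

    # An empty mask means every scale_inv entry is updated
    updated_scale_inv = np.reciprocal(updated_scale)
    if scale_inv_mask.size > 0:
        updated_scale_inv = np.where(scale_inv_mask, updated_scale_inv, scale_inv)
    return updated_amax_history, updated_scale, updated_scale_inv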
File renamed without changes.