Skip to content

Commit

Permalink
PR #16775: Add test for EmitReducePrecisionIR
Browse files Browse the repository at this point in the history
Imported from GitHub PR #16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](#16585) Add support for float8_e4m3

Copybara import of the project:

--
5972205 by Alexander Pivovarov <pivovaa@amazon.com>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205
PiperOrigin-RevId: 696646489
  • Loading branch information
apivovarov authored and Google-ML-Automation committed Nov 14, 2024
1 parent bf5761c commit f9c01b1
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 545 deletions.
12 changes: 0 additions & 12 deletions third_party/llvm/generated.patch
Original file line number Diff line number Diff line change
@@ -1,13 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1465,7 +1465,7 @@
: Suffix;

ValueSet StructValues;
- StructType *StructTy;
+ StructType *StructTy = nullptr;
Function *newFunction = constructFunctionDeclaration(
inputs, outputs, EntryFreq, oldFunction->getName() + "." + SuffixToUse,
StructValues, StructTy);
4 changes: 2 additions & 2 deletions third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "97298853b4de70dbce9c0a140ac38e3ac179e02e"
LLVM_SHA256 = "ac811cb61d281043c865c39260a5114a0e96d16ec0e4eb74a2516a24981b9064"
LLVM_COMMIT = "03730cdd3d10c5270fe436777a37d50b0838a3bf"
LLVM_SHA256 = "54d843249c75b200f7bf9b7947079fe16fa0b657c4aee4abdde4ac05a9cd5f84"

tf_http_archive(
name = name,
Expand Down
612 changes: 99 additions & 513 deletions third_party/shardy/temporary.patch

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions third_party/shardy/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "879e94b974bc47a37263a897fe5fb83d8b52e266"
SHARDY_SHA256 = "a08491556df5185b37a4d94fda3ce5c71233f0319b3efbb45f00120fbcad973a"
SHARDY_COMMIT = "c3f3c8e9ae90470e24891428072ff8cbc7445e95"
SHARDY_SHA256 = "92dd51c7de67e3b0bd09f388f1f71bebb44547a98b33ee9ece66eba0fac439d1"

tf_http_archive(
name = "shardy",
Expand Down
12 changes: 0 additions & 12 deletions third_party/tsl/third_party/llvm/generated.patch
Original file line number Diff line number Diff line change
@@ -1,13 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1465,7 +1465,7 @@
: Suffix;

ValueSet StructValues;
- StructType *StructTy;
+ StructType *StructTy = nullptr;
Function *newFunction = constructFunctionDeclaration(
inputs, outputs, EntryFreq, oldFunction->getName() + "." + SuffixToUse,
StructValues, StructTy);
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "97298853b4de70dbce9c0a140ac38e3ac179e02e"
LLVM_SHA256 = "ac811cb61d281043c865c39260a5114a0e96d16ec0e4eb74a2516a24981b9064"
LLVM_COMMIT = "03730cdd3d10c5270fe436777a37d50b0838a3bf"
LLVM_SHA256 = "54d843249c75b200f7bf9b7947079fe16fa0b657c4aee4abdde4ac05a9cd5f84"

tf_http_archive(
name = name,
Expand Down
4 changes: 2 additions & 2 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ using llvm_ir::SetToFirstInsertPoint;
using xla::float8_fnuz_ir_emitter::EmitF8fnuzToFloating;
using xla::float8_fnuz_ir_emitter::EmitFloatingToF8fnuz;

namespace {

absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits,
int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilderBase* b) {
Expand Down Expand Up @@ -231,6 +229,8 @@ absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
return result;
}

namespace {

template <int f8_exponent_bits>
llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits,
llvm::Value* f8_bits,
Expand Down
4 changes: 4 additions & 0 deletions xla/service/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,10 @@ class ElementalIrEmitterForTests : public ElementalIrEmitter {
HloToElementGeneratorMap generator_map_;
};

absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits,
int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilderBase* b);

} // namespace xla

#endif // XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
192 changes: 192 additions & 0 deletions xla/service/elemental_ir_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_macros.h"
Expand All @@ -48,6 +49,11 @@ namespace {

using std::nullopt;

struct EmitReducePrecisionIrTestCase {
float input;
std::string expected_res;
};

class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
void RunTest(const std::string& hlo_text, absl::Span<Literal* const> args) {
Expand Down Expand Up @@ -123,6 +129,192 @@ ENTRY main {
RunTest(hlo_text, {&lhs, &rhs});
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e5m2) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();
float qnan = std::numeric_limits<float>::quiet_NaN();
float snan = std::numeric_limits<float>::signaling_NaN();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-14, "half 0xH0400"},
{0.250, "half 0xH3400"},
{1.0, "half 0xH3C00"},
{0x1.2p0, "half 0xH3C00"},
{0x1.Cp15, "half 0xH7B00"},
{-0x1.Cp15, "half 0xHFB00"},
{0x1.Dp15, "half 0xH7B00"},
{0x1.Ep15, "half 0xH7C00"},
{0x1.0p16, "half 0xH7C00"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
{qnan, "half 0xH7E00"},
{-qnan, "half 0xHFE00"},
{snan, "half 0xH7F00"},
{-snan, "half 0xHFF00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/primitive_util::ExponentWidth(F8E5M2),
/*dest_mantissa_bits=*/primitive_util::SignificandWidth(F8E5M2) - 1,
/*quiet_nans=*/true, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e4m3) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();
float qnan = std::numeric_limits<float>::quiet_NaN();
float snan = std::numeric_limits<float>::signaling_NaN();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-6, "half 0xH2400"},
{0.125, "half 0xH3000"},
{1.0, "half 0xH3C00"},
{0x1.1p0, "half 0xH3C00"},
{0x1.Ep7, "half 0xH5B80"},
{-0x1.Ep7, "half 0xHDB80"},
{0x1.E8p7, "half 0xH5B80"},
{0x1.Fp7, "half 0xH7C00"},
{0x1.0p8, "half 0xH7C00"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
{qnan, "half 0xH7E00"},
{-qnan, "half 0xHFE00"},
{snan, "half 0xH7E00"},
{-snan, "half 0xHFE00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/4,
/*dest_mantissa_bits=*/3,
/*quiet_nans=*/true, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e3m4) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();
float qnan = std::numeric_limits<float>::quiet_NaN();
float snan = std::numeric_limits<float>::signaling_NaN();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-2, "half 0xH3400"},
{0.5, "half 0xH3800"},
{1.0, "half 0xH3C00"},
{0x1.08p0, "half 0xH3C00"},
{0x1.Fp3, "half 0xH4BC0"},
{-0x1.Fp3, "half 0xHCBC0"},
{0x1.F4p3, "half 0xH4BC0"},
{0x1.F8p3, "half 0xH7C00"},
{0x1.0p4, "half 0xH7C00"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
{qnan, "half 0xH7E00"},
{-qnan, "half 0xHFE00"},
{snan, "half 0xH7E00"},
{-snan, "half 0xHFE00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/3,
/*dest_mantissa_bits=*/4,
/*quiet_nans=*/true, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest,
EmitReducePrecisionIR_F16ToF8e4m3fn) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-6, "half 0xH2400"},
{0.125, "half 0xH3000"},
{1.0, "half 0xH3C00"},
{0x1.1p0, "half 0xH3C00"},
{0x1.Cp8, "half 0xH5F00"},
{-0x1.Cp8, "half 0xHDF00"},
{0x1.Dp8, "half 0xH5F00"},
{0x1.Ep8, "half 0xH5F80"},
{0x1.0p9, "half 0xH6000"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

// Truncate the mantissa to 3 bits. ReducePrecision cannot deal with
// f8E4M3FN's NaN representations, so don't use ReducePrecision to handle
// exponent reduction.
absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/5,
/*dest_mantissa_bits=*/3,
/*quiet_nans=*/false, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, ScalarDotFusion) {
const char* hlo_text = R"(
HloModule ScalarDotFusion
Expand Down

0 comments on commit f9c01b1

Please sign in to comment.