From dfc3fffd146fceb85af0a3e81a9e933697d0a854 Mon Sep 17 00:00:00 2001 From: xla authors Date: Thu, 24 Oct 2024 09:57:27 -0700 Subject: [PATCH] Remove use_bfloat16 from reduce_window_test.cc This CL removes the use_bfloat16 flag from reduce_window_test.cc and replaces it with the test_type flag. This allows the test to be run with any of the supported float types, not just BF16. PiperOrigin-RevId: 689419354 --- xla/tests/BUILD | 10 ++- xla/tests/client_library_test_base.h | 24 ------ xla/tests/reduce_window_test.cc | 119 ++++++++++++++++----------- 3 files changed, 76 insertions(+), 77 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 825e10edfcda9..9bb85dc36157b 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1772,24 +1772,28 @@ xla_test_library( deps = [ ":client_library_test_base", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", "//xla:array4d", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:status", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) diff --git a/xla/tests/client_library_test_base.h b/xla/tests/client_library_test_base.h index 6fd5006b4891f..016b73b6e7682 100644 --- a/xla/tests/client_library_test_base.h +++ b/xla/tests/client_library_test_base.h @@ -44,23 +44,6 @@ limitations under the License. namespace xla { -// Sets the use_bfloat16 on a container of test cases according to the values in -// use_bfloat16_params. Generates one set of test cases for each values in -// use_bfloat16_params with that value. Returns the result. -template -std::vector ExpandUseBfloat16( - absl::Span use_bfloat16_params, - absl::Span specs) { - std::vector expanded; - for (bool use_bfloat16 : use_bfloat16_params) { - for (const auto& spec : specs) { - expanded.push_back(spec); - expanded.back().use_bfloat16 = use_bfloat16; - } - } - return expanded; -} - template std::vector ExpandTestType( absl::Span test_type_params, @@ -413,13 +396,6 @@ class ClientLibraryTestBase : public ::testing::Test { XlaBuilder* builder, XlaOp* data_handle); - // TODO(ralphnathan): These will eventually be removed. Please have new tests - // support multiple primitive types, not just BF16. - // Getter and setter for the test_type flag, which indicates whether to run - // tests with all float-type input/output converted to bfloat16. - bool use_bfloat16() const { return test_type_ == BF16; } - void set_use_bfloat16(bool value) { test_type_ = value ? BF16 : F32; } - // The float type used in this test. PrimitiveType FloatType() const { return test_type_; } void set_float_type(PrimitiveType type) { test_type_ = type; } diff --git a/xla/tests/reduce_window_test.cc b/xla/tests/reduce_window_test.cc index 4b4a257ca7f10..e6e374e95e856 100644 --- a/xla/tests/reduce_window_test.cc +++ b/xla/tests/reduce_window_test.cc @@ -15,40 +15,55 @@ limitations under the License. // Tests the reduce-window XLA operation. -#include +#include +#include +#include +#include #include - +#include +#include +#include +#include +#include + +#include +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" #include "xla/reference_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace { -static std::array use_bfloat16_params{false, true}; +static std::array test_type_params = {F32, BF16}; class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { - if (use_bfloat16()) { + if (FloatType() == BF16) { return ErrorSpec(2e-1, 6e-2); } else { return ErrorSpec(1e-3, 1e-3); @@ -56,10 +71,10 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { } }; -class ReduceWindowTest : public ::testing::WithParamInterface, +class ReduceWindowTest : public ::testing::WithParamInterface, public ReduceWindowTestBase { public: - ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } + ReduceWindowTest() : builder_(TestName()) { set_float_type(GetParam()); } void ReduceWindowAdd(const XlaOp input, absl::Span window_dimensions, @@ -563,7 +578,7 @@ XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, - ::testing::ValuesIn(use_bfloat16_params)); + ::testing::ValuesIn(test_type_params)); enum Reducer { kAdd, kMax }; @@ -580,7 +595,7 @@ struct R4ReduceWindowTestData { std::string R4ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat( "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // @@ -594,17 +609,18 @@ std::string R4ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R4ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R4ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R4ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } void DoIt() { XlaBuilder b(TestName()); @@ -878,7 +894,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { INSTANTIATE_TEST_CASE_P( R4ReduceWindowTestInstantiation, R4ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; @@ -967,7 +983,7 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { INSTANTIATE_TEST_CASE_P( R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R4ReduceWindowTestDataToString); struct R3ReduceWindowTestData { @@ -1017,7 +1033,7 @@ R3ReduceWindowTestData kR3TestCases[] = { std::string R3ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat( "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", @@ -1026,17 +1042,18 @@ std::string R3ReduceWindowTestDataToString( param.padding == Padding::kSame ? "same" : "valid", "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", param.reducer == kAdd ? "add" : "max"); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R3ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R3ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R3ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } void DoIt() { XlaBuilder b(TestName()); @@ -1052,7 +1069,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; - if (use_bfloat16()) { + if (FloatType() == BF16) { input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); // To avoid numerical issues, force the reducer to be kMax for bf16 @@ -1083,7 +1100,7 @@ XLA_TEST_P(R3ReduceWindowTest, DoIt) { DoIt(); } INSTANTIATE_TEST_CASE_P( R3ReduceWindowTestInstantiation, R3ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR3TestCases), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R3ReduceWindowTestDataToString); class R3ReduceWindowLargeTest : public R3ReduceWindowTest {}; @@ -1106,7 +1123,7 @@ const R3ReduceWindowTestData kR3ReduceWindowLargeTestValues[] = { INSTANTIATE_TEST_CASE_P( R3ReduceWindowLargeTestInstantiation, R3ReduceWindowLargeTest, ::testing::Combine(::testing::ValuesIn(kR3ReduceWindowLargeTestValues), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R3ReduceWindowTestDataToString); struct R2ReduceWindowTestData { @@ -1268,7 +1285,7 @@ struct R2ReduceWindowTestData { std::string R2ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat( "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // @@ -1283,24 +1300,25 @@ std::string R2ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R2ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R2ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R2ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - if (!::testing::get<1>(GetParam())) { + if (FloatType() == F32) { // We only do this in F32 mode, to avoid precision issues with BF16. input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0], param.base_bounds[1]); @@ -1343,7 +1361,7 @@ XLA_TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); } INSTANTIATE_TEST_CASE_P( R2ReduceWindowTestInstantiation, R2ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR2TestCases), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { @@ -1499,7 +1517,7 @@ struct R1ReduceWindowTestData { std::string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), @@ -1511,17 +1529,18 @@ std::string R1ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R1ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R1ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R1ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } }; XLA_TEST_P(R1ReduceWindowTest, DoIt) { @@ -1575,7 +1594,7 @@ XLA_TEST_P(R1ReduceWindowTest, DoIt) { INSTANTIATE_TEST_CASE_P( R1ReduceWindowTestInstantiation, R1ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR1TestCases), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R1ReduceWindowTestDataToString); // Test class for text-based test cases. Note that this compares with the