Skip to content

Commit

Permalink
Remove use_bfloat16 from reduce_window_test.cc
Browse files Browse the repository at this point in the history
This CL removes the use_bfloat16 flag from reduce_window_test.cc and replaces it with the test_type flag. This allows the test to be run with any of the supported float types, not just BF16.

PiperOrigin-RevId: 689419354
  • Loading branch information
Google-ML-Automation committed Oct 24, 2024
1 parent 2f99455 commit dfc3fff
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 77 deletions.
10 changes: 7 additions & 3 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1772,24 +1772,28 @@ xla_test_library(
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//xla:array2d",
"//xla:array3d",
"//xla:array4d",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:reference_util",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/client:local_client",
"//xla/hlo/builder:padding",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/builder/lib:arithmetic",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:status",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Expand Down
24 changes: 0 additions & 24 deletions xla/tests/client_library_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,6 @@ limitations under the License.

namespace xla {

// Sets the use_bfloat16 on a container of test cases according to the values in
// use_bfloat16_params. Generates one set of test cases for each values in
// use_bfloat16_params with that value. Returns the result.
template <typename TestCase>
std::vector<TestCase> ExpandUseBfloat16(
absl::Span<const bool> use_bfloat16_params,
absl::Span<const TestCase> specs) {
std::vector<TestCase> expanded;
for (bool use_bfloat16 : use_bfloat16_params) {
for (const auto& spec : specs) {
expanded.push_back(spec);
expanded.back().use_bfloat16 = use_bfloat16;
}
}
return expanded;
}

template <typename TestCase>
std::vector<TestCase> ExpandTestType(
absl::Span<const PrimitiveType> test_type_params,
Expand Down Expand Up @@ -413,13 +396,6 @@ class ClientLibraryTestBase : public ::testing::Test {
XlaBuilder* builder,
XlaOp* data_handle);

// TODO(ralphnathan): These will eventually be removed. Please have new tests
// support multiple primitive types, not just BF16.
// Getter and setter for the test_type flag, which indicates whether to run
// tests with all float-type input/output converted to bfloat16.
bool use_bfloat16() const { return test_type_ == BF16; }
void set_use_bfloat16(bool value) { test_type_ = value ? BF16 : F32; }

// The float type used in this test.
PrimitiveType FloatType() const { return test_type_; }
void set_float_type(PrimitiveType type) { test_type_ = type; }
Expand Down
119 changes: 69 additions & 50 deletions xla/tests/reduce_window_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,51 +15,66 @@ limitations under the License.

// Tests the reduce-window XLA operation.

#include <limits>
#include <algorithm>
#include <array>
#include <cstdint>
#include <iterator>
#include <memory>

#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/array2d.h"
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/client/local_client.h"
#include "xla/error_spec.h"
#include "xla/hlo/builder/lib/arithmetic.h"
#include "xla/hlo/builder/padding.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/reference_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/client_library_test_base.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_macros.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/status.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace {

static std::array<bool, 2> use_bfloat16_params{false, true};
static std::array<PrimitiveType, 2> test_type_params = {F32, BF16};

class ReduceWindowTestBase : public ClientLibraryTestBase {
public:
ErrorSpec DefaultErrorSpec() const {
if (use_bfloat16()) {
if (FloatType() == BF16) {
return ErrorSpec(2e-1, 6e-2);
} else {
return ErrorSpec(1e-3, 1e-3);
}
}
};

class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
class ReduceWindowTest : public ::testing::WithParamInterface<PrimitiveType>,
public ReduceWindowTestBase {
public:
ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
ReduceWindowTest() : builder_(TestName()) { set_float_type(GetParam()); }

void ReduceWindowAdd(const XlaOp input,
absl::Span<const int64_t> window_dimensions,
Expand Down Expand Up @@ -563,7 +578,7 @@ XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
}

INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
::testing::ValuesIn(use_bfloat16_params));
::testing::ValuesIn(test_type_params));

enum Reducer { kAdd, kMax };

Expand All @@ -580,7 +595,7 @@ struct R4ReduceWindowTestData {

std::string R4ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
::testing::tuple<R4ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str = absl::StrCat(
"base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
Expand All @@ -594,17 +609,18 @@ std::string R4ReduceWindowTestDataToString(

// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R4ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R4ReduceWindowTestData, bool>> {
class R4ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R4ReduceWindowTestData, PrimitiveType>> {
protected:
R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R4ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }

void DoIt() {
XlaBuilder b(TestName());
Expand Down Expand Up @@ -878,7 +894,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
INSTANTIATE_TEST_CASE_P(
R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R4ReduceWindowTestDataToString);

class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
Expand Down Expand Up @@ -967,7 +983,7 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
INSTANTIATE_TEST_CASE_P(
R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R4ReduceWindowTestDataToString);

struct R3ReduceWindowTestData {
Expand Down Expand Up @@ -1017,7 +1033,7 @@ R3ReduceWindowTestData kR3TestCases[] = {

std::string R3ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
::testing::tuple<R3ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str = absl::StrCat(
"base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
Expand All @@ -1026,17 +1042,18 @@ std::string R3ReduceWindowTestDataToString(
param.padding == Padding::kSame ? "same" : "valid", "__layout_",
param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R3ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R3ReduceWindowTestData, bool>> {
class R3ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R3ReduceWindowTestData, PrimitiveType>> {
protected:
R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R3ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }

void DoIt() {
XlaBuilder b(TestName());
Expand All @@ -1052,7 +1069,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase,
Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
auto reducer = param.reducer;
if (use_bfloat16()) {
if (FloatType() == BF16) {
input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);

// To avoid numerical issues, force the reducer to be kMax for bf16
Expand Down Expand Up @@ -1083,7 +1100,7 @@ XLA_TEST_P(R3ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P(
R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR3TestCases),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R3ReduceWindowTestDataToString);

class R3ReduceWindowLargeTest : public R3ReduceWindowTest {};
Expand All @@ -1106,7 +1123,7 @@ const R3ReduceWindowTestData kR3ReduceWindowLargeTestValues[] = {
INSTANTIATE_TEST_CASE_P(
R3ReduceWindowLargeTestInstantiation, R3ReduceWindowLargeTest,
::testing::Combine(::testing::ValuesIn(kR3ReduceWindowLargeTestValues),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R3ReduceWindowTestDataToString);

struct R2ReduceWindowTestData {
Expand Down Expand Up @@ -1268,7 +1285,7 @@ struct R2ReduceWindowTestData {

std::string R2ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
::testing::tuple<R2ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str = absl::StrCat(
"base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
Expand All @@ -1283,24 +1300,25 @@ std::string R2ReduceWindowTestDataToString(

// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R2ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R2ReduceWindowTestData, bool>> {
class R2ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R2ReduceWindowTestData, PrimitiveType>> {
protected:
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R2ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }

void DoIt() {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());

Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
if (!::testing::get<1>(GetParam())) {
if (FloatType() == F32) {
// We only do this in F32 mode, to avoid precision issues with BF16.
input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0],
param.base_bounds[1]);
Expand Down Expand Up @@ -1343,7 +1361,7 @@ XLA_TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P(
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR2TestCases),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R2ReduceWindowTestDataToString);

struct R1ReduceWindowTestData {
Expand Down Expand Up @@ -1499,7 +1517,7 @@ struct R1ReduceWindowTestData {

std::string R1ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
::testing::tuple<R1ReduceWindowTestData, PrimitiveType>>& data) {
const auto& param = ::testing::get<0>(data.param);
std::string str =
absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
Expand All @@ -1511,17 +1529,18 @@ std::string R1ReduceWindowTestDataToString(

// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
absl::StrAppend(&str, "_bfloat16");
}
absl::StrAppend(&str, "_",
primitive_util::LowercasePrimitiveTypeName(
::testing::get<1>(data.param)));
return str;
}

class R1ReduceWindowTest : public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R1ReduceWindowTestData, bool>> {
class R1ReduceWindowTest
: public ReduceWindowTestBase,
public ::testing::WithParamInterface<
::testing::tuple<R1ReduceWindowTestData, PrimitiveType>> {
protected:
R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
R1ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); }
};

XLA_TEST_P(R1ReduceWindowTest, DoIt) {
Expand Down Expand Up @@ -1575,7 +1594,7 @@ XLA_TEST_P(R1ReduceWindowTest, DoIt) {
INSTANTIATE_TEST_CASE_P(
R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
::testing::Combine(::testing::ValuesIn(kR1TestCases),
::testing::ValuesIn(use_bfloat16_params)),
::testing::ValuesIn(test_type_params)),
R1ReduceWindowTestDataToString);

// Test class for text-based test cases. Note that this compares with the
Expand Down

0 comments on commit dfc3fff

Please sign in to comment.