
Commit

PR #18413: [XLA:CPU][oneDNN] Refactor and parameterize oneDNN convolution tests

Imported from GitHub PR #18413

This PR refactors and parameterizes the existing tests in the oneDNN convolution test file. Each test now runs with F32, BF16, and F16 precisions on supported hardware.
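For readers unfamiliar with the pattern used here: the refactor relies on googletest value parameterization. The fixture takes a PrimitiveType parameter, the HLO module text carries a $dtype placeholder that is substituted per parameter, and INSTANTIATE_TEST_SUITE_P enumerates F32, BF16, and F16. A minimal, self-contained sketch of the same mechanism (simplified, hypothetical names such as Dtype and DtypeParamTest; not the PR's actual fixture):

#include <string>

#include "absl/strings/str_replace.h"
#include "gtest/gtest.h"

// Hypothetical stand-in for xla::PrimitiveType, for illustration only.
enum class Dtype { kF32, kBF16, kF16 };

std::string DtypeName(Dtype d) {
  switch (d) {
    case Dtype::kF32:
      return "f32";
    case Dtype::kBF16:
      return "bf16";
    default:
      return "f16";
  }
}

// Each TEST_P below runs once per value listed in INSTANTIATE_TEST_SUITE_P.
class DtypeParamTest : public ::testing::TestWithParam<Dtype> {};

TEST_P(DtypeParamTest, SubstitutesDtypePlaceholder) {
  // Same idea as the PR: the HLO text is a template containing $dtype,
  // specialized for the current parameter before the module would be run.
  const std::string hlo_template = "p0 = $dtype[1,22,22,1] parameter(0)";
  const std::string hlo = absl::StrReplaceAll(
      hlo_template, {{"$dtype", DtypeName(GetParam())}});
  EXPECT_EQ(hlo.find('$'), std::string::npos);  // placeholder fully replaced
}

INSTANTIATE_TEST_SUITE_P(
    AllDtypes, DtypeParamTest,
    ::testing::Values(Dtype::kF32, Dtype::kBF16, Dtype::kF16));

The real fixture additionally skips unsupported types in SetUp() and widens the error tolerances for the reduced-precision types, as the diff below shows.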
Copybara import of the project:

--
396244d by Akhil Goel <akhil.goel@intel.com>:

Refactor oneDNN convolution tests

--
52aa720 by Akhil Goel <akhil.goel@intel.com>:

Address review comments

--
17fb547 by Akhil Goel <akhil.goel@intel.com>:

Address review comments

--
2507d5f by Akhil Goel <akhil.goel@intel.com>:

Declare templates as const char*

Merging this change closes #18413

COPYBARA_INTEGRATE_REVIEW=#18413 from Intel-tensorflow:akhil/conv_fusions_3_a 2507d5f
PiperOrigin-RevId: 689300873
akhilgoe authored and Google-ML-Automation committed Oct 24, 2024
1 parent 34ddf96 commit 3ad67d1
Showing 2 changed files with 180 additions and 113 deletions.
1 change: 1 addition & 0 deletions xla/service/cpu/tests/BUILD
@@ -402,6 +402,7 @@ xla_cc_test(
"//xla/tests:hlo_test_base",
"//xla/tests:test_macros_header",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:platform_port",
],
)
292 changes: 179 additions & 113 deletions xla/service/cpu/tests/onednn_convolution_test.cc
@@ -17,6 +17,7 @@ limitations under the License.

#include <utility>

+#include "absl/strings/str_replace.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
#include "xla/service/cpu/onednn_contraction_rewriter.h"
@@ -32,159 +33,224 @@ limitations under the License.
namespace xla {
namespace cpu {

-class ConvolutionTest : public HloTestBase {
+class ConvolutionTest : public HloTestBase,
+                        public ::testing::WithParamInterface<PrimitiveType> {
 protected:
  DebugOptions GetDebugOptionsForTest() const override {
    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
    debug_options.set_xla_cpu_use_thunk_runtime(false);
    return debug_options;
  }

-  const char* conv_rewrite_str_ = R"(
-  ; CHECK: custom_call_target="__onednn$convolution",
-  ; CHECK: backend_config={
-  ; CHECK-DAG: "outer_dimension_partitions":[],
-  ; CHECK-DAG: "onednn_conv_config":{
-  ; CHECK-DAG: }
-  ; CHECK: }
-  )";
-
-  const char* conv_rewrite_bias_str_ = R"(
-  ; CHECK: custom_call_target="__onednn$convolution",
-  ; CHECK: backend_config={
-  ; CHECK-DAG: "outer_dimension_partitions":[],
-  ; CHECK-DAG: "onednn_conv_config":{
-  ; CHECK-DAG: "fusions":{
-  ; CHECK-DAG: "ops":["BIAS"]
-  ; CHECK-DAG: }
-  ; CHECK-DAG: }
-  ; CHECK: }
-  )";
-
-  const char* fused_convolution_binary_add_ = R"(
-  ; CHECK: custom_call_target="__onednn$convolution",
-  ; CHECK: backend_config={
-  ; CHECK-DAG: "outer_dimension_partitions":[],
-  ; CHECK-DAG: "onednn_conv_config":{
-  ; CHECK-DAG: "fusions":{
-  ; CHECK-DAG: "ops":["BINARY_ADD"]
-  ; CHECK-DAG: }
-  ; CHECK-DAG: }
-  ; CHECK: }
-  )";
-};
+  PrimitiveType dtype_;
+  std::string dtypeString_;
+  bool user_scratchpad_;
+  bool weights_prepacked_;
+  float atol_;
+  float rtol_;
+
+  constexpr static const char* kConvRewriteStr = R"(
+  ; CHECK: custom_call_target="__onednn$convolution",
+  ; CHECK: backend_config={
+  ; CHECK-DAG: "outer_dimension_partitions":[],
+  ; CHECK-DAG: "onednn_conv_config":{$fusions_str,$opt_config
+  ; CHECK-DAG: }
+  ; CHECK: }
+  )";
+
+  constexpr static const char* kConvRewriteFusionsStr = R"(
+  ; CHECK-DAG: "fusions":{
+  ; CHECK-DAG: "ops":[$fused_ops]
+  ; CHECK-DAG: },)";
+
+  constexpr static const char* kConvRewriteOptimizationsStr = R"(
+  ; CHECK-DAG: "optimization_config":{
+  ; CHECK-DAG: "weights_prepacked":$weights_prepacked,
+  ; CHECK-DAG: "user_scratchpad":$user_scratchpad,
+  ; CHECK-DAG: })";
+
+  ConvolutionTest() {
+    dtype_ = GetParam();
+    atol_ = rtol_ = (dtype_ == F32) ? 1e-4 : 1e-2;
+    // TODO(intel-tf): Set default value of user_scratchpad to true after
+    // enabling feature
+    user_scratchpad_ = false;
+    weights_prepacked_ = false;
+    dtypeString_ = primitive_util::LowercasePrimitiveTypeName(dtype_);
+  }
+
+  void SetUp() override {
+    if (!IsSupportedType(dtype_)) {
+      GTEST_SKIP() << "CPU does not support " << dtypeString_;
+    }
+  }
+
+  void SetWeightsPrepacked(bool value) { weights_prepacked_ = value; }
+
+  void SetUserScratchpad(bool value) { user_scratchpad_ = value; }
+
+  std::string GetOptimizationsString() {
+    return (user_scratchpad_ || weights_prepacked_)
+               ? absl::StrReplaceAll(kConvRewriteOptimizationsStr,
+                                     {{"$weights_prepacked",
+                                       weights_prepacked_ ? "true" : "false"},
+                                      {"$user_scratchpad",
+                                       user_scratchpad_ ? "true" : "false"}})
+               : "";
+  }
+
+  std::string ConvStringWithOptimizations(
+      const std::vector<absl::string_view> fused_ops) {
+    std::ostringstream stream;
+    std::for_each(
+        fused_ops.begin(), fused_ops.end(),
+        [&](const absl::string_view& arg) { stream << "\"" << arg << "\","; });
+    std::string fusions = stream.str();
+    if (fused_ops.size() > 0) {
+      fusions.pop_back();
+      return absl::StrReplaceAll(
+          kConvRewriteStr,
+          {{"$fusions_str,", absl::StrReplaceAll(kConvRewriteFusionsStr,
+                                                 {{"$fused_ops", fusions}})},
+           {"$opt_config", GetOptimizationsString()}});
+    }
+    return absl::StrReplaceAll(
+        kConvRewriteStr,
+        {{"$fusions_str,", ""}, {"$opt_config", GetOptimizationsString()}});
+  }
+
+  // TODO(intel-tf): Remove this and simplify patterns when Elemental BF16 is
+  // fully supported.
+  PrimitiveType PromotedDtype() {
+    // BF16 is promoted to F32 because not all HLO Instructions currently
+    // support BF16 computations. Meanwhile, FP32 and FP16 elementwise
+    // instructions are not promoted and remain unchanged.
+    return (dtype_ == BF16) ? F32 : dtype_;
+  }
+
+  void AdjustToleranceForDtype(PrimitiveType for_type, float atol, float rtol) {
+    if (dtype_ == for_type) {
+      atol_ = atol;
+      rtol_ = rtol;
+    }
+  }
+
+  std::string PromotedDtypeToString() {
+    return primitive_util::LowercasePrimitiveTypeName(PromotedDtype());
+  }
+
+  void RunCompareAndMatchOptimizedHlo(
+      const absl::string_view outline,
+      const std::vector<absl::string_view> fused_ops) {
+    const std::string convolution_module_str = absl::StrReplaceAll(
+        outline,
+        {{"$dtype", dtypeString_}, {"$pdtype", PromotedDtypeToString()}});
+    EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{atol_, rtol_}));
+    MatchOptimizedHlo(convolution_module_str,
+                      ConvStringWithOptimizations(fused_ops));
+  }
+};

-TEST_F(ConvolutionTest, Simple2DTestF32) {
-  const char* convolution_module_str = R"(
-  HloModule convolution.test.f32
-
-  ENTRY convolution.test.f32 {
-    arg.0 = f32[1,22,22,1] parameter(0)
-    reshape.0 = f32[1,22,22,1] reshape(arg.0)
-    arg.1 = f32[8,8,1,1] parameter(1)
-    reshape.1 = f32[8,8,1,1] reshape(arg.1)
-    convolution.0 = f32[1,11,11,1] convolution(reshape.0, reshape.1), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
-    reshape.2 = f32[1,11,11,1] reshape(convolution.0)
-    tuple.0 = (f32[1,11,11,1]) tuple(reshape.2)
-    ROOT get-tuple-element.0 = f32[1,11,11,1] get-tuple-element(tuple.0), index=0
-  })";
-
-  EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4}));
-  MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_);
-}
+TEST_P(ConvolutionTest, Simple2DTest1) {
+  const absl::string_view outline = R"(
+  HloModule convolution.test
+  ENTRY convolution.test {
+    arg.0 = $dtype[1,22,22,1] parameter(0)
+    reshape.0 = $dtype[1,22,22,1] reshape(arg.0)
+    arg.1 = $dtype[8,8,1,1] parameter(1)
+    reshape.1 = $dtype[8,8,1,1] reshape(arg.1)
+    convolution.0 = $dtype[1,11,11,1] convolution(reshape.0, reshape.1),
+        window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
+    reshape.2 = $dtype[1,11,11,1] reshape(convolution.0)
+    tuple.0 = ($dtype[1,11,11,1]) tuple(reshape.2)
+    ROOT gte.0 = $dtype[1,11,11,1] get-tuple-element(tuple.0), index=0
+  })";
+
+  RunCompareAndMatchOptimizedHlo(outline, {});
+}

-TEST_F(ConvolutionTest, Simple3DTestBF16) {
-  if (!IsSupportedType(PrimitiveType::BF16)) {
-    GTEST_SKIP() << "CPU does not support BF16.";
-  }
-
-  const char* convolution_module_str = R"(
-  HloModule convolution.test.bf16
-
-  ENTRY convolution.test.bf16 {
-    p0 = bf16[8,4,5,5,1] parameter(0)
-    p1 = bf16[3,3,3,1,32] parameter(1)
-    ROOT conv = bf16[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
-  })";
-
-  EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4}));
-  MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_);
-}
+TEST_P(ConvolutionTest, Simple3DTest1) {
+  const absl::string_view outline = R"(
+  HloModule convolution.test
+  ENTRY convolution.test {
+    p0 = $dtype[8,4,5,5,1] parameter(0)
+    p1 = $dtype[3,3,3,1,32] parameter(1)
+    ROOT conv = $dtype[8,4,5,5,32] convolution(p0, p1),
+        window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
+  })";
+
+  RunCompareAndMatchOptimizedHlo(outline, {});
+}

-TEST_F(ConvolutionTest, Simple2DTestF16) {
-  if (!IsSupportedType(PrimitiveType::F16)) {
-    GTEST_SKIP() << "CPU does not support F16.";
-  }
-
-  const char* convolution_module_str = R"(
-  HloModule convolution.test.f16
-
-  ENTRY convolution.test.bf16 {
-    p0 = f16[8,4,5,5,1] parameter(0)
-    p1 = f16[3,3,3,1,32] parameter(1)
-    ROOT conv = f16[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
-  })";
-
-  EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4}));
-  MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_);
-}
-
-TEST_F(ConvolutionTest, Conv3DWithBiasBF16) {
-  const char* convolution_module_str = R"(
-  HloModule convolution.test.with.bias.relu.bf16.3D
-  ENTRY TestComputation {
-    arg.0 = bf16[15,4,5,5,28] parameter(0)
-    arg.1 = bf16[3,3,3,28,64] parameter(1)
-    conv = bf16[15,4,5,5,64] convolution(arg.0, arg.1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
-    bias = bf16[64] parameter(2)
-    broadcasted_bias = bf16[15,4,5,5,64] broadcast(bias), dimensions={4}
-    ROOT add = bf16[15,4,5,5,64] add(conv, broadcasted_bias)
-  })";
-  EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{0.01, 0.01}));
-  MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_str_);
-}
+TEST_P(ConvolutionTest, Conv3DWithBiasTest) {
+  const absl::string_view outline = R"(
+  HloModule convolution.test.with.bias
+  ENTRY convolution.test.with.bias {
+    arg.0 = $dtype[15,4,5,5,28] parameter(0)
+    arg.1 = $dtype[3,3,3,28,64] parameter(1)
+    conv = $dtype[15,4,5,5,64] convolution(arg.0, arg.1),
+        window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
+    bias = $dtype[64] parameter(2)
+    broadcasted_bias = $dtype[15,4,5,5,64] broadcast(bias), dimensions={4}
+    ROOT add = $dtype[15,4,5,5,64] add(conv, broadcasted_bias)
+  })";
+
+  RunCompareAndMatchOptimizedHlo(outline, {"BIAS"});
+}

-TEST_F(ConvolutionTest, SimpleTestF32WithBinaryAddFusion1) {
-  const char* convolution_module_str = R"(
-  HloModule conv.binaryadd.test.f32
-  ENTRY matmul.biasadd.test.f32 {
-    arg0.1 = f32[1,22,22,1] parameter(0)
-    constant.3 = f32[] constant(1)
-    broadcast.4 = f32[8,8,1,1] broadcast(constant.3), dimensions={}
-    convolution.0 = f32[1,11,11,1] convolution(arg0.1, broadcast.4), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
-    constant.5 = f32[] constant(15)
-    broadcast.6 = f32[1] broadcast(constant.5), dimensions={}
-    broadcast.9 = f32[1,11,11,1] broadcast(broadcast.6), dimensions={3}
-    ROOT add.10 = f32[1,11,11,1] add(convolution.0, broadcast.9)
-  })";
-
-  EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4}));
-  MatchOptimizedHlo(convolution_module_str, fused_convolution_binary_add_);
-}
+TEST_P(ConvolutionTest, Conv2DWithBinaryAddTest) {
+  const absl::string_view outline = R"(
+  HloModule convolution.test.with.binaryadd
+  ENTRY convolution.test.with.binaryadd {
+    arg0.1 = $dtype[1,22,22,1] parameter(0)
+    constant.3 = $dtype[] constant(1)
+    broadcast.4 = $dtype[8,8,1,1] broadcast(constant.3), dimensions={}
+    convolution.0 = $dtype[1,11,11,1] convolution(arg0.1, broadcast.4),
+        window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
+    constant.5 = $dtype[] constant(15)
+    broadcast.6 = $dtype[1] broadcast(constant.5), dimensions={}
+    broadcast.9 = $dtype[1,11,11,1] broadcast(broadcast.6), dimensions={3}
+    ROOT add.10 = $dtype[1,11,11,1] add(convolution.0, broadcast.9)
+  })";
+
+  RunCompareAndMatchOptimizedHlo(outline, {"BINARY_ADD"});
+}

-// This test should match BIAS + Residual Add when the residual add fusion is
+// This test should match BIAS + RESIDUAL ADD when the residual add fusion is
 // re-enabled.
-TEST_F(ConvolutionTest, SimpleTestBF16WithBiasAndAddFusion) {
-  const char* convolution_module_str = R"(
-  HloModule convolution.add.test.bf16
-  ENTRY convolution.add.test.bf16 {
-    arg0.1 = bf16[1,22,22,1] parameter(0)
-    arg0.2 = bf16[8,8,1,10] parameter(1)
-    convolution.0 = bf16[1,11,11,10] convolution(arg0.1, arg0.2), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
-    const.0 = bf16[10] constant(15)
-    bcast.1 = bf16[1,11,11,10] broadcast(const.0), dimensions={3}
-    add.0 = bf16[1,11,11,10] add(convolution.0, bcast.1)
-    const.1 = bf16[1,11,11,10] constant({...})
-    ROOT add.1 = bf16[1,11,11,10] add(add.0, const.1)
-  })";
-
-  EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-2, 1e-2}));
-  MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_str_);
-}
+TEST_P(ConvolutionTest, Conv2DWithBiasAndBinaryAddTest) {
+  const absl::string_view outline = R"(
+  HloModule convolution.add.test
+  ENTRY convolution.add.test {
+    arg0.1 = $dtype[1,22,22,1] parameter(0)
+    arg0.2 = $dtype[8,8,1,10] parameter(1)
+    convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2),
+        window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
+    const.0 = $dtype[10] constant(15)
+    bcast.1 = $dtype[1,11,11,10] broadcast(const.0), dimensions={3}
+    add.0 = $dtype[1,11,11,10] add(convolution.0, bcast.1)
+    const.1 = $dtype[1,11,11,10] constant({...})
+    ROOT add.1 = $dtype[1,11,11,10] add(add.0, const.1)
+  })";
+
+  RunCompareAndMatchOptimizedHlo(outline, {"BIAS"});
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    OneDnnConvolutionTestSuite, ConvolutionTest,
+    ::testing::Values(F32, BF16, F16),
+    [](const ::testing::TestParamInfo<ConvolutionTest::ParamType>& info) {
+      auto test_name = primitive_util::LowercasePrimitiveTypeName(info.param);
+      std::transform(test_name.begin(), test_name.end(), test_name.begin(),
+                     [](auto c) { return std::toupper(c); });
+      return test_name;
+    });

}  // namespace cpu
}  // namespace xla

