From 6aa728d45e49cec49be161b252547348b095b978 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Sat, 21 Aug 2021 02:03:07 +0000 Subject: [PATCH] [ODLA/DNNL] Fix batchnorm --- ODLA/platforms/dnnl/odla_dnnl.cc | 44 +++++++++---------- .../test_dnnl/test_batchnorm_epsilon_dnnl.cc | 2 +- .../test_dnnl/test_batchnorm_example_dnnl.cc | 2 +- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/ODLA/platforms/dnnl/odla_dnnl.cc b/ODLA/platforms/dnnl/odla_dnnl.cc index 15cb1b118..c985bc083 100644 --- a/ODLA/platforms/dnnl/odla_dnnl.cc +++ b/ODLA/platforms/dnnl/odla_dnnl.cc @@ -1090,7 +1090,7 @@ odla_value odla_BatchNormalization(odla_value input, odla_value offset, odla_float32 scalar_scale, odla_float32 scalar_offset, const odla_value_id value_id) { - dnnl::memory oring_mem; + dnnl::memory origin_mem; dnnl::memory::data_type dtype = input->mem.get_desc().data_type(); // black list op should convert to fp32 bool bf16_mode = (dtype == dnnl::memory::data_type::bf16 || @@ -1105,7 +1105,7 @@ odla_value odla_BatchNormalization(odla_value input, scale->mem = cast_op(scale, dnnl::memory::data_type::f32); offset->mem = cast_op(offset, dnnl::memory::data_type::f32); } - oring_mem = input->mem; + origin_mem = input->mem; input->mem = f32_input_mem; } @@ -1121,25 +1121,23 @@ odla_value odla_BatchNormalization(odla_value input, } unsigned channels = input_dims.dims[1]; - dnnl::memory::desc weight_md(dnnl::memory::dims{2, channels}, type, - dnnl::memory::format_tag::nc); - dnnl::memory weight_mem = dnnl::memory(weight_md, g_comp->eng); - - if (scale != nullptr && offset != nullptr) { + dnnl::memory scale_offset_mem = dnnl::memory(); + if (scale != nullptr || offset != nullptr || scalar_offset != 0.0F || + scalar_scale != 1.0F) { + // make a tensor [scale, bias]. + auto get_value = [channels](odla_value x, float scalar) { + if (x == nullptr) { + x = odla_CreateConstant({ODLA_FLOAT32, {2, {1, 1}}}, &scalar, + nullptr); // FIXME: copy to buf + } + return odla_Reshape(x, {2, {1, channels}}, nullptr); + }; + odla_value s = get_value(scale, scalar_scale); + odla_value b = get_value(offset, scalar_offset); flags |= dnnl::normalization_flags::use_scale_shift; - auto scale_md = - dnnl::memory::desc({1, channels}, type, dnnl::memory::format_tag::nc); - auto scale_mem = - dnnl::memory(scale_md, g_comp->eng, scale->mem.get_data_handle()); - auto offset_md = - dnnl::memory::desc({1, channels}, type, dnnl::memory::format_tag::nc); - auto c_pd = dnnl::concat::primitive_desc( - weight_md, 0, {scale_md, offset_md}, g_comp->eng); - auto c = dnnl::concat(c_pd); - c.execute(dnnl::stream(g_comp->eng), - {{DNNL_ARG_MULTIPLE_SRC, scale->mem}, - {DNNL_ARG_MULTIPLE_SRC + 1, offset->mem}, - {DNNL_ARG_DST, weight_mem}}); + auto scale_offset = + odla_Concat({2, {s, b}}, 0, {2, {2, channels}}, nullptr); + scale_offset_mem = scale_offset->mem; } auto op_desc = dnnl::batch_normalization_forward::desc( dnnl::prop_kind::forward, input_md, epsilon, flags); @@ -1148,15 +1146,15 @@ odla_value odla_BatchNormalization(odla_value input, auto prim = dnnl::batch_normalization_forward(pd); auto ret_mem = dnnl::memory(input_md, g_comp->eng); - odla_value v = CreateValue(ret_mem, orig_dims, value_id); add_op(prim, {{DNNL_ARG_SRC, input->mem}, {DNNL_ARG_MEAN, mean->mem}, {DNNL_ARG_VARIANCE, var->mem}, - {DNNL_ARG_SCALE_SHIFT, weight_mem}, + {DNNL_ARG_SCALE_SHIFT, scale_offset_mem}, {DNNL_ARG_DST, ret_mem}}); + odla_value v = CreateValue(ret_mem, orig_dims, value_id); if (g_comp->opts.bf16_mode == BF16_PERFORMACE_MODE) { v->mem = cast_op(v, dnnl::memory::data_type::bf16); - input->mem = oring_mem; + input->mem = origin_mem; } InterpretIfNeeded(); diff --git a/tests/unittests/lit_cases/test_dnnl/test_batchnorm_epsilon_dnnl.cc b/tests/unittests/lit_cases/test_dnnl/test_batchnorm_epsilon_dnnl.cc index cc6f36324..cc5bbf89f 100644 --- a/tests/unittests/lit_cases/test_dnnl/test_batchnorm_epsilon_dnnl.cc +++ b/tests/unittests/lit_cases/test_dnnl/test_batchnorm_epsilon_dnnl.cc @@ -29,5 +29,5 @@ // RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_batchnorm_epsilon | FileCheck %s // CHECK: Result Pass // clang-format on -// XFAIL: * + #include "test_batchnorm_epsilon_dnnl.cc.tmp.main.cc.in" diff --git a/tests/unittests/lit_cases/test_dnnl/test_batchnorm_example_dnnl.cc b/tests/unittests/lit_cases/test_dnnl/test_batchnorm_example_dnnl.cc index 7594d6158..7104d5502 100644 --- a/tests/unittests/lit_cases/test_dnnl/test_batchnorm_example_dnnl.cc +++ b/tests/unittests/lit_cases/test_dnnl/test_batchnorm_example_dnnl.cc @@ -29,5 +29,5 @@ // RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_batchnorm_example | FileCheck %s // CHECK: Result Pass // clang-format on -// XFAIL: * + #include "test_batchnorm_example_dnnl.cc.tmp.main.cc.in"