[ODLA/DNNL] Fix batchnorm
Weiming Zhao authored and weimingzha0 committed Aug 21, 2021
1 parent 66a3da3 commit 6aa728d
Showing 3 changed files with 23 additions and 25 deletions.
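Summary of the change (as recoverable from the diff below): odla_BatchNormalization previously assembled the DNNL scale/shift weights with a raw dnnl::concat primitive, and only when both the scale and offset tensors were non-null, so the scalar fallbacks scalar_scale and scalar_offset were never applied. The fix builds the combined tensor through odla_Reshape and odla_Concat, substituting constants created from the scalars when a tensor operand is null, binds the result as DNNL_ARG_SCALE_SHIFT, and creates the output value only after the primitive is queued. It also renames the misspelled oring_mem to origin_mem, and the two batchnorm lit tests are no longer marked XFAIL.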
44 changes: 21 additions & 23 deletions ODLA/platforms/dnnl/odla_dnnl.cc
@@ -1090,7 +1090,7 @@ odla_value odla_BatchNormalization(odla_value input,
                                    odla_value offset, odla_float32 scalar_scale,
                                    odla_float32 scalar_offset,
                                    const odla_value_id value_id) {
-  dnnl::memory oring_mem;
+  dnnl::memory origin_mem;
   dnnl::memory::data_type dtype = input->mem.get_desc().data_type();
   // black list op should convert to fp32
   bool bf16_mode = (dtype == dnnl::memory::data_type::bf16 ||
@@ -1105,7 +1105,7 @@ odla_value odla_BatchNormalization(odla_value input,
       scale->mem = cast_op(scale, dnnl::memory::data_type::f32);
       offset->mem = cast_op(offset, dnnl::memory::data_type::f32);
     }
-    oring_mem = input->mem;
+    origin_mem = input->mem;
     input->mem = f32_input_mem;
   }

@@ -1121,25 +1121,23 @@ odla_value odla_BatchNormalization(odla_value input,
   }

   unsigned channels = input_dims.dims[1];
-  dnnl::memory::desc weight_md(dnnl::memory::dims{2, channels}, type,
-                               dnnl::memory::format_tag::nc);
-  dnnl::memory weight_mem = dnnl::memory(weight_md, g_comp->eng);
-
-  if (scale != nullptr && offset != nullptr) {
+  dnnl::memory scale_offset_mem = dnnl::memory();
+  if (scale != nullptr || offset != nullptr || scalar_offset != 0.0F ||
+      scalar_scale != 1.0F) {
+    // make a tensor [scale, bias].
+    auto get_value = [channels](odla_value x, float scalar) {
+      if (x == nullptr) {
+        x = odla_CreateConstant({ODLA_FLOAT32, {2, {1, 1}}}, &scalar,
+                                nullptr); // FIXME: copy to buf
+      }
+      return odla_Reshape(x, {2, {1, channels}}, nullptr);
+    };
+    odla_value s = get_value(scale, scalar_scale);
+    odla_value b = get_value(offset, scalar_offset);
     flags |= dnnl::normalization_flags::use_scale_shift;
-    auto scale_md =
-        dnnl::memory::desc({1, channels}, type, dnnl::memory::format_tag::nc);
-    auto scale_mem =
-        dnnl::memory(scale_md, g_comp->eng, scale->mem.get_data_handle());
-    auto offset_md =
-        dnnl::memory::desc({1, channels}, type, dnnl::memory::format_tag::nc);
-    auto c_pd = dnnl::concat::primitive_desc(
-        weight_md, 0, {scale_md, offset_md}, g_comp->eng);
-    auto c = dnnl::concat(c_pd);
-    c.execute(dnnl::stream(g_comp->eng),
-              {{DNNL_ARG_MULTIPLE_SRC, scale->mem},
-               {DNNL_ARG_MULTIPLE_SRC + 1, offset->mem},
-               {DNNL_ARG_DST, weight_mem}});
+    auto scale_offset =
+        odla_Concat({2, {s, b}}, 0, {2, {2, channels}}, nullptr);
+    scale_offset_mem = scale_offset->mem;
   }
   auto op_desc = dnnl::batch_normalization_forward::desc(
       dnnl::prop_kind::forward, input_md, epsilon, flags);
@@ -1148,15 +1146,15 @@ odla_value odla_BatchNormalization(odla_value input,
   auto prim = dnnl::batch_normalization_forward(pd);
   auto ret_mem = dnnl::memory(input_md, g_comp->eng);

-  odla_value v = CreateValue(ret_mem, orig_dims, value_id);
   add_op(prim, {{DNNL_ARG_SRC, input->mem},
                 {DNNL_ARG_MEAN, mean->mem},
                 {DNNL_ARG_VARIANCE, var->mem},
-                {DNNL_ARG_SCALE_SHIFT, weight_mem},
+                {DNNL_ARG_SCALE_SHIFT, scale_offset_mem},
                 {DNNL_ARG_DST, ret_mem}});
+  odla_value v = CreateValue(ret_mem, orig_dims, value_id);
   if (g_comp->opts.bf16_mode == BF16_PERFORMACE_MODE) {
     v->mem = cast_op(v, dnnl::memory::data_type::bf16);
-    input->mem = oring_mem;
+    input->mem = origin_mem;
   }

   InterpretIfNeeded();
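
For context on the weights layout: with dnnl::normalization_flags::use_scale_shift, oneDNN's batch_normalization_forward consumes a single {2, C} weights tensor, row 0 holding the per-channel scale (gamma) and row 1 the shift (beta), bound under DNNL_ARG_SCALE_SHIFT. That is why the patch reshapes scale and offset to {1, channels} and concatenates them along axis 0. Below is a minimal standalone sketch of that layout, assuming the oneDNN v2.x C++ API this file targets and a CPU engine; make_scale_shift is a hypothetical helper for illustration, not part of ODLA:

    #include <algorithm>
    #include <vector>

    #include "dnnl.hpp"

    // Build the {2, C} scale/shift weights tensor that
    // dnnl::normalization_flags::use_scale_shift expects:
    // row 0 = per-channel scale (gamma), row 1 = per-channel shift (beta).
    // Hypothetical helper; assumes oneDNN v2.x and a CPU engine, so the
    // data handle is a host pointer we can write through directly.
    dnnl::memory make_scale_shift(const dnnl::engine& eng,
                                  const std::vector<float>& gamma,
                                  const std::vector<float>& beta) {
      const auto c = static_cast<dnnl::memory::dim>(gamma.size());
      dnnl::memory::desc md({2, c}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nc);
      dnnl::memory mem(md, eng); // library-allocated buffer
      auto* dst = static_cast<float*>(mem.get_data_handle());
      std::copy(gamma.begin(), gamma.end(), dst);    // row 0: scale
      std::copy(beta.begin(), beta.end(), dst + c);  // row 1: shift
      return mem;
    }

When none of scale, offset, scalar_scale, or scalar_offset is supplied beyond its default, the patched code leaves use_scale_shift unset and DNNL normalizes with mean and variance only.
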
2 changes: 1 addition & 1 deletion test_batchnorm_epsilon_dnnl.cc
@@ -29,5 +29,5 @@
 // RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_batchnorm_epsilon | FileCheck %s
 // CHECK: Result Pass
 // clang-format on
-// XFAIL: *
+
 #include "test_batchnorm_epsilon_dnnl.cc.tmp.main.cc.in"
2 changes: 1 addition & 1 deletion test_batchnorm_example_dnnl.cc
@@ -29,5 +29,5 @@
 // RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_batchnorm_example | FileCheck %s
 // CHECK: Result Pass
 // clang-format on
-// XFAIL: *
+
 #include "test_batchnorm_example_dnnl.cc.tmp.main.cc.in"
