diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp index e7478829cd7..4c4eee2b603 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.cpp +++ b/src/cpu/x64/jit_uni_pool_kernel.cpp @@ -368,6 +368,11 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, } assert(jpp.ur > 0); + jpp.needs_f32_accum_for_bf16 = jpp.is_bf16 + && jpp.alg == alg_kind::pooling_max && jpp.is_backward + && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh + || jpp.stride_w < jpp.kw); + // select jpp.ur_bc if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { auto min_ur_w = nstl::max(1, utils::div_up(jpp.l_pad, jpp.stride_w)); @@ -393,9 +398,8 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, } //take into account cache re-usage after zeroing on backward - if (jpp.is_backward && ndims < 5) { - const int L2 = platform::get_per_core_cache_size(2) - / sizeof(jpp.dt_size); + if (jpp.is_backward && ndims < 5 && !jpp.needs_f32_accum_for_bf16) { + const int L2 = platform::get_per_core_cache_size(2) / jpp.dt_size; int ur_bc = nstl::max(1, L2 / (jpp.kh * jpp.iw * jpp.c_block)); jpp.ur_bc = nstl::min(jpp.ur_bc, ur_bc); } @@ -423,10 +427,6 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, * nscr); } - jpp.needs_f32_accum_for_bf16 = jpp.is_bf16 - && jpp.alg == alg_kind::pooling_max && jpp.is_backward - && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh - || jpp.stride_w < jpp.kw); jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block; if (jpp.needs_f32_accum_for_bf16) { auto tmp_d = memory_desc_wrapper(jpp.tmp_md);