Skip to content

Commit

Permalink
cpu: x64: pooling: fix ur_bc calculation in back propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
asimonov1 committed Jan 14, 2025
1 parent f8d455f commit 73b85e2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,11 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
}
assert(jpp.ur > 0);

+ jpp.needs_f32_accum_for_bf16 = jpp.is_bf16
+ && jpp.alg == alg_kind::pooling_max && jpp.is_backward
+ && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh
+ || jpp.stride_w < jpp.kw);
+
// select jpp.ur_bc
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
auto min_ur_w = nstl::max(1, utils::div_up(jpp.l_pad, jpp.stride_w));
Expand All @@ -393,9 +398,8 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
}

//take into account cache re-usage after zeroing on backward
- if (jpp.is_backward && ndims < 5) {
- const int L2 = platform::get_per_core_cache_size(2)
- / sizeof(jpp.dt_size);
+ if (jpp.is_backward && ndims < 5 && !jpp.needs_f32_accum_for_bf16) {
+ const int L2 = platform::get_per_core_cache_size(2) / jpp.dt_size;
int ur_bc = nstl::max(1, L2 / (jpp.kh * jpp.iw * jpp.c_block));
jpp.ur_bc = nstl::min(jpp.ur_bc, ur_bc);
}
Expand Down Expand Up @@ -423,10 +427,6 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
* nscr);
}

- jpp.needs_f32_accum_for_bf16 = jpp.is_bf16
- && jpp.alg == alg_kind::pooling_max && jpp.is_backward
- && (jpp.stride_d < jpp.kd || jpp.stride_h < jpp.kh
- || jpp.stride_w < jpp.kw);
jpp.f32_accum_block_size = jpp.ur_bc * jpp.c_block;
if (jpp.needs_f32_accum_for_bf16) {
auto tmp_d = memory_desc_wrapper(jpp.tmp_md);
Expand Down

0 comments on commit 73b85e2

Please sign in to comment.