Skip to content

Commit

Permalink
graph: backend: dnnl: exclude non-bcast cond case for decomp kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiexin-Zheng authored and TaoLv committed Jan 16, 2025
1 parent ea6c0b7 commit 162281c
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*******************************************************************************/

#include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"
#include "graph/interface/shape_infer.hpp"

#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_decomp_kernel_t, (cond), status, msg, \
Expand Down Expand Up @@ -62,17 +63,11 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
if (graph_inport[5] != -1 && graph_inport[6] != -1) {
const auto select_cond_dims = ltw(inputs[graph_inport[5]]).vdims();
const auto select_src0_dims = ltw(inputs[graph_inport[6]]).vdims();
VCHECK_SDP_DECOMP(select_cond_dims.size() == select_src0_dims.size(),
false,
"Select cond and src0 dims should be same, but got %zu and %zu",
select_cond_dims.size(), select_src0_dims.size());
for (size_t i = 0; i < select_cond_dims.size(); i++) {

VCHECK_SDP_DECOMP(select_cond_dims[i] == select_src0_dims[i], false,
"Select cond and src0 dims should be same, but got %lld "
"and %lld",
select_cond_dims[i], select_src0_dims[i]);
}
VCHECK_SDP_DECOMP(select_cond_dims != select_src0_dims, false,
"Only supports select for case requiring broadcast cond input, "
"but got cond dims %s and src0 dims %s",
dims2str(select_cond_dims).c_str(),
dims2str(select_src0_dims).c_str());
}

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
Expand Down

0 comments on commit 162281c

Please sign in to comment.