diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp index d09567286ad..ae96ce687dd 100644 --- a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp" +#include "graph/interface/shape_infer.hpp" #define VCHECK_SDP_DECOMP(cond, status, msg, ...) \ VCONDCHECK(graph, create, check, sdp_decomp_kernel_t, (cond), status, msg, \ @@ -62,17 +63,11 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr &sg, if (graph_inport[5] != -1 && graph_inport[6] != -1) { const auto select_cond_dims = ltw(inputs[graph_inport[5]]).vdims(); const auto select_src0_dims = ltw(inputs[graph_inport[6]]).vdims(); - VCHECK_SDP_DECOMP(select_cond_dims.size() == select_src0_dims.size(), - false, - "Select cond and src0 dims should be same, but got %zu and %zu", - select_cond_dims.size(), select_src0_dims.size()); - for (size_t i = 0; i < select_cond_dims.size(); i++) { - - VCHECK_SDP_DECOMP(select_cond_dims[i] == select_src0_dims[i], false, - "Select cond and src0 dims should be same, but got %lld " - "and %lld", - select_cond_dims[i], select_src0_dims[i]); - } + VCHECK_SDP_DECOMP(select_cond_dims != select_src0_dims, false, + "Only supports select for case requiring broadcast cond input, " + "but got cond dims %s and src0 dims %s", + dims2str(select_cond_dims).c_str(), + dims2str(select_src0_dims).c_str()); } #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP