From ee15dfdcead281ddf6fb17e25e3beb1d7bc77313 Mon Sep 17 00:00:00 2001 From: xiaying Date: Tue, 20 Aug 2024 13:53:58 +0800 Subject: [PATCH] Converter:Bugfix: Fix einsum bug where the 4 - 4 case did not account for transposeA --- tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp b/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp index ddbdbc657..d0b66d1b9 100644 --- a/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp +++ b/tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp @@ -139,16 +139,22 @@ class OnnxEinsumTransform : public OnnxExtraManager::Transform { } // find reduce dim char reduce_dim; + int reduce_dim_pos = -1; for (int i = 0; i < input0.size(); ++i) { auto c = input0[i]; if (right.find(c) == std::string::npos) { reduce_dim = c; + reduce_dim_pos = i; break; } } + bool needTransposeA = false; + if (reduce_dim_pos >= 0 && input0.size() >= 2 && reduce_dim_pos == input0.size() - 2) { + needTransposeA = true; + } auto need_transpose = input1.find(reduce_dim) == (input1.size() - 1); // matmul: matmul auto broadcast such: `bhwc @ hkc` -> `bhwc @ bhkc` - auto output = _MatMul(var0, var1, false, need_transpose); + auto output = _MatMul(var0, var1, needTransposeA, need_transpose); // squeeze if (sqeeze_axis >= 0) { output = _Squeeze(output, {sqeeze_axis});