Skip to content

Commit

Permalink
Converter:Bugfix: Fix bug for einsum for 4 - 4 not care about transposeA
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaying committed Aug 20, 2024
1 parent 72ef1ff commit ee15dfd
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,22 @@ class OnnxEinsumTransform : public OnnxExtraManager::Transform {
}
// find reduce dim
char reduce_dim;
int reduce_dim_pos = -1;
for (int i = 0; i < input0.size(); ++i) {
auto c = input0[i];
if (right.find(c) == std::string::npos) {
reduce_dim = c;
reduce_dim_pos = i;
break;
}
}
bool needTransposeA = false;
if (reduce_dim_pos >= 0 && input0.size() >= 2 && reduce_dim_pos == input0.size() - 2) {
needTransposeA = true;
}
auto need_transpose = input1.find(reduce_dim) == (input1.size() - 1);
// matmul: matmul auto broadcast such: `bhwc @ hkc` -> `bhwc @ bhkc`
auto output = _MatMul(var0, var1, false, need_transpose);
auto output = _MatMul(var0, var1, needTransposeA, need_transpose);
// squeeze
if (sqeeze_axis >= 0) {
output = _Squeeze(output, {sqeeze_axis});
Expand Down

0 comments on commit ee15dfd

Please sign in to comment.