Skip to content

Commit

Permalink
Extend documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jorickert committed Dec 20, 2024
1 parent 5519e16 commit 720a8cc
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -937,12 +937,41 @@ struct DecomposeScatterNDPattern : public OpRewritePattern<ONNXScatterNDOp> {
}

// Check that all indices are contigous.
// let r = shape(data),
// q = shape(indices)
// u = shape(updates)
// We checked that rank(q) == rank(r) and that u and r only differ in a
// single dimension 'a'. 'a' is the dimension, where the split and concat
// will happen. To ensure that the decomposition to split and concat is
// valid, the indices need to be contiguous beginning from the first index.
// We call them contiguous if each index is pointing at the element in data
// directly following the element pointed to by the previous index.
// "The following element" is the element that will be accessed, if the
// least significant value of an elements index is increased by one. To
// check that the indices are contiguous, we iterate over all indices.
// The indices need to completey cover 'data', with the exception of
// dimension 'a', where they need to only cover u[a].
// By definition, rank(q) -1 dimensions of indices and updates are
// identical.
// Because we checked that u and r only differ in 'a', all
// dimension except 'a' and rank(q) -1 (which contains the indices), are
// also identical in r and q.
// The check works by calculating the expected index (called 'counter' in
// the code) and then comparing it against the actual one. The check begins
// with the first index and then always increments it by one. The increment
// works similar to manual addition, incrementing the least significant
// digit and carrying when needed. The carry happens whenever a dimension
// was completly covered. A complication is, that the first index is not
// always [0, ...], it can be 'shifted'. We checked that the shift is only
// in the split dimension 'a'. If a carry is required, the counter that
// required the carry is not reset to zero, but to its value in the first
// index, which can contain the shift.
{
IndicesContiguosCounter counter(firstIndex, indicesShape.drop_back(1));
for (size_t i = 0; i < indicesFlatAccessor.size(); ++i) {
if (counter.getCounter() != indicesFlatAccessor[i]) {
return rewriter.notifyMatchFailure(
scatterNDOp, "Indices are not contigous");
scatterNDOp, "Indices are not contiguous");
}
counter.increment();
}
Expand Down

0 comments on commit 720a8cc

Please sign in to comment.