Skip to content

Commit

Permalink
[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern imp…
Browse files Browse the repository at this point in the history
…lementation. (llvm#105566)
  • Loading branch information
Peiming Liu authored Aug 23, 2024
1 parent a968ae6 commit 7186704
Showing 1 changed file with 36 additions and 82 deletions.
118 changes: 36 additions & 82 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);

if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
SmallVector<Value> ivs;
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);

Block *loopBody = op.getBody();
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
loopBody->getArgumentTypes(), bodyTypeMapping)))
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

rewriter.eraseBlock(forOp.getBody());
Region &dstRegion = forOp.getRegion();
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());

auto yieldOp =
llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());

rewriter.setInsertionPointToEnd(forOp.getBody());
// replace sparse_tensor.yield with scf.yield.
rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
rewriter.eraseOp(yieldOp);

const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
// TODO: put iterator at the end of argument list to be consistent with
// coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);

assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));

TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
SmallVector<Location> l(types.size(), op.getIterator().getLoc());

// Generates loop conditions.
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
rewriter.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
assert(remArgs.size() == adaptor.getInitArgs().size());
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

// Generates loop body.
Block *loopBody = op.getBody();
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
loopBody->getArgumentTypes(), bodyTypeMapping)))
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

Region &dstRegion = whileOp.getAfter();
// TODO: handle uses of coordinate!
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
ValueRange aArgs = whileOp.getAfterArguments();
auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
whileOp.getAfterBody()->getTerminator());

rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
SmallVector<Value> ivs;
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);

// Type conversion on iterate op block.
OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
op.getBody()->getArgumentTypes(), blockTypeMapping)))
return rewriter.notifyMatchFailure(
op, "failed to convert iterate region argurment types");
rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);

Block *block = op.getBody();
ValueRange ret = genLoopWithIterator(
rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
SmallVector<Value> blockArgs(it->getCursor());
// TODO: Also appends coordinates if used.
// blockArgs.push_back(it->deref(rewriter, loc));
llvm::append_range(blockArgs, reduc);

Block *dstBlock = &loopBody.getBlocks().front();
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
blockArgs);
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
// We can not use ValueRange as the operation holding the values will
// be destoryed.
SmallVector<Value> result(yield.getResults());
rewriter.eraseOp(yield);
return result;
});

aArgs = it->linkNewScope(aArgs);
ValueRange nx = it->forward(rewriter, loc);
SmallVector<Value> yields;
llvm::append_range(yields, nx);
llvm::append_range(yields, yieldOp.getResults());

// replace sparse_tensor.yield with scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.create<scf::YieldOp>(loc, yields);
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(
op, whileOp.getResults().drop_front(it->getCursor().size()),
resultMapping);
}
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, ret, resultMapping);
return success();
}
};
Expand Down Expand Up @@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping)))
blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
}

rewriter.applySignatureConversion(block, blockTypeMapping);
}
Expand Down

0 comments on commit 7186704

Please sign in to comment.