Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/fused-ops' into tina.FXM…
Browse files Browse the repository at this point in the history
…L-3548-bump-llvm-to-d13da154a7c7eff77df8686b2de1cfdfa7cc7029
  • Loading branch information
mgehre-amd committed Feb 1, 2024
2 parents 0dc3171 + 83a820c commit ef93e5e
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 6 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Tools/PDLL/AST/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class NamedAttributeDecl;
class OpNameDecl;
class VariableDecl;

StringRef copyStringWithNull(Context &ctx, StringRef str);

//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,19 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
*root = std::make_unique<ExitNode>();
}

/// Sorts the range begin/end with the partial order given by cmp.
/// cmp must be a partial ordering.
template <typename Iterator, typename Compare>
void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
while (begin != end) {
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
return std::none_of(begin, end, [&](auto const &b) { return cmp(b, a); });
});
assert(next != begin && "not a partial ordering");
begin = next;
}
}

/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
std::unique_ptr<MatcherNode>
Expand Down Expand Up @@ -964,6 +977,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
return *lhs < *rhs;
});

// Mostly keep the now established order, but also ensure that
// ConstraintQuestions come after the results they use.
stableTopologicalSort(ordered.begin(), ordered.end(),
[](OrderedPredicate *a, OrderedPredicate *b) {
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
if (cqa && cqb) {
// Does any argument of b use a? Then b must be
// sorted after a.
return llvm::any_of(
cqb->getArgs(), [&](Position *p) {
auto *cp = dyn_cast<ConstraintPosition>(p);
return cp && cp->getQuestion() == cqa;
});
}
return false;
});

// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;
for (OrderedPredicateList &list : lists)
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Tools/PDLL/AST/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using namespace mlir;
using namespace mlir::pdll::ast;

/// Copy a string reference into the context with a null terminator.
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
StringRef mlir::pdll::ast::copyStringWithNull(Context &ctx, StringRef str) {
if (str.empty())
return str;

Expand Down
2 changes: 0 additions & 2 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,6 @@ Token Lexer::lexToken() {
return formToken(Token::l_paren, tokStart);
case ')':
return formToken(Token::r_paren, tokStart);
case '!':
return formToken(Token::exclam, tokStart);
case '/':
if (*curPtr == '/') {
lexComment();
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class Token {
equal,
equal_arrow,
semicolon,
exclam,
/// Paired punctuation.
less,
greater,
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ class Parser {
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseNegatedExpr();
FailureOr<ast::Expr *> parseIntegerExpr();
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
FailureOr<ast::Expr *>
Expand Down Expand Up @@ -1835,6 +1836,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::l_square:
lhsExpr = parseArrayAttrExpr();
break;
case Token::integer:
lhsExpr = parseIntegerExpr();
break;
case Token::string_block:
return emitError("expected expression. If you are trying to create an "
"ArrayAttr, use a space between `[` and `{`.");
Expand Down Expand Up @@ -2079,6 +2083,25 @@ FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
return parseCallExpr(*identifierExpr, /*isNegated = */ true);
}

/// Parse
/// integer : identifier
/// into an AttributeExpr.
/// Examples: '4 : i32', '0 : si1'
FailureOr<ast::Expr *> Parser::parseIntegerExpr() {
SMRange loc = curToken.getLoc();
StringRef value = curToken.getSpelling();
consumeToken();
if (!consumeIf(Token::colon))
return emitError("expected colon after integer literal");
if (!curToken.is(Token::identifier))
return emitError("expected integer type");
StringRef type = curToken.getSpelling();
consumeToken();

auto allocated = copyStringWithNull(ctx, (Twine(value) + ":" + type).str());
return ast::AttributeExpr::create(ctx, loc, allocated);
}

FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
SMRange loc = curToken.getLoc();

Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s

// Ensuse that the dependency between add & less
// causes them to be in the correct order.
// CHECK: apply_constraint "__builtin_add"
// CHECK: apply_constraint "__builtin_less"

module {
pdl.pattern @test : benefit(1) {
%0 = attribute
%1 = types
%2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
%3 = attribute = 0 : i32
%4 = attribute = 1 : i32
%5 = apply_native_constraint "__builtin_add"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
apply_native_constraint "__builtin_less"(%0, %5 : !pdl.attribute, !pdl.attribute)
rewrite %2 {
replace %2 with %2
}
}
}
19 changes: 18 additions & 1 deletion mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ Pattern {
// -----

Pattern {
// CHECK: expected expression
// CHECK: expected colon after integer literal
let tuple = (10 = _: Value);
erase op<>;
}
Expand Down Expand Up @@ -249,6 +249,23 @@ Pattern {
};;
}

// -----

Pattern {
let root = op<func.func> -> ();
3;
// CHECK: expected colon after integer literal
replace root with root;
}

// -----

Pattern {
let root = op<func.func> -> ();
3 :;
// CHECK: expected integer type
replace root with root;
}

// -----

Expand Down
9 changes: 9 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ Pattern {

// -----

// CHECK: Module
// CHECK: `-AttributeExpr {{.*}} Value<"10:i32">
Pattern {
let attr = 10 : i32;
erase _: Op;
}

// -----

// CHECK: |-NamedAttributeDecl {{.*}} Name<some_array>
// CHECK: `-UserRewriteDecl {{.*}} Name<addElemToArrayAttr> ResultType<Attr>
// CHECK: `Arguments`
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/mlir-pdll/Parser/stmt-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ Pattern {
// -----

Pattern {
// CHECK: expected expression
// CHECK: expected colon after integer literal
let foo: ValueRange<10>;
}

Expand Down

0 comments on commit ef93e5e

Please sign in to comment.