Skip to content

Commit

Permalink
OpaqueType with format strings (#391)
Browse files Browse the repository at this point in the history
OpaqueType: Use format string
  • Loading branch information
josel-amd authored Oct 28, 2024
1 parent 4b36487 commit ad4697c
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 58 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ bool isPointerWideType(mlir::Type type);
/// Give the name of the EmitC reference attribute.
StringRef getReferenceAttributeName();

// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;

} // namespace emitc
} // namespace mlir

Expand Down
6 changes: 1 addition & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1168,11 +1168,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
}];

let extraClassDeclaration = [{
// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;

FailureOr<SmallVector<ReplacementItem>> parseFormatString();
FailureOr<SmallVector<::mlir::emitc::ReplacementItem>> parseFormatString();
}];

let arguments = (ins StrAttr:$value,
Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,16 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
```
}];

let parameters = (ins StringRefParameter<"the opaque value">:$value);
let assemblyFormat = "`<` $value `>`";
let parameters = (ins StringRefParameter<"the opaque value">:$value,
OptionalArrayRefParameter<"Type">:$fmtArgs);
let assemblyFormat = "`<` $value (`,` custom<VariadicTypeFmtArgs>($fmtArgs)^)? `>`";
let genVerifyDecl = 1;

let builders = [TypeBuilder<(ins "::llvm::StringRef":$value), [{ return $_get($_ctxt, value, SmallVector<Type>{}); }] >];

let extraClassDeclaration = [{
FailureOr<SmallVector<::mlir::emitc::ReplacementItem>> parseFormatString();
}];
}

def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
Expand Down
154 changes: 105 additions & 49 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::emitc;
Expand Down Expand Up @@ -154,6 +155,64 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
return success();
}

/// Parse a format string and return a list of its parts.
/// A part is either a StringRef that has to be printed as-is, or
/// a Placeholder which requires printing the next operand of the VerbatimOp.
/// In the format string, all `{}` are replaced by Placeholders, except if the
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
template <class ArgType>
FailureOr<SmallVector<ReplacementItem>>
parseFormatString(StringRef toParse, ArgType fmtArgs,
std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
emitError = {}) {
SmallVector<ReplacementItem> items;

// If there are not operands, the format string is not interpreted.
if (fmtArgs.empty()) {
items.push_back(toParse);
return items;
}

while (!toParse.empty()) {
size_t idx = toParse.find('{');
if (idx == StringRef::npos) {
// No '{'
items.push_back(toParse);
break;
}
if (idx > 0) {
// Take all chars excluding the '{'.
items.push_back(toParse.take_front(idx));
toParse = toParse.drop_front(idx);
continue;
}
if (toParse.size() < 2) {
// '{' is last character
items.push_back(toParse);
break;
}
// toParse contains at least two characters and starts with `{`.
char nextChar = toParse[1];
if (nextChar == '{') {
// Double '{{' -> '{' (escaping).
items.push_back(toParse.take_front(1));
toParse = toParse.drop_front(2);
continue;
}
if (nextChar == '}') {
items.push_back(Placeholder{});
toParse = toParse.drop_front(2);
continue;
}

if (emitError.has_value()) {
return (*emitError)() << "expected '}' after unescaped '{'";
}
return failure();
}
return items;
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -914,7 +973,11 @@ LogicalResult emitc::SubscriptOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult emitc::VerbatimOp::verify() {
FailureOr<SmallVector<ReplacementItem>> fmt = parseFormatString();
auto errorCallback = [&]() -> InFlightDiagnostic {
return this->emitOpError();
};
FailureOr<SmallVector<ReplacementItem>> fmt =
::parseFormatString(getValue(), getFmtArgs(), errorCallback);
if (failed(fmt))
return failure();

Expand All @@ -929,56 +992,29 @@ LogicalResult emitc::VerbatimOp::verify() {
return success();
}

/// Parse a format string and return a list of its parts.
/// A part is either a StringRef that has to be printed as-is, or
/// a Placeholder which requires printing the next operand of the VerbatimOp.
/// In the format string, all `{}` are replaced by Placeholders, except if the
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>>
emitc::VerbatimOp::parseFormatString() {
SmallVector<ReplacementItem> items;
static ParseResult parseVariadicTypeFmtArgs(AsmParser &p,
SmallVector<Type> &params) {
Type type;
if (p.parseType(type))
return failure();

// If there are not operands, the format string is not interpreted.
if (getFmtArgs().empty()) {
items.push_back(getValue());
return items;
params.push_back(type);
while (succeeded(p.parseOptionalComma())) {
if (p.parseType(type))
return failure();
params.push_back(type);
}

StringRef toParse = getValue();
while (!toParse.empty()) {
size_t idx = toParse.find('{');
if (idx == StringRef::npos) {
// No '{'
items.push_back(toParse);
break;
}
if (idx > 0) {
// Take all chars excluding the '{'.
items.push_back(toParse.take_front(idx));
toParse = toParse.drop_front(idx);
continue;
}
if (toParse.size() < 2) {
// '{' is last character
items.push_back(toParse);
break;
}
// toParse contains at least two characters and starts with `{`.
char nextChar = toParse[1];
if (nextChar == '{') {
// Double '{{' -> '{' (escaping).
items.push_back(toParse.take_front(1));
toParse = toParse.drop_front(2);
continue;
}
if (nextChar == '}') {
items.push_back(Placeholder{});
toParse = toParse.drop_front(2);
continue;
}
return emitOpError() << "expected '}' after unescaped '{'";
}
return items;
return success();
}

static void printVariadicTypeFmtArgs(AsmPrinter &p, ArrayRef<Type> params) {
llvm::interleaveComma(params, p, [&](Type type) { p.printType(type); });
}

FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
// Error checking is done in verify.
return ::parseFormatString(getValue(), getFmtArgs());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1072,17 +1108,37 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,

LogicalResult mlir::emitc::OpaqueType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::StringRef value) {
llvm::StringRef value, ArrayRef<Type> fmtArgs) {
if (value.empty()) {
return emitError() << "expected non empty string in !emitc.opaque type";
}
if (value.back() == '*') {
return emitError() << "pointer not allowed as outer type with "
"!emitc.opaque, use !emitc.ptr instead";
}

FailureOr<SmallVector<ReplacementItem>> fmt =
::parseFormatString(value, fmtArgs, emitError);
if (failed(fmt))
return failure();

size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
return std::holds_alternative<Placeholder>(item);
});

if (numPlaceholders != fmtArgs.size()) {
return emitError()
<< "requires operands for each placeholder in the format string";
}

return success();
}

FailureOr<SmallVector<ReplacementItem>> emitc::OpaqueType::parseFormatString() {
// Error checking is done in verify.
return ::parseFormatString(getValue(), getFmtArgs());
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 19 additions & 2 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,14 +512,14 @@ static LogicalResult printOperation(CppEmitter &emitter,
emitc::VerbatimOp verbatimOp) {
raw_ostream &os = emitter.ostream();

FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>> items =
FailureOr<SmallVector<ReplacementItem>> items =
verbatimOp.parseFormatString();
if (failed(items))
return failure();

auto fmtArg = verbatimOp.getFmtArgs().begin();

for (emitc::VerbatimOp::ReplacementItem &item : *items) {
for (ReplacementItem &item : *items) {
if (auto *str = std::get_if<StringRef>(&item)) {
os << *str;
} else {
Expand Down Expand Up @@ -1728,6 +1728,23 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
FailureOr<SmallVector<ReplacementItem>> items = oType.parseFormatString();
if (failed(items))
return failure();

auto fmtArg = oType.getFmtArgs().begin();
for (ReplacementItem &item : *items) {
if (auto *str = std::get_if<StringRef>(&item)) {
os << *str;
} else {
if (failed(emitType(loc, *fmtArg++))) {
return failure();
}
}
}

return success();

os << oType.getValue();
return success();
}
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@ func.func @illegal_opaque_type_2() {

// -----

// expected-error @+1 {{expected non-function type}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", "string">) {
return
}

// -----

// expected-error @+1 {{requires operands for each placeholder in the format string}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"a", f32>) {
return
}

// -----

// expected-error @+1 {{requires operands for each placeholder in the format string}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", f32>) {
return
}

// -----

// expected-error @+1 {{expected '}' after unescaped '{'}}
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{ ", i32>) {
return
}

// -----

func.func @illegal_array_missing_spec(
// expected-error @+1 {{expected non-function type}}
%arg0: !emitc.array<>) {
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/EmitC/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ func.func @opaque_types() {
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"SmallVector<int*, 4>">
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"SmallVector<int*, 4>">]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"{}", i32>
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", i32>>]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"{}, {}", i32, f32>]
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}, {}", i32, f32>>]} : () -> ()
// CHECK-NEXT: !emitc.opaque<"{}"
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}">>]} : () -> ()

return
}
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Target/Cpp/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ func.func @opaque_types() {
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"status_t">>]} : () -> ()
// CHECK-NEXT: f<std::vector<std::string>>();
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
// CHECK: f<float>()
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", f32>>]} : () -> ()
// CHECK: f<int16_t {>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {{", si16>>]} : () -> ()
// CHECK: f<int8_t {>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {", i8>>]} : () -> ()
// CHECK: f<status_t>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", !emitc<opaque<"status_t">> >>]} : () -> ()
// CHECK: f<top<nested<float>,int32_t>>();
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"top<{},{}>", !emitc<opaque<"nested<{}>", f32>>, i32>>]} : () -> ()

return
}
Expand Down

0 comments on commit ad4697c

Please sign in to comment.