Skip to content

Commit

Permalink
Capture output of defered ConstantOps without major modification of C…
Browse files Browse the repository at this point in the history
…ppEmitter methods
  • Loading branch information
lmendesp-amd committed Dec 12, 2024
1 parent e35eb2f commit 2dd8f12
Showing 1 changed file with 63 additions and 85 deletions.
148 changes: 63 additions & 85 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ struct CppEmitter {
/// taken care of by transformations run by the backend.
bool shouldBeInlined(ExpressionOp expressionOp);

/// This emitter will only emit translation units whos id matches this value.
StringRef willOnlyEmitTu() { return onlyTu; }

private:
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
Expand Down Expand Up @@ -307,20 +310,6 @@ struct CppEmitter {

/// Determine whether expression \p op should be emitted in a deferred way.
bool hasDeferredEmission(Operation *op);

public:
/// Emits attribute to the specified stream or returns failure.
LogicalResult emitAttributeToStream(Location loc, Attribute attr,
raw_ostream &ss);

/// Emits type 'type' to the specified stream or returns failure.
LogicalResult emitTypeToStream(Location loc, Type type, raw_ostream &ss);

private:
/// Emits array of types as a std::tuple of the emitted types independently of
/// the array size to the specified stream.
LogicalResult emitTupleTypeToStream(Location loc, ArrayRef<Type> types,
raw_ostream &ss);
};
} // namespace

Expand Down Expand Up @@ -398,14 +387,19 @@ static LogicalResult printOperation(CppEmitter &emitter,
std::string out;
llvm::raw_string_ostream ss(out);

/// Temporary emitter object that writes to our stream instead of the output
/// allowing for the capture and caching of the produced string.
CppEmitter sniffer = CppEmitter(ss, emitter.shouldDeclareVariablesAtTop(),
emitter.willOnlyEmitTu(),
emitter.shouldUseConstantsAsVariables());

ss << "(";
if (failed(emitter.emitTypeToStream(constantOp.getLoc(),
constantOp.getType(), ss)))
if (failed(sniffer.emitType(constantOp.getLoc(), constantOp.getType())))
return failure();
ss << ") ";

if (failed(emitter.emitAttributeToStream(constantOp.getLoc(),
constantOp.getValue(), ss)))
if (failed(
sniffer.emitAttribute(constantOp.getLoc(), constantOp.getValue())))
return failure();

emitter.cacheDeferredOpResult(constantOp.getResult(), out);
Expand Down Expand Up @@ -1389,21 +1383,16 @@ bool CppEmitter::hasBlockLabel(Block &block) {
}

LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return CppEmitter::emitAttributeToStream(loc, attr, os);
}

LogicalResult CppEmitter::emitAttributeToStream(Location loc, Attribute attr,
raw_ostream &ss) {
auto printInt = [&](const APInt &val, bool isUnsigned) {
if (val.getBitWidth() == 1) {
if (val.getBoolValue())
ss << "true";
os << "true";
else
ss << "false";
os << "false";
} else {
SmallString<128> strValue;
val.toString(strValue, 10, !isUnsigned, false);
ss << strValue;
os << strValue;
}
};

Expand All @@ -1412,28 +1401,28 @@ LogicalResult CppEmitter::emitAttributeToStream(Location loc, Attribute attr,
SmallString<128> strValue;
// Use default values of toString except don't truncate zeros.
val.toString(strValue, 0, 0, false);
ss << strValue;
os << strValue;
switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
case llvm::APFloatBase::S_IEEEhalf:
ss << "f16";
os << "f16";
break;
case llvm::APFloatBase::S_BFloat:
ss << "bf16";
os << "bf16";
break;
case llvm::APFloatBase::S_IEEEsingle:
ss << "f";
os << "f";
break;
case llvm::APFloatBase::S_IEEEdouble:
break;
default:
llvm_unreachable("unsupported floating point type");
};
} else if (val.isNaN()) {
ss << "NAN";
os << "NAN";
} else if (val.isInfinity()) {
if (val.isNegative())
ss << "-";
ss << "INFINITY";
os << "-";
os << "INFINITY";
}
};

Expand All @@ -1453,9 +1442,9 @@ LogicalResult CppEmitter::emitAttributeToStream(Location loc, Attribute attr,
return emitError(
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
}
ss << '{';
interleaveComma(dense, ss, [&](const APFloat &val) { printFloat(val); });
ss << '}';
os << '{';
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
os << '}';
return success();
}

Expand All @@ -1473,40 +1462,40 @@ LogicalResult CppEmitter::emitAttributeToStream(Location loc, Attribute attr,
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
if (auto iType = dyn_cast<IntegerType>(
cast<TensorType>(dense.getType()).getElementType())) {
ss << '{';
interleaveComma(dense, ss, [&](const APInt &val) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
});
ss << '}';
os << '}';
return success();
}
if (auto iType = dyn_cast<IndexType>(
cast<TensorType>(dense.getType()).getElementType())) {
ss << '{';
interleaveComma(dense, ss,
os << '{';
interleaveComma(dense, os,
[&](const APInt &val) { printInt(val, false); });
ss << '}';
os << '}';
return success();
}
}

// Print opaque attributes.
if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
ss << oAttr.getValue();
os << oAttr.getValue();
return success();
}

// Print symbolic reference attributes.
if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
if (sAttr.getNestedReferences().size() > 1)
return emitError(loc, "attribute has more than 1 nested reference");
ss << sAttr.getRootReference().getValue();
os << sAttr.getRootReference().getValue();
return success();
}

// Print type attributes.
if (auto type = dyn_cast<TypeAttr>(attr))
return emitTypeToStream(loc, type.getValue(), ss);
return emitType(loc, type.getValue());

return emitError(loc, "cannot emit attribute: ") << attr;
}
Expand Down Expand Up @@ -1830,23 +1819,18 @@ LogicalResult CppEmitter::emitReferenceToType(Location loc, Type type) {
}

LogicalResult CppEmitter::emitType(Location loc, Type type) {
return emitTypeToStream(loc, type, os);
}

LogicalResult CppEmitter::emitTypeToStream(Location loc, Type type,
raw_ostream &ss) {
if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
case 1:
return (ss << "bool"), success();
return (os << "bool"), success();
case 8:
case 16:
case 32:
case 64:
if (shouldMapToUnsigned(iType.getSignedness()))
return (ss << "uint" << iType.getWidth() << "_t"), success();
return (os << "uint" << iType.getWidth() << "_t"), success();
else
return (ss << "int" << iType.getWidth() << "_t"), success();
return (os << "int" << iType.getWidth() << "_t"), success();
default:
return emitError(loc, "cannot emit integer type ") << type;
}
Expand All @@ -1855,48 +1839,48 @@ LogicalResult CppEmitter::emitTypeToStream(Location loc, Type type,
switch (fType.getWidth()) {
case 16: {
if (llvm::isa<Float16Type>(type))
return (ss << "_Float16"), success();
return (os << "_Float16"), success();
else if (llvm::isa<BFloat16Type>(type))
return (ss << "__bf16"), success();
return (os << "__bf16"), success();
else
return emitError(loc, "cannot emit float type ") << type;
}
case 32:
return (ss << "float"), success();
return (os << "float"), success();
case 64:
return (ss << "double"), success();
return (os << "double"), success();
default:
return emitError(loc, "cannot emit float type ") << type;
}
}
if (auto iType = dyn_cast<IndexType>(type))
return (ss << "size_t"), success();
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SizeTType>(type))
return (ss << "size_t"), success();
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
return (ss << "ssize_t"), success();
return (os << "ssize_t"), success();
if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
return (ss << "ptrdiff_t"), success();
return (os << "ptrdiff_t"), success();
if (auto tType = dyn_cast<TensorType>(type)) {
if (!tType.hasRank())
return emitError(loc, "cannot emit unranked tensor type");
if (!tType.hasStaticShape())
return emitError(loc, "cannot emit tensor type with non static shape");
ss << "Tensor<";
os << "Tensor<";
if (isa<ArrayType>(tType.getElementType()))
return emitError(loc, "cannot emit tensor of array type ") << type;
if (failed(emitTypeToStream(loc, tType.getElementType(), ss)))
if (failed(emitType(loc, tType.getElementType())))
return failure();
auto shape = tType.getShape();
for (auto dimSize : shape) {
ss << ", ";
ss << dimSize;
os << ", ";
os << dimSize;
}
ss << ">";
os << ">";
return success();
}
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleTypeToStream(loc, tType.getTypes(), ss);
return emitTupleType(loc, tType.getTypes());
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
FailureOr<SmallVector<ReplacementItem>> items = oType.parseFormatString();
if (failed(items))
Expand All @@ -1905,34 +1889,34 @@ LogicalResult CppEmitter::emitTypeToStream(Location loc, Type type,
auto fmtArg = oType.getFmtArgs().begin();
for (ReplacementItem &item : *items) {
if (auto *str = std::get_if<StringRef>(&item)) {
ss << *str;
os << *str;
} else {
if (failed(emitTypeToStream(loc, *fmtArg++, ss))) {
if (failed(emitType(loc, *fmtArg++))) {
return failure();
}
}
}

return success();

ss << oType.getValue();
os << oType.getValue();
return success();
}
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitTypeToStream(loc, aType.getElementType(), ss)))
if (failed(emitType(loc, aType.getElementType())))
return failure();
for (auto dim : aType.getShape())
ss << "[" << dim << "]";
os << "[" << dim << "]";
return success();
}
if (auto lType = dyn_cast<emitc::LValueType>(type))
return emitTypeToStream(loc, lType.getValueType(), ss);
return emitType(loc, lType.getValueType());
if (auto pType = dyn_cast<emitc::PointerType>(type)) {
if (isa<ArrayType>(pType.getPointee()))
return emitError(loc, "cannot emit pointer to array type ") << type;
if (failed(emitTypeToStream(loc, pType.getPointee(), ss)))
if (failed(emitType(loc, pType.getPointee())))
return failure();
ss << "*";
os << "*";
return success();
}
return emitError(loc, "cannot emit type ") << type;
Expand All @@ -1951,20 +1935,14 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
}

LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
return emitTupleTypeToStream(loc, types, os);
}
LogicalResult CppEmitter::emitTupleTypeToStream(Location loc,
ArrayRef<Type> types,
raw_ostream &ss) {
if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
return emitError(loc, "cannot emit tuple of array type");
}
ss << "std::tuple<";
if (failed(interleaveCommaWithError(types, ss, [&](Type type) {
return emitTypeToStream(loc, type, ss);
})))
os << "std::tuple<";
if (failed(interleaveCommaWithError(
types, os, [&](Type type) { return emitType(loc, type); })))
return failure();
ss << ">";
os << ">";
return success();
}

Expand Down

0 comments on commit 2dd8f12

Please sign in to comment.