Skip to content

Commit

Permalink
[mlir] Use StringRef::operator== instead of StringRef::equals (NFC) (llvm#91560)
Browse files Browse the repository at this point in the history

I'm planning to remove StringRef::equals in favor of
StringRef::operator==.

- StringRef::operator==/!= outnumber StringRef::equals by a factor of
  10 under mlir/ in terms of their usage.

- The elimination of StringRef::equals brings StringRef closer to
  std::string_view, which has operator== but not equals.

- S == "foo" is more readable than S.equals("foo"), especially for
  !Long.Expression.equals("str") vs Long.Expression != "str".
  • Loading branch information
kazutakahirata committed May 9, 2024
1 parent fd1bd53 commit dec8055
Show file tree
Hide file tree
Showing 14 changed files with 39 additions and 47 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
/// Returns true if `value` is defined by an LLVM call to the function named
/// `functionName`. `value` must have LLVM pointer type; a value with no
/// defining LLVM::CallOp (e.g. a block argument) yields false.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    // getCallee() returns the direct callee's symbol name; compare it with
    // StringRef::operator== (StringRef::equals is being phased out).
    return *defOp.getCallee() == functionName;
  return false;
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

/// Maps a gpu.mma_matrix operand tag ("AOp", "BOp", or "COp") to the
/// corresponding NVVM MMA fragment kind. Any other name is a programming
/// error and aborts via llvm_unreachable.
static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}
Expand All @@ -55,8 +55,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF16())
return NVVM::MMATypes::f16;
if (type.getElementType().isF32())
return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;
return type.getOperand() == "COp" ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;

if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
Expand Down Expand Up @@ -99,15 +99,15 @@ struct WmmaLoadOpToNVVMLowering
NVVM::MMATypes eltype = getElementType(retType);
// NVVM intrinsics require to give mxnxk dimensions, infer the missing
// dimension based on the valid intrinsics available.
if (retType.getOperand().equals("AOp")) {
if (retType.getOperand() == "AOp") {
m = retTypeShape[0];
k = retTypeShape[1];
n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
} else if (retType.getOperand().equals("BOp")) {
} else if (retType.getOperand() == "BOp") {
k = retTypeShape[0];
n = retTypeShape[1];
m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
} else if (retType.getOperand().equals("COp")) {
} else if (retType.getOperand() == "COp") {
m = retTypeShape[0];
n = retTypeShape[1];
k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
if (isa<RankedTensorType>(xferOp.getShapedType())) {
if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
// TransferWriteOps on tensors have a result.
assert(xferOp->getNumResults() > 0);
}
Expand Down
14 changes: 6 additions & 8 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3585,20 +3585,18 @@ ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
parser.resolveOperands(mapOperands, indexTy, result.operands))
return failure();

if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
if (readOrWrite != "read" && readOrWrite != "write")
return parser.emitError(parser.getNameLoc(),
"rw specifier has to be 'read' or 'write'");
result.addAttribute(
AffinePrefetchOp::getIsWriteAttrStrName(),
parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
parser.getBuilder().getBoolAttr(readOrWrite == "write"));

if (!cacheType.equals("data") && !cacheType.equals("instr"))
if (cacheType != "data" && cacheType != "instr")
return parser.emitError(parser.getNameLoc(),
"cache type has to be 'data' or 'instr'");

result.addAttribute(
AffinePrefetchOp::getIsDataCacheAttrStrName(),
parser.getBuilder().getBoolAttr(cacheType.equals("data")));
result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
parser.getBuilder().getBoolAttr(cacheType == "data"));

return success();
}
Expand Down
13 changes: 5 additions & 8 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand) {
if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp"))
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError() << "operand expected to be one of AOp, BOp or COp";

if (shape.size() != 2)
Expand Down Expand Up @@ -1941,8 +1940,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
return emitError(
"expected source memref most minor dim must have unit stride");

if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp"))
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError("only AOp, BOp and COp can be loaded");

return success();
Expand All @@ -1962,7 +1960,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
return emitError(
"expected destination memref most minor dim must have unit stride");

if (!srcMatrixType.getOperand().equals("COp"))
if (srcMatrixType.getOperand() != "COp")
return emitError(
"expected the operand matrix being stored to have 'COp' operand type");

Expand All @@ -1980,9 +1978,8 @@ LogicalResult SubgroupMmaComputeOp::verify() {
opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

if (!opTypes[A].getOperand().equals("AOp") ||
!opTypes[B].getOperand().equals("BOp") ||
!opTypes[C].getOperand().equals("COp"))
if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
opTypes[C].getOperand() != "COp")
return emitError("operands must be in the order AOp, BOp, COp");

ArrayRef<int64_t> aShape, bShape, cShape;
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
/// Compares two keys.
bool operator==(const Key &other) const {
if (isIdentified())
return other.isIdentified() &&
other.getIdentifier().equals(getIdentifier());
return other.isIdentified() && other.getIdentifier() == getIdentifier();

return !other.isIdentified() && other.isPacked() == isPacked() &&
other.getTypeList() == getTypeList();
Expand Down
14 changes: 6 additions & 8 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1742,20 +1742,18 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperands(indexInfo, indexTy, result.operands))
return failure();

if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
if (readOrWrite != "read" && readOrWrite != "write")
return parser.emitError(parser.getNameLoc(),
"rw specifier has to be 'read' or 'write'");
result.addAttribute(
PrefetchOp::getIsWriteAttrStrName(),
parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
parser.getBuilder().getBoolAttr(readOrWrite == "write"));

if (!cacheType.equals("data") && !cacheType.equals("instr"))
if (cacheType != "data" && cacheType != "instr")
return parser.emitError(parser.getNameLoc(),
"cache type has to be 'data' or 'instr'");

result.addAttribute(
PrefetchOp::getIsDataCacheAttrStrName(),
parser.getBuilder().getBoolAttr(cacheType.equals("data")));
result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
parser.getBuilder().getBoolAttr(cacheType == "data"));

return success();
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.equals(toPropString(LevelPropNonDefault::Nonunique))) {
if (strVal == toPropString(LevelPropNonDefault::Nonunique)) {
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
} else if (strVal.equals(toPropString(LevelPropNonDefault::Nonordered))) {
} else if (strVal == toPropString(LevelPropNonDefault::Nonordered)) {
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
} else if (strVal.equals(toPropString(LevelPropNonDefault::SoA))) {
} else if (strVal == toPropString(LevelPropNonDefault::SoA)) {
*properties |= static_cast<uint64_t>(LevelPropNonDefault::SoA);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/AttributeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
// Check to see if this storage represents a splat. If it doesn't then
// combine the hash for the data starting with the first non splat element.
for (size_t i = 1, e = data.size(); i != e; i++)
if (!firstElt.equals(data[i]))
if (firstElt != data[i])
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));

// Otherwise, this is a splat so just return the hash of the first element.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
// Initialize the parameters of the builder.
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
if (!defInit || !defInit->getDef()->getName().equals("ins"))
if (!defInit || defInit->getDef()->getName() != "ins")
PrintFatalError(def->getLoc(), "expected 'ins' in builders");

bool seenDefaultValue = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
return failure();

// Handle function entry count metadata.
if (name->getString().equals("function_entry_count")) {
if (name->getString() == "function_entry_count") {

// TODO support function entry count metadata with GUID fields.
if (node->getNumOperands() != 2)
Expand All @@ -111,7 +111,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
<< "expected function_entry_count to be attached to a function";
}

if (!name->getString().equals("branch_weights"))
if (name->getString() != "branch_weights")
return failure();

// Handle branch weights metadata.
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
// of inner-op), then we can print the entire region in a succinct way.
// Here we assume that the prototype of "test.special.op" can be trivially
// derived while parsing it back.
if (innerOp.getName().getStringRef().equals("test.special.op")) {
if (innerOp.getName().getStringRef() == "test.special.op") {
p << " start test.special.op end";
} else {
p << " (";
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static void collectAllDefs(StringRef selectedDialect,
} else {
// Otherwise, generate the defs that belong to the selected dialect.
auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
return def.getDialect().getName().equals(selectedDialect);
return def.getDialect().getName() == selectedDialect;
});
resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
std::string sanitizedName = sanitizeName(namedAttr.name);

// Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
Expand Down Expand Up @@ -668,7 +668,7 @@ populateBuilderLinesAttr(const Operator &op,
continue;

// Unit attributes are handled specially.
if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
attribute->name, argNames[i]));
continue;
Expand Down

0 comments on commit dec8055

Please sign in to comment.