From dec8055a1e71fe25d4b85416ede742e8fdfaf3f0 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Wed, 8 May 2024 23:52:22 -0700 Subject: [PATCH] [mlir] Use StringRef::operator== instead of StringRef::equals (NFC) (#91560) I'm planning to remove StringRef::equals in favor of StringRef::operator==. - StringRef::operator==/!= outnumber StringRef::equals by a factor of 10 under mlir/ in terms of their usage. - The elimination of StringRef::equals brings StringRef closer to std::string_view, which has operator== but not equals. - S == "foo" is more readable than S.equals("foo"), especially for !Long.Expression.equals("str") vs Long.Expression != "str". --- .../Conversion/GPUCommon/GPUToLLVMConversion.cpp | 2 +- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 16 ++++++++-------- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 14 ++++++-------- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 13 +++++-------- mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h | 3 +-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 14 ++++++-------- .../SparseTensor/IR/Detail/LvlTypeParser.cpp | 6 +++--- mlir/lib/IR/AttributeDetail.h | 2 +- mlir/lib/TableGen/Builder.cpp | 2 +- .../Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp | 4 ++-- mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp | 2 +- mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 2 +- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 4 ++-- 14 files changed, 39 insertions(+), 47 deletions(-) diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 3a4fc7d8063f40..82bfa9514a8841 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -926,7 +926,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( static bool isDefinedByCallTo(Value value, StringRef functionName) { assert(isa(value.getType())); if (auto defOp = value.getDefiningOp()) - return defOp.getCallee()->equals(functionName); + return *defOp.getCallee() == functionName; return false; } diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 775dd1e609037f..b7fd454c60902f 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -42,11 +42,11 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant."; static NVVM::MMAFrag convertOperand(StringRef operandName) { - if (operandName.equals("AOp")) + if (operandName == "AOp") return NVVM::MMAFrag::a; - if (operandName.equals("BOp")) + if (operandName == "BOp") return NVVM::MMAFrag::b; - if (operandName.equals("COp")) + if (operandName == "COp") return NVVM::MMAFrag::c; llvm_unreachable("Unknown operand name"); } @@ -55,8 +55,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) { if (type.getElementType().isF16()) return NVVM::MMATypes::f16; if (type.getElementType().isF32()) - return type.getOperand().equals("COp") ? NVVM::MMATypes::f32 - : NVVM::MMATypes::tf32; + return type.getOperand() == "COp" ? NVVM::MMATypes::f32 + : NVVM::MMATypes::tf32; if (type.getElementType().isSignedInteger(8)) return NVVM::MMATypes::s8; @@ -99,15 +99,15 @@ struct WmmaLoadOpToNVVMLowering NVVM::MMATypes eltype = getElementType(retType); // NVVM intrinsics require to give mxnxk dimensions, infer the missing // dimension based on the valid intrinsics available. - if (retType.getOperand().equals("AOp")) { + if (retType.getOperand() == "AOp") { m = retTypeShape[0]; k = retTypeShape[1]; n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype); - } else if (retType.getOperand().equals("BOp")) { + } else if (retType.getOperand() == "BOp") { k = retTypeShape[0]; n = retTypeShape[1]; m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype); - } else if (retType.getOperand().equals("COp")) { + } else if (retType.getOperand() == "COp") { m = retTypeShape[0]; n = retTypeShape[1]; k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index f8485e02a2208e..19f02297bfbb71 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -261,7 +261,7 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, template static bool isTensorOp(OpTy xferOp) { if (isa(xferOp.getShapedType())) { - if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) { + if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) { // TransferWriteOps on tensors have a result. assert(xferOp->getNumResults() > 0); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index c9c0a7b4cc6860..2e31487bd55a0a 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -3585,20 +3585,18 @@ ParseResult AffinePrefetchOp::parse(OpAsmParser &parser, parser.resolveOperands(mapOperands, indexTy, result.operands)) return failure(); - if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + if (readOrWrite != "read" && readOrWrite != "write") return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); - result.addAttribute( - AffinePrefetchOp::getIsWriteAttrStrName(), - parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(), + parser.getBuilder().getBoolAttr(readOrWrite == "write")); - if (!cacheType.equals("data") && !cacheType.equals("instr")) + if (cacheType != "data" && cacheType != "instr") return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); - result.addAttribute( - AffinePrefetchOp::getIsDataCacheAttrStrName(), - parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(), + parser.getBuilder().getBoolAttr(cacheType == "data")); return success(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index f1b9ca5c500208..0c2590d711301b 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -152,8 +152,7 @@ LogicalResult MMAMatrixType::verify(function_ref emitError, ArrayRef shape, Type elementType, StringRef operand) { - if (!operand.equals("AOp") && !operand.equals("BOp") && - !operand.equals("COp")) + if (operand != "AOp" && operand != "BOp" && operand != "COp") return emitError() << "operand expected to be one of AOp, BOp or COp"; if (shape.size() != 2) @@ -1941,8 +1940,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() { return emitError( "expected source memref most minor dim must have unit stride"); - if (!operand.equals("AOp") && !operand.equals("BOp") && - !operand.equals("COp")) + if (operand != "AOp" && operand != "BOp" && operand != "COp") return emitError("only AOp, BOp and COp can be loaded"); return success(); @@ -1962,7 +1960,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() { return emitError( "expected destination memref most minor dim must have unit stride"); - if (!srcMatrixType.getOperand().equals("COp")) + if (srcMatrixType.getOperand() != "COp") return emitError( "expected the operand matrix being stored to have 'COp' operand type"); @@ -1980,9 +1978,8 @@ LogicalResult SubgroupMmaComputeOp::verify() { opTypes.push_back(llvm::cast(getOpB().getType())); opTypes.push_back(llvm::cast(getOpC().getType())); - if (!opTypes[A].getOperand().equals("AOp") || - !opTypes[B].getOperand().equals("BOp") || - !opTypes[C].getOperand().equals("COp")) + if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" || + opTypes[C].getOperand() != "COp") return emitError("operands must be in the order AOp, BOp, COp"); ArrayRef aShape, bShape, cShape; diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h index 2040d0a06b2e3b..8767b1c3ffc5bd 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h +++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h @@ -131,8 +131,7 @@ struct LLVMStructTypeStorage : public TypeStorage { /// Compares two keys. bool operator==(const Key &other) const { if (isIdentified()) - return other.isIdentified() && - other.getIdentifier().equals(getIdentifier()); + return other.isIdentified() && other.getIdentifier() == getIdentifier(); return !other.isIdentified() && other.isPacked() == isPacked() && other.getTypeList() == getTypeList(); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c9a85919ec799b..199e7330a233c0 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1742,20 +1742,18 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { parser.resolveOperands(indexInfo, indexTy, result.operands)) return failure(); - if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + if (readOrWrite != "read" && readOrWrite != "write") return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); - result.addAttribute( - PrefetchOp::getIsWriteAttrStrName(), - parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + result.addAttribute(PrefetchOp::getIsWriteAttrStrName(), + parser.getBuilder().getBoolAttr(readOrWrite == "write")); - if (!cacheType.equals("data") && !cacheType.equals("instr")) + if (cacheType != "data" && cacheType != "instr") return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); - result.addAttribute( - PrefetchOp::getIsDataCacheAttrStrName(), - parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(), + parser.getBuilder().getBoolAttr(cacheType == "data")); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp index 92e5efaa810497..39f5cf1a750828 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp @@ -89,11 +89,11 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser, auto loc = parser.getCurrentLocation(); ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)), "expected valid level property (e.g. nonordered, nonunique or high)") - if (strVal.equals(toPropString(LevelPropNonDefault::Nonunique))) { + if (strVal == toPropString(LevelPropNonDefault::Nonunique)) { *properties |= static_cast(LevelPropNonDefault::Nonunique); - } else if (strVal.equals(toPropString(LevelPropNonDefault::Nonordered))) { + } else if (strVal == toPropString(LevelPropNonDefault::Nonordered)) { *properties |= static_cast(LevelPropNonDefault::Nonordered); - } else if (strVal.equals(toPropString(LevelPropNonDefault::SoA))) { + } else if (strVal == toPropString(LevelPropNonDefault::SoA)) { *properties |= static_cast(LevelPropNonDefault::SoA); } else { parser.emitError(loc, "unknown level property: ") << strVal; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index dcd24af0107ddf..26d40ac3a38f63 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -261,7 +261,7 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage { // Check to see if this storage represents a splat. If it doesn't then // combine the hash for the data starting with the first non splat element. for (size_t i = 1, e = data.size(); i != e; i++) - if (!firstElt.equals(data[i])) + if (firstElt != data[i]) return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i))); // Otherwise, this is a splat so just return the hash of the first element. diff --git a/mlir/lib/TableGen/Builder.cpp b/mlir/lib/TableGen/Builder.cpp index 47a2f6cc4456eb..044765c726019d 100644 --- a/mlir/lib/TableGen/Builder.cpp +++ b/mlir/lib/TableGen/Builder.cpp @@ -52,7 +52,7 @@ Builder::Builder(const llvm::Record *record, ArrayRef loc) // Initialize the parameters of the builder. const llvm::DagInit *dag = def->getValueAsDag("dagParams"); auto *defInit = dyn_cast(dag->getOperator()); - if (!defInit || !defInit->getDef()->getName().equals("ins")) + if (!defInit || defInit->getDef()->getName() != "ins") PrintFatalError(def->getLoc(), "expected 'ins' in builders"); bool seenDefaultValue = false; diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index 40d8253d822f64..06673965245c00 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -93,7 +93,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, return failure(); // Handle function entry count metadata. - if (name->getString().equals("function_entry_count")) { + if (name->getString() == "function_entry_count") { // TODO support function entry count metadata with GUID fields. if (node->getNumOperands() != 2) @@ -111,7 +111,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, << "expected function_entry_count to be attached to a function"; } - if (!name->getString().equals("branch_weights")) + if (name->getString() != "branch_weights") return failure(); // Handle branch weights metadata. diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp index c376d6c73c6452..ebaced57a24a49 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp @@ -413,7 +413,7 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { // of inner-op), then we can print the entire region in a succinct way. // Here we assume that the prototype of "test.special.op" can be trivially // derived while parsing it back. - if (innerOp.getName().getStringRef().equals("test.special.op")) { + if (innerOp.getName().getStringRef() == "test.special.op") { p << " start test.special.op end"; } else { p << " ("; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index b9a72119790e5a..55bc0714c20ec6 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -50,7 +50,7 @@ static void collectAllDefs(StringRef selectedDialect, } else { // Otherwise, generate the defs that belong to the selected dialect. auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) { - return def.getDialect().getName().equals(selectedDialect); + return def.getDialect().getName() == selectedDialect; }); resultDefs.assign(dialectDefs.begin(), dialectDefs.end()); } diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 814008c2545114..052020acdcb764 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -457,7 +457,7 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { std::string sanitizedName = sanitizeName(namedAttr.name); // Unit attributes are handled specially. - if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { + if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") { os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name); os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, @@ -668,7 +668,7 @@ populateBuilderLinesAttr(const Operator &op, continue; // Unit attributes are handled specially. - if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { + if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") { builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, attribute->name, argNames[i])); continue;