diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h index 929bf4efc..4a8c30423 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h @@ -52,6 +52,10 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) { return !options.ptr; } +MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerClientGetCompilationTask( + MTRT_CompilerClient client, MlirStringRef taskMnemonic, + const MlirStringRef *argv, unsigned argc, MlirPassManager *result); + //===----------------------------------------------------------------------===// // MTRT_OptionsContext //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h index 6824ffad6..030d64bf7 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h @@ -29,6 +29,7 @@ #include "mlir-executor/Support/Status.h" #include "mlir-tensorrt/Compiler/OptionsProviders.h" +#include "mlir-tensorrt/Compiler/OptionsRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/TypeID.h" @@ -100,28 +101,69 @@ class CompilerClient { ~CompilerClient() = default; - /// Create or retrieve a cached PassManager of the given derived type using - /// the provided options. PassManagers are cached by type and a hash of the - /// string representation of the options. - /// This function should only be called if the options have a valid hash. 
- template - mlir::PassManager &getOrCreatePassManager(const OptionsType &options) { - std::optional hash = options.getHash(); - if (!hash) - llvm::report_fatal_error("attempted to lookup a PassManager from a cache " - "with an un-hashable options key"); - - auto key = - std::make_pair(mlir::TypeID::get(), hash.value()); + /// Create or retrieve from the cache a compilation task of the specified + /// type and options. If an existing compilation task is not in the cache, + /// then it is constructed using the registered construction function and + /// inserted into the cache. + StatusOr + getCompilationTask(mlir::TypeID taskID, + llvm::ArrayRef options); + + /// Create or retrieve from the cache a compilation task of the specified + /// type ID and options. If an existing compilation task is not in the cache, + /// then it is constructed using the registered construction function and + /// inserted into the cache. + StatusOr + getCompilationTask(mlir::TypeID taskID, llvm::ArrayRef options) { + return getCompilationTask( + taskID, llvm::map_to_vector(options, [](const std::string &x) { + return llvm::StringRef(x); + })); + } + + StatusOr + getCompilationTask(llvm::StringRef mnemonic, + llvm::ArrayRef options); + + /// Create or retrieve from the cache a compilation task of the specified + /// type and options. If an existing compilation task is not in the cache, + /// then it is constructed using the registered construction function and + /// inserted into the cache. + template + StatusOr getCompilationTask(Args &&...args) { + return getCompilationTask(mlir::TypeID::get(), + std::forward(args)...); + } + + /// Insert a compilation task of type T with options hash `hash` into the + /// cache. 
+ template + void updateCachedCompilationTask(const llvm::hash_code &hash, + std::unique_ptr task) { + cachedPassManagers[std::make_pair(mlir::TypeID::get(), hash)] = + std::move(task); + } + + /// Check whether a CompilationTask with the specified typeID and whose + /// options have the given hash is in the cache. If so, return it; otherwise + /// returns nullptr. + CompilationTaskBase * + lookupCachedCompilationTask(mlir::TypeID taskID, + const llvm::hash_code &optionsHash) { + auto key = std::make_pair(taskID, optionsHash); auto it = cachedPassManagers.find(key); - if (it == cachedPassManagers.end()) { - auto pm = std::make_unique(context, options); - setupPassManagerLogging(*pm, options.template get()); - auto *ptr = pm.get(); - cachedPassManagers[key] = std::move(pm); - return *ptr; - } - return *it->second; + if (it == cachedPassManagers.end()) + return nullptr; + return it->second.get(); + } + + /// Check whether a CompilationTask with the specified type T and whose + /// options have the given hash is in the cache. If so, return it; otherwise + /// returns nullptr. + template + CompilationTaskBase * + lookupCachedCompilationTask(const llvm::hash_code &optionsHash) { + return lookupCachedCompilationTask(mlir::TypeID::get(), optionsHash); } /// Return the MLIRContext associated with the client. @@ -147,6 +189,68 @@ class CompilerClient { cachedPassManagers; }; +/// A registry function that adds passes to the given pass manager. This should +/// also parse options and return success() if parsing succeeded. +/// `errorHandler` is a functor used to emit errors during parsing. +/// parameter corresponds to the raw location within the pipeline string. This +/// should always return failure. 
+using TaskRegistryFunction = std::function( + CompilerClient &client, llvm::ArrayRef options)>; + +struct TaskRegistration { + TaskRegistryFunction registryFunc; +}; + +void registerCompilationTask(llvm::StringRef mnemonic, mlir::TypeID typeID, + TaskRegistryFunction func); + +template +void registerCompilationTask(llvm::StringRef mnemonic, + TaskRegistryFunction func) { + return registerCompilationTask(mnemonic, mlir::TypeID::get(), + std::move(func)); +} + +template +void registerCompilationTaskWithNoExtensions(llvm::StringRef mnemonic) { + registerCompilationTask( + mnemonic, + [](CompilerClient &client, llvm::ArrayRef options) + -> StatusOr { + OptionsType result; + std::string err; + if (failed(result.parse(options, err))) + return getInvalidArgStatus( + "failed to parse options string \"{0:$[ ]}\" due to error {1}", + llvm::iterator_range(options), err); + + llvm::Error finalizeStatus = result.finalize(); + std::optional errMsg{}; + llvm::handleAllErrors(std::move(finalizeStatus), + [&errMsg](const llvm::StringError &err) { + errMsg = err.getMessage(); + }); + + if (errMsg) + return getInvalidArgStatus("failed to parse options due to error {0}", + errMsg); + + std::optional hashCode = result.getHash(); + if (!hashCode) + return getInvalidArgStatus("failed to hash options"); + + CompilationTaskBase *cached = + client.lookupCachedCompilationTask(*hashCode); + if (cached) + return cached; + + auto newPM = std::make_unique(client.getContext(), result); + auto ptr = newPM.get(); + client.updateCachedCompilationTask(*hashCode, std::move(newPM)); + return ptr; + }); +} + } // namespace mlirtrt::compiler #endif // MLIR_TENSORRT_COMPILER_CLIENT diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h index 8e86f8bd0..4a0f0ffa3 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h +++ 
b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h @@ -47,6 +47,24 @@ constexpr bool has_finalize_impl_v< // a default implementation otherwise. template struct OptionsProvider { + OptionsProvider(mlir::OptionsContext &ctx) : ctx(ctx) {} + + // We don't allow move construction since the actual ptrs/locations of + // individual member elements of an OptionsProvider are captured into the + // OptionsContext. If the OptionsContext is populated upon construction, + // moving can change the memory location of the owned values, which will cause + // a crash later on. This is in particular can happen if you are constructing + // a tuple of `OptionsProviders`. Since we are deleting the move constructor, + // one must instead use a tuple of `unique_ptr`. + OptionsProvider(OptionsProvider &&) = delete; + + mlir::OptionsContext &ctx; + + template + using Option = mlir::OptionsContext::Option; + template + using ListOption = mlir::OptionsContext::ListOption; + /// Modifies options after parsing. This is required since we may need /// to make changes to options based on the values of other options. /// Do *not* override this method; instead, implement `finalizeImpl()`. @@ -62,67 +80,63 @@ struct OptionsProvider { /// interfaces. struct DebugOptions : public OptionsProvider { public: + using OptionsProvider::OptionsProvider; /// A directory path where the IR will be dumped during compilation /// using the `mlir-print-ir-tree-dir` mechanism. - std::string dumpIRPath = ""; + Option dumpIRPath{&this->ctx, "mlir-print-ir-tree-dir", + llvm::cl::init("")}; /// Whether the LLVM 'debug' flag that enables execution of code guarded by /// the `LLVM_DEBUG` macro should be set to 'on'. This results in very verbose /// output from the compiler dumped to stderr. 
- bool enableLLVMDebugFlag = false; + Option enableLLVMDebugFlag{&this->ctx, "debug", llvm::cl::init(false)}; /// A set of names to be given to the LLVM 'debug types' option, akin to /// setting /// `-debug-types=...` from the command line. - mlir::SmallVector llvmDebugTypes = {}; - -public: - void addToOptions(mlir::OptionsContext &context) { - context.addOption("mlir-print-ir-tree-dir", dumpIRPath, llvm::cl::init("")); - context.addOption("debug", enableLLVMDebugFlag); - context.addList("debug-only", llvmDebugTypes, - llvm::cl::ZeroOrMore, - llvm::cl::CommaSeparated); - } + ListOption llvmDebugTypes{ + &this->ctx, "debug-only", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated}; }; struct ExecutorOptions : public OptionsProvider { public: - /// The host index bit-width. - int64_t indexBitwidth{64}; + using OptionsProvider::OptionsProvider; - /// Whether to pass memref's as struct/table in function calls. - bool usePackedMemRefCConv{true}; + Option indexBitwidth{&this->ctx, "executor-index-bitwidth", + llvm::cl::init(64), + llvm::cl::desc("executor index bitwidth")}; -public: - void addToOptions(mlir::OptionsContext &context) { - context.addOption("executor-index-bitwidth", indexBitwidth, - llvm::cl::init(64)); - } + Option usePackedMemRefCConv{ + &this->ctx, "executor-use-packed-memref-cconv", llvm::cl::init(true), + llvm::cl::desc( + "whether to use packed or unpacked memref calling convention")}; }; struct DeviceOptions : public OptionsProvider { public: + using OptionsProvider::OptionsProvider; + + /// Device information. Members are manually bound to options in the + /// constructor. DeviceInfo info; - /// Whether to ignore `deviceX` options and instead infer them from the GPUs - /// on the host system running the compilation. 
- bool shouldInferFromHost = false; + Option shouldInferFromHost{ + &this->ctx, "device-infer-from-host", llvm::cl::init(true), + llvm::cl::desc("whether to ignore `deviceX` options and instead infer " + "them from the host GPU")}; + Status inferFromHost(); public: - void addToOptions(mlir::OptionsContext &context) { - context.addOption( + DeviceOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) { + ctx.addOption( "device-compute-capability", info.computeCapability, llvm::cl::init(60), llvm::cl::desc("Sets the device compute capbility. Only relevant " "if '--device-infer-from-host=false'")); - context.addOption("device-max-shared-memory-per-block-kb", - info.maxSharedMemoryPerBlockKb, llvm::cl::init(48)); - context.addOption("device-max-registers-per-block", - info.maxRegistersPerBlock, llvm::cl::init(65536)); - context.addOption("device-infer-from-host", shouldInferFromHost, - llvm::cl::init(true), - llvm::cl::desc("Infers device information from host")); + ctx.addOption("device-max-shared-memory-per-block-kb", + info.maxSharedMemoryPerBlockKb, llvm::cl::init(48)); + ctx.addOption("device-max-registers-per-block", info.maxRegistersPerBlock, + llvm::cl::init(65536)); } llvm::Error finalizeImpl(); diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h index 8fc387eeb..25cf0bbc6 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h @@ -28,8 +28,8 @@ #define MLIR_TENSORRT_COMPILER_OPTIONS_REGISTRY #include "mlir-tensorrt-dialect/Utils/Options.h" -#include "mlir-tensorrt/Compiler/Client.h" #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" +#include "mlir/IR/MLIRContext.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" @@ -39,25 +39,23 @@ namespace mlirtrt::compiler { using OptionsConstructorFuncT = 
std::function>( - const CompilerClient &client, const llvm::ArrayRef)>; + mlir::MLIRContext *, llvm::ArrayRef)>; /// Registers an options creation function for a specific options type. -void registerOption(const llvm::StringRef optionsType, - OptionsConstructorFuncT func); +void registerOption(llvm::StringRef optionsType, OptionsConstructorFuncT func); /// Creates an options instance for the specified options type using a creation /// function that was previously registered. StatusOr> -createOptions(const CompilerClient &client, const llvm::StringRef optionsType, - const llvm::ArrayRef args); +createOptions(mlir::MLIRContext *client, llvm::StringRef optionsType, + llvm::ArrayRef args); /// Helper to build callbacks that can create options. template -StatusOr> -optionsCreateFromArgs(const CompilerClient &client, - const llvm::ArrayRef args) { +StatusOr> +optionsCreateFromArgs(mlir::MLIRContext *context, + llvm::ArrayRef args) { // Load available extensions. - mlir::MLIRContext *context = client.getContext(); mlir::plan::PlanDialect *planDialect = context->getLoadedDialect(); compiler::TaskExtensionRegistry extensions = @@ -83,7 +81,7 @@ optionsCreateFromArgs(const CompilerClient &client, return getInternalErrorStatus("failed to initialize options: %s", errMsg->c_str()); - return std::unique_ptr(result.release()); + return result; } } // namespace mlirtrt::compiler diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h index e67b07bf2..101a33d76 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h @@ -58,15 +58,19 @@ struct StablehloToExecutableOptions StablehloToExecutableOptions(TaskExtensionRegistry extensions); /// Whether to disallow host tensors in TensorRT clusters. 
- bool disallowHostTensorsInTensorRTClusters = false; + Option disallowHostTensorsInTensorRTClusters{ + this, "plan-clustering-disallow-host-tensors-in-tensorrt-clusters", + llvm::cl::init(false), + llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor " + "calculations (but they can still be inputs)")}; + + Option entrypoint{this, "entrypoint", llvm::cl::init("main"), + llvm::cl::desc("entrypoint function name")}; /// Use non-DPS style calling convention for entrypoint function /// and backend types that support allocating results. bool enableNonDPSReturns = false; - /// Entrypoint function name. - std::string entrypoint = "main"; - /// Base class for extensions associated with StableHloToExecutableTask. class ExtensionBase : public TaskExtensionBase { public: @@ -134,13 +138,6 @@ class StablehloToExecutableTask static void populatePassManager(mlir::PassManager &pm, const StablehloToExecutableOptions &options); - /// Compile a StableHLO module into a MLIR-TensorRT Runtime executable. - /// This is the "functional" entrypoint that will allocate a new PassManager - /// for a single run. - static mlirtrt::StatusOr> - compileStableHLOToExecutable(mlir::ModuleOp module, - const StablehloToExecutableOptions &options); - /// Compile a StableHLO module into a MLIR-TensorRT Runtime executable. /// This is the "functional" entrypoint that will allocate a new PassManager /// for a single run. 
diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp index c5d390f0f..4a6bfc3da 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-c/Compiler/Compiler.h" #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #include "mlir-executor-c/Support/Status.h" #include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h" @@ -35,6 +36,7 @@ #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Utils.h" +#include "mlir/Pass/PassManager.h" #include "llvm/ADT/StringExtras.h" using namespace mlirtrt; @@ -104,6 +106,22 @@ MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) { return mtrtStatusGetOk(); } +MTRT_Status mtrtCompilerClientGetCompilationTask(MTRT_CompilerClient client, + MlirStringRef taskMnemonic, + const MlirStringRef *argv, + unsigned argc, + MlirPassManager *result) { + std::vector argvStrRef(argc); + for (unsigned i = 0; i < argc; i++) + argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length); + StatusOr task = unwrap(client)->getCompilationTask( + StringRef(taskMnemonic.data, taskMnemonic.length), argvStrRef); + if (!task.isOk()) + return wrap(task.getStatus()); + *result = MlirPassManager{static_cast(*task)}; + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_OptionsContext //===----------------------------------------------------------------------===// @@ -116,8 +134,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs( argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length); auto result = createOptions( - *unwrap(client), llvm::StringRef(optionsType.data, optionsType.length), - argvStrRef); + 
unwrap(client)->getContext(), + llvm::StringRef(optionsType.data, optionsType.length), argvStrRef); if (!result.isOk()) return wrap(result.getStatus()); @@ -260,15 +278,16 @@ mtrtStableHloPipelineGetCached(MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions options, MlirPassManager *result) { - mlir::PassManager *runner{}; - if (unwrap(options)->getHash()) { - runner = &unwrap(client)->getOrCreatePassManager( - *unwrap(options)); - result->ptr = runner; - return mtrtStatusGetOk(); - } - return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError, - "options cannot be hashed"); + if (!unwrap(options)->getHash()) + return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError, + "options cannot be hashed"); + StatusOr runner = + unwrap(client)->getCompilationTask( + unwrap(options)->serialize()); + if (!runner.isOk()) + return wrap(runner.getStatus()); + *result = MlirPassManager{static_cast(*runner)}; + return mtrtStatusGetOk(); } //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/Compiler/Client.cpp b/mlir-tensorrt/compiler/lib/Compiler/Client.cpp index dca44e340..4783f16cc 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/Client.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/Client.cpp @@ -22,8 +22,11 @@ /// //===----------------------------------------------------------------------===// #include "mlir-tensorrt/Compiler/Client.h" +#include "mlir-executor/Support/Status.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/FileUtilities.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ManagedStatic.h" using namespace mlirtrt; using namespace mlirtrt::compiler; @@ -32,6 +35,12 @@ using namespace mlir; #define DEBUG_TYPE "compiler-api" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]") +static llvm::ManagedStatic> + taskRegistry{}; + +/// Global registry for mapping task mnemonics to type IDs. 
+static llvm::ManagedStatic> taskNameRegistry; + //===----------------------------------------------------------------------===// // CompilationTask //===----------------------------------------------------------------------===// @@ -67,3 +76,35 @@ void CompilerClient::setupPassManagerLogging(mlir::PassManager &pm, mlir::OpPrintingFlags().elideLargeElementsAttrs(32)); } } + +StatusOr +CompilerClient::getCompilationTask(mlir::TypeID taskID, + llvm::ArrayRef options) { + auto it = taskRegistry->find(taskID); + if (it == taskRegistry->end()) + llvm::report_fatal_error("no such task registered"); + return it->second.registryFunc(*this, options); +} + +StatusOr +CompilerClient::getCompilationTask(llvm::StringRef mnemonic, + llvm::ArrayRef options) { + auto it = taskNameRegistry->find(mnemonic); + if (it == taskNameRegistry->end()) + return getInvalidArgStatus("no compilation task registered with name {0}", + mnemonic); + + return getCompilationTask(taskNameRegistry->lookup(mnemonic), options); +} + +void compiler::registerCompilationTask(llvm::StringRef mnemonic, + mlir::TypeID typeID, + TaskRegistryFunction func) { + if (taskNameRegistry->contains(mnemonic) || taskRegistry->contains(typeID)) + llvm::report_fatal_error( + "detected double registration of compilation task \"" + mnemonic + + "\""); + taskNameRegistry->insert({mnemonic, typeID}); + taskRegistry->insert( + std::make_pair(typeID, TaskRegistration{std::move(func)})); +} diff --git a/mlir-tensorrt/compiler/lib/Compiler/OptionsRegistry.cpp b/mlir-tensorrt/compiler/lib/Compiler/OptionsRegistry.cpp index 13729f82c..7bd1e46fb 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/OptionsRegistry.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/OptionsRegistry.cpp @@ -23,18 +23,18 @@ using namespace mlirtrt::compiler; static llvm::ManagedStatic> registry{}; -void mlirtrt::compiler::registerOption(const llvm::StringRef optionsType, +void mlirtrt::compiler::registerOption(llvm::StringRef optionsType, OptionsConstructorFuncT 
func) { (*registry)[optionsType] = std::move(func); } mlirtrt::StatusOr> -mlirtrt::compiler::createOptions(const CompilerClient &client, - const llvm::StringRef optionsType, - const llvm::ArrayRef args) { +mlirtrt::compiler::createOptions(mlir::MLIRContext *ctx, + llvm::StringRef optionsType, + llvm::ArrayRef args) { if (!registry->contains(optionsType)) return getInvalidArgStatus( "{0} is not a valid option type. Valid options were: {1:$[ ]}", optionsType, llvm::iterator_range(registry->keys())); - return (*registry)[optionsType](client, args); + return (*registry)[optionsType](ctx, args); } diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index 3d609a93e..dbd8dec2a 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -156,15 +156,6 @@ StablehloToExecutableOptions::StablehloToExecutableOptions( // Link in options for all extensions. 
for (auto &[id, ext] : this->extensions) ext->addToOptions(*this); - - addOption( - "plan-clustering-disallow-host-tensors-in-tensorrt-clusters", - disallowHostTensorsInTensorRTClusters, llvm::cl::init(false), - llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor " - "calculations (but they can still be inputs)")); - - addOption("entrypoint", entrypoint, llvm::cl::init("main"), - llvm::cl::desc("entrypoint function name")); } //===----------------------------------------------------------------------===// @@ -297,69 +288,6 @@ maybePopulateDefaultClusterKinds(mlir::ModuleOp module, } } -StatusOr> -StablehloToExecutableTask::compileStableHLOToExecutable( - mlir::ModuleOp module, const StablehloToExecutableOptions &options) { - LLVM_DEBUG({ - DBGS() << "compiling with options:\n"; - options.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - maybePopulateDefaultClusterKinds(module, options); - -#ifndef NDEBUG - //===----------------------------------------------------------------------===// - // Set debug options. - //===----------------------------------------------------------------------===// - if (options.get().enableLLVMDebugFlag) { - SmallVector debugTypeLiterals = - llvm::map_to_vector(options.get().llvmDebugTypes, - [](const std::string &x) { return x.c_str(); }); - llvm::setCurrentDebugTypes(debugTypeLiterals.data(), - debugTypeLiterals.size()); - llvm::DebugFlag = true; - } -#endif - - //===----------------------------------------------------------------------===// - // Setup pass manager - //===----------------------------------------------------------------------===// - - StablehloToExecutableTask runner(module->getContext(), options); - if (failed(setupPassManager(runner, options.get()))) { - /// TODO: Ignored. This can fail if pass manager static CL options were not - /// registered/initialized. This happens through invocation of e.g. this - /// function in e.g. 
Python bindings or standalone calls to C++ or C API - /// without doing all the typical static CL setup. We should instead be - /// accepting a PassManager here that has already been setup to the caller's - /// specifications. - } - if (failed(runner.run(module))) - return getInternalErrorStatus( - "failed to run compilation on module with symbol name: {0}", - module.getName() ? *module.getName() : "no-symbol-name"); - - //===----------------------------------------------------------------------===// - // Translate to Runtime Executable - //===----------------------------------------------------------------------===// - - FailureOr> exeStorage = - mlir::translateToRuntimeExecutable(module); - if (failed(exeStorage)) - return getStatusWithMsg(StatusCode::InternalError, - "failed to translate compiled MLIR module to a " - "MLIR-TensorRT runtime Executable"); - -#ifndef NDEBUG - // Turn debugging back off if we turned it on. - if (options.get().enableLLVMDebugFlag) - llvm::DebugFlag = false; -#endif - - return std::make_unique(std::move(*exeStorage)); -} - mlirtrt::StatusOr> StablehloToExecutableTask::compileStableHLOToExecutable( CompilerClient &client, mlir::ModuleOp module, @@ -387,19 +315,13 @@ StablehloToExecutableTask::compileStableHLOToExecutable( } #endif - mlir::PassManager *runner; - std::unique_ptr pm{}; - - if (options.getHash()) - runner = &client.getOrCreatePassManager(options); - else { - pm.reset(new StablehloToExecutableTask(client.getContext(), options)); - CompilerClient::setupPassManagerLogging(*pm, options.get()); - runner = pm.get(); - } + StatusOr runner = + client.getCompilationTask(options.serialize()); + if (!runner.isOk()) + return runner.getStatus(); // Setup pass manager - if (failed(runner->run(module))) + if (failed((*runner)->run(module))) return getInternalErrorStatus( "failed to run compilation on module with symbol name: {0}", module.getName() ? 
*module.getName() : "no-symbol-name"); @@ -470,9 +392,64 @@ static StablehloToExecutableOptions populateStablehloClusteringPipelineOpts( } void mlirtrt::compiler::registerStableHloToExecutableTask() { - registerOption("stablehlo-to-executable", - optionsCreateFromArgs); + registerOption( + "stablehlo-to-executable", + [](MLIRContext *ctx, ArrayRef opts) + -> StatusOr> { + auto task = optionsCreateFromArgs(ctx, opts); + if (!task.isOk()) + return task.getStatus(); + return std::unique_ptr(std::move(*task)); + }); + + registerCompilationTask( + "stablehlo-to-executable", + [](CompilerClient &client, llvm::ArrayRef options) + -> StatusOr { + // Load available extensions. + mlir::MLIRContext *context = client.getContext(); + mlir::plan::PlanDialect *planDialect = + context->getLoadedDialect(); + compiler::TaskExtensionRegistry extensions = + planDialect->extensionConstructors + .getExtensionRegistryForTask(); + + StablehloToExecutableOptions result(std::move(extensions)); + + std::string err; + if (failed(result.parse(options, err))) + return getInvalidArgStatus( + "failed to parse options string \"{0:$[ ]}\" due to error {1}", + llvm::iterator_range(options), err); + + llvm::Error finalizeStatus = result.finalize(); + std::optional errMsg{}; + llvm::handleAllErrors(std::move(finalizeStatus), + [&errMsg](const llvm::StringError &err) { + errMsg = err.getMessage(); + }); + + if (errMsg) + return getInvalidArgStatus("failed to parse options due to error {0}", + errMsg); + + std::optional hashCode = result.getHash(); + if (!hashCode) + return getInvalidArgStatus("failed to hash options"); + + CompilationTaskBase *cached = client.lookupCachedCompilationTask( + mlir::TypeID::get(), *hashCode); + if (cached) + return cached; + + auto newPM = std::make_unique( + client.getContext(), result); + auto ptr = newPM.get(); + client.updateCachedCompilationTask( + *hashCode, std::move(newPM)); + return ptr; + }); } void mlirtrt::compiler::registerStablehloClusteringPipelines() { diff 
--git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_executor_translation.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_executor_translation.py new file mode 100644 index 000000000..7304a10ce --- /dev/null +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/compiler_api/test_executor_translation.py @@ -0,0 +1,22 @@ +# RUN: %PYTHON %s | FileCheck %s +import mlir_tensorrt.compiler.api as compiler +import mlir_tensorrt.compiler.ir as ir + + +with ir.Context() as ctx: + client = compiler.CompilerClient(ctx) + ASM = """ + func.func @main(%arg0: i32, %arg1: i32) -> i32 attributes{ + executor.function_metadata=#executor.func_meta<[i32, i32],[i32], num_output_args = 0> + }{ + %0 = executor.sremi %arg0, %arg1 : i32 + return %0 : i32 + } + """ + + m = ir.Module.parse(ASM) + exe = compiler.translate_mlir_to_executable(m.operation) + + sig = exe.get_signature("main") + print(sig) + # CHECK: FunctionSignature(Signature) diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Target/ExecutorTranslations.h b/mlir-tensorrt/executor/include/mlir-executor-c/Target/ExecutorTranslations.h new file mode 100644 index 000000000..4137a24d0 --- /dev/null +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Target/ExecutorTranslations.h @@ -0,0 +1,37 @@ +//===- ExecutorTranslations.h ------------------------------------*- C -*-===// +// +// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_EXECUTOR_C_TARGET_EXECUTORTRANSLATIONS +#define MLIR_EXECUTOR_C_TARGET_EXECUTORTRANSLATIONS + +#include "mlir-c/IR.h" +#include "mlir-executor-c/Common/Common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_CAPI_EXPORTED MTRT_Status +translateToRuntimeExecutable(MlirOperation op, MTRT_Executable *executable); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_EXECUTOR_C_TARGET_EXECUTORTRANSLATIONS diff --git a/mlir-tensorrt/executor/lib/CAPI/CMakeLists.txt b/mlir-tensorrt/executor/lib/CAPI/CMakeLists.txt index 6f68dead7..de95436a1 100644 --- a/mlir-tensorrt/executor/lib/CAPI/CMakeLists.txt +++ b/mlir-tensorrt/executor/lib/CAPI/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Support) add_subdirectory(Common) -add_subdirectory(Runtime) \ No newline at end of file +add_subdirectory(Runtime) +add_subdirectory(Target) diff --git a/mlir-tensorrt/executor/lib/CAPI/Target/CMakeLists.txt b/mlir-tensorrt/executor/lib/CAPI/Target/CMakeLists.txt new file mode 100644 index 000000000..c94127722 --- /dev/null +++ b/mlir-tensorrt/executor/lib/CAPI/Target/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_public_c_api_library(MLIRTensorRTCAPIExecutorTranslations + ExecutorTranslations.cpp + + LINK_LIBS PUBLIC + MLIRTensorRTSupportStatus + MLIRTensorRTTargetLua + MLIRTensorRTCAPICommon + ) diff --git a/mlir-tensorrt/executor/lib/CAPI/Target/ExecutorTranslations.cpp b/mlir-tensorrt/executor/lib/CAPI/Target/ExecutorTranslations.cpp new file mode 100644 index 
000000000..5790b9d24 --- /dev/null +++ b/mlir-tensorrt/executor/lib/CAPI/Target/ExecutorTranslations.cpp @@ -0,0 +1,41 @@ +//===- ExecutorTranslations.cpp -----------------------------------*- C -*-===// +// +// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +#include "mlir-executor-c/Target/ExecutorTranslations.h" +#include "mlir-executor-c/Support/Status.h" +#include "mlir-executor/Runtime/API/API.h" +#include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h" +#include "mlir/CAPI/IR.h" + +using namespace mlir; + +MTRT_Status translateToRuntimeExecutable(MlirOperation op, + MTRT_Executable *result) { + FailureOr> exeStorage = + mlir::translateToRuntimeExecutable(unwrap(op)); + if (failed(exeStorage)) + return mtrtStatusCreate(MTRT_StatusCode_InternalError, + "failed to translate to executable"); + + *result = MTRT_Executable{ + std::make_unique(std::move(*exeStorage)) + .release()}; + + return mtrtStatusGetOk(); +} diff --git a/mlir-tensorrt/python/CompilerPackage.cmake b/mlir-tensorrt/python/CompilerPackage.cmake index 44b3dd94e..09473e523 100644 --- a/mlir-tensorrt/python/CompilerPackage.cmake +++ b/mlir-tensorrt/python/CompilerPackage.cmake @@ -95,6 +95,7 @@ 
declare_mlir_python_extension(MLIRTensorRTPythonCompiler.CompilerAPI.PyBind MLIRTensorRTCAPICompiler MLIRTensorRTCAPISupportStatus MLIRTensorRTCAPICommon + MLIRTensorRTCAPIExecutorTranslations PRIVATE_LINK_LIBS LLVMSupport TensorRTHeaderOnly diff --git a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp index 07e1930a3..509309342 100644 --- a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp +++ b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp @@ -11,9 +11,11 @@ #include "../Utils.h" #include "NvInferRuntime.h" #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #include "mlir-executor-c/Common/Common.h" #include "mlir-executor-c/Support/Status.h" +#include "mlir-executor-c/Target/ExecutorTranslations.h" #include "mlir-tensorrt-c/Compiler/Compiler.h" #include "mlir/Bindings/Python/PybindAdaptors.h" #include "pybind11/pybind11.h" @@ -71,11 +73,11 @@ class PyStableHLOToExecutableOptions }; /// Python object type wrapper for `MlirPassManager`. 
-class PyStableHloPipeline - : public PyMTRTWrapper { +class PyPassManagerReference + : public PyMTRTWrapper { public: using PyMTRTWrapper::PyMTRTWrapper; - DECLARE_WRAPPER_CONSTRUCTORS(PyStableHloPipeline); + DECLARE_WRAPPER_CONSTRUCTORS(PyPassManagerReference); static constexpr auto kMethodTable = CAPITable{mtrtStableHloPipelineIsNull, nullptr}; @@ -240,13 +242,34 @@ PYBIND11_MODULE(_api, m) { populateCommonBindingsInModule(m); + m.def("translate_mlir_to_executable", [](MlirOperation op) { + MTRT_Executable exe{nullptr}; + MTRT_Status status = translateToRuntimeExecutable(op, &exe); + THROW_IF_MTRT_ERROR(status); + return new PyExecutable(exe); + }); + py::class_(m, "CompilerClient", py::module_local()) .def(py::init<>([](MlirContext context) -> PyCompilerClient * { MTRT_CompilerClient client; MTRT_Status s = mtrtCompilerClientCreate(context, &client); THROW_IF_MTRT_ERROR(s); return new PyCompilerClient(client); - })); + })) + .def("get_compilation_task", + [](PyCompilerClient &self, const std::string &mnemonic, + const std::vector &args) { + std::vector refs(args.size()); + for (unsigned i = 0; i < args.size(); i++) + refs[i] = mlirStringRefCreate(args[i].data(), args[i].size()); + + MlirPassManager pm{nullptr}; + MTRT_Status status = mtrtCompilerClientGetCompilationTask( + self, mlirStringRefCreate(mnemonic.data(), mnemonic.size()), + refs.data(), refs.size(), &pm); + THROW_IF_MTRT_ERROR(status); + return new PyPassManagerReference(pm); + }); py::class_(m, "OptionsContext", py::module_local()) .def(py::init<>([](PyCompilerClient &client, @@ -315,20 +338,25 @@ PYBIND11_MODULE(_api, m) { py::arg("dump_ir_tree_dir") = py::none(), py::arg("dump_tensorrt_dir") = py::none()); - py::class_(m, "StableHloPipeline", py::module_local()) + py::class_(m, "StableHloPipeline", py::module_local()) .def(py::init<>([](PyCompilerClient &client, PyStableHLOToExecutableOptions &options) { MlirPassManager pm{}; MTRT_Status status = mtrtStableHloPipelineGetCached(client, options, 
&pm); THROW_IF_MTRT_ERROR(status); - return new PyStableHloPipeline(pm); + return new PyPassManagerReference(pm); }), - py::arg("client"), py::arg("options")); + py::arg("client"), py::arg("options")) + .def("run", [](PyPassManagerReference &self, MlirOperation op) { + MlirLogicalResult result = mlirPassManagerRunOnOp(self.get(), op); + if (mlirLogicalResultIsFailure(result)) + throw MTRTException("failed to run pass pipeline"); + }); m.def( "get_executable", - [](PyStableHloPipeline &pm, MlirOperation module) { + [](PyPassManagerReference &pm, MlirOperation module) { MTRT_Executable exe{nullptr}; MTRT_Status status = mtrtCompilerGetExecutable(pm, module, &exe); THROW_IF_MTRT_ERROR(status); diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h index 6c4e4b2a7..eb7e8e91d 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h @@ -26,6 +26,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" +#include namespace mlir { @@ -88,6 +89,67 @@ class OptionsContext : public llvm::cl::SubCommand { OptionsContext(OptionsContext &&) = default; virtual ~OptionsContext() = default; + /// Convenience type for declaring options as class/struct member without + /// having to explicitly write `addOption` in the constructor of the options + /// container class. 
+ template + struct Option { + + T value; + operator const T &() const { return value; } + Option &operator=(const T &rhs) { + value = rhs; + return *this; + } + + template + std::enable_if_t, bool> empty() const { + return value.empty(); + } + + // Implicit conversion operator to StringRef, enabled only if T is + // std::string + template + operator typename std::enable_if_t, + llvm::StringRef>() const { + return value; + } + + template + Option(OptionsContext *ctx, llvm::StringRef name, Args &&...args) { + ctx->addOption(name, value, std::forward(args)...); + } + + Option() = delete; + Option(const Option &) = delete; + Option(Option &&) = default; + Option &operator=(const Option &) = delete; + }; + + /// Convenience type for declaring vector class member as an option without + /// having to explicitly write `addList` in the constructor of the options + /// container class. + template + struct ListOption { + std::vector value; + operator const std::vector &() const { return value; } + + auto empty() const { return value.empty(); } + auto begin() const { return value.begin(); } + auto end() const { return value.end(); } + auto front() const { return value.front(); } + auto back() const { return value.back(); } + auto emplace_back(T &&item) { + return value.emplace_back(std::forward(item)); + } + auto push_back(T &&item) { return value.push_back(std::forward(item)); } + + template + ListOption(OptionsContext *ctx, llvm::StringRef name, Args &&...args) { + ctx->addList(name, value, std::forward(args)...); + } + }; + protected: /// Add an option to this context. The storage `value` must outlive the /// OptionsContext. @@ -148,6 +210,8 @@ class OptionsContext : public llvm::cl::SubCommand { /// Print the options to the stream. void print(llvm::raw_ostream &os) const; + SmallVector serialize() const; + /// Get a hash derived from the string representation of the options. 
/// Derived classes can use this method to incorporate additional factors /// which cannot be captured by the options string representation. Returning diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/OptionsBundle.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/OptionsBundle.h index 959ad0d50..447c47a80 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/OptionsBundle.h +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/OptionsBundle.h @@ -31,22 +31,18 @@ namespace mlir { template class OptionsBundle : public OptionsContext { public: - OptionsBundle() { - std::apply( - [&](auto &...optionProvider) { - (optionProvider.addToOptions(*this), ...); - }, - optionProviders); - } + OptionsBundle() + : optionProviders(std::make_unique( + *static_cast(this))...) {} template const OptionsProviderT &get() const { - return std::get(optionProviders); + return *std::get>(optionProviders); } template OptionsProviderT &get() { - return std::get(optionProviders); + return *std::get>(optionProviders); } llvm::Error finalize() override { @@ -54,7 +50,7 @@ class OptionsBundle : public OptionsContext { std::apply( [&](auto &...optionProvider) { ((result = std::move(llvm::joinErrors(std::move(result), - optionProvider.finalize()))), + optionProvider->finalize()))), ...); }, optionProviders); @@ -63,7 +59,7 @@ class OptionsBundle : public OptionsContext { } private: - std::tuple optionProviders{}; + std::tuple...> optionProviders; }; } // namespace mlir diff --git a/mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp b/mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp index d551234c8..1248c8aa3 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp @@ -66,7 +66,7 @@ using namespace mlir::tensorrt; bool ByteSizeParser::parse(llvm::cl::Option &option, StringRef argName, StringRef arg, std::optional &val) { val = std::nullopt; - 
if (arg.empty()) + if (arg.empty() || arg.lower() == "none") return false; char *End; diff --git a/mlir-tensorrt/tensorrt/lib/Utils/Options.cpp b/mlir-tensorrt/tensorrt/lib/Utils/Options.cpp index b65005122..d41389f5e 100644 --- a/mlir-tensorrt/tensorrt/lib/Utils/Options.cpp +++ b/mlir-tensorrt/tensorrt/lib/Utils/Options.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-dialect/Utils/Options.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -63,6 +64,22 @@ void OptionsContext::print(llvm::raw_ostream &os) const { " "); } +SmallVector OptionsContext::serialize() const { + assert(getHash() && "cannot serialize non-hashable options"); + SmallVector result; + for (const auto &[key, option] : this->OptionsMap) { + std::string val; + { + llvm::raw_string_ostream ss(val); + auto printer = this->printers.lookup(option); + if (printer) + printer(ss); + } + result.push_back(llvm::formatv("--{0}={1}", key, val)); + } + return result; +} + std::optional OptionsContext::getHash() const { // We hash by just hashing the string representation. llvm::SmallString<128> str;