Move internal change: [compiler] Add compilation task registry (#465)
This change adds a registry of CompilationTasks (cached pass managers),
which makes it possible to create and look up cached compilation tasks
from the Python API by passing just a mnemonic task name and a list of
string options.

GitOrigin-RevId: f19e634e8ff8338809fe2c1b8efa730ef2e14f21
christopherbate authored Dec 18, 2024
1 parent 7dfa0fb commit c17ace7
Showing 20 changed files with 568 additions and 199 deletions.
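
The core new capability is mnemonic-based task lookup on CompilerClient. The sketch below shows the intended C++ usage under a few assumptions: the client already exists, the mnemonic "stablehlo-to-executable" is illustrative, and the exact option-string syntax is not confirmed by this excerpt.

#include "mlir-tensorrt/Compiler/Client.h"
#include "llvm/ADT/SmallVector.h"

// Look up (or lazily construct and cache) a compilation task by mnemonic.
void lookupTaskByName(mlirtrt::compiler::CompilerClient &client) {
  llvm::SmallVector<llvm::StringRef> opts = {"--entrypoint=main"};
  mlirtrt::StatusOr<mlirtrt::compiler::CompilationTaskBase *> task =
      client.getCompilationTask("stablehlo-to-executable", opts);
  // On success, *task is a cached pass-manager-like object that can be run on
  // a module; on failure, the status carries the parse/finalization error.
  (void)task;
}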
@@ -52,6 +52,10 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) {
return !options.ptr;
}

MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerClientGetCompilationTask(
MTRT_CompilerClient client, MlirStringRef taskMnemonic,
const MlirStringRef *argv, unsigned argc, MlirPassManager *result);

//===----------------------------------------------------------------------===//
// MTRT_OptionsContext
//===----------------------------------------------------------------------===//
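
For reference, here is a hedged sketch of driving the new C API entry point. Only the declaration above comes from this change; the MTRT_CompilerClient handle, mnemonic, and option string are assumptions, and status/handle helpers from the existing C API are elided.

#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
// Plus the MLIR-TensorRT C API header that declares the functions above.

void getTaskViaCAPI(MTRT_CompilerClient client) {
  MlirStringRef args[] = {mlirStringRefCreateFromCString("--entrypoint=main")};
  MlirPassManager pm;
  MTRT_Status status = mtrtCompilerClientGetCompilationTask(
      client, mlirStringRefCreateFromCString("stablehlo-to-executable"), args,
      /*argc=*/1, &pm);
  (void)status; // Check with the existing MTRT_Status helpers before using pm.
}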
146 changes: 125 additions & 21 deletions mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h
@@ -29,6 +29,7 @@

#include "mlir-executor/Support/Status.h"
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
#include "mlir-tensorrt/Compiler/OptionsRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/TypeID.h"
@@ -100,28 +101,69 @@ class CompilerClient {

~CompilerClient() = default;

/// Create or retrieve a cached PassManager of the given derived type using
/// the provided options. PassManagers are cached by type and a hash of the
/// string representation of the options.
/// This function should only be called if the options have a valid hash.
template <typename CompilationTaskType, typename OptionsType>
mlir::PassManager &getOrCreatePassManager(const OptionsType &options) {
std::optional<llvm::hash_code> hash = options.getHash();
if (!hash)
llvm::report_fatal_error("attempted to lookup a PassManager from a cache "
"with an un-hashable options key");

auto key =
std::make_pair(mlir::TypeID::get<CompilationTaskType>(), hash.value());
/// Create or retrieve from the cache a compilation task of the specified
/// type and options. If an existing compilation task is not in the cache,
/// then it is constructed using the registered construction function and
/// inserted into the cache.
StatusOr<CompilationTaskBase *>
getCompilationTask(mlir::TypeID taskID,
llvm::ArrayRef<llvm::StringRef> options);

/// Create or retrieve from the cache a compilation task of the specified
/// type ID and options. If an existing compilation task is not in the cache,
/// then it is constructed using the registered construction function and
/// inserted into the cache.
StatusOr<CompilationTaskBase *>
getCompilationTask(mlir::TypeID taskID, llvm::ArrayRef<std::string> options) {
return getCompilationTask(
taskID, llvm::map_to_vector(options, [](const std::string &x) {
return llvm::StringRef(x);
}));
}

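/// Create or retrieve from the cache a compilation task registered under the
/// given mnemonic. If it is not already cached, it is constructed from the
/// given options using the registered construction function.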
StatusOr<CompilationTaskBase *>
getCompilationTask(llvm::StringRef mnemonic,
llvm::ArrayRef<llvm::StringRef> options);

/// Create or retrieve from the cache a compilation task of the specified
/// type and options. If an existing compilation task is not in the cache,
/// then it is constructed using the registered construction function and
/// inserted into the cache.
template <typename T, typename... Args>
StatusOr<CompilationTaskBase *> getCompilationTask(Args &&...args) {
return getCompilationTask(mlir::TypeID::get<T>(),
std::forward<Args>(args)...);
}

/// Insert a compilation task of type T with options hash `hash` into the
/// cache.
template <typename T>
void updateCachedCompilationTask(const llvm::hash_code &hash,
std::unique_ptr<CompilationTaskBase> task) {
cachedPassManagers[std::make_pair(mlir::TypeID::get<T>(), hash)] =
std::move(task);
}

/// Check whether a CompilationTask with the specified type ID and whose
/// options have the given hash is in the cache. If so, return it; otherwise,
/// return nullptr.
CompilationTaskBase *
lookupCachedCompilationTask(mlir::TypeID taskID,
const llvm::hash_code &optionsHash) {
auto key = std::make_pair(taskID, optionsHash);
auto it = cachedPassManagers.find(key);
if (it == cachedPassManagers.end()) {
auto pm = std::make_unique<CompilationTaskType>(context, options);
setupPassManagerLogging(*pm, options.template get<DebugOptions>());
auto *ptr = pm.get();
cachedPassManagers[key] = std::move(pm);
return *ptr;
}
return *it->second;
if (it == cachedPassManagers.end())
return nullptr;
return it->second.get();
}

/// Check whether a CompilationTask with the specified type T and whose
/// options have the given hash is in the cache. If so, return it; otherwise,
/// return nullptr.
template <typename T>
CompilationTaskBase *
lookupCachedCompilationTask(const llvm::hash_code &optionsHash) {
return lookupCachedCompilationTask(mlir::TypeID::get<T>(), optionsHash);
}

/// Return the MLIRContext associated with the client.
@@ -147,6 +189,68 @@ class CompilerClient {
cachedPassManagers;
};

/// A registry function that constructs a compilation task of a particular
/// type from a list of string options, or retrieves an equivalent task from
/// the client's cache. Returns an error status if option parsing or
/// finalization fails.
using TaskRegistryFunction = std::function<StatusOr<CompilationTaskBase *>(
CompilerClient &client, llvm::ArrayRef<llvm::StringRef> options)>;

struct TaskRegistration {
TaskRegistryFunction registryFunc;
};

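/// Register a construction function for a compilation task under the given
/// mnemonic and task type ID.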
void registerCompilationTask(llvm::StringRef mnemonic, mlir::TypeID typeID,
TaskRegistryFunction func);

template <typename T>
void registerCompilationTask(llvm::StringRef mnemonic,
TaskRegistryFunction func) {
return registerCompilationTask(mnemonic, mlir::TypeID::get<T>(),
std::move(func));
}

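/// Register a compilation task type `T`, whose options type `OptionsType`
/// does not use task extensions, under the given mnemonic. The generated
/// construction function parses, finalizes, and hashes the options, then
/// returns the cached task for that hash or constructs and caches a new one.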
template <typename T, typename OptionsType>
void registerCompilationTaskWithNoExtensions(llvm::StringRef mnemonic) {
registerCompilationTask<T>(
mnemonic,
[](CompilerClient &client, llvm::ArrayRef<llvm::StringRef> options)
-> StatusOr<CompilationTaskBase *> {
OptionsType result;
std::string err;
if (failed(result.parse(options, err)))
return getInvalidArgStatus(
"failed to parse options string \"{0:$[ ]}\" due to error {1}",
llvm::iterator_range(options), err);

llvm::Error finalizeStatus = result.finalize();
std::optional<std::string> errMsg{};
llvm::handleAllErrors(std::move(finalizeStatus),
[&errMsg](const llvm::StringError &err) {
errMsg = err.getMessage();
});

if (errMsg)
return getInvalidArgStatus("failed to parse options due to error {0}",
errMsg);

std::optional<llvm::hash_code> hashCode = result.getHash();
if (!hashCode)
return getInvalidArgStatus("failed to hash options");

CompilationTaskBase *cached =
client.lookupCachedCompilationTask<T>(*hashCode);
if (cached)
return cached;

auto newPM = std::make_unique<T>(client.getContext(), result);
auto ptr = newPM.get();
client.updateCachedCompilationTask<T>(*hashCode, std::move(newPM));
return ptr;
});
}

} // namespace mlirtrt::compiler

#endif // MLIR_TENSORRT_COMPILER_CLIENT
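
Putting the pieces of Client.h together, a task type opts into the registry roughly as sketched below. `MyTask` and `MyTaskOptions` are hypothetical names used only for illustration, and the mnemonic is likewise assumed.

#include "mlir-tensorrt/Compiler/Client.h"

// Registration, typically performed once at startup. Afterwards, repeated
// CompilerClient::getCompilationTask calls with identical option strings
// return the same cached task, keyed by (TypeID, options hash).
void registerMyTask() {
  mlirtrt::compiler::registerCompilationTaskWithNoExtensions<MyTask,
                                                             MyTaskOptions>(
      "my-task");
}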
mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h
@@ -47,6 +47,24 @@ constexpr bool has_finalize_impl_v<
// a default implementation otherwise.
template <typename Derived>
struct OptionsProvider {
OptionsProvider(mlir::OptionsContext &ctx) : ctx(ctx) {}

// We don't allow move construction since the actual ptrs/locations of
// individual member elements of an OptionsProvider are captured into the
// OptionsContext. If the OptionsContext is populated upon construction,
// moving can change the memory location of the owned values, which will cause
// a crash later on. This can happen in particular when constructing a tuple
// of `OptionsProviders`. Since the move constructor is deleted, one must
// instead use a tuple of `unique_ptr<OptionsProviders...>`.
OptionsProvider(OptionsProvider &&) = delete;

mlir::OptionsContext &ctx;

template <typename T, typename... Mods>
using Option = mlir::OptionsContext::Option<T, Mods...>;
template <typename T, typename... Mods>
using ListOption = mlir::OptionsContext::ListOption<T, Mods...>;

/// Modifies options after parsing. This is required since we may need
/// to make changes to options based on the values of other options.
/// Do *not* override this method; instead, implement `finalizeImpl()`.
@@ -62,67 +80,63 @@
/// interfaces.
struct DebugOptions : public OptionsProvider<DebugOptions> {
public:
using OptionsProvider::OptionsProvider;
/// A directory path where the IR will be dumped during compilation
/// using the `mlir-print-ir-tree-dir` mechanism.
std::string dumpIRPath = "";
Option<std::string> dumpIRPath{&this->ctx, "mlir-print-ir-tree-dir",
llvm::cl::init("")};

/// Whether the LLVM 'debug' flag that enables execution of code guarded by
/// the `LLVM_DEBUG` macro should be set to 'on'. This results in very verbose
/// output from the compiler dumped to stderr.
bool enableLLVMDebugFlag = false;
Option<bool> enableLLVMDebugFlag{&this->ctx, "debug", llvm::cl::init(false)};

/// A set of names to be given to the LLVM 'debug types' option, akin to
/// setting
/// `-debug-types=...` from the command line.
mlir::SmallVector<std::string> llvmDebugTypes = {};

public:
void addToOptions(mlir::OptionsContext &context) {
context.addOption("mlir-print-ir-tree-dir", dumpIRPath, llvm::cl::init(""));
context.addOption("debug", enableLLVMDebugFlag);
context.addList<std::string>("debug-only", llvmDebugTypes,
llvm::cl::ZeroOrMore,
llvm::cl::CommaSeparated);
}
ListOption<std::string> llvmDebugTypes{
&this->ctx, "debug-only", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated};
};

struct ExecutorOptions : public OptionsProvider<ExecutorOptions> {
public:
/// The host index bit-width.
int64_t indexBitwidth{64};
using OptionsProvider::OptionsProvider;

/// Whether to pass memref's as struct/table in function calls.
bool usePackedMemRefCConv{true};
Option<int64_t> indexBitwidth{&this->ctx, "executor-index-bitwidth",
llvm::cl::init(64),
llvm::cl::desc("executor index bitwidth")};

public:
void addToOptions(mlir::OptionsContext &context) {
context.addOption("executor-index-bitwidth", indexBitwidth,
llvm::cl::init(64));
}
Option<bool> usePackedMemRefCConv{
&this->ctx, "executor-use-packed-memref-cconv", llvm::cl::init(true),
llvm::cl::desc(
"whether to use packed or unpacked memref calling convention")};
};

struct DeviceOptions : public OptionsProvider<DeviceOptions> {
public:
using OptionsProvider::OptionsProvider;

/// Device information. Members are manually bound to options in the
/// constructor.
DeviceInfo info;

/// Whether to ignore `deviceX` options and instead infer them from the GPUs
/// on the host system running the compilation.
bool shouldInferFromHost = false;
Option<bool> shouldInferFromHost{
&this->ctx, "device-infer-from-host", llvm::cl::init(true),
llvm::cl::desc("whether to ignore `deviceX` options and instead infer "
"them from the host GPU")};

Status inferFromHost();

public:
void addToOptions(mlir::OptionsContext &context) {
context.addOption(
DeviceOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {
ctx.addOption(
"device-compute-capability", info.computeCapability, llvm::cl::init(60),
llvm::cl::desc("Sets the device compute capbility. Only relevant "
"if '--device-infer-from-host=false'"));
context.addOption("device-max-shared-memory-per-block-kb",
info.maxSharedMemoryPerBlockKb, llvm::cl::init(48));
context.addOption("device-max-registers-per-block",
info.maxRegistersPerBlock, llvm::cl::init(65536));
context.addOption("device-infer-from-host", shouldInferFromHost,
llvm::cl::init(true),
llvm::cl::desc("Infers device information from host"));
ctx.addOption("device-max-shared-memory-per-block-kb",
info.maxSharedMemoryPerBlockKb, llvm::cl::init(48));
ctx.addOption("device-max-registers-per-block", info.maxRegistersPerBlock,
llvm::cl::init(65536));
}

llvm::Error finalizeImpl();
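
The OptionsProviders.h changes above replace the imperative addToOptions() hook with declarative Option/ListOption members that bind themselves to the OptionsContext at construction time. A minimal sketch of a provider written in the new style; the struct name and flag are illustrative, and the mlirtrt::compiler namespace is assumed.

#include "mlir-tensorrt/Compiler/OptionsProviders.h"

namespace mlirtrt::compiler {
// Illustrative provider in the new declarative style. Each Option member
// registers itself with the OptionsContext passed to the constructor, so no
// separate addToOptions() hook is needed.
struct MyProviderOptions : public OptionsProvider<MyProviderOptions> {
  using OptionsProvider::OptionsProvider;

  Option<int64_t> tileSize{&this->ctx, "tile-size", llvm::cl::init(32),
                           llvm::cl::desc("illustrative tile size option")};
};
} // namespace mlirtrt::compiler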
mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h
@@ -28,8 +28,8 @@
#define MLIR_TENSORRT_COMPILER_OPTIONS_REGISTRY

#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir-tensorrt/Compiler/Client.h"
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
@@ -39,25 +39,23 @@ namespace mlirtrt::compiler {

using OptionsConstructorFuncT =
std::function<StatusOr<std::unique_ptr<mlir::OptionsContext>>(
const CompilerClient &client, const llvm::ArrayRef<llvm::StringRef>)>;
mlir::MLIRContext *, llvm::ArrayRef<llvm::StringRef>)>;

/// Registers an options creation function for a specific options type.
void registerOption(const llvm::StringRef optionsType,
OptionsConstructorFuncT func);
void registerOption(llvm::StringRef optionsType, OptionsConstructorFuncT func);

/// Creates an options instance for the specified options type using a creation
/// function that was previously registered.
StatusOr<std::unique_ptr<mlir::OptionsContext>>
createOptions(const CompilerClient &client, const llvm::StringRef optionsType,
const llvm::ArrayRef<llvm::StringRef> args);
createOptions(mlir::MLIRContext *client, llvm::StringRef optionsType,
llvm::ArrayRef<llvm::StringRef> args);

/// Helper to build callbacks that can create options.
template <typename OptionsT, typename TaskT>
StatusOr<std::unique_ptr<mlir::OptionsContext>>
optionsCreateFromArgs(const CompilerClient &client,
const llvm::ArrayRef<llvm::StringRef> args) {
StatusOr<std::unique_ptr<OptionsT>>
optionsCreateFromArgs(mlir::MLIRContext *context,
llvm::ArrayRef<llvm::StringRef> args) {
// Load available extensions.
mlir::MLIRContext *context = client.getContext();
mlir::plan::PlanDialect *planDialect =
context->getLoadedDialect<mlir::plan::PlanDialect>();
compiler::TaskExtensionRegistry extensions =
@@ -83,7 +81,7 @@ optionsCreateFromArgs(const CompilerClient &client,
return getInternalErrorStatus("failed to initialize options: %s",
errMsg->c_str());

return std::unique_ptr<mlir::OptionsContext>(result.release());
return result;
}
} // namespace mlirtrt::compiler

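
With the OptionsRegistry.h entry points now keyed off an MLIRContext* instead of a CompilerClient, building an options bundle by name might look like the sketch below; the options-type name and argument string are assumptions.

#include "mlir-tensorrt/Compiler/OptionsRegistry.h"

// Creates an options bundle by name; dispatches to whatever creation function
// was registered for "stablehlo-to-executable" (an assumed registry name).
mlirtrt::StatusOr<std::unique_ptr<mlir::OptionsContext>>
makeOptions(mlir::MLIRContext *context) {
  return mlirtrt::compiler::createOptions(
      context, "stablehlo-to-executable",
      llvm::ArrayRef<llvm::StringRef>{"--entrypoint=main"});
}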
@@ -58,15 +58,19 @@ struct StablehloToExecutableOptions
StablehloToExecutableOptions(TaskExtensionRegistry extensions);

/// Whether to disallow host tensors in TensorRT clusters.
bool disallowHostTensorsInTensorRTClusters = false;
Option<bool> disallowHostTensorsInTensorRTClusters{
this, "plan-clustering-disallow-host-tensors-in-tensorrt-clusters",
llvm::cl::init(false),
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)")};

Option<std::string> entrypoint{this, "entrypoint", llvm::cl::init("main"),
llvm::cl::desc("entrypoint function name")};

/// Use non-DPS style calling convention for entrypoint function
/// and backend types that support allocating results.
bool enableNonDPSReturns = false;

/// Entrypoint function name.
std::string entrypoint = "main";

/// Base class for extensions associated with StableHloToExecutableTask.
class ExtensionBase : public TaskExtensionBase {
public:
@@ -134,13 +138,6 @@ class StablehloToExecutableTask
static void populatePassManager(mlir::PassManager &pm,
const StablehloToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(mlir::ModuleOp module,
const StablehloToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
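
Finally, the same task can be requested from C++ through the typed overload using option flags declared in StablehloToExecutableOptions above. This is a sketch that assumes the task lives in the mlirtrt::compiler namespace and that options are passed as `--flag=value` strings.

#include "mlir-tensorrt/Compiler/Client.h"
// Plus the header that declares StablehloToExecutableTask and its options.

mlirtrt::StatusOr<mlirtrt::compiler::CompilationTaskBase *>
getStablehloTask(mlirtrt::compiler::CompilerClient &client) {
  // Flag names mirror the Option declarations shown in this diff.
  return client.getCompilationTask<
      mlirtrt::compiler::StablehloToExecutableTask>(
      llvm::ArrayRef<llvm::StringRef>{
          "--entrypoint=main",
          "--plan-clustering-disallow-host-tensors-in-tensorrt-clusters="
          "true"});
}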