Skip to content

Commit

Permalink
Rename Experimental to Advanced (#1611)
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored Jul 11, 2024
1 parent b39b566 commit e0b2c79
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 15 deletions.
6 changes: 3 additions & 3 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def hash(self):

class XPUBackend(BaseBackend):

# Experimental pass pipeline for kernels using block pointers.
class Experimental:
# AdvancedPath pass pipeline for kernels using block pointers.
class AdvancedPath:

@staticmethod
def make_ttgir(mod, metadata, opt):
Expand Down Expand Up @@ -175,7 +175,7 @@ def make_ttgir(mod, metadata, opt, device_arch):
pm.enable_debug()

if (not is_lts and os.getenv("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"):
return XPUBackend.Experimental.make_ttgir(mod, metadata, opt)
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)

passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
intel.passes.ttgpuir.add_accelerate_matmul(pm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TritonIntelGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter {
using TypeConverter::convertType;

TritonIntelGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option, bool isLTSDriver,
MLIRContext *ctx, LowerToLLVMOptions &option, bool isAdvancedPathEnabled,
const DataLayoutAnalysis *analysis = nullptr);
};

Expand Down
10 changes: 5 additions & 5 deletions third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,21 @@ class TritonGPUToLLVMPipelineManager {
public:
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
: mod(mod), ctx(ctx),
blockPtrPathIsEnabled(
isAdvancedPathEnabled(
!mod->hasAttr("triton_gpu.is_lts") &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) {}

/// FIXME: remove once the block ptr conversion path is capable of handling
/// shared memory.
bool skipSharedMemoryAllocation() const { return blockPtrPathIsEnabled; }
bool skipSharedMemoryAllocation() const { return isAdvancedPathEnabled; }

/// Populate the conversion pipeline for function operations.
void populateFunctionConversionPatterns(
RewritePatternSet &funcPatterns,
TritonIntelGPUToLLVMTypeConverter &typeConverter, int numWarps) const {
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
/*benefit=*/1);
if (!blockPtrPathIsEnabled)
if (!isAdvancedPathEnabled)
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
funcPatterns);
}
Expand All @@ -213,7 +213,7 @@ class TritonGPUToLLVMPipelineManager {
patterns.add<AddSPIRVEnvPattern>(&typeConverter.getContext(),
patternBenefitAddSPIRVEnv);

if (blockPtrPathIsEnabled) {
if (isAdvancedPathEnabled) {
intel::populateTritonOpsToLLVMPatterns(typeConverter, patterns, benefit);
intel::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
benefit);
Expand Down Expand Up @@ -267,7 +267,7 @@ class TritonGPUToLLVMPipelineManager {
/// Selects which conversion pipeline to use.
/// FIXME: this is temporary and should be removed once we have an analysis to
/// determine whether a kernel uses block pointers.
bool blockPtrPathIsEnabled = false;
bool isAdvancedPathEnabled = false;
};

} // namespace mlir::triton::intel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ struct ConvertTritonGPUToLLVM
intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context);
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
bool isLTSDriver = mod->hasAttr("triton_gpu.is_lts");
bool isAdvancedPathEnabled =
!mod->hasAttr("triton_gpu.is_lts") &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR");
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option,
isLTSDriver);
isAdvancedPathEnabled);
TritonLLVMConversionTarget convTarget(*context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
Expand All @@ -98,7 +100,7 @@ struct ConvertTritonGPUToLLVM
{
mlir::LowerToLLVMOptions option(context);
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option,
isLTSDriver);
isAdvancedPathEnabled);
TritonLLVMFunctionConversionTarget funcTarget(*context);
RewritePatternSet funcPatterns(context);
pipelineManager.populateFunctionConversionPatterns(
Expand Down
5 changes: 2 additions & 3 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
#include "triton/Tools/Sys/GetEnv.hpp"

TritonIntelGPUToLLVMTypeConverter::TritonIntelGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option, bool isLTSDriver,
MLIRContext *ctx, LowerToLLVMOptions &option, bool isAdvancedPathEnabled,
const DataLayoutAnalysis *analysis)
: TritonGPUToLLVMTypeConverter(ctx, option, analysis) {
// Augment/overwrite type conversions required for the Intel conversion
// passes.
if (!isLTSDriver &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) {
if (isAdvancedPathEnabled) {
// tt::pointer to v2i32.
addConversion([&](PointerType type) -> std::optional<Type> {
if (isa<RankedTensorType>(type.getPointeeType())) {
Expand Down

0 comments on commit e0b2c79

Please sign in to comment.