diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 0b21c4850f..50fb6655e9 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2963,6 +2963,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO { break; // should already have been parsed and used. } + case kIROp_GlobalValueRef: + { + emitOperand(as(inst)->getOperand(0), getInfo(EmitOp::General)); + break; + } default: diagnoseUnhandledInst(inst); break; diff --git a/source/slang/slang-ir-legalize-global-values.cpp b/source/slang/slang-ir-legalize-global-values.cpp new file mode 100644 index 0000000000..92a1edb2b8 --- /dev/null +++ b/source/slang/slang-ir-legalize-global-values.cpp @@ -0,0 +1,247 @@ +#include "slang-ir-legalize-global-values.h" + +#include "slang-ir-clone.h" +#include "slang-ir-util.h" + +namespace Slang +{ + +void GlobalInstInliningContextGeneric::inlineGlobalValuesAndRemoveIfUnused(IRModule* module) +{ + List globalInstUsesToInline; + + for (auto globalInst : module->getGlobalInsts()) + { + if (isInlinableGlobalInst(globalInst)) + { + for (auto use = globalInst->firstUse; use; use = use->nextUse) + { + if (getParentFunc(use->getUser()) != nullptr) + globalInstUsesToInline.add(use); + } + } + } + + HashSet globalInstsToConsiderDeleting; + for (auto use : globalInstUsesToInline) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(getOutsideASM(user)); + IRCloneEnv cloneEnv; + auto val = maybeInlineGlobalValue(builder, use->getUser(), use->get(), cloneEnv); + if (val != use->get()) + { + // Since certain globals that appear in the IR are considered illegal for all targets, + // e.g. calls to functions, we delete the globals we've inlined. + // Note that the inlining is done such that none of the descendants of the global will + // have any uses either. + globalInstsToConsiderDeleting.add(use->usedValue); + + builder.replaceOperand(use, val); + } + } + + for (auto globalInst : globalInstsToConsiderDeleting) + { + if (!globalInst->hasUses()) + globalInst->removeAndDeallocate(); + } +} + +bool GlobalInstInliningContextGeneric::isLegalGlobalInst(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + return true; + default: + if (as(inst)) + return true; + if (isLegalGlobalInstForTarget(inst)) + return true; + return false; + } +} + +bool GlobalInstInliningContextGeneric::isInlinableGlobalInst(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_Not: + case kIROp_Neg: + case kIROp_Div: + case kIROp_FieldExtract: + case kIROp_FieldAddress: + case kIROp_GetElement: + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_UpdateElement: + case kIROp_MakeTuple: + case kIROp_GetTupleElement: + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + case kIROp_swizzle: + case kIROp_swizzleSet: + case kIROp_MatrixReshape: + case kIROp_MakeString: + case kIROp_MakeResultError: + case kIROp_MakeResultValue: + case kIROp_GetResultError: + case kIROp_GetResultValue: + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_CastIntToPtr: + case kIROp_PtrCast: + case kIROp_CastPtrToBool: + case kIROp_CastPtrToInt: + case kIROp_BitAnd: + case kIROp_BitNot: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitCast: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + case kIROp_Neq: + case kIROp_Eql: + case kIROp_Call: + return true; + default: + if (isInlinableGlobalInstForTarget(inst)) + return true; + return false; + } +} + +bool GlobalInstInliningContextGeneric::shouldInlineInstImpl(IRInst* inst) +{ + // If 'inst' has an ancestor that is currently being inlined, then we + // better inline it since we'll be removing the ancestor. + bool ancestorShouldBeInlined = false; + for (IRInst* ancestor = inst->parent; ancestor != nullptr; ancestor = ancestor->parent) + if (m_mapGlobalInstToShouldInline.tryGetValue(inst, ancestorShouldBeInlined) && + ancestorShouldBeInlined) + return true; + + if (!isInlinableGlobalInst(inst)) + return false; + if (isLegalGlobalInst(inst)) + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + if (shouldInlineInst(inst->getOperand(i))) + return true; + return false; + } + return true; +} + +bool GlobalInstInliningContextGeneric::shouldInlineInst(IRInst* inst) +{ + bool result = false; + if (m_mapGlobalInstToShouldInline.tryGetValue(inst, result)) + return result; + result = shouldInlineInstImpl(inst); + m_mapGlobalInstToShouldInline[inst] = result; + return result; +} + +IRInst* GlobalInstInliningContextGeneric::inlineInst( + IRBuilder& builder, + IRCloneEnv& cloneEnv, + IRInst* inst) +{ + // We rely on this dictionary in order to force inlining of any nodes with that should be + // inlined + SLANG_ASSERT(m_mapGlobalInstToShouldInline[inst]); + + IRInst* result; + if (cloneEnv.mapOldValToNew.tryGetValue(inst, result)) + return result; + + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + IRBuilder operandBuilder(builder); + operandBuilder.setInsertBefore(getOutsideASM(builder.getInsertLoc().getInst())); + maybeInlineGlobalValue(operandBuilder, inst, operand, cloneEnv); + } + result = cloneInstAndOperands(&cloneEnv, &builder, inst); + cloneEnv.mapOldValToNew[inst] = result; + IRBuilder subBuilder(builder); + subBuilder.setInsertInto(result); + for (auto child : inst->getDecorations()) + { + cloneInst(&cloneEnv, &subBuilder, child); + } + for (auto child : inst->getChildren()) + { + m_mapGlobalInstToShouldInline[child] = true; + inlineInst(subBuilder, cloneEnv, child); + } + return result; +} + +IRInst* GlobalInstInliningContextGeneric::maybeInlineGlobalValue( + IRBuilder& builder, + IRInst* user, + IRInst* inst, + IRCloneEnv& cloneEnv) +{ + if (!shouldInlineInst(inst)) + { + switch (inst->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + case kIROp_Generic: + case kIROp_LookupWitness: + return inst; + } + if (as(inst)) + return inst; + + // If we encounter a global value that shouldn't be inlined, e.g. a const literal, + // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent + // uses can be pinned to the function body. + auto result = inst; + bool shouldWrapGlobalRef = true; + if (!isLegalGlobalInst(user) && !getIROpInfo(user->getOp()).isHoistable()) + shouldWrapGlobalRef = false; + else if (shouldBeInlinedForTarget(user)) + shouldWrapGlobalRef = false; + if (shouldWrapGlobalRef) + result = builder.emitGlobalValueRef(inst); + cloneEnv.mapOldValToNew[inst] = result; + return result; + } + + // If the global value is inlinable, we make all its operands avaialble locally, and + // then copy it to the local scope. + return inlineInst(builder, cloneEnv, inst); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-legalize-global-values.h b/source/slang/slang-ir-legalize-global-values.h new file mode 100644 index 0000000000..c563b02c9c --- /dev/null +++ b/source/slang/slang-ir-legalize-global-values.h @@ -0,0 +1,46 @@ +#pragma once + +#include "core/slang-dictionary.h" + +namespace Slang +{ +struct IRBuilder; +struct IRCloneEnv; +struct IRInst; +struct IRModule; + +struct GlobalInstInliningContextGeneric +{ + Dictionary m_mapGlobalInstToShouldInline; + + // Target-specific control over how inlining happens + virtual bool isLegalGlobalInstForTarget(IRInst* inst) = 0; + virtual bool isInlinableGlobalInstForTarget(IRInst* inst) = 0; + virtual bool shouldBeInlinedForTarget(IRInst* user) = 0; + virtual IRInst* getOutsideASM(IRInst* beforeInst) = 0; + + // Inline global values that can't represented by the target to their use sites. + // If this leaves any global unused, then remove it. + void inlineGlobalValuesAndRemoveIfUnused(IRModule* module); + + // Opcodes that can exist in global scope, as long as the operands are. + bool isLegalGlobalInst(IRInst* inst); + + // Opcodes that can be inlined into function bodies. + bool isInlinableGlobalInst(IRInst* inst); + + bool shouldInlineInstImpl(IRInst* inst); + + bool shouldInlineInst(IRInst* inst); + + IRInst* inlineInst(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* inst); + + /// Inline `inst` in the local function body so they can be emitted as a local inst. + /// + IRInst* maybeInlineGlobalValue( + IRBuilder& builder, + IRInst* user, + IRInst* inst, + IRCloneEnv& cloneEnv); +}; +} // namespace Slang diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 1de2edd4a0..c9764b2032 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -11,6 +11,7 @@ #include "slang-ir-glsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" +#include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-lower-buffer-element-type.h" @@ -1590,196 +1591,51 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } - struct GlobalInstInliningContext + struct GlobalInstInliningContext : public GlobalInstInliningContextGeneric { - Dictionary m_mapGlobalInstToShouldInline; - - // Opcodes that can exist in global scope, as long as the operands are. - bool isLegalGlobalInst(IRInst* inst) + bool isLegalGlobalInstForTarget(IRInst* inst) override { - switch (inst->getOp()) - { - case kIROp_MakeStruct: - case kIROp_MakeArray: - case kIROp_MakeArrayFromElement: - case kIROp_MakeVector: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MakeVectorFromScalar: - return true; - default: - if (as(inst)) - return true; - if (as(inst)) - return true; - return false; - } + return as(inst); } - // Opcodes that can be inlined into function bodies. - bool isInlinableGlobalInst(IRInst* inst) + bool isInlinableGlobalInstForTarget(IRInst* inst) override { switch (inst->getOp()) { - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_FRem: - case kIROp_IRem: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_And: - case kIROp_Or: - case kIROp_Not: - case kIROp_Neg: - case kIROp_Div: - case kIROp_FieldExtract: - case kIROp_FieldAddress: - case kIROp_GetElement: - case kIROp_GetElementPtr: - case kIROp_GetOffsetPtr: - case kIROp_UpdateElement: - case kIROp_MakeTuple: - case kIROp_GetTupleElement: - case kIROp_MakeStruct: - case kIROp_MakeArray: - case kIROp_MakeArrayFromElement: - case kIROp_MakeVector: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MakeVectorFromScalar: - case kIROp_swizzle: - case kIROp_swizzleSet: - case kIROp_MatrixReshape: - case kIROp_MakeString: - case kIROp_MakeResultError: - case kIROp_MakeResultValue: - case kIROp_GetResultError: - case kIROp_GetResultValue: - case kIROp_CastFloatToInt: - case kIROp_CastIntToFloat: - case kIROp_CastIntToPtr: - case kIROp_PtrCast: - case kIROp_CastPtrToBool: - case kIROp_CastPtrToInt: - case kIROp_BitAnd: - case kIROp_BitNot: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_BitCast: - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_Greater: - case kIROp_Less: - case kIROp_Geq: - case kIROp_Leq: - case kIROp_Neq: - case kIROp_Eql: - case kIROp_Call: case kIROp_SPIRVAsm: return true; default: - if (as(inst)) - return true; - if (as(inst)) - return true; - return false; - } - } - - bool shouldInlineInstImpl(IRInst* inst) - { - if (!isInlinableGlobalInst(inst)) - return false; - if (isLegalGlobalInst(inst)) - { - for (UInt i = 0; i < inst->getOperandCount(); i++) - if (shouldInlineInst(inst->getOperand(i))) - return true; - return false; + break; } - return true; - } - bool shouldInlineInst(IRInst* inst) - { - bool result = false; - if (m_mapGlobalInstToShouldInline.tryGetValue(inst, result)) - return result; - result = shouldInlineInstImpl(inst); - m_mapGlobalInstToShouldInline[inst] = result; - return result; + if (as(inst)) + return true; + if (as(inst)) + return true; + return false; } - IRInst* inlineInst(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* inst) + bool shouldBeInlinedForTarget(IRInst* user) override { - IRInst* result; - if (cloneEnv.mapOldValToNew.tryGetValue(inst, result)) - return result; - - for (UInt i = 0; i < inst->getOperandCount(); i++) - { - auto operand = inst->getOperand(i); - IRBuilder operandBuilder(builder); - setInsertBeforeOutsideASM(operandBuilder, builder.getInsertLoc().getInst()); - maybeInlineGlobalValue(operandBuilder, inst, operand, cloneEnv); - } - result = cloneInstAndOperands(&cloneEnv, &builder, inst); - cloneEnv.mapOldValToNew[inst] = result; - IRBuilder subBuilder(builder); - subBuilder.setInsertInto(result); - for (auto child : inst->getDecorations()) - { - cloneInst(&cloneEnv, &subBuilder, child); - } - for (auto child : inst->getChildren()) - { - inlineInst(subBuilder, cloneEnv, child); - } - return result; + if (as(user) && as(user)) + return true; + else if (as(user)) + return true; + return false; } - /// Inline `inst` in the local function body so they can be emitted as a local inst. - /// - IRInst* maybeInlineGlobalValue( - IRBuilder& builder, - IRInst* user, - IRInst* inst, - IRCloneEnv& cloneEnv) + IRInst* getOutsideASM(IRInst* beforeInst) override { - if (!shouldInlineInst(inst)) + auto parent = beforeInst->getParent(); + while (parent) { - switch (inst->getOp()) + if (as(parent)) { - case kIROp_Func: - case kIROp_Specialize: - case kIROp_Generic: - case kIROp_LookupWitness: - return inst; + return parent; } - if (as(inst)) - return inst; - - // If we encounter a global value that shouldn't be inlined, e.g. a const literal, - // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent - // uses can be pinned to the function body. - auto result = inst; - bool shouldWrapGlobalRef = true; - if (!isLegalGlobalInst(user) && !getIROpInfo(user->getOp()).isHoistable()) - shouldWrapGlobalRef = false; - else if (as(user) && as(user)) - shouldWrapGlobalRef = false; - else if (as(user)) - shouldWrapGlobalRef = false; - if (shouldWrapGlobalRef) - result = builder.emitGlobalValueRef(inst); - cloneEnv.mapOldValToNew[inst] = result; - return result; + parent = parent->getParent(); } - - // If the global value is inlinable, we make all its operands avaialble locally, and - // then copy it to the local scope. - return inlineInst(builder, cloneEnv, inst); + return beforeInst; } }; @@ -1965,21 +1821,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } - static void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst) - { - auto parent = beforeInst->getParent(); - while (parent) - { - if (as(parent)) - { - builder.setInsertBefore(parent); - return; - } - parent = parent->getParent(); - } - builder.setInsertBefore(beforeInst); - } - void determineSpirvVersion() { // Determine minimum spirv version from target request. @@ -2160,11 +2001,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase t->replaceUsesWith(lowered); } - // Inline global values that can't represented by SPIRV constant inst - // to their use sites. - List globalInstUsesToInline; - GlobalInstInliningContext globalInstInliningContext; - for (auto globalInst : m_module->getGlobalInsts()) { if (auto func = as(globalInst)) @@ -2178,28 +2014,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // true. sortBlocksInFunc(func); } - - if (globalInstInliningContext.isInlinableGlobalInst(globalInst)) - { - for (auto use = globalInst->firstUse; use; use = use->nextUse) - { - if (getParentFunc(use->getUser()) != nullptr) - globalInstUsesToInline.add(use); - } - } } - for (auto use : globalInstUsesToInline) - { - auto user = use->getUser(); - IRBuilder builder(user); - setInsertBeforeOutsideASM(builder, user); - IRCloneEnv cloneEnv; - auto val = globalInstInliningContext - .maybeInlineGlobalValue(builder, use->getUser(), use->get(), cloneEnv); - if (val != use->get()) - builder.replaceOperand(use, val); - } + GlobalInstInliningContext().inlineGlobalValuesAndRemoveIfUnused(m_module); // Some legalization processing may change the function parameter types, // so we need to update the function types to match that. diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index e6e3755928..afdf412b13 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1,6 +1,7 @@ #include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" +#include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-varying-params.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -1587,6 +1588,35 @@ struct LegalizeWGSLEntryPointContext } }; +struct GlobalInstInliningContext : public GlobalInstInliningContextGeneric +{ + bool isLegalGlobalInstForTarget(IRInst* /* inst */) override + { + // The global instructions that are generically considered legal are fine for + // WGSL. + return false; + } + + bool isInlinableGlobalInstForTarget(IRInst* /* inst */) override + { + // The global instructions that are generically considered inlineable are fine + // for WGSL. + return false; + } + + bool shouldBeInlinedForTarget(IRInst* /* user */) override + { + // WGSL doesn't do any extra inlining beyond what is generically done by default. + return false; + } + + IRInst* getOutsideASM(IRInst* beforeInst) override + { + // Not needed for WGSL, check e.g. the SPIR-V case to see why this is used. + return beforeInst; + } +}; + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) { List entryPoints; @@ -1612,6 +1642,10 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) // Go through every instruction in the module and legalize them as needed. context.processInst(module->getModuleInst()); + + // Some global insts are illegal, e.g. function calls. + // We need to inline and remove those. + GlobalInstInliningContext().inlineGlobalValuesAndRemoveIfUnused(module); } } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 29fbcc3c95..d1c16a3a1a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7971,7 +7971,7 @@ void IRInst::removeOperand(Index index) } // Remove this instruction from its parent block, -// and then destroy it (it had better have no uses!) +// and then destroy it (it had better have no uses, or descendants with uses!) void IRInst::removeAndDeallocate() { removeAndDeallocateAllDecorationsAndChildren(); diff --git a/tests/expected-failure-github.txt b/tests/expected-failure-github.txt index 81a61133bc..b9b1e4b393 100644 --- a/tests/expected-failure-github.txt +++ b/tests/expected-failure-github.txt @@ -12,5 +12,4 @@ tests/autodiff/custom-intrinsic.slang.2 syn (wgpu) tests/bugs/buffer-swizzle-store.slang.3 syn (wgpu) tests/compute/interface-shader-param-in-struct.slang.4 syn (wgpu) tests/compute/interface-shader-param.slang.5 syn (wgpu) -tests/language-feature/constants/static-const-in-generic-interface.slang.1 syn (wgpu) tests/language-feature/shader-params/interface-shader-param-ordinary.slang.4 syn (wgpu) diff --git a/tests/language-feature/constants/static-const-in-generic-interface.slang b/tests/language-feature/constants/static-const-in-generic-interface.slang index e980a812a4..87d8e3be84 100644 --- a/tests/language-feature/constants/static-const-in-generic-interface.slang +++ b/tests/language-feature/constants/static-const-in-generic-interface.slang @@ -1,8 +1,6 @@ // static-const-in-generic-interface.slang //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -// WGSL: Functions cannot be called at module scope #5607 -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu // Test that `static const` variable declarations inside of // a generic `interface` type correctly translate to interface requirements.