From 388de5f68ebef556f9addcf36685109d2524ee4e Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Fri, 5 Jul 2024 21:01:46 +0800 Subject: [PATCH 01/10] Correct type for double log10 (#4550) Fixes https://github.com/shader-slang/slang/issues/4549 --- prelude/slang-cpp-scalar-intrinsics.h | 2 +- tests/hlsl-intrinsic/scalar-double-log10.slang | 10 ++++++++++ .../scalar-double-log10.slang.expected.txt | 5 +++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 tests/hlsl-intrinsic/scalar-double-log10.slang create mode 100644 tests/hlsl-intrinsic/scalar-double-log10.slang.expected.txt diff --git a/prelude/slang-cpp-scalar-intrinsics.h b/prelude/slang-cpp-scalar-intrinsics.h index 8fc90fac8c..55001cb217 100644 --- a/prelude/slang-cpp-scalar-intrinsics.h +++ b/prelude/slang-cpp-scalar-intrinsics.h @@ -266,7 +266,7 @@ double F64_cosh(double f); double F64_tanh(double f); double F64_log2(double f); double F64_log(double f); -double F64_log10(float f); +double F64_log10(double f); double F64_exp2(double f); double F64_exp(double f); double F64_abs(double f); diff --git a/tests/hlsl-intrinsic/scalar-double-log10.slang b/tests/hlsl-intrinsic/scalar-double-log10.slang new file mode 100644 index 0000000000..e68088302f --- /dev/null +++ b/tests/hlsl-intrinsic/scalar-double-log10.slang @@ -0,0 +1,10 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(in uint i : SV_GroupIndex) +{ + outputBuffer[i] = int(log10(double(i) + 1.0) * 100); +} diff --git a/tests/hlsl-intrinsic/scalar-double-log10.slang.expected.txt b/tests/hlsl-intrinsic/scalar-double-log10.slang.expected.txt new file mode 100644 index 0000000000..318d50830a --- /dev/null +++ b/tests/hlsl-intrinsic/scalar-double-log10.slang.expected.txt @@ -0,0 +1,5 @@ +type: int32_t +0 +30 +47 +60 From 65194cf0a926267839ff56e47c1a1eb14e2b0977 Mon Sep 17 
00:00:00 2001 From: Ellie Hermaszewska Date: Fri, 5 Jul 2024 21:09:13 +0800 Subject: [PATCH 02/10] Add vector overloads for or and and (#4529) * Add vector overloads for or and and Closes #4441 and #4434 * Disable cuda checks which use unsupported bool vectors * Add tests for 4531 --- source/slang/core.meta.slang | 45 ++++++++++++++++++++++++++++++++++++ tests/bugs/gh-4434.slang | 36 +++++++++++++++++++++++++++++ tests/bugs/gh-4441.slang | 36 +++++++++++++++++++++++++++++ tests/bugs/gh-4531.slang | 20 ++++++++++++++++ tests/bugs/gh-4533.slang | 2 +- 5 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 tests/bugs/gh-4434.slang create mode 100644 tests/bugs/gh-4441.slang create mode 100644 tests/bugs/gh-4531.slang diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 631c09d194..a8badc05aa 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2255,6 +2255,51 @@ bool and(bool v0, bool v1) return __and(v0, v1); } +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +__intrinsic_op($(kIROp_And)) +vector and(vector v0, vector v1); + +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +vector and(bool b, vector v) +{ + return and(vector(b), v); +} + +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +vector and(vector v, bool b) +{ + return and(v, vector(b)); +} + +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +bool or(bool v0, bool v1) +{ + return __or(v0, v1); +} + +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +__intrinsic_op($(kIROp_Or)) +vector or(vector v0, vector v1); + +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +vector or(bool b, vector v) +{ + return or(vector(b), v); +} + +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +vector or(vector v, bool b) +{ + return or(v, vector(b)); +} + __generic [__unsafeForceInlineEarly] [OverloadRank(-10)] diff --git a/tests/bugs/gh-4434.slang b/tests/bugs/gh-4434.slang new file mode 100644 index 0000000000..9752260088 --- /dev/null +++ 
b/tests/bugs/gh-4434.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda + +// CHECK: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint tid : SV_GroupIndex) +{ + bool a, b, c; + c = and(a, b); + + bool1 i, j, k; + bool2 l, m, n; + bool3 o, p, q; + bool4 r, s, t; + k = and(i, j); + n = and(m, l); + q = and(o, p); + t = and(r, s); + + k = !and(k, false); + n = !and(n, false); + q = !and(q, false); + t = !and(t, false); + + outputBuffer[tid] = all(k) && all(n) && all(q) && all(t); +} diff --git a/tests/bugs/gh-4441.slang b/tests/bugs/gh-4441.slang new file mode 100644 index 0000000000..59d577c8d5 --- /dev/null +++ b/tests/bugs/gh-4441.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda + +// CHECK: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint tid : SV_GroupIndex) +{ + bool a, b, c; + c = or(a, b); + + bool1 i, j, k; + bool2 l, m, n; + bool3 o, p, q; + bool4 r, s, t; + k = or(i, j); + n = or(m, l); + q = or(o, p); + t = or(r, s); + + k = or(k, true); + n = or(n, true); + q = or(q, true); + t = or(t, true); + + outputBuffer[tid] = all(k) && all(n) 
&& all(q) && all(t); +} diff --git a/tests/bugs/gh-4531.slang b/tests/bugs/gh-4531.slang new file mode 100644 index 0000000000..e84c4713fb --- /dev/null +++ b/tests/bugs/gh-4531.slang @@ -0,0 +1,20 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda + +// CHECK: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint tid : SV_GroupIndex) +{ + vector k = true; + outputBuffer[tid] = all(k); +} diff --git a/tests/bugs/gh-4533.slang b/tests/bugs/gh-4533.slang index 3ee27996b9..a483c89b95 100644 --- a/tests/bugs/gh-4533.slang +++ b/tests/bugs/gh-4533.slang @@ -16,5 +16,5 @@ RWStructuredBuffer outputBuffer; void computeMain(uint tid : SV_GroupIndex) { vector k = float1(tid); - outputBuffer[tid] = all(k) && any(k) && bool(asint(k)) && bool(asuint(k)); + outputBuffer[tid] = all(k) && any(k) && bool(asint(k)) && bool(asuint(k)) && bool(sign(k)); } From 2cb65a89f33ac1c4bad216e576069edbdc04933f Mon Sep 17 00:00:00 2001 From: ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> Date: Sun, 7 Jul 2024 07:21:38 -0400 Subject: [PATCH 03/10] correctly setting launch parameters should fix the test (#4551) --- tests/compute/semantic.slang | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compute/semantic.slang b/tests/compute/semantic.slang index 4946b928dc..e42df4a11c 100644 --- a/tests/compute/semantic.slang +++ b/tests/compute/semantic.slang @@ -2,7 +2,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -compute-dispatch 3,1,1 -shaderobj //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -compute-dispatch 3,1,1 
-shaderobj //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -compute-dispatch 3,1,1 -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl +//TEST(compute):COMPARE_COMPUTE_EX:-mtl -compute -compute-dispatch 3,1,1 -shaderobj //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer outputBuffer; From 4a49769c5b6b351b3c1c9a9968b3926839504606 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 8 Jul 2024 18:48:08 -0400 Subject: [PATCH 04/10] Treat global variables and parameters as non-differentiable when checking derivative data-flow (#4526) Global parameters are by-default not differentiable (even if they are of a differentiable type), because our auto-diff passes do not touch anything outside of function bodies. The solution is to use wrapper objects with differentiable getter/setter methods (and we should provide a few such objects in the stdlib). Fixes: #3289 This is a potentially breaking change: User code that was previously working with global variables of a differentiable type will now throw an error (previously the gradient would be dropped without warning). The solution is to use `detach()` to keep same behavior as before or rewrite the access using differentiable getter/setter methods. 
--- .../slang-ir-check-differentiability.cpp | 2 -- .../warn-on-shared-memory-access.slang | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 tests/autodiff/warn-on-shared-memory-access.slang diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index c5c03f7da4..8b4886a2cf 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -169,9 +169,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase switch (addr->getOp()) { case kIROp_Var: - case kIROp_GlobalVar: case kIROp_Param: - case kIROp_GlobalParam: return isDifferentiableType(diffTypeContext, addr->getDataType()); case kIROp_FieldAddress: if (!as(addr)->getField() || diff --git a/tests/autodiff/warn-on-shared-memory-access.slang b/tests/autodiff/warn-on-shared-memory-access.slang new file mode 100644 index 0000000000..bccf8b1fa0 --- /dev/null +++ b/tests/autodiff/warn-on-shared-memory-access.slang @@ -0,0 +1,32 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +groupshared float s_shared; + +[BackwardDifferentiable] +float get_thread_5_value(float v, uint group_thread_id) +{ + if(group_thread_id == 5) + { + // Using 'detach(v)' makes the error go away + s_shared = v; + // CHECK: tests/autodiff/warn-on-shared-memory-access.slang(14): error 41024: derivative is lost during assignment to non-differentiable location, use 'detach()' to clarify intention. 
+ // CHECK: s_shared = v; + // CHECK: ^ + } + GroupMemoryBarrierWithGroupSync(); + return s_shared; +} + +[shader("compute")] +[numthreads(128, 1, 1)] +void computeMain(uint3 group_thread_id: SV_GroupThreadID, uint3 dispatch_thread_id: SV_DispatchThreadID) +{ + DifferentialPair value = diffPair(3.f, 0.f); + + bwd_diff(get_thread_5_value)(value, group_thread_id.x, 1.0f); + + outputBuffer[dispatch_thread_id.x] = value.d; +} \ No newline at end of file From a453fadfb373499f08779dd7df8f2347d292fd91 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 8 Jul 2024 19:33:51 -0700 Subject: [PATCH 05/10] Language server performance and document symbol fix. (#4561) --- source/slang/slang-check-decl.cpp | 1 + source/slang/slang-language-server-document-symbols.cpp | 2 ++ source/slang/slang.cpp | 8 +++++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c3074bc553..66bdbc18ec 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1823,6 +1823,7 @@ namespace Slang static ConstructorDecl* _createCtor(SemanticsDeclVisitorBase* visitor, ASTBuilder* m_astBuilder, AggTypeDecl* decl) { auto ctor = m_astBuilder->create(); + addModifier(ctor, m_astBuilder->create()); auto ctorName = visitor->getName("$init"); ctor->ownedScope = m_astBuilder->create(); ctor->ownedScope->containerDecl = ctor; diff --git a/source/slang/slang-language-server-document-symbols.cpp b/source/slang/slang-language-server-document-symbols.cpp index 26366d6cfd..ec9b434ebe 100644 --- a/source/slang/slang-language-server-document-symbols.cpp +++ b/source/slang/slang-language-server-document-symbols.cpp @@ -150,6 +150,8 @@ namespace Slang continue; if (!nameLoc.loc.isValid()) continue; + if (child->hasModifier() || child->hasModifier()) + continue; auto humaneLoc = srcManager->getHumaneLoc(nameLoc.loc, SourceLocType::Actual); if (humaneLoc.line == 0) continue; diff --git a/source/slang/slang.cpp 
b/source/slang/slang.cpp index 67492d9b22..adfa031f58 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -3669,8 +3669,14 @@ RefPtr Linkage::findOrImportModule( // Look for a precompiled module first, if not exist, load from source. - for (int checkBinaryModule = 1; checkBinaryModule >= 0; checkBinaryModule--) + bool shouldCheckBinaryModuleSettings[2] = { true, false }; + + for (auto checkBinaryModule : shouldCheckBinaryModuleSettings) { + // When in language server, we always prefer to use source module if it is available. + if (isInLanguageServer()) + checkBinaryModule = !checkBinaryModule; + // Try without translating `_` to `-` first, if that fails, try translating. for (int translateUnderScore = 0; translateUnderScore <= 1; translateUnderScore++) { From 5a174dfab4ae0852cb96df5f48bae474949cc017 Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Mon, 8 Jul 2024 21:34:51 -0700 Subject: [PATCH 06/10] Fix the issue in emitFloatCast (#4559) * Fix the issue in emitFloatCast In emitFloatCast function, we only considered the input type is float scalar or float vector, so if the input type is a float matrix type, it will crash. We should also handle the float matrix type. Also, we add some diagnose info to point out the source location where there is error happened, so in the future it's easier to tell us what happens. * Add a unit test * Disable the test for metal Metal doesn't support 'double'. 
" metal 32023.35: /tmp/unknown-YgHAsJ.metal(15): error : 'double' is not supported in Metal matrix b_0 = matrix (a_0); " --- source/slang/slang-emit-spirv.cpp | 35 +++++++++++++++++++++++++++---- source/slang/slang-ir-util.cpp | 7 +++++++ source/slang/slang-ir-util.h | 3 +++ tests/bugs/gh-4556.slang | 21 +++++++++++++++++++ 4 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 tests/bugs/gh-4556.slang diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 1a4d80ae10..f7d807190f 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2403,7 +2403,7 @@ struct SPIRVEmitContext // For now we aren't handling function declarations; // we expect to deal only with fully linked modules. // - SLANG_UNUSED(irFunc); + m_sink->diagnose(irFunc, Diagnostics::internalCompilerError); SLANG_UNEXPECTED("function declaration in SPIR-V emit"); UNREACHABLE_RETURN(nullptr); } @@ -5194,9 +5194,36 @@ struct SPIRVEmitContext { const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); - SLANG_ASSERT(!as(fromTypeV) == !as(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + + IRType* fromType = nullptr; + IRType* toType = nullptr; + + if (as(fromTypeV) || as(toTypeV)) + { + fromType = getVectorElementType(fromTypeV); + toType = getVectorElementType(toTypeV); + } + else if (as(fromTypeV) || as(toTypeV)) + { + fromType = getMatrixElementType(fromTypeV); + toType = getMatrixElementType(toTypeV); + } + else + { + fromType = fromTypeV; + toType = toTypeV; + } + + // We'd better give some diagnostics to at least point out which line in the shader is wrong, so + // it can help the user or developers to locate the issue easier. 
+ if (!isFloatingType(fromType)) { + m_sink->diagnose(inst, Diagnostics::internalCompilerError); + } + + if (!isFloatingType(toType)) { + m_sink->diagnose(inst, Diagnostics::internalCompilerError); + } + SLANG_ASSERT(isFloatingType(fromType)); SLANG_ASSERT(isFloatingType(toType)); SLANG_ASSERT(!isTypeEqual(fromType, toType)); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 10c7bfea62..8294cd533c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -23,6 +23,13 @@ IRType* getVectorElementType(IRType* type) return type; } +IRType* getMatrixElementType(IRType* type) +{ + if (auto matrixType = as(type)) + return matrixType->getElementType(); + return type; +} + Dictionary buildInterfaceRequirementDict(IRInterfaceType* interfaceType) { Dictionary result; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 855046c048..c7d6a15441 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -77,6 +77,9 @@ bool isComInterfaceType(IRType* type); // If `type` is a vector, returns its element type. Otherwise, return `type`. IRType* getVectorElementType(IRType* type); +// If `type` is a matrix, returns its element type. Otherwise, return `type`. 
+IRType* getMatrixElementType(IRType* type); + // True if type is a resource backing memory bool isResourceType(IRType* type); diff --git a/tests/bugs/gh-4556.slang b/tests/bugs/gh-4556.slang new file mode 100644 index 0000000000..eed84779e5 --- /dev/null +++ b/tests/bugs/gh-4556.slang @@ -0,0 +1,21 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -compute -output-using-type -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -output-using-type -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -glsl -compute -output-using-type -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl -compute -output-using-type -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0.0 0.0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 id: SV_DispatchThreadID) +{ + float3x4 a = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + double3x4 b = (double3x4)a; + + // CHECK: 1.000000 + outputBuffer[0] = (float)b[0][0]; + // CHECK: 2.000000 + outputBuffer[1] = (float)b[0][1]; +} From 29418735e5f806e8c0e365314154f55a753e5271 Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Tue, 9 Jul 2024 08:58:35 -0700 Subject: [PATCH 07/10] Add intrinsic for clock2x32ARB (#4563) --- source/slang/hlsl.meta.slang | 19 +++++++++++++++++++ source/slang/slang-capabilities.capdef | 2 ++ tests/glsl-intrinsic/clock-read.slang | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+) create mode 100644 tests/glsl-intrinsic/clock-read.slang diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index b282dca2ad..5564094db3 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -20229,3 +20229,22 @@ struct ConstBufferPointer } } } + +__glsl_version(450) 
+__glsl_extension(GL_ARB_shader_clock) +__target_intrinsic(glsl, clock2x32ARB) +[require(glsl_spirv, GL_ARB_shader_clock)] +uint2 clock2x32ARB() +{ + __target_switch + { + case glsl: __intrinsic_asm "clock2x32ARB"; + case spirv: + const uint32_t scopeId_subgroup = 3; + return spirv_asm { + OpCapability ShaderClockKHR; + OpExtension "SPV_KHR_shader_clock"; + result:$$uint2 = OpReadClockKHR $scopeId_subgroup; + }; + } +} diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 140346163d..4f71eb8233 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -316,6 +316,7 @@ def _GL_ARB_sparse_texture2 : _GL_ARB_sparse_texture; def _GL_ARB_sparse_texture_clamp : _GL_ARB_sparse_texture2; def _GL_ARB_texture_gather : _GLSL_130; def _GL_ARB_texture_query_levels : _GLSL_130; +def _GL_ARB_shader_clock : _GLSL_450; def _GL_KHR_memory_scope_semantics : _GLSL_420; def _GL_KHR_shader_subgroup_arithmetic : _GLSL_140; @@ -370,6 +371,7 @@ alias GL_ARB_shader_texture_image_samples = _GL_ARB_shader_texture_image_samples alias GL_ARB_sparse_texture_clamp = _GL_ARB_sparse_texture_clamp | spvSparseResidency; alias GL_ARB_texture_gather = _GL_ARB_texture_gather | spvImageGatherExtended | metal; alias GL_ARB_texture_query_levels = _GL_ARB_texture_query_levels | spvImageQuery | metal; +alias GL_ARB_shader_clock = _GL_ARB_shader_clock | spvShaderClockKHR; alias GL_KHR_memory_scope_semantics = _GL_KHR_memory_scope_semantics | _spirv_1_0; alias GL_KHR_shader_subgroup_arithmetic = _GL_KHR_shader_subgroup_arithmetic | spvGroupNonUniformArithmetic; diff --git a/tests/glsl-intrinsic/clock-read.slang b/tests/glsl-intrinsic/clock-read.slang new file mode 100644 index 0000000000..4bac5a07d3 --- /dev/null +++ b/tests/glsl-intrinsic/clock-read.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK1): -target glsl +//TEST:SIMPLE(filecheck=CHECK2): -target spirv +//TEST:SIMPLE(filecheck=CHECK3): -target spirv 
-emit-spirv-via-glsl + + +// CHECK1: GL_ARB_shader_clock : require +// CHECK2: OpCapability ShaderClockKHR +// CHECK2: OpExtension "SPV_KHR_shader_clock" +RWStructuredBuffer output; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 id: SV_DispatchThreadID) +{ + output[0] = clock2x32ARB().x; + // CHECK1: clock2x32ARB + // CHECK2: OpReadClockKHR %v2uint %uint_3 + // CHECK3: OpReadClockKHR %v2uint %uint_3 +} From ddd14be7a7e807a124a29221d53a5e83f92c570a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:57:07 -0400 Subject: [PATCH 08/10] Add documentation of the type system and decorations (#4470) --- docs/design/autodiff/decorators.md | 92 +++++++++ docs/design/autodiff/ir-overview.md | 2 +- docs/design/autodiff/types.md | 290 ++++++++++++++++++++++++++++ 3 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 docs/design/autodiff/decorators.md create mode 100644 docs/design/autodiff/types.md diff --git a/docs/design/autodiff/decorators.md b/docs/design/autodiff/decorators.md new file mode 100644 index 0000000000..b08da29122 --- /dev/null +++ b/docs/design/autodiff/decorators.md @@ -0,0 +1,92 @@ +This document details auto-diff-related decorations that are lowered in to the IR to help annotate methods with relevant information. + +## `[Differentiable]` +The `[Differentiable]` attribute is used to mark functions as being differentiable. The auto-diff process will only touch functions that are marked explicitly as `[Differentiable]`. All other functions are considered non-differentiable and calls to such functions from a differentiable function are simply copied as-is with no transformation. + +Further, only `[Differentiable]` methods are checked during the derivative data-flow pass. 
This decorator is translated into `BackwardDifferentiableAttribute` (which implies both forward and backward differentiability), and then lowered into the IR `OpBackwardDifferentiableDecoration` + +**Note:** `[Differentiable]` was previously implemented as two separate decorators `[ForwardDifferentiable]` and `[BackwardDifferentiable]` to denote differentiability with each type of auto-diff transformation. However, these are now **deprecated**. The preferred approach is to use only `[Differentiable]` + +`fwd_diff` and `bwd_diff` cannot be directly called on methods that don't have the `[Differentiable]` tag (will result in an error). If non-`[Differentiable]` methods are called from within a `[Differentiable]` method, they must be wrapped in a `no_diff()` operation (enforced by the [derivative data-flow analysis pass](./types.md#derivative-data-flow-analysis) ) + +### `[Differentiable]` for `interface` Requirements +The `[Differentiable]` attribute can also be used to decorate interface requirements. In this case, the attribute is handled in a slightly different manner, since we do not have access to the concrete implementations. + +The process is roughly as follows: +1. During the semantic checking step, when checking a method that is an interface requirement (in `checkCallableDeclCommon` in `slang-check-decl.cpp`), we check if the method has a `[Differentiable]` attribute +2. If yes, we create a set of new method declarations, one for the forward-mode derivative (`ForwardDerivativeRequirementDecl`) and one for the reverse-mode derivative (`BackwardDerivativeRequirementDecl`), with the appropriate translated function types and insert them into the same interface. +3. Insert a new member into the original method to reference the new declarations (`DerivativeRequirementReferenceDecl`) +4. 
When lowering to IR, the `DerivativeRequirementReferenceDecl` member is converted into a custom derivative reference by adding the `OpBackwardDerivativeDecoration(deriv-fn-req-key)` and `OpForwardDerivativeDecoration(deriv-fn-req-key)` decorations on the primal method's requirement key. + +Here is an example of what this would look like: + +```C +interface IFoo +{ + [Differentiable] + float bar(float); +}; + +// After checking & lowering +interface IFoo_after_checking_and_lowering +{ + [BackwardDerivative(bar_bwd)] + [ForwardDerivative(bar_fwd)] + float bar(float); + + void bar_bwd(inout DifferentialPair<float>, float); + + DifferentialPair<float> bar_fwd(DifferentialPair<float>); +}; +``` + +**Note:** All conforming types must _also_ declare their corresponding implementations as differentiable so that their derivative implementations are synthesized to match the interface signature. In this sense, the `[Differentiable]` attribute is part of the function's signature, so a `[Differentiable]` interface requirement can only be satisfied by a `[Differentiable]` function implementation + +### `[TreatAsDifferentiable]` +In large codebases where some interfaces may have several possible implementations, it may not be reasonable to have to mark all possible implementations with `[Differentiable]`, especially if certain implementations use hacks or workarounds that need additional consideration before they can be marked `[Differentiable]` + +In such cases, we provide the `[TreatAsDifferentiable]` decoration (AST node: `TreatAsDifferentiableAttribute`, IR: `OpTreatAsDifferentiableDecoration`), which instructs the auto-diff passes to construct an 'empty' function that returns a 0 (or 0-equivalent) for the derivative values. This allows the signature of a `[TreatAsDifferentiable]` function to match a `[Differentiable]` requirement without actually having to produce a derivative. 
+ +## Custom derivative decorators +In many cases, it is desirable to manually specify the derivative code for a method rather than let the auto-diff pass synthesize it from the method body. This is usually desirable if: +1. The body of the method is too complex, and there is a simpler, mathematically equivalent way to compute the same value (often the case for intrinsics like `sin(x)`, `arccos(x)`, etc.) +2. The method involves global/shared memory accesses, and synthesized derivative code may cause race conditions or be very slow due to overuse of synchronization. For this reason Slang assumes global memory accesses are non-differentiable by default, and requires that the user (or stdlib) define separate accessors with different derivative semantics. + +The Slang front-end provides two sets of decorators to facilitate this: +1. To reference a custom derivative function from a primal function: `[ForwardDerivative(fn)]` and `[BackwardDerivative(fn)]` (AST Nodes: `ForwardDerivativeAttribute`/`BackwardDerivativeAttribute`, IR: `OpForwardDerivativeDecoration`/`OpBackwardDerivativeDecoration`), and +2. To reference a primal function from its custom derivative function: `[ForwardDerivativeOf(fn)]` and `[BackwardDerivativeOf(fn)]` (AST Nodes: `ForwardDerivativeAttributeOf`/`BackwardDerivativeAttributeOf`). These attributes are useful to provide custom derivatives for existing methods in a different file without having to edit/change that module. For instance, we use `diff.meta.slang` to provide derivatives for stdlib functions in `hlsl.meta.slang`. When lowering to IR, these references are placed on the target (primal function). That way both sets of decorations are lowered on the primal function. + +These decorators also work on generically defined methods, as well as struct methods. 
Similar to how function calls work, these decorators also work on overloaded methods (and reuse the `ResolveInvoke` infrastructure to perform resolution) + +### Checking custom derivative signatures +To ensure that the user-provided derivatives agree with the expected signature, as well as resolve the appropriate method when multiple overloads are available, we check the signature of the custom derivative function against the translated version of the primal function. This currently occurs in `checkDerivativeAttribute()`/`checkDerivativeOfAttribute()`. + +The checking process re-uses existing infrastructure from `ResolveInvoke`, by constructing a temporary invoke expr to call the user-provided derivative using a set of 'imaginary' arguments according to the translated type of the primal method. If `ResolveInvoke` is successful, the provided derivative signature is considered to be a match. This approach also automatically allows us to resolve overloaded methods, account for generic types and type coercion. + +## `[PrimalSubstitute(fn)]` and `[PrimalSubstituteOf(fn)]` +In some cases, we face the opposite problem that inspired custom derivatives. That is, we want the compiler to auto-synthesize the derivative from the function body, but there _is_ no function body to translate. +This frequently occurs with hardware intrinsic operations that are lowered into special op-codes that map to hardware units, such as texture sampling & interpolation operations. +However, these operations do have reference 'software' implementations which can be used to produce the derivative. 
+ +To allow user code to use the fast hardware intrinsics for the primal pass, but use synthesized derivatives for the derivative pass, we provide decorators `[PrimalSubstitute(ref-fn)]` and `[PrimalSubstituteOf(orig-fn)]` (AST Node: `PrimalSubstituteAttribute`/`PrimalSubstituteOfAttribute`, IR: `OpPrimalSubstituteDecoration`), that can be used to provide a reference implementation for the auto-diff pass. + +Example: +```C +[PrimalSubstitute(sampleTexture_ref)] +float sampleTexture(TexHandle2D tex, float2 uv) +{ + // Hardware intrinsics +} + +float sampleTexture_ref(TexHandle2D tex, float2 uv) +{ + // Reference SW implementation. +} + +void sampleTexture_bwd(TexHandle2D tex, inout DifferentialPair<float2> dp_uv, float dOut) +{ + // Backward derivative code synthesized using the reference implementation. +} +``` + +The implementation of `[PrimalSubstitute(fn)]` is relatively straightforward. When the transcribers are asked to synthesize a derivative of a function, they check for an `OpPrimalSubstituteDecoration`, and swap the current function out for the substitute function before proceeding with derivative synthesis. \ No newline at end of file diff --git a/docs/design/autodiff/ir-overview.md b/docs/design/autodiff/ir-overview.md index 2dd13b3760..a6b3ec2070 100644 --- a/docs/design/autodiff/ir-overview.md +++ b/docs/design/autodiff/ir-overview.md @@ -1,7 +1,7 @@ This documentation is intended for Slang contributors and is written from a compiler engineering point of view. For Slang users, see the user-guide at this link: [https://shader-slang.com/slang/user-guide/autodiff.html](https://shader-slang.com/slang/user-guide/autodiff.html) # Overview of Automatic Differentiation's IR Passes -In this document we will detail how Slang's auto-diff passes generate valid forward-mode and reverse-mode derivative functions. 
Refer to [Basics](./basics.md) for a review of two derivative propagation methods and their mathematical connotations & [Types](./types.md) for a review of how different types transform under differentiation. +In this document we will detail how Slang's auto-diff passes generate valid forward-mode and reverse-mode derivative functions. Refer to [Basics](./basics.md) for a review of the two derivative propagation methods and their mathematical connotations & [Types](./types.md) for a review of how types are handled under differentiation. ## Auto-Diff Pass Invocation Note that without an explicit auto-diff instruction (`fwd_diff(fn)` or `bwd_diff(fn)`) from the user present anywhere in the code, none of the auto-diff passes will do anything. diff --git a/docs/design/autodiff/types.md b/docs/design/autodiff/types.md new file mode 100644 index 0000000000..12b468a6e4 --- /dev/null +++ b/docs/design/autodiff/types.md @@ -0,0 +1,290 @@ + +This documentation is intended for Slang contributors and is written from a compiler engineering point of view. For Slang users, see the user-guide at this link: [https://shader-slang.com/slang/user-guide/autodiff.html](https://shader-slang.com/slang/user-guide/autodiff.html) + +Before diving into this document, please review the document on [Basics](./basics.md) for the fundamentals of automatic differentiation. + +# Components of the Type System +Here we detail the main components of the type system: the `IDifferentiable` interface to define differentiable types, and the `DifferentialPair` type to carry a primal and corresponding differential in a single type. +We also detail how auto-diff operators are type-checked (the higher-order function checking system), how the `no_diff` decoration can be used to avoid differentiation through attributed types, and the derivative data flow analysis that warns the user of unintentionally stopping derivatives. 
+ +## `interface IDifferentiable` +Defined in core.meta.slang, `IDifferentiable` forms the basis for denoting differentiable types, both within the stdlib, and otherwise. +The definition of `IDifferentiable` is designed to encapsulate the following 4 items: +1. `Differential`: The type of the differential value of the conforming type. This allows custom data-structures to be defined to carry the differential values, which may be optimized for space instead of relying solely on compiler synthesis. + +Since the computation of derivatives is inherently linear, we only need access to a few operations. These are: + +2. `dadd(Differential, Differential) -> Differential`: Addition of two values of the differential type. Its implementation must be associative and commutative, or the resulting derivative code may be incorrect. +3. `dzero() -> Differential`: Additive identity (i.e. the zero or empty value) that can be used to initialize variables during gradient aggregation. +4. `dmul(S, Differential)`: Scalar multiplication of a real number with the differential type. Its implementation must be distributive over differential addition (`dadd`). + +Points 2, 3 & 4 are derived from the concept of vector spaces. The derivative values of any Slang function always form a vector space (https://en.wikipedia.org/wiki/Vector_space). + +### Derivative member associations +In certain scenarios, the compiler needs information on how the fields in the original type map to the differential type. Particularly, this is a problem when differentiating the implicit construction of a struct through braces (i.e. `{}`), represented by `kIROp_MakeStruct`. We provide the decorator `[DerivativeMember(DifferentialTypeName.fieldName)]` (ASTNode: DerivativeMemberAttribute, IR: kIROp_DerivativeMemberDecoration) to explicitly mark these associations. 
+Example: +```C +struct MyType : IDifferentiable +{ + typealias Differential = MyDiffType; + float a; + + [DerivativeMember(MyDiffType.db)] + float b; + + /* ... */ +}; + +struct MyDiffType +{ + float db; +}; +``` + +### Automatic Synthesis of `IDifferentiable` Conformances for Aggregate Types +It can be tedious to expect users to hand-write the associated `Differential` type, the corresponding mappings and interface methods for every user-defined `struct` type. For aggregate types, these are trivial to construct by analysing which of their components conform to `IDifferentiable`. +The synthesis proceeds in roughly the following fashion: +1. `IDifferentiable`'s components are tagged with a special decorator `__builtin_requirement(unique_integer_id)` which carries an enum value from `BuiltinRequirementKind`. +2. When checking that types conform to their interfaces, if a user-provided definition does not satisfy a requirement with a built-in tag, we perform synthesis by dispatching to `trySynthesizeRequirementWitness`. +3. For _user-defined types_, Differential **types** are synthesized during conformance-checking through `trySynthesizeDifferentialAssociatedTypeRequirementWitness` by checking if each constituent type conforms to `IDifferentiable`, looking up the corresponding `Differential` type, and constructing a new aggregate type from these differential types. Note that since it is possible that a `Differential` type of a constituent member has not yet been synthesized, we have additional logic in the lookup system (`trySynthesizeRequirementWitness`) that synthesizes a temporary empty type with a `ToBeSynthesizedModifier`, so that the fields can be filled in later, when the member type undergoes conformance checking. +4. 
For _user-defined types_, Differential methods (`dadd`, `dzero` and `dmul`) are synthesized in `trySynthesizeDifferentialMethodRequirementWitness` by utilizing the `Differential` member and its `[DerivativeMember]` decorations to determine which fields need to be considered and the base type to use for each field. There are two synthesis patterns. The fully-inductive pattern is used for `dadd` and `dzero` which works by calling `dadd` and `dzero` respectively on the individual fields of the `Differential` type under consideration. +Example: +```C +// Synthesized from "struct T {FT1 field1; FT2 field2;}" +T.Differential dadd(T.Differential a, T.Differential b) +{ + return Differential( + FT1.dadd(a.field1, b.field1), + FT2.dadd(a.field2, b.field2), + ) +} +``` +On the other hand, `dmul` uses the fixed-first arg pattern since the first argument is a common scalar, and proceeds inductively on all the other args. +Example: +```C +// Synthesized from "struct T {FT1 field1; FT2 field2;}" +T.Differential dmul(S s, T.Differential a) +{ + return Differential( + FT1.dmul(s, a.field1), + FT2.dmul(s, a.field2), + ) +} +``` +5. During auto-diff, the compiler can sometimes synthesize new aggregate types. The most common case is the intermediate context type (`kIROp_BackwardDerivativeIntermediateContextType`), which is lowered into a standard struct once the auto-diff pass is complete. It is important to synthesize the `IDifferentiable` conformance for such types since they may be further differentiated (through higher-order differentiation). This implementation is contained in `fillDifferentialTypeImplementationForStruct(...)` and is roughly analogous to the AST-side synthesis. + +### Differentiable Type Dictionaries +During auto-diff, the IR passes frequently need to perform lookups to check if an `IRType` is differentiable, and retrieve references to the corresponding `IDifferentiable` methods. 
These lookups also need to work on generic parameters (that are defined inside generic containers), and existential types that are interface-typed parameters. + +To accommodate this range of different type systems, Slang uses a type dictionary system that associates a dictionary of relevant types with each function. This works in the following way: +1. When `CheckTerm()` is called on an expression within a function that is marked differentiable (`[Differentiable]`), we check if the resolved type conforms to `IDifferentiable`. If so, we add this type to the dictionary along with the witness to its differentiability. The dictionary is currently located on `DifferentiableAttribute` that corresponds to the `[Differentiable]` modifier. + +2. When lowering to IR, we create a `DifferentiableTypeDictionaryDecoration` which holds the IR versions of all the types in the dictionary as well as a reference to their `IDifferentiable` witness tables. + +3. When synthesizing the derivative code, all the transcriber passes use `DifferentiableTypeConformanceContext::setFunc()` to load the type dictionary. `DifferentiableTypeConformanceContext` then provides convenience functions to look up differentiable types, appropriate `IDifferentiable` methods, and construct appropriate `DifferentialPair`s. + +### Looking up Differential Info on _Generic_ types +Generically defined types are also lowered into the differentiable type dictionary, but rather than having a concrete witness table, the witness table is itself a parameter. When auto-diff passes need to find the differential type or place a call to the IDifferentiable methods, this is turned into a lookup on the witness table parameter (i.e. `Lookup(requirementKey, witnessTableParam)`). Note that these lookup instructions are inserted into the generic parent container rather than the innermost function. 
+Example: +```C +T myFunc(T a) +{ + return a * a; +} + +// Reverse-mode differentiated version +void bwd_myFunc( + inout DifferentialPair dpa, + T.Differential dOut) // T.Differential is Lookup('Differential', T_Witness_Table) +{ + T.Differential da = T.dzero(); // T.dzero is Lookup('dzero', T_Witness_Table) + + da = T.dadd(dpa.p * dOut, da); // T.dadd is Lookup('dadd', T_Witness_Table) + da = T.dadd(dpa.p * dOut, da); + + dpa = diffPair(dpa.p, da); +} +``` + +### Looking up Differential Info on _Existential_ types +Existential types are interface-typed values, where there are multiple possible implementations at run-time. The existential type carries information about the concrete type at run-time and is effectively a 'tagged union' of all possible types. + +#### Differential type of an Existential +The differential type of an existential type is tricky to define since our type system's only restriction on the `.Differential` type is that it also conforms to `IDifferentiable`. The differential type of any interface `IInterface : IDifferentiable` is therefore the interface type `IDifferentiable`. This is problematic since Slang generally requires a static `anyValueSize` that must be a strict upper bound on the sizes of all conforming types (since this size is used to allocate space for the union). Since `IDifferentiable` is defined in the stdlib `core.meta.slang` and can be used by the user, it is impossible to define a reliable bound. +We instead provide a new **any-value-size inference** pass (`slang-ir-any-value-inference.h`/`slang-ir-any-value-inference.cpp`) that assembles a list of types that conform to each interface in the final linked IR and determines a relevant upper bound. This allows us to ignore types that conform to `IDifferentiable` but aren't used in the final IR, and generate a tighter upper bound. 
+ +**Future work:** +This approach, while functional, creates a locality problem since the size of `IDifferentiable` is the max of _all_ types that conform to `IDifferentiable` in visible modules, even though we only care about the subset of types that appear as `T.Differential` for `T : IInterface`. The reason for this problem is that upon performing an associated type lookup, the Slang IR drops all information about the base interface that the lookup starts from and only considers the constraint interface (in this case `Differential : IDifferentiable`). +There are several ways to resolve this issue, including (i) a static analysis pass that determines the possible set of types at each use location and propagates them to determine a narrower set of types, or (ii) generic (or 'parameterized') interfaces, such as `IDifferentiable` where each version can have a different set of conforming types. + + + +Example: +```C +interface IInterface : IDifferentiable +{ + [Differentiable] + This foo(float val); + + [Differentiable] + float bar(); +}; + +float myFunc(IInterface obj, float a) +{ + IInterface k = obj.foo(a); + return k.bar(); +} + +// Reverse-mode differentiated version (in pseudo-code corresponding to IR, some of these will get lowered further) +void bwd_myFunc( + inout DifferentialPair dpobj, + inout DifferentialPair dpa, + float.Differential dOut) // T.Differential is Lookup('Differential', T_Witness_Table) +{ + // Primal pass.. + IInterface obj = dpobj.p; + IInterface k = obj.foo(a); + // ..... + + // Backward pass + DifferentialPair dpk = diffPair(k); + bwd_bar(dpk, dOut); + IDifferentiable dk = dpk.d; // Differential of `IInterface` is `IDifferentiable` + + DifferentialPair dp = diffPair(dpobj.p); + bwd_foo(dpobj, dpa, dk); +} + +``` + +#### Looking up `dadd()` and `dzero()` on Existential Types +There are two distinct cases for lookup on an existential type. The more common case is the closed-box existential type represented simply by an interface. 
Every value of this type contains a type identifier & a witness table identifier along with the value itself. The less common case is when the function calls are performed directly on the value after being cast to the concrete type. + +**`dzero()` for "closed" Existential type: The `NullDifferential` Type** +For concrete and even generic types, we can initialize a derivative accumulator variable by calling the appropriate `Type.dzero()` method. This is unfortunately not possible when initializing an existential differential (which is currently of type `IDifferentiable`), since we must also initialize the type-id of this existential to one of the implementations, but we do not know which one yet since that is a run-time value that only becomes known after the first differential value is generated. + +To get around this issue, we declare a special type called `NullDifferential` that acts as a "none type" for any `IDifferentiable` existential object. + +**`dadd()` for "closed" Existential types: `__existential_dadd`** +We cannot directly use `dadd()` on two existential differentials of type `IDifferentiable` because we must handle the case where one of them is of type `NullDifferential` and `dadd()` is only defined for differentials of the same type. +We handle this currently by synthesizing a special method called `__existential_dadd` (`getOrCreateExistentialDAddMethod` in `slang-ir-autodiff.cpp`) that performs a run-time type-id check to see if one of the operands is of type `NullDifferential` and returns the other operand if so. If both are non-null, we dispatch to the appropriate `dadd` for the concrete type. + +**`dadd()` and `dzero()` for "open" Existential types** +If we are dealing with values of the concrete type (i.e. the opened value obtained through `ExtractExistentialValue(ExistentialParam)`), then we can perform lookups in the same way we do for generic types. All existential parameters come with a witness table. 
We insert instructions to extract this witness table and perform lookups accordingly. That is, for `dadd()`, we use `Lookup('dadd', ExtractExistentialWitnessTable(ExistentialParam))` and place a call to the result. + +## `struct DifferentialPair` +The second major component is `DifferentialPair` that represents a pair of a primal value and its corresponding differential value. +The differential pair is primarily used for passing & receiving derivatives from the synthesized derivative methods, as well as for block parameters on the IR-side. +Both `fwd_diff(fn)` and `bwd_diff(fn)` act as function-to-function transformations, and so the Slang front-end translates the type of `fn` to its derivative version so the arguments can be type checked. + +### Pair type lowering. +The differential pair type is a special type throughout the AST and IR passes (AST Node: `DifferentialPairType`, IR: `kIROp_DifferentialPairType`) because of its use in front-end semantic checking and when synthesizing the derivative code for the functions. Once the auto-diff passes are complete, the pair types are lowered into simple `struct`s so they can be easily emitted (`DiffPairLoweringPass` in `slang-ir-autodiff-pairs.cpp`). +We also define additional instructions for pair construction (`kIROp_MakeDifferentialPair`) and extraction (`kIROp_DifferentialPairGetDifferential` & `kIROp_DifferentialPairGetPrimal`) which are lowered into struct construction and field accessors, respectively. + +### "User-code" Differential Pairs +Just as we use special IR codes for differential pairs because they have special handling in the IR passes, sometimes differential pairs should be _treated as_ regular struct types during the auto-diff passes. +This happens primarily during higher-order differentiation when the user wishes to differentiate the same code multiple times. 
+Slang's auto-diff approaches this by rewriting all the relevant differential pairs into 'irrelevant' differential pairs (`kIROp_DifferentialPairUserCode`) and 'irrelevant' accessors (`kIROp_DifferentialPairGetDifferentialUserCode`, `kIROp_DifferentialPairGetPrimalUserCode`) at the end of **each auto-diff iteration** so that the next iteration treats these as regular differentiable types. +The user-code versions are also lowered into `struct`s in the same way. + +## Type Checking of Auto-Diff Calls (and other _higher-order_ functions) +Since `fwd_diff` and `bwd_diff` are represented as higher order functions that take a function as an input and return the derivative function, the front-end semantic checking needs some notion of higher-order functions to be able to check and lower the calls into appropriate IR. + +### Higher-order Invocation Base: `HigherOrderInvokeExpr` +All higher order transformations derive from `HigherOrderInvokeExpr`. For auto-diff there are two possible expression classes `ForwardDifferentiateExpr` and `BackwardDifferentiateExpr`, both of which derive from this parent expression. + +### Higher-order Function Call Checking: `HigherOrderInvokeExprCheckingActions` +Resolving the concrete method is not a trivial issue in Slang, given its support for overloading, type coercion and more. This becomes more complex with the presence of a function transformation in the chain. +For example, if we have `fwd_diff(f)(DiffPair(...), DiffPair(...))`, we would need to find the correct match for `f` based on its post-transform argument types. + +To facilitate this we use the following workflow: +1. The `HigherOrderInvokeExprCheckingActions` base class provides a mechanism for different higher-order expressions to implement their type translation (i.e. what is the type of the transformed function). +2. 
The checking mechanism passes all detected overloads for `f` through the type translation and assembles a new group out of the results (the new functions are 'temporary') +3. This new group is used by `ResolveInvoke` when performing overload resolution and type coercion using the user-provided argument list. +4. The resolved signature (if there is one) is then replaced with the corresponding function reference and wrapped in the appropriate higher-order invoke. + +**Example:** + +Let's say we have two functions with the same name `f`: (`int -> float`, `double, double -> float`) +and we want to resolve `fwd_diff(f)(DiffPair(1.0, 0.0), DiffPair(0.0, 1.0))`. + +The higher-order checking actions will synthesize the 'temporary' group of translated signatures (`int -> DiffPair`, `DiffPair, DiffPair -> DiffPair`). +Invoke resolution will then narrow this down to a single match (`DiffPair, DiffPair -> DiffPair`) by automatically casting the `float`s to `double`s. Once the resolution is complete, +we return `InvokeExpr(ForwardDifferentiateExpr(f : double, double -> float), casted_args)` by wrapping the corresponding function in the corresponding higher-order expr + +## Attributed Types (`no_diff` parameters) + +Often, it will be necessary to prevent gradients from propagating through certain parameters, for correctness reasons. For example, values representing random samples are often not differentiated since the result may be mathematically incorrect. + +Slang provides the `no_diff` operator to mark parameters as non-differentiable, even if they use a type that conforms to `IDifferentiable` + +```C +float myFunc(float a, no_diff float b) +{ + return a * b; +} + +// Resulting fwd-mode derivative: +DiffPair myFunc(DiffPair dpa, float b) +{ + return diffPair(dpa.p * b, dpa.d * b); +} +``` + +Slang uses _OpAttributedType_ to denote the IR type of such parameters. For example, the lowered type of `b` in the above example is `OpAttributedType(OpFloat, OpNoDiffAttr)`. 
In the front-end, this is represented through the `ModifiedType` AST node. + +Sometimes, this additional layer can get in the way of things like type equality checks and other mechanisms where the `no_diff` is irrelevant. Thus, we provide the `unwrapAttributedType` helper to remove attributed type layers for such cases. + +## Derivative Data-Flow Analysis +Slang has a derivative data-flow analysis pass that is performed on a per-function basis immediately after lowering to IR and before the linking step (`slang-ir-check-differentiability.h`/`slang-ir-check-differentiability.cpp`). + +The job of this pass is to enforce that instructions that are of a differentiable type will propagate derivatives, unless explicitly dropped by the user through `detach()` or `no_diff`. The reason for this is that Slang requires functions to be decorated with `[Differentiable]` to allow it to propagate derivatives. Otherwise, the function is considered non-differentiable, and effectively produces a 0 derivative. This can lead to frustrating situations where a function silently drops derivatives without the user realizing it. Example: +```C +float nonDiffFunc(float x) +{ + /* ... */ +} + +float differentiableFunc(float x) // Forgot to annotate with [Differentiable] +{ + /* ... */ +} + +float main(float x) +{ + // User doesn't realise that the function that is supposed to be differentiable is not + // getting differentiated, because the types here are all 'float'. + // + return nonDiffFunc(x) * differentiableFunc(x); +} +``` + +The data-flow analysis step enforces that non-differentiable functions used in a differentiable context should get their derivative dropped explicitly. That way, it is clear to the user whether a call is getting differentiated or dropped. + +Same example with `no_diff` enforcement: +```C +float nonDiffFunc(float x) +{ + /* ... */ +} + +[Differentiable] +float differentiableFunc(float x) +{ + /* ... 
*/ +} + +float main(float x) +{ + return no_diff(nonDiffFunc(x)) * differentiableFunc(x); +} +``` + +A `no_diff` can only be used directly on a function call, and turns into a `TreatAsDifferentiableDecoration` that indicates that the function will not produce a derivative. + +The derivative data-flow analysis pass works similarly to a standard data-flow pass: +1. We start by assembling a set of instructions that 'produce' derivatives by starting with the parameters of differentiable types (and without an explicit `no_diff`), and propagating them through each instruction in the block. An inst carries a derivative if one of its operands carries a derivative, and the result type is differentiable. +2. We then assemble a set of instructions that expect a derivative. These are differentiable operands of differentiable functions (unless they have been marked by `no_diff`). We then reverse-propagate this set by adding in all differentiable operands (and repeating this process). +3. During this reverse-propagation, if there is any `OpCall` in the 'expect' set that is not also in the 'produce' set, then we have a situation where the gradient hasn't been explicitly dropped, and we create a user diagnostic. 
From 1caef5907d0b0f16f686a8fcca479c6afc09f146 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:45:57 -0700 Subject: [PATCH 09/10] Fix Lexer to recognize swizzling on an integer scalar value (#4515) * Fix Lexer to recognize swizzling on an integer scalar value Close #4413 --- source/compiler-core/slang-lexer.cpp | 78 +++++++++++++------- tests/hlsl-intrinsic/scalar-swizzling.slang | 80 +++++++++++++++++++++ 2 files changed, 133 insertions(+), 25 deletions(-) create mode 100644 tests/hlsl-intrinsic/scalar-swizzling.slang diff --git a/source/compiler-core/slang-lexer.cpp b/source/compiler-core/slang-lexer.cpp index 5954dc668b..8c428159cf 100644 --- a/source/compiler-core/slang-lexer.cpp +++ b/source/compiler-core/slang-lexer.cpp @@ -173,38 +173,49 @@ namespace Slang // Look ahead one code point, dealing with complications like // escaped newlines. - static int _peek(Lexer* lexer) + static int _peek(Lexer* lexer, int offset = 0) { - // Look at the next raw byte, and decide what to do - int c = _peekRaw(lexer); + int pos = 0; + int c = kEOF; - if(c == '\\') + do { - // We might have a backslash-escaped newline. - // Look at the next byte (if any) to see. - // - // Note(tfoley): We are assuming a null-terminated input here, - // so that we can safely look at the next byte without issue. - int d = lexer->m_cursor[1]; - switch (d) + if (lexer->m_cursor + pos == lexer->m_end) + return kEOF; + + c = lexer->m_cursor[pos++]; + + if (c == '\\') { - case '\r': case '\n': + // We might have a backslash-escaped newline. + // Look at the next byte (if any) to see. + // + // Note(tfoley): We are assuming a null-terminated input here, + // so that we can safely look at the next byte without issue. 
+ int d = lexer->m_cursor[pos++]; + switch (d) + { + case '\r': case '\n': { // The newline was escaped, so return the code point after *that* - int e = lexer->m_cursor[2]; + int e = lexer->m_cursor[pos++]; if ((d ^ e) == ('\r' ^ '\n')) - return lexer->m_cursor[3]; - return e; + c = lexer->m_cursor[pos++]; + else + c = e; + break; } - default: - break; + default: + break; + } } - } - // TODO: handle UTF-8 encoding for non-ASCII code points here + // TODO: handle UTF-8 encoding for non-ASCII code points here + + // Default case is to just hand along the byte we read as an ASCII code point. + } while (offset--); - // Default case is to just hand along the byte we read as an ASCII code point. return c; } @@ -494,10 +505,19 @@ namespace Slang if( _peek(lexer) == '.' ) { - tokenType = TokenType::FloatingPointLiteral; + switch (_peek(lexer, 1)) + { + // 123.xxxx or 123.rrrr + case 'x': + case 'r': + break; - _advance(lexer); - _lexDigits(lexer, base); + default: + tokenType = TokenType::FloatingPointLiteral; + + _advance(lexer); + _lexDigits(lexer, base); + } } if( _maybeLexNumberExponent(lexer, base)) @@ -1089,8 +1109,16 @@ namespace Slang return _maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral); case '.': - _advance(lexer); - return _lexNumberAfterDecimalPoint(lexer, 10); + switch (_peek(lexer, 1)) + { + // 0.xxxx or 0.rrrr + case 'x': + case 'r': + return _maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral); + default: + _advance(lexer); + return _lexNumberAfterDecimalPoint(lexer, 10); + } case 'x': case 'X': _advance(lexer); diff --git a/tests/hlsl-intrinsic/scalar-swizzling.slang b/tests/hlsl-intrinsic/scalar-swizzling.slang new file mode 100644 index 0000000000..9ca0247553 --- /dev/null +++ b/tests/hlsl-intrinsic/scalar-swizzling.slang @@ -0,0 +1,80 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT: ubuffer(data=[0], stride=4):out,name outputBuffer +RWStructuredBuffer 
outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + bool result = true + && (0.x is int) + && 0 == 0.x + && all(uint2(0) == 0.xx) + && all(uint3(0) == 0.xxx) + && all(uint4(0) == 0.xxxx) + && (0.r is int) + && 0 == 0.r + && all(uint2(0) == 0.rr) + && all(uint3(0) == 0.rrr) + && all(uint4(0) == 0.rrrr) + + && (123.x is int) + && 123 == 123.x + && all(uint2(123) == 123.xx) + && all(uint3(123) == 123.xxx) + && all(uint4(123) == 123.xxxx) + && (123.r is int) + && 123 == 123.r + && all(uint2(123) == 123.rr) + && all(uint3(123) == 123.rrr) + && all(uint4(123) == 123.rrrr) + + && (0.f.x is float) + && 0.f == 0.f.x + && all(float2(0.f) == 0.f.xx) + && all(float3(0.f) == 0.f.xxx) + && all(float4(0.f) == 0.f.xxxx) + && (0.f.r is float) + && 0.f == 0.f.r + && all(float2(0.f) == 0.f.rr) + && all(float3(0.f) == 0.f.rrr) + && all(float4(0.f) == 0.f.rrrr) + + && (123.f.x is float) + && 123.f == 123.f.x + && all(float2(123.f) == 123.f.xx) + && all(float3(123.f) == 123.f.xxx) + && all(float4(123.f) == 123.f.xxxx) + && (123.f.r is float) + && 123.f == 123.f.r + && all(float2(123.f) == 123.f.rr) + && all(float3(123.f) == 123.f.rrr) + && all(float4(123.f) == 123.f.rrrr) + + && (0..x is float) + && 0.f == 0..x + && all(float2(0.f) == 0..xx) + && all(float3(0.f) == 0..xxx) + && all(float4(0.f) == 0..xxxx) + && (0..r is float) + && 0.f == 0..r + && all(float2(0.f) == 0..rr) + && all(float3(0.f) == 0..rrr) + && all(float4(0.f) == 0..rrrr) + + && (123..x is float) + && 123.f == 123..x + && all(float2(123.f) == 123..xx) + && all(float3(123.f) == 123..xxx) + && all(float4(123.f) == 123..xxxx) + && (123..r is float) + && 123.f == 123..r + && all(float2(123.f) == 123..rr) + && all(float3(123.f) == 123..rrr) + && all(float4(123.f) == 123..rrrr) + ; + + //CHK:1 + outputBuffer[0] = int(result); +} + From 0e6c5c518953141f31c09e5f10d3939054f9b1ee Mon Sep 17 00:00:00 2001 From: venkataram-nv Date: Tue, 9 Jul 2024 16:18:36 -0700 Subject: [PATCH 10/10] Warnings for uninitialized values 
(#4530) This extends the code for handling uninitialized output parameters. Still needs to handle generic templates and assignment of uninitialized values more carefully. The file containing the relevant code are now in source/slang/slang-ir-use-uninitialized-values.cpp rather than the previous source/slang/slang-ir-use-uninitialized-out-param.h and the top-level function is now checkForUsingUinitializedValues. Additionally a rudimentary test shader has been added for this case, which replaces the old file for out params only; tests/diagnositcs/uninitialized-out.slang becomes tests/diagnositcs/uninitialized.slang. What this does not implement (could be future PRs): * Checking uninitialized fields within constructors * Partially uninitialized values with respect to data structure (e.g. arrays/structs/vector types) * Partially uninitialized values with respect to control flow (e.g. if/else/loop) --- build/visual-studio/slang/slang.vcxproj | 4 +- .../visual-studio/slang/slang.vcxproj.filters | 4 +- source/slang/slang-diagnostic-defs.h | 10 +- .../slang-ir-use-uninitialized-out-param.cpp | 150 ------- .../slang-ir-use-uninitialized-values.cpp | 382 ++++++++++++++++++ ....h => slang-ir-use-uninitialized-values.h} | 2 +- source/slang/slang-lower-to-ir.cpp | 6 +- tests/bugs/gh-4434.slang | 24 +- tests/bugs/gh-4441.slang | 24 +- .../ctor-ordinary-retval-legal.slang | 3 +- tests/diagnostics/uninitialized-out.slang | 57 --- tests/diagnostics/uninitialized.slang | 263 ++++++++++++ 12 files changed, 687 insertions(+), 242 deletions(-) delete mode 100644 source/slang/slang-ir-use-uninitialized-out-param.cpp create mode 100644 source/slang/slang-ir-use-uninitialized-values.cpp rename source/slang/{slang-ir-use-uninitialized-out-param.h => slang-ir-use-uninitialized-values.h} (80%) delete mode 100644 tests/diagnostics/uninitialized-out.slang create mode 100644 tests/diagnostics/uninitialized.slang diff --git a/build/visual-studio/slang/slang.vcxproj 
b/build/visual-studio/slang/slang.vcxproj index cbafa99752..13eea9f0a4 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -492,7 +492,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla - + @@ -737,7 +737,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla - + diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index ab272f7e13..87a46477d6 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -564,7 +564,7 @@ Header Files - + Header Files @@ -1295,7 +1295,7 @@ Source Files - + Source Files diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 496fb7e328..7ebe77a8f2 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -598,7 +598,7 @@ DIAGNOSTIC(39999, Error, unableToFindSymbolInModule, "unable to find the mangled DIAGNOSTIC(39999, Error, overloadedParameterToHigherOrderFunction, "passing overloaded functions to higher order functions is not supported") -// 38xxx +// 38xxx DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'") @@ -735,9 +735,11 @@ DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'vo DIAGNOSTIC(41011, Error, profileIncompatibleWithTargetSwitch, "__target_switch has no compatable target with current profile '$0'") DIAGNOSTIC(41012, Warning, profileImplicitlyUpgraded, "user set `profile` had an implicit upgrade applied to it, atoms added: '$0'") DIAGNOSTIC(41012, Error, profileImplicitlyUpgradedRestrictive, "user set `profile` had an implicit upgrade applied to it, atoms added: '$0'") -DIAGNOSTIC(41015, Error, usingUninitializedValue, "use of uninitialized 
value '$0'") -DIAGNOSTIC(41016, Warning, returningWithUninitializedOut, "returning without initializing out parameter '$0'") -DIAGNOSTIC(41017, Warning, returningWithPartiallyUninitializedOut, "returning without fully initializing out parameter '$0'") +DIAGNOSTIC(41015, Warning, usingUninitializedOut, "use of uninitialized out parameter '$0'") +DIAGNOSTIC(41016, Warning, usingUninitializedVariable, "use of uninitialized variable '$0'") +DIAGNOSTIC(41017, Warning, usingUninitializedGlobalVariable, "use of uninitialized global variable '$0'") +DIAGNOSTIC(41018, Warning, returningWithUninitializedOut, "returning without initializing out parameter '$0'") +DIAGNOSTIC(41019, Warning, returningWithPartiallyUninitializedOut, "returning without fully initializing out parameter '$0'") DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.") DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2") diff --git a/source/slang/slang-ir-use-uninitialized-out-param.cpp b/source/slang/slang-ir-use-uninitialized-out-param.cpp deleted file mode 100644 index 7e3ef9ca2d..0000000000 --- a/source/slang/slang-ir-use-uninitialized-out-param.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include "slang-ir-use-uninitialized-out-param.h" -#include "slang-ir-util.h" -#include "slang-ir-reachability.h" - -namespace Slang -{ - class DiagnosticSink; - struct IRModule; - - struct StoreSite - { - IRInst* storeInst; - IRInst* address; - }; - - void checkForUsingUninitializedOutParams(IRFunc* func, DiagnosticSink* sink) - { - List outParams; - auto firstBlock = func->getFirstBlock(); - if (!firstBlock) - return; - - ReachabilityContext reachability(func); - - for (auto param : firstBlock->getParams()) - { - if (auto outType = as(param->getFullType())) - { - // Don't check `out Vertices` or `out Indices` parameters - // in mesh shaders. 
- // TODO: we should find a better way to represent these mesh shader - // parameters so they conform to the initialize before use convention. - // For example, we can use a `OutputVetices` and `OutputIndices` type - // to represent an output, like `OutputPatch` in domain shader. - // For now, we just skip the check for these parameters. - switch (outType->getValueType()->getOp()) - { - case kIROp_VerticesType: - case kIROp_IndicesType: - case kIROp_PrimitivesType: - continue; - default: - break; - } - } - else - { - continue; - } - List addresses; - addresses.add(param); - List stores; - // Collect all sub-addresses from the param. - for (Index i = 0; i < addresses.getCount(); i++) - { - auto addr = addresses[i]; - for (auto use = addr->firstUse; use; use = use->nextUse) - { - switch (use->getUser()->getOp()) - { - case kIROp_GetElementPtr: - case kIROp_FieldAddress: - addresses.add(use->getUser()); - break; - case kIROp_Store: - case kIROp_SwizzledStore: - // If we see a store of this address, add it to stores set. - if (use == use->getUser()->getOperands()) - stores.add(StoreSite{ use->getUser(), addr }); - break; - case kIROp_Call: - case kIROp_SPIRVAsm: - // If we see a call using this address, treat it as a store. - stores.add(StoreSite{ use->getUser(), addr }); - break; - case kIROp_SPIRVAsmOperandInst: - stores.add(StoreSite{ use->getUser()->getParent(), addr}); - break; - } - } - } - // Check all address loads. - List loadsAndReturns; - for (auto addr : addresses) - { - for (auto use = addr->firstUse; use; use = use->nextUse) - { - if (auto load = as(use->getUser())) - loadsAndReturns.add(load); - } - } - for(const auto& b : func->getBlocks()) - { - auto t = as(b->getTerminator()); - if (!t) continue; - loadsAndReturns.add(t); - } - - for (auto store : stores) - { - // Remove insts from `loads` that is reachable from the store. 
- for (Index i = 0; i < loadsAndReturns.getCount();) - { - auto load = as(loadsAndReturns[i]); - if (load && !canAddressesPotentiallyAlias(func, store.address, load->getPtr())) - continue; - if (reachability.isInstReachable(store.storeInst, loadsAndReturns[i])) - { - loadsAndReturns.fastRemoveAt(i); - } - else - { - i++; - } - } - } - // If there are any loads left, it means they are using uninitialized out params. - for (auto load : loadsAndReturns) - { - sink->diagnose( - load, - load->m_op == kIROp_Return - ? Diagnostics::returningWithUninitializedOut - : Diagnostics::usingUninitializedValue, - param); - } - } - } - - void checkForUsingUninitializedOutParams( - IRModule* module, - DiagnosticSink* sink) - { - for (auto inst : module->getGlobalInsts()) - { - if (auto func = as(inst)) - { - checkForUsingUninitializedOutParams(func, sink); - } - else if (auto generic = as(inst)) - { - auto retVal = findGenericReturnVal(generic); - if (auto funcVal = as(retVal)) - { - checkForUsingUninitializedOutParams(funcVal, sink); - } - } - } - } -} diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp new file mode 100644 index 0000000000..762773ad4b --- /dev/null +++ b/source/slang/slang-ir-use-uninitialized-values.cpp @@ -0,0 +1,382 @@ +#include "slang-ir-use-uninitialized-values.h" +#include "slang-ir-insts.h" +#include "slang-ir-reachability.h" +#include "slang-ir.h" + +namespace Slang +{ + static bool isMetaOp(IRInst* inst) + { + switch (inst->getOp()) + { + // These instructions only look at the parameter's type, + // so passing an undefined value to them is permissible + case kIROp_IsBool: + case kIROp_IsInt: + case kIROp_IsUnsignedInt: + case kIROp_IsSignedInt: + case kIROp_IsHalf: + case kIROp_IsFloat: + case kIROp_IsVector: + case kIROp_GetNaturalStride: + case kIROp_TypeEquals: + return true; + default: + break; + } + + return false; + } + + // Casting to IRUndefined is currently vacuous + // (e.g. 
any IRInst can be cast to IRUndefined) + static bool isUndefinedValue(IRInst* inst) + { + return (inst->m_op == kIROp_undefined); + } + + static bool isUndefinedParam(IRParam* param) + { + auto outType = as(param->getFullType()); + if (!outType) + return false; + + // Don't check `out Vertices` or `out Indices` parameters + // in mesh shaders. + // TODO: we should find a better way to represent these mesh shader + // parameters so they conform to the initialize before use convention. + // For example, we can use a `OutputVetices` and `OutputIndices` type + // to represent an output, like `OutputPatch` in domain shader. + // For now, we just skip the check for these parameters. + switch (outType->getValueType()->getOp()) + { + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: + return false; + default: + break; + } + + return true; + } + + static bool isAliasable(IRInst* inst) + { + switch (inst->getOp()) + { + // These instructions generate (implicit) references to inst + case kIROp_FieldExtract: + case kIROp_FieldAddress: + case kIROp_GetElement: + case kIROp_GetElementPtr: + return true; + default: + break; + } + + return false; + } + + static bool isDifferentiableFunc(IRInst* func) + { + for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration()) + { + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDifferentiableDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + return true; + default: + break; + } + } + + return false; + } + + static bool canIgnoreType(IRType* type) + { + if (as(type)) + return true; + + // For structs, ignore if its empty + if (as(type)) + return (type->getFirstChild() == nullptr); + + // Nothing to initialize for a pure interface + if (as(type)) + return true; + + // For pointers, check the value type (primarily for globals) + if (auto 
ptr = as(type)) + return canIgnoreType(ptr->getValueType()); + + // In the case of specializations, check returned type + if (auto spec = as(type)) + { + IRInst* base = spec->getBase(); + IRGeneric* generic = as(base); + IRInst* inner = findInnerMostGenericReturnVal(generic); + IRType* innerType = as(inner); + return canIgnoreType(innerType); + } + + return false; + } + + static List getAliasableInstructions(IRInst* inst) + { + List addresses; + + addresses.add(inst); + for (auto use = inst->firstUse; use; use = use->nextUse) + { + IRInst* user = use->getUser(); + + // Meta instructions only use the argument type + if (isMetaOp(user) || !isAliasable(user)) + continue; + + addresses.addRange(getAliasableInstructions(user)); + } + + return addresses; + } + + static void collectLoadStore(List& stores, List& loads, IRInst* user) + { + // Meta intrinsics (which evaluate on type) do nothing + if (isMetaOp(user)) + return; + + // Ignore instructions generating more aliases + if (isAliasable(user)) + return; + + switch (user->getOp()) + { + case kIROp_loop: + case kIROp_unconditionalBranch: + // TODO: Ignore branches for now + return; + + // These instructions will store data... + case kIROp_Store: + case kIROp_SwizzledStore: + // TODO: for calls, should make check that the + // function is passing as an out param + case kIROp_Call: + case kIROp_SPIRVAsm: + case kIROp_GenericAsm: + // For now assume that __intrinsic_asm blocks will do the right thing... + stores.add(user); + break; + + case kIROp_SPIRVAsmOperandInst: + // For SPIRV asm instructions, need to check out the entire + // block when doing reachability checks + stores.add(user->getParent()); + break; + + case kIROp_MakeExistential: + case kIROp_MakeExistentialWithRTTI: + // For specializing generic structs + stores.add(user); + break; + + // ... 
and the rest will load/use them + default: + loads.add(user); + break; + } + } + + static void cancelLoads(ReachabilityContext &reachability, const List& stores, List& loads) + { + // Remove all loads which are reachable from stores + for (auto store : stores) + { + for (Index i = 0; i < loads.getCount(); ) + { + if (reachability.isInstReachable(store, loads[i])) + loads.fastRemoveAt(i); + else + i++; + } + } + } + + static List getUnresolvedParamLoads(ReachabilityContext &reachability, IRFunc* func, IRInst* inst) + { + // Collect all aliasable addresses + auto addresses = getAliasableInstructions(inst); + + // Partition instructions + List stores; + List loads; + + for (auto alias : addresses) + { + // TODO: Mark specific parts assigned to for partial initialization checks + for (auto use = alias->firstUse; use; use = use->nextUse) + { + IRInst* user = use->getUser(); + collectLoadStore(stores, loads, user); + } + } + + // Only for out params we shall add all returns + for (const auto& b : func->getBlocks()) + { + auto t = as(b->getTerminator()); + if (!t) + continue; + + loads.add(t); + } + + cancelLoads(reachability, stores, loads); + + return loads; + } + + static List getUnresolvedVariableLoads(ReachabilityContext &reachability, IRInst* inst) + { + auto addresses = getAliasableInstructions(inst); + + // Partition instructions + List stores; + List loads; + + for (auto alias : addresses) + { + for (auto use = alias->firstUse; use; use = use->nextUse) + { + IRInst* user = use->getUser(); + collectLoadStore(stores, loads, user); + } + } + + cancelLoads(reachability, stores, loads); + + return loads; + } + + static void checkUninitializedValues(IRFunc* func, DiagnosticSink* sink) + { + if (isDifferentiableFunc(func)) + return; + + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return; + + ReachabilityContext reachability(func); + + // Check out parameters + for (auto param : firstBlock->getParams()) + { + if (!isUndefinedParam(param)) + continue; + 
+ auto loads = getUnresolvedParamLoads(reachability, func, param); + for (auto load : loads) + { + sink->diagnose(load, + as (load) + ? Diagnostics::returningWithUninitializedOut + : Diagnostics::usingUninitializedOut, + param); + } + } + + // Check ordinary instructions + for (auto inst = firstBlock->getFirstInst(); inst; inst = inst->getNextInst()) + { + if (!isUndefinedValue(inst)) + continue; + + IRType* type = inst->getFullType(); + if (canIgnoreType(type)) + continue; + + auto loads = getUnresolvedVariableLoads(reachability, inst); + for (auto load : loads) + { + sink->diagnose(load, + Diagnostics::usingUninitializedVariable, + inst); + } + } + } + + static void checkUninitializedGlobals(IRGlobalVar* variable, DiagnosticSink* sink) + { + IRType* type = variable->getFullType(); + if (canIgnoreType(type)) + return; + + // Check for semantic decorations + // (e.g. globals like gl_GlobalInvocationID) + if (variable->findDecoration()) + return; + + // Check for initialization blocks + for (auto inst : variable->getChildren()) + { + if (as(inst)) + return; + } + + auto addresses = getAliasableInstructions(variable); + + List stores; + List loads; + + for (auto alias : addresses) + { + for (auto use = alias->firstUse; use; use = use->nextUse) + { + IRInst* user = use->getUser(); + collectLoadStore(stores, loads, user); + + // Disregard if there is at least one store, + // since we cannot tell what the control flow is + if (stores.getCount()) + return; + } + } + + for (auto load : loads) + { + sink->diagnose(load, + Diagnostics::usingUninitializedGlobalVariable, + variable); + } + } + + void checkForUsingUninitializedValues(IRModule* module, DiagnosticSink* sink) + { + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as(inst)) + { + checkUninitializedValues(func, sink); + } + else if (auto generic = as(inst)) + { + auto retVal = findGenericReturnVal(generic); + if (auto funcVal = as(retVal)) + checkUninitializedValues(funcVal, sink); + } + else if 
(auto global = as(inst)) + { + checkUninitializedGlobals(global, sink); + } + } + } +} diff --git a/source/slang/slang-ir-use-uninitialized-out-param.h b/source/slang/slang-ir-use-uninitialized-values.h similarity index 80% rename from source/slang/slang-ir-use-uninitialized-out-param.h rename to source/slang/slang-ir-use-uninitialized-values.h index fd090c4f99..9b6867a3b5 100644 --- a/source/slang/slang-ir-use-uninitialized-out-param.h +++ b/source/slang/slang-ir-use-uninitialized-values.h @@ -6,7 +6,7 @@ namespace Slang class DiagnosticSink; struct IRModule; - void checkForUsingUninitializedOutParams( + void checkForUsingUninitializedValues( IRModule* module, DiagnosticSink* sink); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index c946072f95..6fa2ce67f8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -33,7 +33,7 @@ #include "slang-ir-clone.h" #include "slang-ir-lower-error-handling.h" #include "slang-ir-obfuscate-loc.h" -#include "slang-ir-use-uninitialized-out-param.h" +#include "slang-ir-use-uninitialized-values.h" #include "slang-ir-peephole.h" #include "slang-mangle.h" #include "slang-type-layout.h" @@ -10925,8 +10925,8 @@ RefPtr generateIRForTranslationUnit( // call graph) based on constraints imposed by different instructions. propagateConstExpr(module, compileRequest->getSink()); - // Check for using uninitialized out parameters. - checkForUsingUninitializedOutParams(module, compileRequest->getSink()); + // Check for using uninitialized values + checkForUsingUninitializedValues(module, compileRequest->getSink()); // TODO: give error messages if any `undefined` or // instructions remain. 
diff --git a/tests/bugs/gh-4434.slang b/tests/bugs/gh-4434.slang index 9752260088..c9a48d4391 100644 --- a/tests/bugs/gh-4434.slang +++ b/tests/bugs/gh-4434.slang @@ -15,17 +15,19 @@ RWStructuredBuffer outputBuffer; [numthreads(4, 1, 1)] void computeMain(uint tid : SV_GroupIndex) { - bool a, b, c; - c = and(a, b); - - bool1 i, j, k; - bool2 l, m, n; - bool3 o, p, q; - bool4 r, s, t; - k = and(i, j); - n = and(m, l); - q = and(o, p); - t = and(r, s); + bool K = (bool)outputBuffer[tid]; + + bool c = and(K, K); + + bool1 K1 = K; + bool2 K2 = K; + bool3 K3 = K; + bool4 K4 = K; + + bool1 k = and(K, K); + bool2 n = and(K, K); + bool3 q = and(K, K); + bool4 t = and(K, K); k = !and(k, false); n = !and(n, false); diff --git a/tests/bugs/gh-4441.slang b/tests/bugs/gh-4441.slang index 59d577c8d5..a2de7e00dc 100644 --- a/tests/bugs/gh-4441.slang +++ b/tests/bugs/gh-4441.slang @@ -15,17 +15,19 @@ RWStructuredBuffer outputBuffer; [numthreads(4, 1, 1)] void computeMain(uint tid : SV_GroupIndex) { - bool a, b, c; - c = or(a, b); - - bool1 i, j, k; - bool2 l, m, n; - bool3 o, p, q; - bool4 r, s, t; - k = or(i, j); - n = or(m, l); - q = or(o, p); - t = or(r, s); + bool K = (bool)outputBuffer[tid]; + + bool c = or(K, K); + + bool1 K1 = K; + bool2 K2 = K; + bool3 K3 = K; + bool4 K4 = K; + + bool1 k = or(K, K); + bool2 n = or(K, K); + bool3 q = or(K, K); + bool4 t = or(K, K); k = or(k, true); n = or(n, true); diff --git a/tests/diagnostics/ctor-ordinary-retval-legal.slang b/tests/diagnostics/ctor-ordinary-retval-legal.slang index f2d1765ffb..8117b4c802 100644 --- a/tests/diagnostics/ctor-ordinary-retval-legal.slang +++ b/tests/diagnostics/ctor-ordinary-retval-legal.slang @@ -11,7 +11,8 @@ struct S { - float v; + float v; + __init() { v = 0; } } struct S1a : S diff --git a/tests/diagnostics/uninitialized-out.slang b/tests/diagnostics/uninitialized-out.slang deleted file mode 100644 index d0d87449fb..0000000000 --- a/tests/diagnostics/uninitialized-out.slang +++ /dev/null @@ -1,57 +0,0 
@@ -//DIAGNOSTIC_TEST:SIMPLE: - -float foo(out float3 v) -{ - // This should error as we haven't set v before we read from it - float r = v.x + 1.0; - // This should warn as we haven't set v before we return - return r; -} - -// This should warn as we return without x being initialized -float bar(out float x) -{ - return 0; -} - -// This should also warn pointing at the implicit return -void baz(out float x) {} - -void twoReturns(bool b, out float y) -{ - if(b) - { - // Should warn - return; - } - y = 0; - // Shouldn't warn - return; -} - -void twoOkReturns(bool b, out float y) -{ - if(b) - { - // Shouldn't warn - unused(y); - return; - } - y = 0; - // Shouldn't warn - return; -} - -// TODO: This should warn that n is potentially uninitialized -int ok(bool b, out int n) -{ - if(b) - n = 0; - return n; -} - -// TODO: This should warn that arr isn't fully initialized -void partial(out float arr[2]) -{ - arr[0] = 1; -} diff --git a/tests/diagnostics/uninitialized.slang b/tests/diagnostics/uninitialized.slang new file mode 100644 index 0000000000..4779f45c94 --- /dev/null +++ b/tests/diagnostics/uninitialized.slang @@ -0,0 +1,263 @@ +//TEST:SIMPLE(filecheck=CHK): -target spirv + +// TODO: +// * warn potentially uninitialized variables (control flow) +// * warn partially uninitialized variables (structs, arrays, etc.) 
+// * warn uninitialized fields in constructors + +/////////////////////////////////// +// Uninitialized local variables // +/////////////////////////////////// + +// Should not warn here (unconditionalBranch) +float3 unconditional(int mode) +{ + float f(float) { return 1; } + + float k0; + float k1; + + if (mode == 1) + { + k1 = 1; + k0 = 1; + + const float w = k1 * f(1); + k0 = 4.0f * k0 * w; + k1 = 2.0f * k1 * w; + } + + return k0 + k1; +} + +// Warn here for branches using the variables +int conditional() +{ + int k; + //CHK-DAG: warning 41016: use of uninitialized variable 'k' + return (k > 0); +} + +// Using unitialized values +int use_undefined_value(int k) +{ + int x; + x += k; + //CHK-DAG: warning 41016: use of uninitialized variable 'x' + return x; +} + +// Returning uninitialized values +__generic +T generic_undefined_return() +{ + T x; + //CHK-DAG: warning 41016: use of uninitialized variable 'x' + return x; +} + +// Array variables +float undefined_array() +{ + float array[2]; + //CHK-DAG: warning 41016: use of uninitialized variable 'array' + return array[0]; +} + +float filled_array(int mode) +{ + float array[2]; + array[0] = 1.0f; + return array[0]; +} + +// Structs and nested structs +struct Data +{ + float value; +}; + +struct NestedData +{ + Data data; +}; + +// No warnings here, even thought autodiff generates +// IR which frequently returns undefined values +struct DiffStruct : IDifferentiable +{ + Data data; + float x; +} + +// Same story here +[ForwardDifferentiable] +DiffStruct differentiable(float x) +{ + DiffStruct ds; + ds.x = x; + return ds; +} + +// Empty structures should not generate diagnostics +// for empty default constructors +struct EmptyStruct +{ + __init() {} +}; + +// No warnings for empty structs even without __init() +struct NonEmptyStruct +{ + int field; + + __init() + { + field = 1; + } +}; + +// No warnings even when __init() is not specified +struct NoDefault +{ + int f(int i) + { + return i; + } +}; + +// Constructing 
the above structs +int constructors() +{ + EmptyStruct empty; + NoDefault no_default; + return no_default.f(1); +} + +// Using struct fields and nested structs +float structs() +{ + Data inputData = Data(1.0); + + float undefVar; + Data undefData; + NestedData nestedData; + + float result = inputData.value; + + //CHK-DAG: warning 41016: use of uninitialized variable 'undefVar' + result += undefVar; + + //CHK-DAG: warning 41016: use of uninitialized variable 'undefData' + result += undefData.value; + + //CHK-DAG: warning 41016: use of uninitialized variable 'nestedData' + result += nestedData.data.value; + + return result; +} + +//////////////////////////////////// +// Uninitialized global variables // +//////////////////////////////////// + +// Using groupshared variables +groupshared float4 gsConstexpr = float4(1.0f); +groupshared float4 gsUndefined; + +// OK +float use_constexpr_initialized_gs() +{ + return gsConstexpr.x; +} + +float use_undefined_gs() +{ + //CHK-DAG: warning 41017: use of uninitialized global variable 'gsUndefined' + return gsUndefined.x; +} + +// Using static variables +static const float cexprInitialized = 1.0f; +static float writtenNever; +static float writtenLater; + +// OK +float use_initialized_static() +{ + return cexprInitialized; +} + +// Should detect this and treat it as a store +void write_to_later() +{ + writtenLater = 1.0f; +} + +float use_never_written() +{ + //CHK-DAG: warning 41017: use of uninitialized global variable 'writtenNever' + return writtenNever; +} + +// OK because of prior store +float use_later_writte() +{ + return writtenLater; +} + +////////////////////////////////// +// Uninitialized out parameters // +////////////////////////////////// + +// Using before assigning +float regular_undefined_use(out float3 v) +{ + //CHK-DAG: warning 41015: use of uninitialized out parameter 'v' + float r = v.x + 1.0; + + //CHK-DAG: warning 41018: returning without initializing out parameter 'v' + return r; +} + +// Returning before 
assigning +float returning_undefined_use(out float x) +{ + //CHK-DAG: warning 41018: returning without initializing out parameter 'x' + return 0; +} + +// Implicit, still returning before assigning +void implicit_undefined_use(out float x) +{ + //CHK-DAG: warning 41018: returning without initializing out parameter 'x' +} + +// Warn on potential return paths +void control_flow_undefined(bool b, out float y) +{ + if(b) + { + //CHK-DAG: warning 41018: returning without initializing out parameter 'y' + return; + } + y = 0; + return; +} + +// No warnings if all paths are fine +void control_flow_defined(bool b, out float y) +{ + if(b) + { + unused(y); + return; + } + y = 0; + return; +} + +//CHK-NOT: warning 41015 +//CHK-NOT: warning 41016 +//CHK-NOT: warning 41017 +//CHK-NOT: warning 41018