From dda821526f67acebe91bb17bfde4e2637e2e8486 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 10 Dec 2024 11:53:01 -0800 Subject: [PATCH 1/3] Fix attribute reflection. --- include/slang.h | 55 ++++++++----- source/slang/slang-check-modifier.cpp | 15 ++-- source/slang/slang-reflection-api.cpp | 8 +- .../unit-test-attribute-reflection.cpp | 79 +++++++++++++++++++ 4 files changed, 126 insertions(+), 31 deletions(-) create mode 100644 tools/slang-unit-test/unit-test-attribute-reflection.cpp diff --git a/include/slang.h b/include/slang.h index 2dfee6b281..d2e3162bcb 100644 --- a/include/slang.h +++ b/include/slang.h @@ -1775,7 +1775,7 @@ public: \ typedef struct SlangReflectionVariable SlangReflectionVariable; typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout; typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; - typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; + typedef struct SlangReflectionAttribute SlangReflectionUserAttribute; typedef struct SlangReflectionFunction SlangReflectionFunction; typedef struct SlangReflectionGeneric SlangReflectionGeneric; @@ -2140,46 +2140,48 @@ union GenericArgReflection bool boolVal; }; -struct UserAttribute +struct Attribute { char const* getName() { - return spReflectionUserAttribute_GetName((SlangReflectionUserAttribute*)this); + return spReflectionUserAttribute_GetName((SlangReflectionAttribute*)this); } uint32_t getArgumentCount() { return (uint32_t)spReflectionUserAttribute_GetArgumentCount( - (SlangReflectionUserAttribute*)this); + (SlangReflectionAttribute*)this); } TypeReflection* getArgumentType(uint32_t index) { return (TypeReflection*)spReflectionUserAttribute_GetArgumentType( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index); } SlangResult getArgumentValueInt(uint32_t index, int* value) { return spReflectionUserAttribute_GetArgumentValueInt( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index, value); } SlangResult getArgumentValueFloat(uint32_t index, float* value) { return spReflectionUserAttribute_GetArgumentValueFloat( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index, value); } const char* getArgumentValueString(uint32_t index, size_t* outSize) { return spReflectionUserAttribute_GetArgumentValueString( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index, outSize); } }; +typedef Attribute UserAttribute; + struct TypeReflection { enum class Kind @@ -2320,13 +2322,19 @@ struct TypeReflection return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index); } - UserAttribute* findUserAttributeByName(char const* name) + UserAttribute* findAttributeByName(char const* name) { return (UserAttribute*)spReflectionType_FindUserAttributeByName( (SlangReflectionType*)this, name); } + [[deprecated("use findAttributeByName")]] UserAttribute* findUserAttributeByName( + char const* name) + { + return findAttributeByName(name); + } + TypeReflection* applySpecializations(GenericReflection* generic) { return (TypeReflection*)spReflectionType_applySpecializations( @@ -2777,14 +2785,14 @@ struct VariableReflection return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this); } - UserAttribute* getUserAttributeByIndex(unsigned int index) + Attribute* getUserAttributeByIndex(unsigned int index) { return (UserAttribute*)spReflectionVariable_GetUserAttribute( (SlangReflectionVariable*)this, index); } - UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + Attribute* findAttributeByName(SlangSession* globalSession, char const* name) { return (UserAttribute*)spReflectionVariable_FindUserAttributeByName( (SlangReflectionVariable*)this, @@ -2792,6 +2800,13 @@ struct VariableReflection name); } + [[deprecated("use findAttributeByName")]] Attribute* findUserAttributeByName( + SlangSession* globalSession, + char const* name) + { + return findAttributeByName(globalSession, name); + } + bool hasDefaultValue() { return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this); @@ -2908,20 +2923,24 @@ struct FunctionReflection { return spReflectionFunction_GetUserAttributeCount((SlangReflectionFunction*)this); } - UserAttribute* getUserAttributeByIndex(unsigned int index) + Attribute* getUserAttributeByIndex(unsigned int index) { - return (UserAttribute*)spReflectionFunction_GetUserAttribute( - (SlangReflectionFunction*)this, - index); + return ( + Attribute*)spReflectionFunction_GetUserAttribute((SlangReflectionFunction*)this, index); } - UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + Attribute* findAttributeByName(SlangSession* globalSession, char const* name) { - return (UserAttribute*)spReflectionFunction_FindUserAttributeByName( + return (Attribute*)spReflectionFunction_FindUserAttributeByName( (SlangReflectionFunction*)this, globalSession, name); } - + [[deprecated("use findAttributeByName")]] Attribute* findUserAttributeByName( + SlangSession* globalSession, + char const* name) + { + return findAttributeByName(globalSession, name); + } Modifier* findModifier(Modifier::ID id) { return (Modifier*)spReflectionFunction_FindModifier( diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index aebfe3b962..05eb978bcb 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -752,18 +752,15 @@ Modifier* SemanticsVisitor::validateAttribute( { auto& arg = attr->args[paramIndex]; bool typeChecked = false; - if (auto basicType = as(paramDecl->getType())) + if (isValidCompileTimeConstantType(paramDecl->getType())) { - if (basicType->getBaseType() == BaseType::Int) + if (auto cint = checkConstantIntVal(arg)) { - if (auto cint = checkConstantIntVal(arg)) - { - for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) - attr->intArgVals.add(nullptr); - attr->intArgVals[(uint32_t)paramIndex] = cint; - } - typeChecked = true; + for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) + attr->intArgVals.add(nullptr); + attr->intArgVals[(uint32_t)paramIndex] = cint; } + typeChecked = true; } if (!typeChecked) { diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 93e4951944..31cc474cfa 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -23,11 +23,11 @@ namespace Slang // Conversion routines to help with strongly-typed reflection API -static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib) +static inline Attribute* convert(SlangReflectionUserAttribute* attrib) { - return (UserDefinedAttribute*)attrib; + return (Attribute*)attrib; } -static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib) +static inline SlangReflectionUserAttribute* convert(Attribute* attrib) { return (SlangReflectionUserAttribute*)attrib; } @@ -154,7 +154,7 @@ static SlangReflectionUserAttribute* findUserAttributeByName( const char* name) { auto nameObj = session->tryGetNameObj(name); - for (auto x : decl->getModifiersOfType()) + for (auto x : decl->getModifiersOfType()) { if (x->keywordName == nameObj) return (SlangReflectionUserAttribute*)(x); diff --git a/tools/slang-unit-test/unit-test-attribute-reflection.cpp b/tools/slang-unit-test/unit-test-attribute-reflection.cpp new file mode 100644 index 0000000000..e60eeb2d44 --- /dev/null +++ b/tools/slang-unit-test/unit-test-attribute-reflection.cpp @@ -0,0 +1,79 @@ +// unit-test-translation-unit-import.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include + +using namespace Slang; + +// Test that the reflection API provides correct info about attributes. + +SLANG_UNIT_TEST(attributeReflection) +{ + const char* userSourceBody = R"( + public enum E + { + V0, + V1, + }; + + [__AttributeUsage(_AttributeTargets.Struct)] + public struct NormalTextureAttribute + { + public E Type; + }; + + [COM("042BE50B-CB01-4DBB-8367-3A9CDCBE2F49")] + interface IInterface { void f(); } + + [NormalTexture(E.V1)] + struct TS {}; + )"; + String userSource = userSourceBody; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto reflection = module->getLayout(); + + auto interfaceType = reflection->findTypeByName("IInterface"); + SLANG_CHECK(interfaceType != nullptr); + + auto comAttribute = interfaceType->findAttributeByName("COM"); + SLANG_CHECK(comAttribute != nullptr); + + size_t size = 0; + auto guid = comAttribute->getArgumentValueString(0, &size); + UnownedStringSlice stringSlice = UnownedStringSlice(guid, size); + SLANG_CHECK(stringSlice == "\"042BE50B-CB01-4DBB-8367-3A9CDCBE2F49\""); + + auto testType = reflection->findTypeByName("TS"); + SLANG_CHECK(testType != nullptr); + + auto normalTextureAttribute = testType->findAttributeByName("NormalTexture"); + SLANG_CHECK(normalTextureAttribute != nullptr); + + int value = 0; + normalTextureAttribute->getArgumentValueInt(0, &value); + SLANG_CHECK(value == 1); +} From 3a703dcacaa25a23458d7fb1fde93aa1438bd8b2 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 10 Dec 2024 14:03:19 -0800 Subject: [PATCH 2/3] Fix. --- include/slang.h | 3 ++- source/slang/slang-reflection-api.cpp | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/include/slang.h b/include/slang.h index d2e3162bcb..d009ca63bc 100644 --- a/include/slang.h +++ b/include/slang.h @@ -1775,7 +1775,8 @@ public: \ typedef struct SlangReflectionVariable SlangReflectionVariable; typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout; typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; - typedef struct SlangReflectionAttribute SlangReflectionUserAttribute; + typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; + typedef SlangReflectionUserAttribute SlangReflectionAttribute; typedef struct SlangReflectionFunction SlangReflectionFunction; typedef struct SlangReflectionGeneric SlangReflectionGeneric; diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 31cc474cfa..d7f793d051 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -154,6 +154,8 @@ static SlangReflectionUserAttribute* findUserAttributeByName( const char* name) { auto nameObj = session->tryGetNameObj(name); + if (!nameObj) + return nullptr; for (auto x : decl->getModifiersOfType()) { if (x->keywordName == nameObj) From 5ce37af4726bf907813e340ec970da51dde2adf8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 10 Dec 2024 14:23:22 -0800 Subject: [PATCH 3/3] Fix. --- include/slang.h | 14 +++----------- tools/gfx/d3d12/d3d12-shader-object-layout.cpp | 2 +- .../unit-test-decl-tree-reflection.cpp | 2 +- .../unit-test-function-reflection.cpp | 2 +- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/include/slang.h b/include/slang.h index d009ca63bc..2ba7150ec1 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2330,11 +2330,7 @@ struct TypeReflection name); } - [[deprecated("use findAttributeByName")]] UserAttribute* findUserAttributeByName( - char const* name) - { - return findAttributeByName(name); - } + UserAttribute* findUserAttributeByName(char const* name) { return findAttributeByName(name); } TypeReflection* applySpecializations(GenericReflection* generic) { @@ -2801,9 +2797,7 @@ struct VariableReflection name); } - [[deprecated("use findAttributeByName")]] Attribute* findUserAttributeByName( - SlangSession* globalSession, - char const* name) + Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name) { return findAttributeByName(globalSession, name); } @@ -2936,9 +2930,7 @@ struct FunctionReflection globalSession, name); } - [[deprecated("use findAttributeByName")]] Attribute* findUserAttributeByName( - SlangSession* globalSession, - char const* name) + Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name) { return findAttributeByName(globalSession, name); } diff --git a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp index 6cf51ee2bb..8e2d24ad64 100644 --- a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp +++ b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp @@ -39,7 +39,7 @@ bool ShaderObjectLayoutImpl::isBindingRangeRootParameter( { if (auto leafVariable = typeLayout->getBindingRangeLeafVariable(bindingRangeIndex)) { - if (leafVariable->findUserAttributeByName(globalSession, rootParameterAttributeName)) + if (leafVariable->findAttributeByName(globalSession, rootParameterAttributeName)) { isRootParameter = true; } diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index 2ceb9981bf..512be9be5f 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -178,7 +178,7 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(result == SLANG_OK); SLANG_CHECK(val == 1024); SLANG_CHECK( - funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == + funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); } diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index 52c2e795a3..3ce6ab7a5d 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -108,7 +108,7 @@ SLANG_UNIT_TEST(functionReflection) SLANG_CHECK(result == SLANG_OK); SLANG_CHECK(val == 1024); SLANG_CHECK( - funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == + funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); // Check overloaded method resolution