From da1731bf73166cd662e47434bc86dda70e31a2d6 Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Wed, 27 Nov 2024 12:07:51 +0100 Subject: [PATCH] [SPIR-V 1.2] SPIRVReader: Add AlignmentId support (#2869) If there is no `OpDecorate .. Alignment` in the input, see if there is an `OpDecorateId .. AlignmentId` and take the alignment from the referenced constant instead. Once `AlignmentId` has been translated to LLVM IR, it is indistinguishable from an (non-ID) `Alignment` decoration. --- lib/SPIRV/SPIRVReader.cpp | 39 ++++++++++++++++++++++++++++++--------- lib/SPIRV/SPIRVReader.h | 7 +++++++ test/AlignmentId.spvasm | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 test/AlignmentId.spvasm diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index 959c246ae..681140148 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -288,6 +288,29 @@ Value *SPIRVToLLVM::mapFunction(SPIRVFunction *BF, Function *F) { return F; } +std::optional SPIRVToLLVM::transIdAsConstant(SPIRVId Id) { + auto *V = BM->get(Id); + const auto *ConstValue = + dyn_cast(transValue(V, nullptr, nullptr)); + if (!ConstValue) + return {}; + return ConstValue->getZExtValue(); +} + +std::optional SPIRVToLLVM::getAlignment(SPIRVValue *V) { + SPIRVWord AlignmentBytes = 0; + if (V->hasAlignment(&AlignmentBytes)) { + return AlignmentBytes; + } + + // If there was no Alignment decoration, look for AlignmentId instead. + SPIRVId AlignId; + if (V->hasDecorateId(DecorationAlignmentId, 0, &AlignId)) { + return transIdAsConstant(AlignId); + } + return {}; +} + Type *SPIRVToLLVM::transFPType(SPIRVType *T) { switch (T->getFloatBitWidth()) { case 16: @@ -3154,9 +3177,9 @@ void SPIRVToLLVM::transFunctionAttrs(SPIRVFunction *BF, Function *F) { SPIRVWord MaxOffset = 0; if (BA->hasDecorate(DecorationMaxByteOffset, 0, &MaxOffset)) Builder.addDereferenceableAttr(MaxOffset); - SPIRVWord AlignmentBytes = 0; - if (BA->hasAlignment(&AlignmentBytes)) - Builder.addAlignmentAttr(AlignmentBytes); + if (auto Alignment = getAlignment(BA)) { + Builder.addAlignmentAttr(*Alignment); + } I->addAttrs(Builder); } BF->foreachReturnValueAttr([&](SPIRVFuncParamAttrKind Kind) { @@ -4931,15 +4954,13 @@ bool SPIRVToLLVM::transFPGAFunctionMetadata(SPIRVFunction *BF, Function *F) { bool SPIRVToLLVM::transAlign(SPIRVValue *BV, Value *V) { if (auto *AL = dyn_cast(V)) { - SPIRVWord Align = 0; - if (BV->hasAlignment(&Align)) - AL->setAlignment(llvm::Align(Align)); + if (auto Align = getAlignment(BV)) + AL->setAlignment(llvm::Align(*Align)); return true; } if (auto *GV = dyn_cast(V)) { - SPIRVWord Align = 0; - if (BV->hasAlignment(&Align)) - GV->setAlignment(MaybeAlign(Align)); + if (auto Align = getAlignment(BV)) + GV->setAlignment(MaybeAlign(*Align)); return true; } return true; diff --git a/lib/SPIRV/SPIRVReader.h b/lib/SPIRV/SPIRVReader.h index c1b74874c..75dfd52c0 100644 --- a/lib/SPIRV/SPIRVReader.h +++ b/lib/SPIRV/SPIRVReader.h @@ -215,6 +215,13 @@ class SPIRVToLLVM : private BuiltinCallHelper { bool isDirectlyTranslatedToOCL(Op OpCode) const; MDString *transOCLKernelArgTypeName(SPIRVFunctionParameter *); + + // Attempt to translate Id as a (specialization) constant. + std::optional transIdAsConstant(SPIRVId Id); + + // Return the value of an Alignment or AlignmentId decoration for V. + std::optional getAlignment(SPIRVValue *V); + Value *mapFunction(SPIRVFunction *BF, Function *F); Value *getTranslatedValue(SPIRVValue *BV); IntrinsicInst *getLifetimeStartIntrinsic(Instruction *I); diff --git a/test/AlignmentId.spvasm b/test/AlignmentId.spvasm new file mode 100644 index 000000000..aa1b28eda --- /dev/null +++ b/test/AlignmentId.spvasm @@ -0,0 +1,34 @@ +; REQUIRES: spirv-as + +; RUN: spirv-as %s --target-env spv1.2 -o %t.spv +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r -o %t.rev.bc %t.spv +; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s + + OpCapability Addresses + OpCapability Linkage + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %fn "testAlignmentId" + OpName %p "p" + OpDecorate %x LinkageAttributes "x" Export + OpDecorateId %x AlignmentId %al + OpDecorateId %p AlignmentId %al_spec + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %ptr = OpTypePointer CrossWorkgroup %uint + %fnTy = OpTypeFunction %void %ptr + %al = OpConstant %uint 16 + %al_spec = OpSpecConstantOp %uint IAdd %al %al + %uint_42 = OpConstant %uint 42 +; Verify alignment of variable. +; CHECK: @x = addrspace(1) global i32 42, align 16 + %x = OpVariable %ptr CrossWorkgroup %uint_42 + + %fn = OpFunction %void None %fnTy +; Verify alignment of function parameter. +; CHECK: define spir_kernel void @testAlignmentId(ptr addrspace(1) align 32 %p) + %p = OpFunctionParameter %ptr + %entry = OpLabel + OpReturn + OpFunctionEnd