Skip to content

Commit

Permalink
Add normalize builtins and normalize HLSL function to DirectX and SPI…
Browse files Browse the repository at this point in the history
…R-V backend (llvm#102683)

This PR adds the normalize intrinsic and an HLSL function that uses it.
SPIR-V backend support is also implemented.

Used llvm#101256 as a reference,
along with llvm#102243
Fixes llvm#99139
  • Loading branch information
bob80905 authored Aug 13, 2024
1 parent 643a208 commit 1b2d11d
Show file tree
Hide file tree
Showing 14 changed files with 448 additions and 0 deletions.
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

// __builtin_hlsl_normalize: HLSL-only builtin backing the normalize() overloads.
// The prototype is the untyped variadic form; the argument count, operand type
// and result type are checked in SemaHLSL::CheckBuiltinFunctionCall.
def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_normalize"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
let Attributes = [NoThrow, Const];
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18584,6 +18584,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
case Builtin::BI__builtin_hlsl_normalize: {
  // normalize(x) maps directly onto the target-specific normalize intrinsic;
  // the result has the same IR type as the operand.
  assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
         "normalize operand must have a float representation");

  Value *Operand = EmitScalarExpr(E->getArg(0));
  llvm::Type *ResultTy = Operand->getType();
  return Builder.CreateIntrinsic(
      /*ReturnType=*/ResultTy, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
      ArrayRef<Value *>{Operand}, nullptr, "hlsl.normalize");
}
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)

Expand Down
32 changes: 32 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,38 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);

//===----------------------------------------------------------------------===//
// normalize builtins
//===----------------------------------------------------------------------===//

/// \fn T normalize(T x)
/// \brief Returns the normalized unit vector of the specified floating-point
/// vector.
/// \param x [in] The vector of floats.
///
/// Normalize is based on the following formula: x / length(x).
///
/// Overloads are declared for the scalar and the 2-, 3- and 4-component forms
/// of half and float. Every overload aliases __builtin_hlsl_normalize; the
/// half overloads additionally carry the
/// _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) annotation.

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half normalize(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half2 normalize(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half3 normalize(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half4 normalize(half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float normalize(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float2 normalize(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float3 normalize(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float4 normalize(float4);

//===----------------------------------------------------------------------===//
// pow builtins
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_normalize: {
  // Operands must have a float or half representation, and exactly one
  // argument is accepted. The checks run in that order.
  if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall) ||
      SemaRef.checkArgCount(TheCall, 1))
    return true;

  // The call evaluates to the same type as its single operand.
  TheCall->setType(TheCall->getArg(0)->getType());
  break;
}
// Note: these are LLVM builtins for which we want to catch invalid intrinsic
// generation. Normal handling of these builtins will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
Expand Down
100 changes: 100 additions & 0 deletions clang/test/CodeGenHLSL/builtins/normalize.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF,SPIR_NATIVE_HALF,SPIR_CHECK
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK

// Scalar half: with native half enabled the 16-bit normalize intrinsic is
// expected; without it the 32-bit float overload is expected instead.
// DXIL_NATIVE_HALF: define noundef half @
// SPIR_NATIVE_HALF: define spir_func noundef half @
// DXIL_NATIVE_HALF: call half @llvm.dx.normalize.f16(half
// SPIR_NATIVE_HALF: call half @llvm.spv.normalize.f16(half
// DXIL_NO_HALF: call float @llvm.dx.normalize.f32(float
// SPIR_NO_HALF: call float @llvm.spv.normalize.f32(float
// NATIVE_HALF: ret half
// NO_HALF: ret float
half test_normalize_half(half p0)
{
return normalize(p0);
}
// Two-component half vector: expects the v2f16 intrinsic with native half and
// the v2f32 intrinsic without it; the intrinsic result is returned directly.
// DXIL_NATIVE_HALF: define noundef <2 x half> @
// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
// DXIL_NATIVE_HALF: call <2 x half> @llvm.dx.normalize.v2f16(<2 x half>
// SPIR_NATIVE_HALF: call <2 x half> @llvm.spv.normalize.v2f16(<2 x half>
// DXIL_NO_HALF: call <2 x float> @llvm.dx.normalize.v2f32(<2 x float>
// SPIR_NO_HALF: call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
// NATIVE_HALF: ret <2 x half> %hlsl.normalize
// NO_HALF: ret <2 x float> %hlsl.normalize
half2 test_normalize_half2(half2 p0)
{
return normalize(p0);
}
// Three-component half vector: expects the v3f16 intrinsic with native half
// and the v3f32 intrinsic without it; the intrinsic result is returned directly.
// DXIL_NATIVE_HALF: define noundef <3 x half> @
// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
// DXIL_NATIVE_HALF: call <3 x half> @llvm.dx.normalize.v3f16(<3 x half>
// SPIR_NATIVE_HALF: call <3 x half> @llvm.spv.normalize.v3f16(<3 x half>
// DXIL_NO_HALF: call <3 x float> @llvm.dx.normalize.v3f32(<3 x float>
// SPIR_NO_HALF: call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
// NATIVE_HALF: ret <3 x half> %hlsl.normalize
// NO_HALF: ret <3 x float> %hlsl.normalize
half3 test_normalize_half3(half3 p0)
{
return normalize(p0);
}
// Four-component half vector: expects the v4f16 intrinsic with native half and
// the v4f32 intrinsic without it; the intrinsic result is returned directly.
// DXIL_NATIVE_HALF: define noundef <4 x half> @
// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
// DXIL_NATIVE_HALF: call <4 x half> @llvm.dx.normalize.v4f16(<4 x half>
// SPIR_NATIVE_HALF: call <4 x half> @llvm.spv.normalize.v4f16(<4 x half>
// DXIL_NO_HALF: call <4 x float> @llvm.dx.normalize.v4f32(<4 x float>
// SPIR_NO_HALF: call <4 x float> @llvm.spv.normalize.v4f32(<4 x float>
// NATIVE_HALF: ret <4 x half> %hlsl.normalize
// NO_HALF: ret <4 x float> %hlsl.normalize
half4 test_normalize_half4(half4 p0)
{
return normalize(p0);
}

// Scalar float: every run configuration is expected to call the f32 normalize
// intrinsic of its target and return a float.
// DXIL_CHECK: define noundef float @
// SPIR_CHECK: define spir_func noundef float @
// DXIL_CHECK: call float @llvm.dx.normalize.f32(float
// SPIR_CHECK: call float @llvm.spv.normalize.f32(float
// CHECK: ret float
float test_normalize_float(float p0)
{
return normalize(p0);
}
// Two-component float vector: expects the v2f32 normalize intrinsic of the
// target, with its result returned directly.
// DXIL_CHECK: define noundef <2 x float> @
// SPIR_CHECK: define spir_func noundef <2 x float> @
// DXIL_CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
// SPIR_CHECK: %hlsl.normalize = call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
// CHECK: ret <2 x float> %hlsl.normalize
float2 test_normalize_float2(float2 p0)
{
return normalize(p0);
}
// Three-component float vector: expects the v3f32 normalize intrinsic of the
// target, with its result returned directly.
// DXIL_CHECK: define noundef <3 x float> @
// SPIR_CHECK: define spir_func noundef <3 x float> @
// DXIL_CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
// SPIR_CHECK: %hlsl.normalize = call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
// CHECK: ret <3 x float> %hlsl.normalize
float3 test_normalize_float3(float3 p0)
{
return normalize(p0);
}
// Four-component float vector: expects the v4f32 normalize intrinsic of the
// target, with its result returned directly.
// DXIL_CHECK: define noundef <4 x float> @
// SPIR_CHECK: define spir_func noundef <4 x float> @
// DXIL_CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
// SPIR_CHECK: %hlsl.normalize = call <4 x float> @llvm.spv.normalize.v4f32(
// CHECK: ret <4 x float> %hlsl.normalize
float4 test_normalize_float4(float4 p0)
{
  return normalize(p0);
}
31 changes: 31 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected

// Calling the builtin with no arguments must be diagnosed as too few arguments.
void test_too_few_arg()
{
return __builtin_hlsl_normalize();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

// Calling the builtin with two (otherwise valid) float2 arguments must be
// diagnosed as too many arguments.
void test_too_many_arg(float2 p0)
{
return __builtin_hlsl_normalize(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

// A bool operand must be rejected by the builtin as an incompatible type; it
// is not promoted to float.
bool builtin_bool_to_float_type_promotion(bool p1)
{
  return __builtin_hlsl_normalize(p1);
  // expected-error@-1 {{passing 'bool' to parameter of incompatible type 'float'}}
}

// A scalar int operand must be rejected by the builtin as an incompatible
// type; it is not promoted to float.
bool builtin_normalize_int_to_float_promotion(int p1)
{
return __builtin_hlsl_normalize(p1);
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
}

// An int2 operand must be rejected by the builtin as an incompatible type; it
// is not converted to a float vector.
bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
{
return __builtin_hlsl_normalize(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
// dx.normalize: overloaded on any floating-point operand type; the result has
// the same type as the operand.
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
// spv.normalize: overloaded on any floating-point operand type; the result has
// the same type as the operand.
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
72 changes: 72 additions & 0 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
case Intrinsic::dx_normalize:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
Expand Down Expand Up @@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}

/// Expands a call to the dx.normalize intrinsic in place.
///
/// A scalar operand is lowered to `x / fabs(x)`, i.e. x divided by its
/// length. A vector operand of 2, 3 or 4 elements is lowered to
/// `x * splat(rsqrt(dot(x, x)))`, using the dot intrinsic that matches the
/// vector width. All uses of the original call are redirected to the expansion
/// and the call is erased.
///
/// A constant-zero scalar operand, a 1-element vector and any vector wider
/// than 4 elements are rejected with report_fatal_error.
///
/// \param Orig the dx.normalize call to expand.
/// \returns true once the call has been replaced.
static bool expandNormalizeIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  Type *Ty = Orig->getType();
  Type *EltTy = Ty->getScalarType();
  // Emit the expansion immediately before the call being replaced.
  IRBuilder<> Builder(Orig);

  auto *XVec = dyn_cast<FixedVectorType>(Ty);
  if (!XVec) {
    if (auto *ConstantX = dyn_cast<ConstantFP>(X)) {
      if (ConstantX->getValueAPF().isZero())
        report_fatal_error(Twine("Invalid input scalar: length is zero"),
                           /* gen_crash_diag=*/false);
    }
    // The length of a scalar is its magnitude. Dividing by X itself would
    // produce +1 for negative inputs, where x / length(x) is -1.
    Value *Length = Builder.CreateIntrinsic(
        EltTy, Intrinsic::fabs, ArrayRef<Value *>{X}, nullptr, "fabs");
    Value *Result = Builder.CreateFDiv(X, Length);

    Orig->replaceAllUsesWith(Result);
    Orig->eraseFromParent();
    return true;
  }

  unsigned XVecSize = XVec->getNumElements();
  Value *DotProduct = nullptr;
  // Use the dot intrinsic corresponding to the vector size.
  switch (XVecSize) {
  case 1:
    // NOTE(review): this message talks about a zero length although the
    // branch rejects a 1-element vector; confirm the intended wording.
    report_fatal_error(Twine("Invalid input vector: length is zero"),
                       /* gen_crash_diag=*/false);
    break;
  case 2:
    DotProduct = Builder.CreateIntrinsic(
        EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
    break;
  case 3:
    DotProduct = Builder.CreateIntrinsic(
        EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
    break;
  case 4:
    DotProduct = Builder.CreateIntrinsic(
        EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
    break;
  default:
    report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
                       /* gen_crash_diag=*/false);
  }

  // NOTE(review): DotProduct is always the intrinsic call created above, never
  // a ConstantFP, so a zero-length vector cannot be detected by inspecting it
  // here. An all-zero constant vector operand is therefore not diagnosed.

  // x / length(x) == x * rsqrt(dot(x, x)).
  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
                                                ArrayRef<Value *>{DotProduct},
                                                nullptr, "dx.rsqrt");

  Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
  Value *Result = Builder.CreateFMul(X, MultiplicandVec);

  Orig->replaceAllUsesWith(Result);
  Orig->eraseFromParent();
  return true;
}

static bool expandPowIntrinsic(CallInst *Orig) {

Value *X = Orig->getOperand(0);
Expand Down Expand Up @@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
case Intrinsic::dx_normalize:
return expandNormalizeIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

// Selects the spv_normalize intrinsic as an OpExtInst that invokes Normalize
// from the GLSL.std.450 extended instruction set.
bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Expand Down Expand Up @@ -1409,6 +1412,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}

/// Lowers the normalize intrinsic to an OpExtInst that invokes Normalize from
/// the GLSL.std.450 extended instruction set. Operand 2 of \p I supplies the
/// value to normalize; \p ResVReg and \p ResType describe the result.
bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
                                               const SPIRVType *ResType,
                                               MachineInstr &I) const {
  assert(I.getNumOperands() == 3);
  assert(I.getOperand(2).isReg());

  const Register SrcReg = I.getOperand(2).getReg();
  MachineBasicBlock &ParentBlock = *I.getParent();

  auto MIB =
      BuildMI(ParentBlock, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst));
  MIB.addDef(ResVReg);
  MIB.addUse(GR.getSPIRVTypeID(ResType));
  MIB.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450));
  MIB.addImm(GL::Normalize);
  MIB.addUse(SrcReg);
  return MIB.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2142,6 +2162,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectLength(ResVReg, ResType, I);
case Intrinsic::spv_frac:
return selectFrac(ResVReg, ResType, I);
case Intrinsic::spv_normalize:
return selectNormalize(ResVReg, ResType, I);
case Intrinsic::spv_rsqrt:
return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
Expand Down
Loading

0 comments on commit 1b2d11d

Please sign in to comment.