Skip to content

Commit

Permalink
working on fixing formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
natevm committed Dec 8, 2024
1 parent 04c2638 commit d0bd3f9
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 114 deletions.
217 changes: 135 additions & 82 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2866,7 +2866,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
emitBitfieldExtractImpl(inst);
break;
}
case kIROp_BitfieldInsert:
case kIROp_BitfieldInsert:
{
emitBitfieldInsertImpl(inst);
break;
Expand Down Expand Up @@ -3843,40 +3843,60 @@ void CLikeSourceEmitter::emitFuncDecorationsImpl(IRFunc* func)
}
}

bool CLikeSourceEmitter::tryGetIntInfo(IRType* elementType, bool &isSigned, int &bitWidth)
bool CLikeSourceEmitter::tryGetIntInfo(IRType* elementType, bool& isSigned, int& bitWidth)
{
Slang::IROp type = elementType->getOp();
if (!(type >= kIROp_Int8Type && type <= kIROp_UInt64Type)) return false;
if (!(type >= kIROp_Int8Type && type <= kIROp_UInt64Type))
return false;
isSigned = (type >= kIROp_Int8Type && type <= kIROp_Int64Type);

Slang::IROp stype = (isSigned) ? type : Slang::IROp(type - 4);
bitWidth = 8 << (stype - kIROp_Int8Type);
return true;
}

void CLikeSourceEmitter::emitVecNOrScalar(IRVectorType* vectorType, std::function<void()> emitComponentLogic)
void CLikeSourceEmitter::emitVecNOrScalar(
IRVectorType* vectorType,
std::function<void()> emitComponentLogic)
{
if (vectorType)
{
int N = int(getIntVal(vectorType->getElementCount()));
Slang::IRType *elementType = vectorType->getElementType();
Slang::IRType* elementType = vectorType->getElementType();

// Special handling required for CUDA target
if (isCUDATarget(getTargetReq()))
{
m_writer->emit("make_");

switch(elementType->getOp())
switch (elementType->getOp())
{
case kIROp_Int8Type: m_writer->emit("char"); break;
case kIROp_Int16Type: m_writer->emit("short"); break;
case kIROp_IntType: m_writer->emit("int"); break;
case kIROp_Int64Type: m_writer->emit("longlong"); break;
case kIROp_UInt8Type: m_writer->emit("uchar"); break;
case kIROp_UInt16Type: m_writer->emit("ushort"); break;
case kIROp_UIntType: m_writer->emit("uint"); break;
case kIROp_UInt64Type: m_writer->emit("ulonglong"); break;
default: SLANG_ABORT_COMPILATION("Unhandled type emitting CUDA vector");
case kIROp_Int8Type:
m_writer->emit("char");
break;
case kIROp_Int16Type:
m_writer->emit("short");
break;
case kIROp_IntType:
m_writer->emit("int");
break;
case kIROp_Int64Type:
m_writer->emit("longlong");
break;
case kIROp_UInt8Type:
m_writer->emit("uchar");
break;
case kIROp_UInt16Type:
m_writer->emit("ushort");
break;
case kIROp_UIntType:
m_writer->emit("uint");
break;
case kIROp_UInt64Type:
m_writer->emit("ulonglong");
break;
default:
SLANG_ABORT_COMPILATION("Unhandled type emitting CUDA vector");
}

m_writer->emitRawText(std::to_string(N).c_str());
Expand All @@ -3892,12 +3912,13 @@ void CLikeSourceEmitter::emitVecNOrScalar(IRVectorType* vectorType, std::functio
}

// In other languages, we can output the Slang vector type directly
else {
else
{
emitType(vectorType);
}

m_writer->emit("(");
for (int i = 0; i < N; ++i)
for (int i = 0; i < N; ++i)
{
emitType(elementType);
m_writer->emit("(");
Expand All @@ -3916,7 +3937,7 @@ void CLikeSourceEmitter::emitVecNOrScalar(IRVectorType* vectorType, std::functio
}
}

void CLikeSourceEmitter::emitBitfieldExtractImpl(IRInst* inst)
void CLikeSourceEmitter::emitBitfieldExtractImpl(IRInst* inst)
{
// If unsigned, bfue := ((val>>off)&((1u<<bts)-1))
// Else signed, bfse := (((val>>off)&((1u<<bts)-1))<<(nbts-bts)>>(nbts-bts));
Expand All @@ -3930,26 +3951,38 @@ void CLikeSourceEmitter::emitBitfieldExtractImpl(IRInst* inst)
if (vectorType)
elementType = vectorType->getElementType();

bool isSigned;
bool isSigned;
int bitWidth;
if (!tryGetIntInfo(elementType, isSigned, bitWidth))
if (!tryGetIntInfo(elementType, isSigned, bitWidth))
{
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "non-integer element type given to bitfieldExtract");
SLANG_DIAGNOSE_UNEXPECTED(
getSink(),
SourceLoc(),
"non-integer element type given to bitfieldExtract");
return;
}

String one;
switch(bitWidth)
switch (bitWidth)
{
case 8: one = "uint8_t(1)"; break;
case 16: one = "uint16_t(1)"; break;
case 32: one = "uint32_t(1)"; break;
case 64: one = "uint64_t(1)"; break;
default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unexpected bit width");
case 8:
one = "uint8_t(1)";
break;
case 16:
one = "uint16_t(1)";
break;
case 32:
one = "uint32_t(1)";
break;
case 64:
one = "uint64_t(1)";
break;
default:
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unexpected bit width");
}

// Emit open paren and type cast for later sign extension
if (isSigned)
if (isSigned)
{
m_writer->emit("(");
emitType(inst->getDataType());
Expand All @@ -3960,50 +3993,55 @@ void CLikeSourceEmitter::emitBitfieldExtractImpl(IRInst* inst)
m_writer->emit("((");
emitOperand(val, getInfo(EmitOp::General));
m_writer->emit(">>");
emitVecNOrScalar(vectorType, [&]() {
emitOperand(off, getInfo(EmitOp::General));
});
emitVecNOrScalar(vectorType, [&]() { emitOperand(off, getInfo(EmitOp::General)); });
m_writer->emit(")&(");
emitVecNOrScalar(vectorType, [&]() {
m_writer->emit("((" + one + "<<");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")-" + one + ")");
});
emitVecNOrScalar(
vectorType,
[&]()
{
m_writer->emit("((" + one + "<<");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")-" + one + ")");
});
m_writer->emit("))");

// Emit sign extension logic
// (type(bitfield<<(numBits-bts))>>(numBits-bts))
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
if (isSigned)
if (isSigned)
{
m_writer->emit("<<");
emitVecNOrScalar(vectorType, [&]()
{
m_writer->emit("(");
m_writer->emit(bitWidth);
m_writer->emit("-");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")");
});
emitVecNOrScalar(
vectorType,
[&]()
{
m_writer->emit("(");
m_writer->emit(bitWidth);
m_writer->emit("-");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")");
});
m_writer->emit(")>>");
emitVecNOrScalar(vectorType, [&]()
{
m_writer->emit("(");
m_writer->emit(bitWidth);
m_writer->emit("-");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")");
});
emitVecNOrScalar(
vectorType,
[&]()
{
m_writer->emit("(");
m_writer->emit(bitWidth);
m_writer->emit("-");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")");
});
m_writer->emit(")");
}
}

void CLikeSourceEmitter::emitBitfieldInsertImpl(IRInst* inst)
void CLikeSourceEmitter::emitBitfieldInsertImpl(IRInst* inst)
{
// uint clearMask = ~(((1u << bits) - 1u) << offset);
// uint clearedBase = base & clearMask;
// uint maskedInsert = (insert & ((1u << bits) - 1u)) << offset;
// BitfieldInsert := T(uint(clearedBase) | uint(maskedInsert));
// BitfieldInsert := T(uint(clearedBase) | uint(maskedInsert));
Slang::IRType* dataType = inst->getDataType();
Slang::IRInst* bse = inst->getOperand(0);
Slang::IRInst* ins = inst->getOperand(1);
Expand All @@ -4019,55 +4057,70 @@ void CLikeSourceEmitter::emitBitfieldInsertImpl(IRInst* inst)
int bitWidth;
if (!tryGetIntInfo(elementType, isSigned, bitWidth))
{
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "non-integer element type given to bitfieldInsert");
SLANG_DIAGNOSE_UNEXPECTED(
getSink(),
SourceLoc(),
"non-integer element type given to bitfieldInsert");
return;
}

String one;
switch(bitWidth) {
case 8: one = "uint8_t(1)"; break;
case 16: one = "uint16_t(1)"; break;
case 32: one = "uint32_t(1)"; break;
case 64: one = "uint64_t(1)"; break;
default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unexpected bit width");
switch (bitWidth)
{
case 8:
one = "uint8_t(1)";
break;
case 16:
one = "uint16_t(1)";
break;
case 32:
one = "uint32_t(1)";
break;
case 64:
one = "uint64_t(1)";
break;
default:
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unexpected bit width");
}

m_writer->emit("((");

// emit clearedBase := uint(bse & ~(((1u<<bts)-1u)<<off))
emitOperand(bse, getInfo(EmitOp::General));
m_writer->emit("&");
emitVecNOrScalar(vectorType, [&]()
{
m_writer->emit("~(((" + one + "<<");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")-" + one + ")<<");
emitOperand(off, getInfo(EmitOp::General));
m_writer->emit(")");
});


emitVecNOrScalar(
vectorType,
[&]()
{
m_writer->emit("~(((" + one + "<<");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")-" + one + ")<<");
emitOperand(off, getInfo(EmitOp::General));
m_writer->emit(")");
});

// bitwise or clearedBase with maskedInsert
m_writer->emit(")|(");

// Emit maskedInsert := ((insert & ((1u << bits) - 1u)) << offset);

// - first emit mask := (insert & ((1u << bits) - 1u))
m_writer->emit("(");
emitOperand(ins, getInfo(EmitOp::General));
m_writer->emit("&");
emitVecNOrScalar(vectorType, [&](){
m_writer->emit("(" + one + "<<");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")-" + one);
});
emitVecNOrScalar(
vectorType,
[&]()
{
m_writer->emit("(" + one + "<<");
emitOperand(bts, getInfo(EmitOp::General));
m_writer->emit(")-" + one);
});
m_writer->emit(")");

// then emit shift := << offset
m_writer->emit("<<");
emitVecNOrScalar(vectorType, [&](){
emitOperand(off, getInfo(EmitOp::General));
});
emitVecNOrScalar(vectorType, [&]() { emitOperand(off, getInfo(EmitOp::General)); });
m_writer->emit("))");
}

Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-emit-c-like.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ class CLikeSourceEmitter : public SourceEmitterBase
SLANG_UNUSED(baseName);
}

bool tryGetIntInfo(IRType* elementType, bool &isSigned, int &bitWidth);
bool tryGetIntInfo(IRType* elementType, bool& isSigned, int& bitWidth);
void emitVecNOrScalar(IRVectorType* vectorType, std::function<void()> func);
virtual void emitBitfieldExtractImpl(IRInst* inst);
virtual void emitBitfieldInsertImpl(IRInst* inst);
Expand Down
15 changes: 8 additions & 7 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,8 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO

void MetalSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
{
// NM: Passing count here, as Metal 64-bit vector type names do not match their scalar equivalents.
// NM: Passing count here, as Metal 64-bit vector type names do not match their scalar
// equivalents.
emitSimpleTypeKnowingCount(elementType, elementCount);

switch (elementType->getOp())
Expand Down Expand Up @@ -1042,10 +1043,10 @@ void MetalSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)

void MetalSourceEmitter::emitSimpleTypeKnowingCount(IRType* type, IRIntegerValue elementCount)
{
// NM: note, "ulong/ushort" is only type that works for i16/i64 vec, but can't be used for scalars.
// (See metal specification pg 26)
// NM: note, "ulong/ushort" is only type that works for i16/i64 vec, but can't be used for
// scalars. (See metal specification pg 26)

switch (type->getOp())
switch (type->getOp())
{
case kIROp_VoidType:
case kIROp_BoolType:
Expand All @@ -1063,7 +1064,7 @@ void MetalSourceEmitter::emitSimpleTypeKnowingCount(IRType* type, IRIntegerValue
case kIROp_Int64Type:
m_writer->emit("long");
return;
case kIROp_UInt64Type:
case kIROp_UInt64Type:
if (elementCount > 1)
m_writer->emit("ulong");
else
Expand All @@ -1073,7 +1074,7 @@ void MetalSourceEmitter::emitSimpleTypeKnowingCount(IRType* type, IRIntegerValue
m_writer->emit("short");
return;
case kIROp_UInt16Type:
if (elementCount > 1)
if (elementCount > 1)
m_writer->emit("ushort");
else
m_writer->emit("uint16_t");
Expand All @@ -1082,7 +1083,7 @@ void MetalSourceEmitter::emitSimpleTypeKnowingCount(IRType* type, IRIntegerValue
m_writer->emit("long");
return;
case kIROp_UIntPtrType:
if (elementCount > 1)
if (elementCount > 1)
m_writer->emit("ulong");
else
m_writer->emit("uint64_t");
Expand Down
Loading

0 comments on commit d0bd3f9

Please sign in to comment.