Skip to content

Commit

Permalink
Fix XUSGComputeUtil and support multi-miss shaders
Browse files Browse the repository at this point in the history
  • Loading branch information
StarsX committed Oct 5, 2023
1 parent 362997b commit 8f09dfb
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 24 deletions.
4 changes: 3 additions & 1 deletion Optional/Shaders/CSPrefixSum.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// Copyright (c) XU, Tianchen. All rights reserved.
//--------------------------------------------------------------------------------------

#ifndef GROUP_SIZE
#define GROUP_SIZE 1024
#endif

// Min and max wave sizes supported
#define MIN_WAVE_SIZE 32
Expand Down Expand Up @@ -104,7 +106,7 @@ void WaitForSlowestGroup()
}

//--------------------------------------------------------------------------------------
// 1-pass global prefix-sum for at most 1024 * 1024 uints
// Single-pass global prefix-sum for at most GROUP_SIZE * GROUP_SIZE uints
//-------------------------------------------------------------------------------------
[numthreads(GROUP_SIZE, 1, 1)]
void main(uint DTid : SV_DispatchThreadID, uint GIdx : SV_GroupIndex, uint Gid : SV_GroupID)
Expand Down
2 changes: 2 additions & 0 deletions Optional/Shaders/CSPrefixSum2.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// Copyright (c) XU, Tianchen. All rights reserved.
//--------------------------------------------------------------------------------------

#ifndef GROUP_SIZE
#define GROUP_SIZE 1024
#endif

//--------------------------------------------------------------------------------------
// Buffer
Expand Down
2 changes: 1 addition & 1 deletion Optional/XUSGComputeUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bool ComputeUtil::SetPrefixSum(CommandList* pCommandList, bool safeMode,

// Create resources
m_counter = TypedBuffer::MakeUnique();
XUSG_N_RETURN(m_counter->Create(pDevice, 1, sizeof(uint32_t), Format::R32_FLOAT,
XUSG_N_RETURN(m_counter->Create(pDevice, 1, sizeof(uint32_t), Format::R32_UINT,
ResourceFlag::ALLOW_UNORDERED_ACCESS | ResourceFlag::DENY_SHADER_RESOURCE,
MemoryType::DEFAULT, 0, nullptr, 1, nullptr, MemoryFlag::NONE,
L"GlobalBarrierCounter"), false);
Expand Down
7 changes: 4 additions & 3 deletions XUSGRayTracing-EZ/XUSGRayTracing-EZ.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ namespace XUSG
HitGroupType type = HitGroupType::TRIANGLES) = 0;
virtual void RTSetShaderConfig(uint32_t maxPayloadSize, uint32_t maxAttributeSize = sizeof(float[2])) = 0;
virtual void RTSetMaxRecursionDepth(uint32_t depth) = 0;
virtual void DispatchRays(uint32_t width, uint32_t height, uint32_t depth,
const wchar_t* rayGenShaderName, const wchar_t* missShaderName) = 0;
virtual void DispatchRays(uint32_t width, uint32_t height, uint32_t depth, const wchar_t* rayGenShaderName,
const wchar_t* const* pMissShaderNames, uint32_t numMissShaders) = 0;
virtual void DispatchRaysIndirect(const CommandLayout* pCommandlayout,
uint32_t maxCommandCount,
const wchar_t* rayGenShaderName,
const wchar_t* missShaderName,
const wchar_t* const* pMissShaderNames,
uint32_t numMissShaders,
Resource* pArgumentBuffer,
uint64_t argumentBufferOffset = 0,
Resource* pCountBuffer = nullptr,
Expand Down
27 changes: 16 additions & 11 deletions XUSGRayTracing-EZ/XUSGRayTracing-EZ_DX12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,26 +296,26 @@ void EZ::CommandList_DXR::RTSetMaxRecursionDepth(uint32_t depth)
}

void EZ::CommandList_DXR::DispatchRays(uint32_t width, uint32_t height, uint32_t depth,
const wchar_t* rayGenShaderName, const wchar_t* missShaderName)
const wchar_t* rayGenShaderName, const wchar_t* const* pMissShaderNames, uint32_t numMissShaders)
{
CShaderTablePtr pRayGen = nullptr;
CShaderTablePtr pHitGroup = nullptr;
CShaderTablePtr pMiss = nullptr;

predispatchRays(pRayGen, pHitGroup, pMiss, rayGenShaderName, missShaderName);
predispatchRays(pRayGen, pHitGroup, pMiss, rayGenShaderName, pMissShaderNames, numMissShaders);
RayTracing::CommandList_DX12::DispatchRays(width, height, depth, pRayGen, pHitGroup, pMiss);
}

void EZ::CommandList_DXR::DispatchRaysIndirect(const CommandLayout* pCommandlayout, uint32_t maxCommandCount,
const wchar_t* rayGenShaderName, const wchar_t* missShaderName, Resource* pArgumentBuffer, uint64_t argumentBufferOffset,
Resource* pCountBuffer, uint64_t countBufferOffset)
const wchar_t* rayGenShaderName, const wchar_t* const* pMissShaderNames, uint32_t numMissShaders,
Resource* pArgumentBuffer, uint64_t argumentBufferOffset, Resource* pCountBuffer, uint64_t countBufferOffset)
{
CShaderTablePtr pRayGen = nullptr;
CShaderTablePtr pHitGroup = nullptr;
CShaderTablePtr pMiss = nullptr;

preexecuteIndirect(pArgumentBuffer, pCountBuffer);
predispatchRays(pRayGen, pHitGroup, pMiss, rayGenShaderName, missShaderName);
predispatchRays(pRayGen, pHitGroup, pMiss, rayGenShaderName, pMissShaderNames, numMissShaders);

// Populate arg-buffer header with shader tables
string key(sizeof(ArgumentBufferVA), '\0');
Expand Down Expand Up @@ -417,7 +417,7 @@ bool EZ::CommandList_DXR::createComputePipelineLayouts(uint32_t maxSamplers, con
}

void EZ::CommandList_DXR::predispatchRays(CShaderTablePtr& pRayGen, CShaderTablePtr& pHitGroup, CShaderTablePtr& pMiss,
const wchar_t* rayGenShaderName, const wchar_t* missShaderName)
const wchar_t* rayGenShaderName, const wchar_t* const* pMissShaderNames, uint32_t numMissShaders)
{
// Set barrier command
XUSG::CommandList_DX12::Barrier(static_cast<uint32_t>(m_barriers.size()), m_barriers.data());
Expand Down Expand Up @@ -494,12 +494,17 @@ void EZ::CommandList_DXR::predispatchRays(CShaderTablePtr& pRayGen, CShaderTable
pRayGen = getShaderTable(key, m_rayGenTables, 1);
}

if (missShaderName)
// Get miss shader tables; the miss shaders are independent of the pipeline
if (pMissShaderNames)
{
string key(sizeof(void*), '\0');
auto& shaderID = reinterpret_cast<const void*&>(key[0]);
shaderID = ShaderRecord::GetShaderID(m_pipeline, missShaderName, API::DIRECTX_12);
pMiss = getShaderTable(key, m_missTables, 1);
string key(sizeof(void*) * numMissShaders, '\0');
const auto pShaderIDs = reinterpret_cast<const void**>(&key[0]);
for (auto i = 0u; i < numMissShaders; ++i) // Set hit-group shader IDs into the key
{
assert(pMissShaderNames[i]);
pShaderIDs[i] = ShaderRecord::GetShaderID(m_pipeline, pMissShaderNames[i], API::DIRECTX_12);
}
pMiss = getShaderTable(key, m_missTables, numMissShaders);
}
}

Expand Down
9 changes: 5 additions & 4 deletions XUSGRayTracing-EZ/XUSGRayTracing-EZ_DX12.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ namespace XUSG
const wchar_t* anyHitShaderName = nullptr, const wchar_t* intersectionShaderName = nullptr,
HitGroupType type = HitGroupType::TRIANGLES);
void RTSetMaxRecursionDepth(uint32_t depth);
void DispatchRays(uint32_t width, uint32_t height, uint32_t depth,
const wchar_t* rayGenShaderName, const wchar_t* missShaderName);
void DispatchRays(uint32_t width, uint32_t height, uint32_t depth, const wchar_t* rayGenShaderName,
const wchar_t* const* pMissShaderNames, uint32_t numMissShaders);
void DispatchRaysIndirect(const CommandLayout* pCommandlayout,
uint32_t maxCommandCount,
const wchar_t* rayGenShaderName,
const wchar_t* missShaderName,
const wchar_t* const* pMissShaderNames,
uint32_t numMissShaders,
Resource* pArgumentBuffer,
uint64_t argumentBufferOffset = 0,
Resource* pCountBuffer = nullptr,
Expand Down Expand Up @@ -113,7 +114,7 @@ namespace XUSG
uint32_t maxTLASSrvs, uint32_t spaceTLAS);

void predispatchRays(CShaderTablePtr& pRayGen, CShaderTablePtr& pHitGroup, CShaderTablePtr& pMiss,
const wchar_t* rayGenShaderName, const wchar_t* missShaderName);
const wchar_t* rayGenShaderName, const wchar_t* const* pMissShaderNames, uint32_t numMissShaders);

Resource* needScratch(uint32_t size);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ void TopLevelAS_DX12::Build(const CommandList* pCommandList, const Resource* pSc
}

void TopLevelAS_DX12::SetInstances(const RayTracing::Device* pDevice, Resource* pInstances,
uint32_t numInstances, const BottomLevelAS* const* ppBottomLevelASs, float* const* transforms)
uint32_t numInstances, const BottomLevelAS* const* ppBottomLevelASs, const float* const* transforms)
{
assert(numInstances == 0 || (ppBottomLevelASs && transforms));

Expand Down
2 changes: 1 addition & 1 deletion XUSGRayTracing/RayTracing/XUSGAccelerationStructure_DX12.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace XUSG
const TopLevelAS* pSource = nullptr);

static void SetInstances(const Device* pDevice, Resource* pInstances, uint32_t numInstances,
const BottomLevelAS* const* ppBottomLevelASs, float* const* transforms);
const BottomLevelAS* const* ppBottomLevelASs, const float* const* transforms);
static void SetInstances(const Device* pDevice, Resource* pInstances, uint32_t numInstances,
const InstanceDesc* pInstanceDescs);
};
Expand Down
2 changes: 1 addition & 1 deletion XUSGRayTracing/RayTracing/XUSGRayTracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ BottomLevelAS::sptr BottomLevelAS::MakeShared(API api)

void TopLevelAS::SetInstances(const Device* pDevice, Resource* pInstances,
uint32_t numInstances, const BottomLevelAS* const* pBottomLevelASs,
float* const* transforms, API api)
const float* const* transforms, API api)
{
TopLevelAS_DX12::SetInstances(pDevice, pInstances, numInstances, pBottomLevelASs, transforms);
}
Expand Down
2 changes: 1 addition & 1 deletion XUSGRayTracing/RayTracing/XUSGRayTracing.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ namespace XUSG

static void SetInstances(const Device* pDevice, Resource* pInstances,
uint32_t numInstances, const BottomLevelAS* const* ppBottomLevelASs,
float* const* transforms, API api = API::DIRECTX_12);
const float* const* transforms, API api = API::DIRECTX_12);
static void SetInstances(const Device* pDevice, Resource* pInstances,
uint32_t numInstances, const InstanceDesc* pInstanceDescs,
API api = API::DIRECTX_12);
Expand Down

0 comments on commit 8f09dfb

Please sign in to comment.