Skip to content

Commit

Permalink
Trim array in first pass of radix_sort
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Evstyukhin committed Nov 18, 2022
1 parent 4278993 commit 9078172
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 28 deletions.
65 changes: 61 additions & 4 deletions src/Bc7Core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ void AreaReduceTable4(__m128i& mc, uint64_t& indices) noexcept
}
}

NOTINLINED Node* radix_sort(Node* input, Node* work, size_t N) noexcept
NOTINLINED void radix_sort_partitions(Node* input, Node* work, size_t N) noexcept
{
constexpr size_t radix = 7;
constexpr size_t bucketCount = 1 << radix;
Expand Down Expand Up @@ -612,10 +612,13 @@ NOTINLINED Node* radix_sort(Node* input, Node* work, size_t N) noexcept
Node* C = A; A = B; B = C;
}

return A;
if (A != input)
{
memcpy(input, work, sizeof(Node) * N);
}
}

NOTINLINED NodeShort* radix_sort(NodeShort* input, NodeShort* work, size_t N) noexcept
NOTINLINED NodeShort* radix_sort(NodeShort* input, NodeShort* work, size_t N, size_t MaxSize) noexcept
{
constexpr size_t radix = 6;
constexpr size_t bucketCount = 1 << radix;
Expand Down Expand Up @@ -649,7 +652,61 @@ NOTINLINED NodeShort* radix_sort(NodeShort* input, NodeShort* work, size_t N) no
}
}

for (size_t shift = 16; shift < 32; shift += radix)
size_t maxShift = 32 - radix;
if (N > MaxSize)
{
#if defined(OPTION_AVX2)
maxShift -= _lzcnt_u32(any_error | (1u << (16 + radix - 1)));
#else
maxShift = 16;
for (uint32_t v = (any_error >> (maxShift + radix)); v != 0; v >>= 1)
{
maxShift++;
}
#endif

for (size_t i = 0; i < bucketCount; i++)
{
counts[i] = 0;
}

for (size_t i = 0; i < N; i++)
{
size_t value = A[i].ColorError;
counts[value >> maxShift]++;
}

size_t water = 0;
uint32_t total = 0;
do
{
uint32_t oldCount = counts[water];
counts[water++] = total;
total += oldCount;
} while (total < static_cast<uint32_t>(MaxSize));

if (total < N)
{
size_t w = 0;
uint32_t top = static_cast<uint32_t>(water) << maxShift;
for (size_t i = 0; i < N; i++)
{
NodeShort val = A[i];
A[w] = val;
w += val.ColorError < top;
}
N = w;

#if defined(OPTION_AVX2)
maxShift = 32 - radix;
uint32_t zeroes = _lzcnt_u32((top - 1) | (1u << (16 + radix - 1)));
maxShift -= zeroes;
any_error &= ~0u >> zeroes;
#endif
}
}

for (size_t shift = 16; shift < maxShift + radix; shift += radix)
{
if (((any_error >> shift) & bucketMask) == 0)
continue;
Expand Down
4 changes: 2 additions & 2 deletions src/Bc7Core.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ void AreaReduceTable2(const Area& area, __m128i& mc, uint64_t& indices) noexcept
void AreaReduceTable3(const Area& area, __m128i& mc, uint64_t& indices) noexcept;
void AreaReduceTable4(__m128i& mc, uint64_t& indices) noexcept;

NOTINLINED Node* radix_sort(Node* input, Node* work, size_t N) noexcept;
NOTINLINED NodeShort* radix_sort(NodeShort* input, NodeShort* work, size_t N) noexcept;
NOTINLINED void radix_sort_partitions(Node* input, Node* work, size_t N) noexcept;
NOTINLINED NodeShort* radix_sort(NodeShort* input, NodeShort* work, size_t N, size_t MaxSize) noexcept;

NOTINLINED int ComputeSubsetTable2(const Area& area, const __m128i mweights, Modulations& state) noexcept;
NOTINLINED int ComputeSubsetTable3(const Area& area, const __m128i mweights, Modulations& state) noexcept;
Expand Down
5 changes: 1 addition & 4 deletions src/Bc7CoreMode0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,7 @@ namespace Mode0 {

{
Node order2[16];
if (radix_sort(order, order2, partitionsCount) == order2)
{
memcpy(order, order2, sizeof(Node) * partitionsCount);
}
radix_sort_partitions(order, order2, partitionsCount);
}

for (size_t i = 0; i < partitionsCount; i++)
Expand Down
5 changes: 1 addition & 4 deletions src/Bc7CoreMode1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,7 @@ namespace Mode1 {

{
Node order2[64];
if (radix_sort(order, order2, partitionsCount) == order2)
{
memcpy(order, order2, sizeof(Node) * partitionsCount);
}
radix_sort_partitions(order, order2, partitionsCount);
}

for (size_t i = 0; i < partitionsCount; i++)
Expand Down
5 changes: 1 addition & 4 deletions src/Bc7CoreMode2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,7 @@ namespace Mode2 {

{
Node order2[64];
if (radix_sort(order, order2, partitionsCount) == order2)
{
memcpy(order, order2, sizeof(Node) * partitionsCount);
}
radix_sort_partitions(order, order2, partitionsCount);
}

for (size_t i = 0; i < partitionsCount; i++)
Expand Down
5 changes: 1 addition & 4 deletions src/Bc7CoreMode3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,7 @@ namespace Mode3 {

{
Node order2[64];
if (radix_sort(order, order2, partitionsCount) == order2)
{
memcpy(order, order2, sizeof(Node) * partitionsCount);
}
radix_sort_partitions(order, order2, partitionsCount);
}

for (size_t i = 0; i < partitionsCount; i++)
Expand Down
5 changes: 1 addition & 4 deletions src/Bc7CoreMode7.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,7 @@ namespace Mode7 {

{
Node order2[64];
if (radix_sort(order, order2, partitionsCount) == order2)
{
memcpy(order, order2, sizeof(Node) * partitionsCount);
}
radix_sort_partitions(order, order2, partitionsCount);
}

for (size_t i = 0; i < partitionsCount; i++)
Expand Down
2 changes: 1 addition & 1 deletion src/SnippetLevelsBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class LevelsBuffer final
NodeShort* sorted = nodes1;
if (int nodesCount = int(nodesPtr - sorted); nodesCount)
{
sorted = radix_sort(sorted, nodes2, static_cast<uint32_t>(nodesCount));
sorted = radix_sort(sorted, nodes2, static_cast<uint32_t>(nodesCount), MaxSize);

MinErr = (sorted[0].ColorError >> 16) * weight;
Count = Min(nodesCount, MaxSize);
Expand Down
2 changes: 1 addition & 1 deletion src/SnippetLevelsBufferHalf.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class LevelsBufferHalf final
{
if (nodesCount)
{
sorted = radix_sort(sorted, nodes2, (uint32_t)nodesCount);
sorted = radix_sort(sorted, nodes2, (uint32_t)nodesCount, MaxSize);

MinErr = (sorted[0].ColorError >> 16) * weight;
Count = Min(nodesCount, MaxSize);
Expand Down

0 comments on commit 9078172

Please sign in to comment.