Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Draft of segmented reduce optimization #578

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 125 additions & 27 deletions cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ struct AgentReducePolicy : ScalingType
static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;
};

/**
 * @brief Compile-time tuning parameters for a reduction agent that assigns
 *        one warp-sized group of threads to each segment.
 *
 * @tparam BLOCK_THREADS
 *   Threads per thread block. Must be a positive multiple of
 *   `NOMINAL_WARP_THREADS_4B`.
 *
 * @tparam NOMINAL_WARP_THREADS_4B
 *   Threads per warp-sized group. Exposed unchanged as `WARP_THREADS`.
 *   Must be positive.
 *
 * @tparam NOMINAL_ITEMS_PER_THREAD_4B
 *   Items consumed per thread. Exposed unchanged as `ITEMS_PER_THREAD`.
 *   NOTE(review): the `_4B` suffix suggests the value is nominal for 4-byte
 *   types, but no type-dependent scaling is applied yet (see TODO below).
 *
 * @tparam ComputeT
 *   Currently unused by this policy; presumably reserved for the scaling
 *   computation mentioned in the TODO below.
 *
 * @tparam _VECTOR_LOAD_LENGTH
 *   Number of items per vectorized load
 *
 * @tparam _LOAD_MODIFIER
 *   Cache load modifier for reading input elements
 */
template <int BLOCK_THREADS,
          int NOMINAL_WARP_THREADS_4B,
          int NOMINAL_ITEMS_PER_THREAD_4B,
          typename ComputeT,
          int _VECTOR_LOAD_LENGTH,
          CacheLoadModifier _LOAD_MODIFIER>
struct AgentWarpReducePolicy
{
  // Validate the thread counts before they are used as a divisor below.
  // Without the sign checks, a negative warp width that happens to divide
  // BLOCK_THREADS would silently produce a negative SEGMENTS_PER_BLOCK.
  static_assert(BLOCK_THREADS > 0, "Block should contain at least one thread");
  static_assert(NOMINAL_WARP_THREADS_4B > 0,
                "Warp should contain at least one thread");
  static_assert((BLOCK_THREADS % NOMINAL_WARP_THREADS_4B) == 0,
                "Block should be multiple of warp");

  // TODO MemBoundScaling-like computation
  /// Items per thread (currently the nominal value, unscaled)
  static constexpr int ITEMS_PER_THREAD = NOMINAL_ITEMS_PER_THREAD_4B;

  /// Threads per warp-sized group (currently the nominal value, unscaled)
  static constexpr int WARP_THREADS = NOMINAL_WARP_THREADS_4B;

  /// Number of items per vectorized load
  static constexpr int VECTOR_LOAD_LENGTH = _VECTOR_LOAD_LENGTH;

  /// Cache load modifier for reading input elements
  static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;

  /// Items covered by one warp-sized group in a single tile
  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * WARP_THREADS;

  /// Number of warp-sized groups that fit in one thread block
  static constexpr int SEGMENTS_PER_BLOCK = BLOCK_THREADS / WARP_THREADS;
};

/******************************************************************************
* Thread block abstractions
******************************************************************************/
Expand Down Expand Up @@ -116,8 +142,10 @@ template <typename AgentReducePolicy,
typename OutputIteratorT,
typename OffsetT,
typename ReductionOp,
typename AccumT>
struct AgentReduce
typename AccumT,
typename CollectiveReduceT,
int THREADS>
struct AgentReduceImpl
{
//---------------------------------------------------------------------
// Types and constants
Expand All @@ -139,9 +167,8 @@ struct AgentReduce
InputIteratorT>;

/// Constants
static constexpr int BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD = AgentReducePolicy::ITEMS_PER_THREAD;
static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
static constexpr int TILE_ITEMS = THREADS * ITEMS_PER_THREAD;
static constexpr int VECTOR_LOAD_LENGTH =
CUB_MIN(ITEMS_PER_THREAD, AgentReducePolicy::VECTOR_LOAD_LENGTH);

Expand All @@ -155,17 +182,10 @@ struct AgentReduce
static constexpr CacheLoadModifier LOAD_MODIFIER =
AgentReducePolicy::LOAD_MODIFIER;

static constexpr BlockReduceAlgorithm BLOCK_ALGORITHM =
AgentReducePolicy::BLOCK_ALGORITHM;

/// Parameterized BlockReduce primitive
using BlockReduceT =
BlockReduce<AccumT, BLOCK_THREADS, AgentReducePolicy::BLOCK_ALGORITHM>;

/// Shared memory type required by this thread block
struct _TempStorage
{
typename BlockReduceT::TempStorage reduce;
typename CollectiveReduceT::TempStorage reduce;
};

/// Alias wrapper allowing storage to be unioned
Expand All @@ -177,9 +197,10 @@ struct AgentReduce
//---------------------------------------------------------------------

_TempStorage &temp_storage; ///< Reference to temp_storage
InputIteratorT d_in; ///< Input data to reduce
unsigned int lane_id;
WrappedInputIteratorT d_wrapped_in; ///< Wrapped input data to reduce
ReductionOp reduction_op; ///< Binary reduction operator
InputIteratorT d_in; ///< Input data to reduce

//---------------------------------------------------------------------
// Utility
Expand Down Expand Up @@ -213,13 +234,15 @@ struct AgentReduce
* @param d_in Input data to reduce
* @param reduction_op Binary reduction operator
*/
__device__ __forceinline__ AgentReduce(TempStorage &temp_storage,
InputIteratorT d_in,
ReductionOp reduction_op)
__device__ __forceinline__ AgentReduceImpl(TempStorage &temp_storage,
InputIteratorT d_in,
ReductionOp reduction_op,
unsigned int lane_id)
: temp_storage(temp_storage.Alias())
, d_in(d_in)
, d_wrapped_in(d_in)
, reduction_op(reduction_op)
, lane_id(lane_id)
{}

//---------------------------------------------------------------------
Expand All @@ -243,9 +266,9 @@ struct AgentReduce
AccumT items[ITEMS_PER_THREAD];

// Load items in striped fashion
LoadDirectStriped<BLOCK_THREADS>(threadIdx.x,
d_wrapped_in + block_offset,
items);
LoadDirectStriped<THREADS>(lane_id,
d_wrapped_in + block_offset,
items);

// Reduce items within each thread stripe
thread_aggregate =
Expand Down Expand Up @@ -276,7 +299,7 @@ struct AgentReduce

// Fabricate a vectorized input iterator
InputT *d_in_unqualified = const_cast<InputT *>(d_in) + block_offset +
(threadIdx.x * VECTOR_LOAD_LENGTH);
(lane_id * VECTOR_LOAD_LENGTH);
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, VectorT, OffsetT>
d_vec_in(reinterpret_cast<VectorT *>(d_in_unqualified));

Expand All @@ -286,7 +309,7 @@ struct AgentReduce
#pragma unroll
for (int i = 0; i < WORDS; ++i)
{
vec_items[i] = d_vec_in[BLOCK_THREADS * i];
vec_items[i] = d_vec_in[THREADS * i];
}

// Convert from input type to output type
Expand Down Expand Up @@ -320,13 +343,13 @@ struct AgentReduce
Int2Type<CAN_VECTORIZE> /*can_vectorize*/)
{
// Partial tile
int thread_offset = threadIdx.x;
int thread_offset = lane_id;

// Read first item
if ((IS_FIRST_TILE) && (thread_offset < valid_items))
{
thread_aggregate = d_wrapped_in[block_offset + thread_offset];
thread_offset += BLOCK_THREADS;
thread_offset += THREADS;
}

// Continue reading items (block-striped)
Expand All @@ -335,7 +358,7 @@ struct AgentReduce
InputT item(d_wrapped_in[block_offset + thread_offset]);

thread_aggregate = reduction_op(thread_aggregate, item);
thread_offset += BLOCK_THREADS;
thread_offset += THREADS;
}
}

Expand Down Expand Up @@ -364,8 +387,12 @@ struct AgentReduce
valid_items,
Int2Type<false>(),
can_vectorize);
return BlockReduceT(temp_storage.reduce)
.Reduce(thread_aggregate, reduction_op, valid_items);

// TODO: Move this clamping behind SFINAE so that the block-wide version stays as is
int num_valid = (THREADS <= valid_items) ? THREADS : valid_items;

return CollectiveReduceT(temp_storage.reduce)
.Reduce(thread_aggregate, reduction_op, num_valid);
}

// At least one full block
Expand Down Expand Up @@ -399,7 +426,7 @@ struct AgentReduce
}

// Compute block-wide reduction (all threads have valid items)
return BlockReduceT(temp_storage.reduce)
return CollectiveReduceT(temp_storage.reduce)
.Reduce(thread_aggregate, reduction_op);
}

Expand Down Expand Up @@ -440,5 +467,76 @@ struct AgentReduce
}
};

/**
 * @brief Block-wide flavour of AgentReduceImpl.
 *
 * Binds the generic implementation to a `BlockReduce` collective
 * parameterized by the policy's `BLOCK_THREADS` and `BLOCK_ALGORITHM`, and
 * uses `BLOCK_THREADS` as the number of cooperating threads.
 *
 * @tparam AgentReducePolicy
 *   Policy type; must provide `BLOCK_THREADS` and `BLOCK_ALGORITHM`
 *
 * @tparam InputIteratorT
 *   Type of the input iterator
 *
 * @tparam OutputIteratorT
 *   Type of the output iterator
 *
 * @tparam OffsetT
 *   Offset type forwarded to the implementation
 *
 * @tparam ReductionOp
 *   Binary reduction operator type
 *
 * @tparam AccumT
 *   Accumulator type, also used as the `BlockReduce` value type
 */
template <typename AgentReducePolicy,
          typename InputIteratorT,
          typename OutputIteratorT,
          typename OffsetT,
          typename ReductionOp,
          typename AccumT>
struct AgentReduce
    : AgentReduceImpl<AgentReducePolicy,
                      InputIteratorT,
                      OutputIteratorT,
                      OffsetT,
                      ReductionOp,
                      AccumT,
                      BlockReduce<AccumT,
                                  AgentReducePolicy::BLOCK_THREADS,
                                  AgentReducePolicy::BLOCK_ALGORITHM>,
                      AgentReducePolicy::BLOCK_THREADS>
{
  /// The fully-specified implementation type this agent derives from
  using base_t =
    AgentReduceImpl<AgentReducePolicy,
                    InputIteratorT,
                    OutputIteratorT,
                    OffsetT,
                    ReductionOp,
                    AccumT,
                    BlockReduce<AccumT,
                                AgentReducePolicy::BLOCK_THREADS,
                                AgentReducePolicy::BLOCK_ALGORITHM>,
                    AgentReducePolicy::BLOCK_THREADS>;

  /**
   * @brief Constructor
   *
   * Forwards all arguments to the implementation and supplies `threadIdx.x`
   * as the calling thread's rank within the cooperating group.
   *
   * @param temp_storage Reference to temp_storage
   * @param d_in Input data to reduce
   * @param reduction_op Binary reduction operator
   */
  __device__ __forceinline__
  AgentReduce(typename base_t::TempStorage &temp_storage,
              InputIteratorT d_in,
              ReductionOp reduction_op)
      : base_t(temp_storage, d_in, reduction_op, threadIdx.x)
  {}
};

/**
 * @brief Warp-wide flavour of AgentReduceImpl.
 *
 * Binds the generic implementation to a `WarpReduce` collective
 * parameterized by the policy's `WARP_THREADS`, and uses `WARP_THREADS` as
 * the number of cooperating threads.
 *
 * @tparam AgentReducePolicy
 *   Policy type; must provide `WARP_THREADS`
 *
 * @tparam InputIteratorT
 *   Type of the input iterator
 *
 * @tparam OutputIteratorT
 *   Type of the output iterator
 *
 * @tparam OffsetT
 *   Offset type forwarded to the implementation
 *
 * @tparam ReductionOp
 *   Binary reduction operator type
 *
 * @tparam AccumT
 *   Accumulator type, also used as the `WarpReduce` value type
 */
template <typename AgentReducePolicy,
          typename InputIteratorT,
          typename OutputIteratorT,
          typename OffsetT,
          typename ReductionOp,
          typename AccumT>
struct AgentWarpReduce
    : AgentReduceImpl<AgentReducePolicy,
                      InputIteratorT,
                      OutputIteratorT,
                      OffsetT,
                      ReductionOp,
                      AccumT,
                      WarpReduce<AccumT, AgentReducePolicy::WARP_THREADS>,
                      AgentReducePolicy::WARP_THREADS>
{
  /// The fully-specified implementation type this agent derives from
  using base_t =
    AgentReduceImpl<AgentReducePolicy,
                    InputIteratorT,
                    OutputIteratorT,
                    OffsetT,
                    ReductionOp,
                    AccumT,
                    WarpReduce<AccumT, AgentReducePolicy::WARP_THREADS>,
                    AgentReducePolicy::WARP_THREADS>;

  /**
   * @brief Constructor
   *
   * @param temp_storage Reference to temp_storage
   * @param d_in Input data to reduce
   * @param reduction_op Binary reduction operator
   * @param lane_id
   *   Rank of the calling thread within its cooperating group; forwarded to
   *   the implementation, which stores it as an unsigned value
   */
  __device__ __forceinline__
  AgentWarpReduce(typename base_t::TempStorage &temp_storage,
                  InputIteratorT d_in,
                  ReductionOp reduction_op,
                  int lane_id)
      // The cast spells out the int -> unsigned conversion that would
      // otherwise happen implicitly at the base constructor call.
      : base_t(temp_storage,
               d_in,
               reduction_op,
               static_cast<unsigned int>(lane_id))
  {}
};

CUB_NAMESPACE_END

Loading