Skip to content

Commit

Permalink
Add comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Sep 4, 2024
1 parent a0df342 commit 2a11ce6
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions dali/pipeline/executor/executor2/stream_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ class StreamAssignment<StreamPolicy::PerOperator> {

// Runs the stream assignment over all entries of sorted_nodes_.
// The visit markers of the graph's nodes are cleared first, then ProcessNode is invoked
// for each entry, walking sorted_nodes_ from its last element to its first.
// NOTE(review): each entry appears to be a (node, node metadata) pair - confirm against the
// declaration of sorted_nodes_.
void RunAssignment(ExecGraph &graph) {
graph::ClearVisitMarkers(graph.Nodes());
// Process the nodes in reverse topological order, from outputs to inputs.
// The index is signed so that the loop terminates once it drops below zero.
for (int i = sorted_nodes_.size() - 1; i >= 0; i--) {
// Run DFS-based assignment algorithm
ProcessNode(sorted_nodes_[i].first, sorted_nodes_[i].second);
}
}
Expand All @@ -268,21 +270,50 @@ class StreamAssignment<StreamPolicy::PerOperator> {
if (!v)
return;

/* The algorithm
Each node has an associated NodeMeta, which contains the stream assignment and a set
of "ready streams". This is the set of streams which are ready after this node is complete.
It includes the streams of the producers + the stream of this node.
For each node (starting from outputs):
0. Check and set the visit marker
- early-exit if already set
1. Call recursively for the producers of this node's inputs
2a. Pick the producers' ready stream with the smallest id.
2b. If there are no ready streams, bump the total number of streams and get a new stream id.
3. Go over inputs and erase the currently selected stream from their ready sets
- if the node's ready set contained the id, erase it recursively from this node's
producers' and consumers' ready sets.
4. Compute the current ready set as the union of input ready sets + the current stream id.
*/

auto &stream_id = meta->stream_id;
assert(!stream_id.has_value());

for (auto &e : node->inputs) {
auto &prod_meta = node_meta_[e->producer];
ProcessNode(e->producer, &prod_meta);
// If we're a GPU contributor (and therefore we need a stream assignment), check if the
// producer's ready stream set contains something.
if (meta->gpu_contributor && !prod_meta.ready_streams.empty()) {
// OK, we got ourselves a stream id - is it better than the current assignment?
if (!stream_id.has_value() || *stream_id > *prod_meta.ready_streams.begin()) {
stream_id = *prod_meta.ready_streams.begin();
}
}
// Compute our current ready stream set as a union of producers' ready sets.
meta->ready_streams.insert(prod_meta.ready_streams.begin(), prod_meta.ready_streams.end());
}

if (stream_id.has_value()) {
// We're taking this stream_id, so we want to prevent it from being reused - thus,
// we remove it (recursively) from the ready sets of connected nodes.
for (auto &e : node->inputs) {
EraseReady(e->producer, *stream_id, node);
}
Expand All @@ -296,16 +327,25 @@ class StreamAssignment<StreamPolicy::PerOperator> {
assert(!stream_id.has_value() || meta->ready_streams.count(*stream_id));
}

void EraseReady(const ExecNode *node, int id, const ExecNode *sentinel) {
/** Recursively removes a stream id from the ready sets until it's not found.
*
* This function tries to erase `id_to_remove` from the ready set of the `node`.
* If the removal succeeds (i.e. the id was there), then flood to the neighboring nodes.
*
* The pessimistic case happens when there's a long path connecting heavily branched regions.
* In such cases, the complexity is O(n*l) where n is the number of ids and l is the depth of
* the subgraph where those ids appear.
*/
void EraseReady(const ExecNode *node, int id_to_remove, const ExecNode *sentinel) {
if (node == sentinel)
return;
auto &meta = node_meta_[node];
if (meta.ready_streams.erase(id)) {
if (meta.ready_streams.erase(id_to_remove)) {
for (auto &e : node->inputs)
EraseReady(e->producer, id, sentinel);
EraseReady(e->producer, id_to_remove, sentinel);
for (auto &out : node->outputs)
for (auto &e : out.consumers)
EraseReady(e->consumer, id, sentinel);
EraseReady(e->consumer, id_to_remove, sentinel);
}
}

Expand Down

0 comments on commit 2a11ce6

Please sign in to comment.