Skip to content

Commit

Permalink
Better algo.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
  • Loading branch information
mzient committed Sep 6, 2024
1 parent 2a11ce6 commit a80eb0b
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions dali/pipeline/executor/executor2/stream_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <queue>
#include <unordered_map>
#include <set>
#include <utility>
#include <vector>
#include "dali/pipeline/graph/graph_util.h"
Expand Down Expand Up @@ -203,7 +203,7 @@ class StreamAssignment<StreamPolicy::PerOperator> {
bool needs_stream = false; // whether this exact node needs a stream
bool gpu_contributor = false; // whether the node contributes to a GPU operator
std::optional<int> stream_id;
std::set<int> ready_streams;
std::map<int, int> ready_streams;
};

void Assign(ExecGraph &graph) {
Expand Down Expand Up @@ -275,83 +275,85 @@ class StreamAssignment<StreamPolicy::PerOperator> {
Each node has an associated NodeMeta, which contains the stream assignment and a set
of "ready streams". This is the set of streams which are ready after this node is complete.
It includes the streams of the producers + the stream of this node.
Each stream is associated with a "use count" and a ready set contains, alongside the id,
the value of the use count at the time at which the stream was inserted to the set.
Later, when trying to reuse the ready streams, the streams which were inserted with an
outdated use count are rejected.
For each node (starting from outputs):
0. Check and set the visit marker
- early-exit if already set
1. Call recursively for the producers of this node's inputs
2a. Pick the producers' ready stream with the smallest id.
2a. Pick the producers' ready stream with the smallest id (reject streams with stale use count).
2b. If there are no ready streams, bump the total number of streams and get a new stream id.
3. Go over inputs and erase the currently selected stream from their ready sets
- if the node's ready set contained the id, erase it recursively from this node's
producers' and consumers' ready sets.
3. Bump the use count of the currently assigned stream.
4. Compute the current ready set as the union of input ready sets + the current stream id.
When computing the union, stale streams are removed to speed up subsequent lookups.
*/

auto &stream_id = meta->stream_id;
assert(!stream_id.has_value());
std::optional<int> stream_id = {};

for (auto &e : node->inputs) {
auto &prod_meta = node_meta_[e->producer];
ProcessNode(e->producer, &prod_meta);
// If we're a GPU contributor (and therefore we need a stream assignment), check if the
// producer's ready stream set contains something.
if (meta->gpu_contributor && !prod_meta.ready_streams.empty()) {
// OK, we got ourselves a stream id - is it better than the current assignment?
if (!stream_id.has_value() || *stream_id > *prod_meta.ready_streams.begin()) {
stream_id = *prod_meta.ready_streams.begin();
if (meta->gpu_contributor) {
for (auto it = prod_meta.ready_streams.begin(); it != prod_meta.ready_streams.end(); ++it) {
auto [id, use_count] = *it;
if (use_count < stream_use_count_[id]) {
continue;
}
if (!stream_id.has_value() || *stream_id > id)
stream_id = id;
}
}
// Compute our current ready stream set as a union of producers' ready sets.
meta->ready_streams.insert(prod_meta.ready_streams.begin(), prod_meta.ready_streams.end());
// Add producer's ready set to the current one.
CombineReady(*meta, prod_meta);
}

if (stream_id.has_value()) {
// We're taking this stream_id, so we want to prevent it from being reused - thus,
// we remove it (recursively) from the ready sets of connected nodes.
for (auto &e : node->inputs) {
EraseReady(e->producer, *stream_id, node);
}
UseStream(*meta, *stream_id);
} else {
if (meta->needs_stream) {
stream_id = total_streams_++;
meta->ready_streams.insert(*stream_id);
stream_use_count_[*stream_id] = 1;
UseStream(*meta, *stream_id);
}
}

assert(!stream_id.has_value() || meta->ready_streams.count(*stream_id));
}

/** Recursively removes a stream id from the ready sets until it's not found.
*
* This function tries to erase `id_to_remove` from the ready set of the `node`.
* If the removal succeeds (i.e. the id was there), then flood to the neighboring nodes
* (the producers of the node's inputs and the consumers of the node's outputs).
*
* The pessimistic case happens when there's a long path connecting heavily branched regions.
* In such cases, the complexity is O(n*l) where n is the number of ids and l is the depth of
* the subgraph where those ids appear.
*
* @param node          the node whose ready set is examined
* @param id_to_remove  the stream id to erase from the ready sets
* @param sentinel      the node at which the traversal stops; its ready set is not modified
*/
void EraseReady(const ExecNode *node, int id_to_remove, const ExecNode *sentinel) {
// Never touch the sentinel node - the flood stops here.
if (node == sentinel)
return;
auto &meta = node_meta_[node];
// erase() returns the number of removed elements; a zero means the id was not present
// in this node's ready set, which terminates the recursion along this path.
if (meta.ready_streams.erase(id_to_remove)) {
// Propagate upstream, to the producers of this node's inputs...
for (auto &e : node->inputs)
EraseReady(e->producer, id_to_remove, sentinel);
// ...and downstream, to every consumer of every output.
for (auto &out : node->outputs)
for (auto &e : out.consumers)
EraseReady(e->consumer, id_to_remove, sentinel);
/** Assigns the stream `id` to the node described by `meta`.
 *
 * The use count of the stream is bumped and the new value is stored, together with the id,
 * in the node's ready set.
 *
 * @param meta  metadata of the node that takes the stream
 * @param id    the stream id being assigned
 */
void UseStream(NodeMeta &meta, int id) {
  // Bump the counter first - the ready set entry must carry the post-increment value.
  int &uses = stream_use_count_[id];
  ++uses;
  meta.stream_id = id;
  meta.ready_streams[id] = uses;
}

/** Merges the ready set of `from` into the ready set of `to`.
 *
 * Entries of `from` whose recorded use count is lower than the stream's current use count
 * are stale: they are erased from `from` and not copied. All remaining entries are stored
 * in `to` with the current use count.
 *
 * @param to    metadata whose ready set receives the entries
 * @param from  metadata whose ready set is read (and pruned of stale entries)
 */
void CombineReady(NodeMeta &to, NodeMeta &from) {
  auto &source = from.ready_streams;
  auto pos = source.begin();
  while (pos != source.end()) {
    // Copy the key and value out before a possible erase invalidates `pos`.
    const int stream = pos->first;
    const int recorded = pos->second;
    const int latest = stream_use_count_[stream];
    if (latest <= recorded) {
      to.ready_streams[stream] = latest;
      ++pos;
    } else {
      // The stream has been used since this entry was recorded - drop the entry.
      pos = source.erase(pos);
    }
  }
}

// Number of stream ids handed out so far; a new stream gets the id `total_streams_++`.
int total_streams_ = 0;

// Per-node bookkeeping (stream assignment and ready set), keyed by node pointer.
std::unordered_map<const ExecNode *, NodeMeta> node_meta_;
// Current use count of each stream id; bumped whenever the stream is assigned to a node and
// compared against the counts recorded in ready sets to detect stale entries.
std::unordered_map<int, int> stream_use_count_;
// NOTE(review): not referenced in the visible part of the file - presumably the nodes paired
// with their metadata in a sorted processing order; confirm against the code that fills it.
std::vector<std::pair<const ExecNode *, NodeMeta *>> sorted_nodes_;
};

Expand Down

0 comments on commit a80eb0b

Please sign in to comment.