Skip to content

Commit

Permalink
Improve robustness.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Sep 2, 2024
1 parent c199977 commit 771d840
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
22 changes: 13 additions & 9 deletions dali/pipeline/executor/executor2/stream_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class StreamAssignment<StreamPolicy::PerOperator> {
sorted_nodes_.push_back(&node);
node_ids_[&node] = idx;
if (node.inputs.empty()) {
queue_.push({ node_ids_[&node], AssignStreamId(&node).value_or(kInvalidStreamIdx) });
queue_.push({ node_ids_[&node], AssignStreamId(&node).first.value_or(kInvalidStreamIdx) });
} else {
for (auto &inp : node.inputs) {
assert(node_ids_.count(inp->producer) >= 0 && "Nodes must be topologically sorted.");
Expand Down Expand Up @@ -261,13 +261,14 @@ class StreamAssignment<StreamPolicy::PerOperator> {

for (auto &output_desc : node->outputs) {
for (auto *out : output_desc.consumers) {
auto out_stream_id = AssignStreamId(out->consumer, stream_id);
if (out_stream_id.has_value())
auto [out_stream_id, is_new] = AssignStreamId(out->consumer, stream_id);
if (out_stream_id == stream_id)
keep_stream_id = false;
queue_.push({node_ids_[out->consumer], out_stream_id.value_or(kInvalidStreamIdx)});
if (is_new)
queue_.push({node_ids_[out->consumer], out_stream_id.value_or(kInvalidStreamIdx)});
}
}
if (keep_stream_id)
if (stream_id.has_value() && keep_stream_id)
free_stream_ids_.erase(*stream_id);
}
}
Expand Down Expand Up @@ -310,8 +311,9 @@ class StreamAssignment<StreamPolicy::PerOperator> {
os << "\n";
}

std::optional<int> AssignStreamId(const ExecNode *node,
std::optional<int> prev_stream_id = std::nullopt) {
std::pair<std::optional<int>, bool> AssignStreamId(
const ExecNode *node,
std::optional<int> prev_stream_id = std::nullopt) {
// If the preceding node had a stream, then we have to pass it on through CPU nodes if
// there are any GPU nodes down the graph.
// If the preceding node didn't have a stream, then we only need a stream if the current
Expand All @@ -327,10 +329,12 @@ class StreamAssignment<StreamPolicy::PerOperator> {
if (!current.has_value() || *current > next_free) {
current = next_free;
free_stream_ids_.erase(b);
return { next_free, true };
} else {
return { *current, false };
}
return *current;
} else {
return std::nullopt;
return { std::nullopt, true };
}
}

Expand Down
32 changes: 23 additions & 9 deletions dali/pipeline/executor/executor2/stream_assignment_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,11 @@ TEST(Exec2Test, StreamAssignment_PerOperator_2) {
TEST(Exec2Test, StreamAssignment_PerOperator_3) {
ExecGraph eg;
/*
a --- c -- out
\ /
X
/ \
b --- d -- out
a --- c --- e -- out
\ / \ /
X X
/ \ / \
b --- d --- f -- out
*/
graph::OpGraph::Builder b;
b.Add("a",
Expand All @@ -354,14 +354,26 @@ TEST(Exec2Test, StreamAssignment_PerOperator_3) {
SpecGPU()
.AddInput("a->c", "gpu")
.AddInput("b->c", "gpu")
.AddOutput("c->o", "gpu"));
.AddOutput("c->e", "gpu")
.AddOutput("c->f", "gpu"));
b.Add("d",
SpecGPU()
.AddInput("a->d", "gpu")
.AddInput("b->d", "gpu")
.AddOutput("d->o", "gpu"));
b.AddOutput("c->o_gpu");
b.AddOutput("d->o_gpu");
.AddOutput("d->e", "gpu")
.AddOutput("d->f", "gpu"));
b.Add("e",
SpecGPU()
.AddInput("c->e", "gpu")
.AddInput("d->e", "gpu")
.AddOutput("e->o", "gpu"));
b.Add("f",
SpecGPU()
.AddInput("c->f", "gpu")
.AddInput("d->f", "gpu")
.AddOutput("f->o", "gpu"));
b.AddOutput("e->o_gpu");
b.AddOutput("f->o_gpu");
auto g = std::move(b).GetGraph(true);
eg.Lower(g);

Expand All @@ -371,6 +383,8 @@ TEST(Exec2Test, StreamAssignment_PerOperator_3) {
EXPECT_EQ(assignment[map["b"]], 1);
EXPECT_EQ(assignment[map["c"]], 0);
EXPECT_EQ(assignment[map["d"]], 1);
EXPECT_EQ(assignment[map["e"]], 0);
EXPECT_EQ(assignment[map["f"]], 1);
}

TEST(Exec2Test, StreamAssignment_PerOperator_4) {
Expand Down

0 comments on commit 771d840

Please sign in to comment.