Commit 2923781
Fingers crossed.
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient committed Jan 3, 2024
1 parent 4aa621b commit 2923781
Showing 6 changed files with 104 additions and 65 deletions.
3 changes: 2 additions & 1 deletion dali/pipeline/executor/executor_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
#include "dali/pipeline/executor/pipelined_executor.h"
#include "dali/pipeline/executor/async_pipelined_executor.h"
#include "dali/pipeline/executor/async_separated_pipelined_executor.h"
#include "dali/pipeline/operator/builtin/external_source.h"
#include "dali/test/dali_test_utils.h"
#include "dali/test/tensor_test_utils.h"

4 changes: 2 additions & 2 deletions dali/pipeline/pipeline.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
#include "dali/pipeline/executor/executor.h"
#include "dali/pipeline/graph/op_graph.h"
#include "dali/pipeline/pipeline_output_desc.h"
#include "dali/pipeline/operator/builtin/external_source.h"
#include "dali/pipeline/operator/builtin/input_operator.h"
#include "dali/pipeline/operator/checkpointing/checkpoint.h"


3 changes: 2 additions & 1 deletion dali/python/backend_impl.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -1927,6 +1927,7 @@ PYBIND11_MODULE(backend_impl, m) {
             p->SetOutputDescs(out_desc);
           })
      .def("Run", &Pipeline::Run, py::call_guard<py::gil_scoped_release>())
+      .def("Prefetch", &Pipeline::Prefetch, py::call_guard<py::gil_scoped_release>())
      .def("Outputs",
           [](Pipeline *p) {
             Workspace ws;
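Note: the new `Prefetch` binding releases the GIL for the duration of the call, just like `Run`. As a hedged illustration of what this enables at the user level — a toy CPU-only pipeline, assuming a DALI version where `device_id=None` and these operator arguments are accepted — the first `schedule_run()` now fills the entire prefetch queue through a single backend call:

import numpy as np
from nvidia.dali import pipeline_def, fn

@pipeline_def(batch_size=4, num_threads=2, device_id=None, prefetch_queue_depth=2)
def toy_pipe():
    # any CPU operator works here; random.uniform keeps the example self-contained
    return fn.random.uniform(range=[0.0, 1.0], shape=[2])

pipe = toy_pipe()
pipe.build()
pipe.schedule_run()      # first call runs _prefetch(), which issues one backend Prefetch()
(out,) = pipe.outputs()  # wait for one batch and release the previous one
print(np.array(out.as_tensor()).shape)  # (4, 2)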
139 changes: 91 additions & 48 deletions dali/python/nvidia/dali/pipeline.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -940,9 +940,13 @@ def build(self):
    def input_feed_count(self, input_name):
        return self._pipe.InputFeedCount(input_name)

-    def _feed_input(self, name, data, layout=None, cuda_stream=None, use_copy_kernel=False, is_prefetch=False):
+    def _feed_input(
+        self, name, data, layout=None, cuda_stream=None, use_copy_kernel=False, is_prefetch=False
+    ):
        from nvidia.dali.external_source import _prep_data_for_feed_input

+        trace(name, data)

        if cuda_stream is None:
            cuda_stream = types._get_default_stream_for_array(data)
        if cuda_stream == -1:
@@ -1049,23 +1053,6 @@ def feed_input(self, data_node, data, layout=None, cuda_stream=None, use_copy_ke

        self._feed_input(name, data, layout, cuda_stream, use_copy_kernel)

-    def _run_cpu(self):
-        """Run CPU portion of the pipeline."""
-        if not self._built:
-            raise RuntimeError("Pipeline must be built first.")
-        if not self._last_iter:
-            self._pipe.RunCPU()
-            self._cpu_batches_to_consume += 1
-
-    def _run_gpu(self):
-        """Run GPU portion of the pipeline."""
-        if not self._built:
-            raise RuntimeError("Pipeline must be built first.")
-        if self._cpu_batches_to_consume > 0:
-            self._pipe.RunGPU()
-            self._cpu_batches_to_consume -= 1
-            self._gpu_batches_to_consume += 1

    def outputs(self):
        """Returns the outputs of the pipeline and releases previous buffer.
@@ -1154,7 +1141,8 @@ def release_outputs(self):
        with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
            if not self._built:
                raise RuntimeError("Pipeline must be built first.")
-            return self._pipe.ReleaseOutputs()
+            ret = self._pipe.ReleaseOutputs()
+            return ret

    # for the backward compatibility
    def _release_outputs(self):
@@ -1254,20 +1242,29 @@ def _prefetch(self):
            raise RuntimeError("Pipeline must be built first.")
        if not self._pipe:
            raise RuntimeError("The pipeline was destroyed.")
+        trace(self._first_iter, self._last_iter)
        self._schedule_py_workers()

+        try:
+            self._prefetch_inputs()
+            self._first_iter = False
+            self._pipe.Prefetch()
+        except StopIteration:
+            self._last_iter = True

+    def _prefetch_inputs(self):
        self._run_input_callbacks(True)

-        prefetch_count = self._cpu_queue_size
        if self._exec_separated:
            prefetch_count = self._cpu_queue_size + self._gpu_queue_size
-            self._batches_to_consume += self._gpu_queue_size
+        else:
+            prefetch_count = self._cpu_queue_size
+        self._batches_to_consume += prefetch_count

        for i in range(prefetch_count):
            self.iter_setup()

-        self._batches_to_consume += self._gpu_queue_size
-        self._first_iter = False
-        self._pipe.Prefetch()

    def _run_once(self):
        """Start running the whole pipeline once without waiting for its results.
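Note: the accounting in the new `_prefetch_inputs` above is easiest to see with concrete numbers. A standalone sketch (toy values, not DALI code): with separated execution both queues are filled, so the input callbacks and `iter_setup()` run once per queued iteration, and each queued iteration becomes a consumable batch.

# Standalone sketch of the new prefetch accounting.
cpu_queue_size, gpu_queue_size = 2, 3

def prefetch_count(exec_separated):
    if exec_separated:
        return cpu_queue_size + gpu_queue_size  # fill both the CPU and the GPU queue
    return cpu_queue_size                       # uniform executor: a single queue

for sep in (False, True):
    n = prefetch_count(sep)
    # iter_setup() runs n times; n batches become consumable
    print(f"exec_separated={sep}: prefetch_count={n}, batches_to_consume += {n}")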
@@ -1280,7 +1277,8 @@ def _run_once(self):
            # Special case to prevent a deadlock if user didn't release the only buffer
            if not self._exec_async and self._prefetch_queue_depth == 1:
                self.release_outputs()
-            self._pipe._run()
+            if not self._last_iter:
+                self._pipe.Run()
        except StopIteration:
            self._last_iter = True

@@ -1295,6 +1293,7 @@ def reset(self):
        If pipeline iterator reached the end then reset its state to the beginning.
        """
+        trace(self._last_iter)
        if self._last_iter:
            self._first_iter = True
            self._last_iter = False
@@ -1310,6 +1309,7 @@

    def empty(self):
        """If there is any work scheduled in the pipeline but not yet consumed"""
+        trace(self._batches_to_consume == 0)
        return self._batches_to_consume == 0

    def serialize(self, define_graph=None, filename=None):
@@ -1525,32 +1525,42 @@ def _run_input_callbacks(self, is_prefetch=False):
        if self._input_callbacks is None:
            return

-        batches = []  # data from external source callbacks is gathered here
+        max_count = 1
+        done = False
        stop_iter = False

[Code scanning / CodeQL — notice: Unused local variable: `max_count` is not used.]
-        for i, group in enumerate(self._parallel_input_callbacks):
-            try:
-                count = group.feed_count(self) if is_prefetch else 1
-                for i in range(count):
-                    batches.append(
+        iter = 0
+        while not done and not stop_iter:
+            done = True
+            batches = []  # data from external source callbacks is gathered here
+            for i, group in enumerate(self._parallel_input_callbacks):
+                try:
+                    count = group.feed_count(self) if is_prefetch else 1
+                    if iter < count:
                        group.schedule_and_receive(
                            self, self._py_pool, i, self._max_batch_size, self._epoch_idx
                        )
-                    )
-            except StopIteration:
-                stop_iter = True
-        for group in self._seq_input_callbacks:
-            try:
-                count = group.feed_count(self) if is_prefetch else 1
-                for i in range(count):
-                    batches.append(group.get_batch(self, self._max_batch_size, self._epoch_idx))
-            except StopIteration:
-                stop_iter = True
-        if stop_iter:
-            raise StopIteration()

-        # we only fill external source queues when we know that all callbacks succeeded
-        for batch in batches:
-            batch.feed()
+                        if iter + 1 < count:
+                            done = False
+                except StopIteration:
+                    stop_iter = True
+            for group in self._seq_input_callbacks:
+                try:
+                    count = group.feed_count(self) if is_prefetch else 1
+                    if iter < count:
+                        batches.append(group.get_batch(self, self._max_batch_size, self._epoch_idx))
+                        if iter + 1 < count:
+                            done = False
+                except StopIteration:
+                    stop_iter = True

+            if stop_iter:
+                raise StopIteration()

+            # we only fill external source queues when we know that all callbacks succeeded
+            for batch in batches:
+                batch.feed()

+            iter += 1

    def _iter_setup(self):
        self._run_input_callbacks()
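Note: the rewritten `_run_input_callbacks` above replaces the per-group `for i in range(count)` loops with global rounds: in round `iter`, every group that still needs data is fed once, the batches gathered in that round are committed together, and the loop continues while any group reports a higher `feed_count`. A standalone toy model of just that control flow (a plain `feed_counts` list stands in for DALI's callback groups):

# Toy model of the round-based feeding loop.
def run_rounds(feed_counts):
    rounds = []
    it = 0
    done = False
    while not done:
        done = True
        fed = []  # stand-in for the batches gathered in this round
        for g, count in enumerate(feed_counts):
            if it < count:
                fed.append(g)        # stand-in for schedule_and_receive/get_batch
                if it + 1 < count:
                    done = False     # this group needs another round
        rounds.append(fed)           # commit only after every group ran this round
        it += 1
    return rounds

print(run_rounds([3, 1, 2]))  # [[0, 1, 2], [0, 2], [0]]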
@@ -1985,3 +1995,36 @@ def _insert_experimental_pipeline_def():


_insert_experimental_pipeline_def()


+_indent = 0


+def trace(*args, **kwargs):
+    pass


+# def trace(*args, **kwargs):
+#     print(' ' * _indent, *args, **kwargs)


+# def trace_pipeline_funcs():
+#     for name, f in inspect.getmembers(Pipeline, predicate=inspect.isfunction):
+#         if name[0:2] == '__':
+#             continue
+#         #@functools.wraps(f)
+#         def decorate(name, f):
+#             def tmp(*args, **kwargs):
+#                 global _indent
+#                 try:
+#                     trace(name, "--->")
+#                     _indent += 1
+#                     return f(*args, **kwargs)
+#                 finally:
+#                     _indent -= 1
+#                     trace("<---", name)
+#             return tmp
+#         setattr(Pipeline, name, decorate(name, f))


+# trace_pipeline_funcs()
10 changes: 5 additions & 5 deletions dali_tf_plugin/dali_dataset_op.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -419,6 +419,7 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
   * When there are input datasets, feed the pipeline required number of input batches.
   *
   * TODO(klecki): Inputs handled only for an uniform executor
+   * TODO(michalz): Clean up the control flow (reverse if nesting)
   */
  Status PrefetchPipeline(IteratorContext *context, daliPipelineHandle *pipeline_handle) {
    if (!dataset()->pipeline_def_.exec_separated) {
@@ -441,14 +442,13 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
      } else {
        actual_prefetch_depth = prefetch_depth;
      }
-      TF_DALI_CALL(daliPrefetchUniform(pipeline_handle, actual_prefetch_depth));
+      for (int i = 0; i < actual_prefetch_depth; i++)
+        TF_DALI_CALL(daliRun(pipeline_handle));
    } else {
      if (dataset()->HasInputs()) {
        return errors::InvalidArgument("Input datasets are not compatible with split executor.");
      }
-      TF_DALI_CALL(daliPrefetchSeparate(pipeline_handle,
-                                        dataset()->pipeline_def_.cpu_prefetch_queue_depth,
-                                        dataset()->pipeline_def_.gpu_prefetch_queue_depth));
+      TF_DALI_CALL(daliPrefetch(pipeline_handle));
    }
    return Status();
  }
}
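Note: after this change the two TensorFlow paths use the new C API differently: with a uniform executor the dataset op prefetches by scheduling `daliRun` once per prefetch slot (so input batches can be fed in between), while the separated executor delegates to a single `daliPrefetch` call. A toy Python model of the branching only — the strings stand in for the C calls, and the way `actual_prefetch_depth` is limited by available input batches is collapsed in the diff above, so the `min()` here is an assumption:

# Toy model of PrefetchPipeline's control flow (not the DALI C API).
def prefetch_pipeline(exec_separated, has_inputs, prefetch_depth, input_batches=0):
    calls = []
    if not exec_separated:
        # assumption: with input datasets, depth is capped by what can be fed
        depth = min(prefetch_depth, input_batches) if has_inputs else prefetch_depth
        calls += ["daliRun"] * depth
    else:
        if has_inputs:
            raise ValueError("Input datasets are not compatible with split executor.")
        calls.append("daliPrefetch")
    return calls

print(prefetch_pipeline(False, False, 3))  # ['daliRun', 'daliRun', 'daliRun']
print(prefetch_pipeline(True, False, 3))   # ['daliPrefetch']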
10 changes: 2 additions & 8 deletions dali_tf_plugin/daliop.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -155,13 +155,7 @@ class DaliOp : public tf::OpKernel {
#endif
    LOG_LINE << "Pipeline created\n";
    LOG_LINE << "Prefetching...\n";
-    if (!exec_separated) {
-      TF_DALI_CALL(daliPrefetchUniform(&pipe_handle_, prefetch_queue_depth_));
-    } else {
-      TF_DALI_CALL(daliPrefetchSeparate(&pipe_handle_,
-                                        cpu_prefetch_queue_depth,
-                                        prefetch_queue_depth_));
-    }
+    TF_DALI_CALL(daliPrefetch(&pipe_handle_));
    LOG_LINE << "After first run\n";
  }

