From 2923781c99e6009d78593e5885077454d74d0645 Mon Sep 17 00:00:00 2001
From: Michal Zientkiewicz
Date: Wed, 3 Jan 2024 19:54:17 +0100
Subject: [PATCH] Unify pipeline prefetching behind a single Prefetch API

Signed-off-by: Michal Zientkiewicz
---
 dali/pipeline/executor/executor_test.cc |   3 +-
 dali/pipeline/pipeline.h                |   4 +-
 dali/python/backend_impl.cc             |   3 +-
 dali/python/nvidia/dali/pipeline.py     | 146 +++++++++++++++++--------
 dali_tf_plugin/dali_dataset_op.cc       |  10 +-
 dali_tf_plugin/daliop.cc                |  10 +-
 6 files changed, 108 insertions(+), 68 deletions(-)

diff --git a/dali/pipeline/executor/executor_test.cc b/dali/pipeline/executor/executor_test.cc
index a74086c7801..7621de8371f 100644
--- a/dali/pipeline/executor/executor_test.cc
+++ b/dali/pipeline/executor/executor_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
 #include "dali/pipeline/executor/pipelined_executor.h"
 #include "dali/pipeline/executor/async_pipelined_executor.h"
 #include "dali/pipeline/executor/async_separated_pipelined_executor.h"
+#include "dali/pipeline/operator/builtin/external_source.h"
 #include "dali/test/dali_test_utils.h"
 #include "dali/test/tensor_test_utils.h"

diff --git a/dali/pipeline/pipeline.h b/dali/pipeline/pipeline.h
index 6a12a9e8758..f515f67e871 100644
--- a/dali/pipeline/pipeline.h
+++ b/dali/pipeline/pipeline.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
 #include "dali/pipeline/executor/executor.h"
 #include "dali/pipeline/graph/op_graph.h"
 #include "dali/pipeline/pipeline_output_desc.h"
-#include "dali/pipeline/operator/builtin/external_source.h"
+#include "dali/pipeline/operator/builtin/input_operator.h"
 #include "dali/pipeline/operator/checkpointing/checkpoint.h"

diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc
index a3c9d096d50..04c46d54117 100644
--- a/dali/python/backend_impl.cc
+++ b/dali/python/backend_impl.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -1927,6 +1927,7 @@ PYBIND11_MODULE(backend_impl, m) {
             p->SetOutputDescs(out_desc);
           })
       .def("Run", &Pipeline::Run, py::call_guard<py::gil_scoped_release>())
+      .def("Prefetch", &Pipeline::Prefetch, py::call_guard<py::gil_scoped_release>())
      .def("Outputs",
          [](Pipeline *p) {
            Workspace ws;

diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py
index 97f29d04744..b4121e8e268 100644
--- a/dali/python/nvidia/dali/pipeline.py
+++ b/dali/python/nvidia/dali/pipeline.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
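
The Prefetch binding added above replaces stage-wise scheduling from Python. A condensed
sketch of the intent (not the shipped code; the real change is in the pipeline.py hunks
below, and the queue-size arguments stand for the pipeline's configured queue depths):

    # Before: the frontend replicated queue bookkeeping with per-stage bindings.
    def prefetch_old(pipe, cpu_queue_size, gpu_queue_size):
        for _ in range(cpu_queue_size):
            pipe.RunCPU()
        for _ in range(gpu_queue_size):
            pipe.RunGPU()

    # After: the backend owns its queue configuration, so warm-up is a single
    # call that releases the GIL while the queues fill.
    def prefetch_new(pipe):
        pipe.Prefetch()
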
@@ -940,9 +940,13 @@ def build(self):
     def input_feed_count(self, input_name):
         return self._pipe.InputFeedCount(input_name)
 
-    def _feed_input(self, name, data, layout=None, cuda_stream=None, use_copy_kernel=False, is_prefetch=False):
+    def _feed_input(
+        self, name, data, layout=None, cuda_stream=None, use_copy_kernel=False, is_prefetch=False
+    ):
         from nvidia.dali.external_source import _prep_data_for_feed_input
 
+        trace(name, data)
+
         if cuda_stream is None:
             cuda_stream = types._get_default_stream_for_array(data)
         if cuda_stream == -1:
@@ -1049,23 +1053,6 @@ def feed_input(self, data_node, data, layout=None, cuda_stream=None, use_copy_ke
 
         self._feed_input(name, data, layout, cuda_stream, use_copy_kernel)
 
-    def _run_cpu(self):
-        """Run CPU portion of the pipeline."""
-        if not self._built:
-            raise RuntimeError("Pipeline must be built first.")
-        if not self._last_iter:
-            self._pipe.RunCPU()
-            self._cpu_batches_to_consume += 1
-
-    def _run_gpu(self):
-        """Run GPU portion of the pipeline."""
-        if not self._built:
-            raise RuntimeError("Pipeline must be built first.")
-        if self._cpu_batches_to_consume > 0:
-            self._pipe.RunGPU()
-            self._cpu_batches_to_consume -= 1
-            self._gpu_batches_to_consume += 1
-
     def outputs(self):
         """Returns the outputs of the pipeline and releases previous buffer.
 
@@ -1154,7 +1141,8 @@ def release_outputs(self):
         with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
             if not self._built:
                 raise RuntimeError("Pipeline must be built first.")
-            return self._pipe.ReleaseOutputs()
+            ret = self._pipe.ReleaseOutputs()
+            return ret
 
     # for the backward compatibility
     def _release_outputs(self):
@@ -1254,20 +1242,29 @@ def _prefetch(self):
             raise RuntimeError("Pipeline must be built first.")
         if not self._pipe:
             raise RuntimeError("The pipeline was destroyed.")
+        trace(self._first_iter, self._last_iter)
         self._schedule_py_workers()
+
+        try:
+            self._prefetch_inputs()
+            self._first_iter = False
+            self._pipe.Prefetch()
+        except StopIteration:
+            self._last_iter = True
+
+    def _prefetch_inputs(self):
         self._run_input_callbacks(True)
-        prefetch_count = self._cpu_queue_size
         if self._exec_separated:
             prefetch_count = self._cpu_queue_size + self._gpu_queue_size
+            self._batches_to_consume += self._gpu_queue_size
+        else:
+            prefetch_count = self._cpu_queue_size
+            self._batches_to_consume += prefetch_count
 
         for i in range(prefetch_count):
             self.iter_setup()
-        self._batches_to_consume += self._gpu_queue_size
-        self._first_iter = False
-        self._pipe.Prefetch()
 
     def _run_once(self):
         """Start running the whole pipeline once without waiting for its results.
 
@@ -1280,7 +1277,8 @@ def _run_once(self):
             # Special case to prevent a deadlock if user didn't release the only buffer
             if not self._exec_async and self._prefetch_queue_depth == 1:
                 self.release_outputs()
-            self._pipe._run()
+            if not self._last_iter:
+                self._pipe.Run()
         except StopIteration:
             self._last_iter = True
 
@@ -1295,6 +1293,7 @@ def reset(self):
 
         If pipeline iterator reached the end then reset its state to the beginning.
""" + trace(self._last_iter) if self._last_iter: self._first_iter = True self._last_iter = False @@ -1310,6 +1309,7 @@ def reset(self): def empty(self): """If there is any work scheduled in the pipeline but not yet consumed""" + trace(self._batches_to_consume == 0) return self._batches_to_consume == 0 def serialize(self, define_graph=None, filename=None): @@ -1525,32 +1525,42 @@ def _run_input_callbacks(self, is_prefetch=False): if self._input_callbacks is None: return - batches = [] # data from external source callbacks is gathered here + max_count = 1 + done = False stop_iter = False - for i, group in enumerate(self._parallel_input_callbacks): - try: - count = group.feed_count(self) if is_prefetch else 1 - for i in range(count): - batches.append( + iter = 0 + while not done and not stop_iter: + done = True + batches = [] # data from external source callbacks is gathered here + for i, group in enumerate(self._parallel_input_callbacks): + try: + count = group.feed_count(self) if is_prefetch else 1 + if iter < count: group.schedule_and_receive( self, self._py_pool, i, self._max_batch_size, self._epoch_idx ) - ) - except StopIteration: - stop_iter = True - for group in self._seq_input_callbacks: - try: - count = group.feed_count(self) if is_prefetch else 1 - for i in range(count): - batches.append(group.get_batch(self, self._max_batch_size, self._epoch_idx)) - except StopIteration: - stop_iter = True - if stop_iter: - raise StopIteration() - - # we only fill external source queues when we know that all callbacks succeeded - for batch in batches: - batch.feed() + if iter + 1 < count: + done = False + except StopIteration: + stop_iter = True + for group in self._seq_input_callbacks: + try: + count = group.feed_count(self) if is_prefetch else 1 + if iter < count: + batches.append(group.get_batch(self, self._max_batch_size, self._epoch_idx)) + if iter + 1 < count: + done = False + except StopIteration: + stop_iter = True + + if stop_iter: + raise StopIteration() + + # we only fill external source queues when we know that all callbacks succeeded + for batch in batches: + batch.feed() + + iter += 1 def _iter_setup(self): self._run_input_callbacks() @@ -1985,3 +1995,36 @@ def _insert_experimental_pipeline_def(): _insert_experimental_pipeline_def() + + +_indent = 0 + + +def trace(*args, **kwargs): + pass + + +# def trace(*args, **kwargs): +# print(' ' * _indent, *args, **kwargs) + + +# def trace_pipeline_funcs(): +# for name, f in inspect.getmembers(Pipeline, predicate=inspect.isfunction): +# if name[0:2] == '__': +# continue +# #@functools.wraps(f) +# def decorate(name, f): +# def tmp(*args, **kwargs): +# global _indent +# try: +# trace(name, "--->") +# _indent += 1 +# return f(*args, **kwargs) +# finally: +# _indent -= 1 +# trace("<---", name) +# return tmp +# setattr(Pipeline, name, decorate(name, f)) + + +# trace_pipeline_funcs() diff --git a/dali_tf_plugin/dali_dataset_op.cc b/dali_tf_plugin/dali_dataset_op.cc index 10c44f56fbd..e6b54cf12d9 100644 --- a/dali_tf_plugin/dali_dataset_op.cc +++ b/dali_tf_plugin/dali_dataset_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -419,6 +419,7 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
    * When there are input datasets, feed the pipeline required number of input batches.
    *
    * TODO(klecki): Inputs handled only for a uniform executor
+   * TODO(michalz): Clean up the control flow (reverse if nesting)
    */
   Status PrefetchPipeline(IteratorContext *context, daliPipelineHandle *pipeline_handle) {
     if (!dataset()->pipeline_def_.exec_separated) {
@@ -441,14 +442,13 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
       } else {
         actual_prefetch_depth = prefetch_depth;
       }
-      TF_DALI_CALL(daliPrefetchUniform(pipeline_handle, actual_prefetch_depth));
+      for (int i = 0; i < actual_prefetch_depth; i++)
+        TF_DALI_CALL(daliRun(pipeline_handle));
     } else {
       if (dataset()->HasInputs()) {
         return errors::InvalidArgument("Input datasets are not compatible with split executor.");
       }
-      TF_DALI_CALL(daliPrefetchSeparate(pipeline_handle,
-                                        dataset()->pipeline_def_.cpu_prefetch_queue_depth,
-                                        dataset()->pipeline_def_.gpu_prefetch_queue_depth));
+      TF_DALI_CALL(daliPrefetch(pipeline_handle));
     }
     return Status();
   }

diff --git a/dali_tf_plugin/daliop.cc b/dali_tf_plugin/daliop.cc
index ee59867b65f..db0abd9bc56 100644
--- a/dali_tf_plugin/daliop.cc
+++ b/dali_tf_plugin/daliop.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -155,13 +155,7 @@ class DaliOp : public tf::OpKernel {
 #endif
 
     LOG_LINE << "Pipeline created\n";
     LOG_LINE << "Prefetching...\n";
-    if (!exec_separated) {
-      TF_DALI_CALL(daliPrefetchUniform(&pipe_handle_, prefetch_queue_depth_));
-    } else {
-      TF_DALI_CALL(daliPrefetchSeparate(&pipe_handle_,
-                                        cpu_prefetch_queue_depth,
-                                        prefetch_queue_depth_));
-    }
+    TF_DALI_CALL(daliPrefetch(&pipe_handle_));
     LOG_LINE << "After first run\n";
   }
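
With these changes the TF integrations no longer choose between daliPrefetchUniform and
daliPrefetchSeparate: the plain DALI op warms up with a single daliPrefetch call regardless
of executor type, and DALIDataset keeps only the two cases shown above. A Python model of
that control flow (hypothetical helper names; the real code goes through the C API's
daliRun/daliPrefetch and returns TF statuses, and min(...) stands in for the input-feed
accounting elided from the hunk above):

    def prefetch_pipeline(pipe, has_inputs, exec_separated, prefetch_depth, fed_batches):
        if exec_separated:
            if has_inputs:
                raise ValueError("Input datasets are not compatible with split executor.")
            pipe.Prefetch()  # the backend fills both stage queues itself
            return
        # Uniform executor: schedule one iteration per prefetched batch; with input
        # datasets we can only go as deep as the batches we managed to feed.
        depth = min(prefetch_depth, fed_batches) if has_inputs else prefetch_depth
        for _ in range(depth):
            pipe.Run()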