Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable GPU->CPU transfers #5593

Merged
merged 9 commits into from
Sep 11, 2024
Merged
33 changes: 26 additions & 7 deletions dali/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,33 @@ daliCreatePipeline2(daliPipelineHandle *pipe_handle, const char *serialized_pipe
int async_execution, int separated_execution, int prefetch_queue_depth,
int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth,
int enable_memory_stats) {
bool se = separated_execution != 0;
bool pe = pipelined_execution != 0;
bool ae = async_execution != 0;
dali_exec_flags_t flags = {};
if (async_execution)
flags = flags | DALI_EXEC_IS_ASYNC;
if (pipelined_execution)
flags = flags | DALI_EXEC_IS_PIPELINED;
if (separated_execution)
flags = flags | DALI_EXEC_IS_SEPARATED;
daliCreatePipeline3(pipe_handle, serialized_pipeline, length,
max_batch_size, num_threads, device_id, flags,
prefetch_queue_depth, cpu_prefetch_queue_depth, gpu_prefetch_queue_depth,
enable_memory_stats);
}

auto pipeline =
std::make_unique<dali::Pipeline>(std::string(serialized_pipeline, length), max_batch_size,
num_threads, device_id, pe, prefetch_queue_depth, ae);
DLL_PUBLIC void
daliCreatePipeline3(daliPipelineHandle *pipe_handle, const char *serialized_pipeline, int length,
int max_batch_size, int num_threads, int device_id,
dali_exec_flags_t exec_flags, int prefetch_queue_depth,
int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth,
int enable_memory_stats) {
bool se = exec_flags & DALI_EXEC_IS_SEPARATED;
bool pe = exec_flags & DALI_EXEC_IS_PIPELINED;
bool ae = exec_flags & DALI_EXEC_IS_ASYNC;
bool de = exec_flags & DALI_EXEC_IS_DYNAMIC;

auto pipeline = std::make_unique<dali::Pipeline>(
std::string(serialized_pipeline, length), max_batch_size, num_threads, device_id,
pe, prefetch_queue_depth, ae, de);
pipeline->SetExecutionTypes(pe, se, ae);
if (se) {
pipeline->SetQueueSizes(cpu_prefetch_queue_depth, gpu_prefetch_queue_depth);
Expand All @@ -263,7 +283,6 @@ daliCreatePipeline2(daliPipelineHandle *pipe_handle, const char *serialized_pipe
*pipe_handle = WrapPipeline(std::move(pipeline)).release();
}


void daliDeserializeDefault(daliPipelineHandle *pipe_handle, const char *serialized_pipeline,
int length) {
auto pipeline = std::make_unique<dali::Pipeline>(std::string(serialized_pipeline, length));
Expand Down
26 changes: 15 additions & 11 deletions dali/pipeline/executor/executor_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,40 +58,44 @@ bool ForceExec2() {
} // namespace

template <typename... T>
std::unique_ptr<ExecutorBase> GetExecutorImpl(bool pipelined, bool separated, bool async,
T&&... args) {
if (async && separated && pipelined) {
std::unique_ptr<ExecutorBase> GetExecutorImpl(
bool pipelined, bool separated, bool async, bool dynamic, T&&... args) {
// Go over possible combinations and throw otherwise.
if (async && separated && pipelined && !dynamic) {
return std::make_unique<AsyncSeparatedPipelinedExecutor>(std::forward<T>(args)...);
} else if (async && !separated && pipelined) {
if (ForceExec2()) {
std::cerr << "\n!!! Forced use of Executor 2.0 !!!" << std::endl;
bool force_exec2 = ForceExec2();
if (dynamic || force_exec2) {
if (force_exec2)
std::cerr << "\n!!! Forced use of Executor 2.0 !!!" << std::endl;
return std::make_unique<exec2::Executor2>(MakeExec2Config(std::forward<T>(args)...));
} else {
return std::make_unique<AsyncPipelinedExecutor>(std::forward<T>(args)...);
}
} else if (!async && separated && pipelined) {
} else if (!async && separated && pipelined && !dynamic) {
return std::make_unique<SeparatedPipelinedExecutor>(std::forward<T>(args)...);
} else if (!async && !separated && pipelined) {
} else if (!async && !separated && pipelined && !dynamic) {
return std::make_unique<PipelinedExecutor>(std::forward<T>(args)...);
} else if (!async && !separated && !pipelined) {
} else if (!async && !separated && !pipelined && !dynamic) {
return std::make_unique<SimpleExecutor>(std::forward<T>(args)...);
JanuszL marked this conversation as resolved.
Show resolved Hide resolved
}
std::stringstream error;
error << std::boolalpha;
error << "No supported executor selected for pipelined = " << pipelined
<< ", separated = " << separated << ", async = " << async << std::endl;
<< ", separated = " << separated << ", async = " << async
<< ", dynamic = " << dynamic << std::endl;
DALI_FAIL(error.str());
}


// Creates an executor for the given execution mode and pipeline parameters.
// Thin public wrapper over GetExecutorImpl; see GetExecutorImpl for the
// supported combinations of pipelined/separated/async/dynamic flags.
std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async, bool dynamic,
                                          int batch_size, int num_thread, int device_id,
                                          size_t bytes_per_sample_hint, bool set_affinity,
                                          int max_num_stream,
                                          int default_cuda_stream_priority,
                                          QueueSizes prefetch_queue_depth) {
  return GetExecutorImpl(
      pipelined, separated, async, dynamic,
      batch_size, num_thread, device_id, bytes_per_sample_hint, set_affinity,
      max_num_stream, default_cuda_stream_priority, prefetch_queue_depth);
}
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/executor/executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
namespace dali {

DLL_PUBLIC
std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async,
std::unique_ptr<ExecutorBase> GetExecutor(bool pipelined, bool separated, bool async, bool dynamic,
int batch_size, int num_thread, int device_id,
size_t bytes_per_sample_hint, bool set_affinity = false,
int max_num_stream = -1,
Expand Down
25 changes: 20 additions & 5 deletions dali/pipeline/executor/lowered_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,18 @@ void OpGraph::Lower(const graph::OpGraph &definition) {
if (!op_nodes_.empty() || !tensor_nodes_.empty())
throw std::logic_error("The target graph must be empty");
for (auto &node : definition.OpNodes()) {
auto &lowered_op = AddOp(node.spec, node.instance_name);
lowered_op.definition = &node;
try {
JanuszL marked this conversation as resolved.
Show resolved Hide resolved
auto &lowered_op = AddOp(node.spec, node.instance_name);
lowered_op.definition = &node;
} catch (...) {
PropagateError({
std::current_exception(),
make_string(
"Critical error when building pipeline:\n",
GetErrorContextMessage(node.spec)),
"\nCurrent pipeline object is no longer valid."
});
}
}
for (auto &t : tensor_nodes_) {
t.definition = definition.GetData(t.name);
Expand Down Expand Up @@ -131,22 +141,27 @@ OpNode &OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) {
// Validate the op specification
CheckOpConstraints(spec);

const char *gpu2cpu_error =
mdabek-nvidia marked this conversation as resolved.
Show resolved Hide resolved
"This pipeline doesn't support transition from GPU to CPU.\n"
"To enable GPU->CPU transitions, use the experimental \"dynamic\" executor.\n"
"Specify experimental_exec_dynamic=True in your Pipeline constructor or @pipeline_def.";

string device = spec.GetArgument<string>("device");
auto op_type = ParseOpType(device);
// TODO(klecki): refactor this out
switch (op_type) {
case OpType::CPU: {
// Enforce graph constraints
DALI_ENFORCE(AllInputsCPU(spec), "CPU ops cannot receive GPU input data.");
DALI_ENFORCE(AllOutputsCPU(spec), "CPU ops can only produce CPU output data.");
DALI_ENFORCE(AllInputsCPU(spec), gpu2cpu_error);
DALI_ENFORCE(AllOutputsCPU(spec), "CPU operators can only produce CPU output data.");
break;
}
case OpType::GPU: {
break;
}
case OpType::MIXED: {
// Enforce graph constraints
DALI_ENFORCE(AllInputsCPU(spec), "Mixed ops cannot receive GPU input data.");
DALI_ENFORCE(AllInputsCPU(spec), gpu2cpu_error);
break;
}
default:
Expand Down
40 changes: 38 additions & 2 deletions dali/pipeline/operator/builtin/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,47 @@
namespace dali {

template <>
void Copy<CPUBackend>::RunImpl(Workspace &ws) {
  // CPU-output copy. The input may reside on either backend:
  // - CPU input: per-sample copies parallelized on the thread pool;
  // - GPU input: a device-to-host batch copy ordered on the output stream.
  if (ws.InputIsType<CPUBackend>(0)) {
    auto &input = ws.Input<CPUBackend>(0);
    auto &output = ws.Output<CPUBackend>(0);

    int batch_size = input.num_samples();
    output.SetLayout(input.GetLayout());
    auto shapes = input.shape();

    auto &thread_pool = ws.GetThreadPool();
    for (int sample_id = 0; sample_id < batch_size; ++sample_id) {
      // Work is weighted by sample size so the pool can balance the load.
      thread_pool.AddWork(
          [sample_id, &input, &output](int tid) {
            output.CopySample(sample_id, input, sample_id, AccessOrder::host());
          },
          shapes.tensor_size(sample_id));
    }
    thread_pool.RunAll();
  } else {
    // GPU -> CPU transfer; only valid with the dynamic executor.
    auto &input = ws.Input<GPUBackend>(0);
    auto &output = ws.Output<CPUBackend>(0);
    output.Copy(input, ws.output_order());
  }
}

template <>
void Copy<GPUBackend>::RunImpl(Workspace &ws) {
  // GPU-output copy; the input may come from either backend, so dispatch on
  // the actual input type and perform a batch copy in the output order.
  const bool host_input = ws.InputIsType<CPUBackend>(0);
  if (host_input) {
    auto &in = ws.Input<CPUBackend>(0);
    auto &out = ws.Output<GPUBackend>(0);
    out.Copy(in, ws.output_order());
  } else {
    auto &in = ws.Input<GPUBackend>(0);
    auto &out = ws.Output<GPUBackend>(0);
    out.Copy(in, ws.output_order());
  }
}

DALI_REGISTER_OPERATOR(Copy, Copy<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(Copy, Copy<GPUBackend>, GPU);


DALI_SCHEMA(Copy)
.DocStr("Creates a copy of the input tensor.")
Expand Down
27 changes: 0 additions & 27 deletions dali/pipeline/operator/builtin/copy.cu

This file was deleted.

36 changes: 6 additions & 30 deletions dali/pipeline/operator/builtin/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ namespace dali {
template <typename Backend>
class Copy : public StatelessOperator<Backend> {
public:
inline explicit Copy(const OpSpec &spec) :
StatelessOperator<Backend>(spec), scatter_gather_(kMaxSizePerBlock) {}

inline ~Copy() override = default;
explicit Copy(const OpSpec &spec) :
StatelessOperator<Backend>(spec) {}

DISABLE_COPY_MOVE_ASSIGN(Copy);

Expand All @@ -42,37 +40,15 @@ class Copy : public StatelessOperator<Backend> {

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
output_desc.resize(1);
const auto &input = ws.Input<Backend>(0);
output_desc[0].type = input.type();
output_desc[0].shape = input.shape();
output_desc[0].type = ws.GetInputDataType(0);
output_desc[0].shape = ws.GetInputShape(0);
return true;
}

void RunImpl(Workspace &ws) override {
auto &input = ws.Input<Backend>(0);
auto data_type_size = input.type_info().size();
auto &output = ws.Output<Backend>(0);
output.SetLayout(input.GetLayout());
for (int i = 0; i < input.num_samples(); i++) {
auto tensor_shape = input.tensor_shape(i);
auto tensor_size = volume(tensor_shape);
scatter_gather_.AddCopy(output.raw_mutable_tensor(i), input.raw_tensor(i),
tensor_size * data_type_size);
}
RunCopies(ws);
}

void RunCopies(Workspace &ws);

std::conditional_t<
std::is_same<Backend, CPUBackend>::value,
kernels::ScatterGatherCPU,
kernels::ScatterGatherGPU> scatter_gather_;
// 256 kB per block for GPU
static constexpr size_t kMaxSizePerBlock =
std::is_same<Backend, CPUBackend>::value ? kernels::ScatterGatherCPU::kAnyBlockSize : 1 << 18;
void RunImpl(Workspace &ws) override;
};


} // namespace dali

#endif // DALI_PIPELINE_OPERATOR_BUILTIN_COPY_H_
2 changes: 1 addition & 1 deletion dali/pipeline/operator/checkpointing/checkpoint_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void BuildFromLegacyGraph(Checkpoint &checkpoint, const OpGraph &graph) {
}

// Builds a SimpleExecutor (all execution flags off, batch size 1, one thread)
// for checkpointing tests on a CPU-only device.
auto GetSimpleExecutor() {
  return GetExecutor(false, false, false, false, 1, 1, CPU_ONLY_DEVICE_ID, 0);
}

} // namespace
Expand Down
Loading
Loading