From 3dd85e6d083e9a99a4f38a3f6c3c3179fbe867ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Tue, 18 Jun 2024 15:25:58 +0200 Subject: [PATCH] [WIP] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/pipeline/executor/executor2/exec2.cc | 45 +++++++++++++++++++++++ dali/pipeline/executor/executor2/exec2.h | 33 ++++++++++++++--- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/dali/pipeline/executor/executor2/exec2.cc b/dali/pipeline/executor/executor2/exec2.cc index ea996c1a7f0..669b7be379c 100644 --- a/dali/pipeline/executor/executor2/exec2.cc +++ b/dali/pipeline/executor/executor2/exec2.cc @@ -17,5 +17,50 @@ namespace dali { namespace exec2 { +void Executor2::Init() { +} + +void Executor2::Run() { +} + +void Executor2::Prefetch() { +} + + +void Executor2::Outputs(Workspace *ws) { +} + +void Executor2::ShareOutputs(Workspace *ws) { +} + +void Executor2::ReleaseOutputs() { +} + +void Executor2::EnableMemoryStats(bool enable_memory_stats) { +} + +void Executor2::EnableCheckpointing(bool checkpointing) { +} + +ExecutorMetaMap Executor2::GetExecutorMeta() { + return {}; +} + +void Executor2::Shutdown() { +} + +Checkpoint &Executor2::GetCurrentCheckpoint() { +} + +void Executor2::RestoreStateFromCheckpoint(const Checkpoint &cpt) { +} + +int Executor2::InputFeedCount(std::string_view input_name) { +} + +OperatorBase *Executor2::GetOperator(std::string_view name) { +} + + } // namespace exec2 } // namespace dali diff --git a/dali/pipeline/executor/executor2/exec2.h b/dali/pipeline/executor/executor2/exec2.h index c5d763a9104..c7844dd187f 100644 --- a/dali/pipeline/executor/executor2/exec2.h +++ b/dali/pipeline/executor/executor2/exec2.h @@ -19,20 +19,41 @@ #include "dali/pipeline/graph/op_graph2.h" #include "dali/pipeline/executor/executor2/exec_graph.h" #include "dali/pipeline/workspace/workspace.h" +#include "dali/pipeline/executor/executor.h" namespace dali { namespace exec2 { -class Executor2 { +class DLL_PUBLIC Executor2 : public ExecutorBase { public: - void Initialize(std::shared_ptr graph) { - graph_ = graph; - } + explicit Executor2(int queue_depth); - void Run() { + void Build(const graph::OpGraph &graph) override { + + } + // TODO(michalz): Remove + void Build(OpGraph *graph, std::vector output_names) override { + throw std::logic_error("This function is maintained in the interface for legacy tests only."); } - void GetOutputs(Workspace &ws) { + void Init() override; + void Run() override; + void Prefetch() override; + + void Outputs(Workspace *ws) override; + void ShareOutputs(Workspace *ws) override; + void ReleaseOutputs() override; + void EnableMemoryStats(bool enable_memory_stats = false) override; + void EnableCheckpointing(bool checkpointing = false) override; + ExecutorMetaMap GetExecutorMeta() override; + void Shutdown() override; + Checkpoint& GetCurrentCheckpoint() override; + void RestoreStateFromCheckpoint(const Checkpoint &cpt) override; + int InputFeedCount(std::string_view input_name) override; + OperatorBase *GetOperator(std::string_view name) override; + + void Initialize(std::shared_ptr graph) { + graph_ = graph; } ExecGraph exec_graph_;