From 8754e3be687625300bb0b4425924728d1bef3adf Mon Sep 17 00:00:00 2001 From: Xiaoxuan Meng Date: Fri, 20 Dec 2024 09:41:20 -0800 Subject: [PATCH] [native] Fix unsafe row exchange source with compression support --- .../operators/UnsafeRowExchangeSource.cpp | 23 +++++++++++++++---- .../main/operators/tests/BroadcastTest.cpp | 3 ++- presto-native-execution/velox | 2 +- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp index e952a62970e4..93d46721bda0 100644 --- a/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp +++ b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp @@ -16,6 +16,7 @@ #include "presto_cpp/main/common/Configs.h" #include "presto_cpp/main/operators/UnsafeRowExchangeSource.h" +#include "velox/serializers/RowSerializer.h" namespace facebook::presto::operators { @@ -36,7 +37,7 @@ UnsafeRowExchangeSource::request( return std::move(shuffle_->next()) .deferValue([this](velox::BufferPtr buffer) { std::vector promises; - int64_t totalBytes = 0; + int64_t totalBytes{0}; { std::lock_guard l(queue_->mutex()); @@ -45,14 +46,26 @@ UnsafeRowExchangeSource::request( queue_->enqueueLocked(nullptr, promises); } else { totalBytes = buffer->size(); - + VELOX_CHECK_LE(totalBytes, std::numeric_limits::max()); + ++numBatches_; - auto ioBuf = - folly::IOBuf::wrapBuffer(buffer->as(), buffer->size()); + velox::serializer::detail::RowGroupHeader rowHeader{ + .compressedSize = static_cast(totalBytes), + .uncompressedSize = static_cast(totalBytes), + .compressed = false}; + auto headBuffer = std::make_shared( + velox::serializer::detail::RowGroupHeader::size(), '0'); + rowHeader.write(const_cast(headBuffer->data())); + + auto ioBuf = folly::IOBuf::wrapBuffer( + headBuffer->data(), headBuffer->size()); + ioBuf->appendToChain( + folly::IOBuf::wrapBuffer(buffer->as(), buffer->size())); queue_->enqueueLocked( std::make_unique( - std::move(ioBuf), [buffer](auto& /*unused*/) {}), + std::move(ioBuf), + [buffer, headBuffer](auto& /*unused*/) {}), promises); } } diff --git a/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp b/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp index 0973a8359c0b..f206ac6e1c71 100644 --- a/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp +++ b/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp @@ -216,7 +216,8 @@ class BroadcastTest : public exec::test::OperatorTestBase { pool(), dataType, velox::getNamedVectorSerde(velox::VectorSerde::Kind::kPresto), - &result); + &result, + nullptr); return result; } }; diff --git a/presto-native-execution/velox b/presto-native-execution/velox index 12942c1eb76d..cded6c2aa4f2 160000 --- a/presto-native-execution/velox +++ b/presto-native-execution/velox @@ -1 +1 @@ -Subproject commit 12942c1eb76de019b775f4b207bc4595d8ace5c0 +Subproject commit cded6c2aa4f22adfb6a7682204a1bc3dece94161