Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-8481][VL] Clean up shuffle reader cpp code #8482

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
shuffle/RandomPartitioner.cc
shuffle/RoundRobinPartitioner.cc
shuffle/ShuffleMemoryPool.cc
shuffle/ShuffleReader.cc
shuffle/ShuffleWriter.cc
shuffle/SinglePartitioner.cc
shuffle/Spill.cc
Expand Down
1 change: 0 additions & 1 deletion cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,6 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper
jlong shuffleReaderHandle) {
JNI_METHOD_START
auto reader = ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle);
GLUTEN_THROW_NOT_OK(reader->close());
ObjectStore::release(shuffleReaderHandle);
JNI_METHOD_END()
}
Expand Down
55 changes: 0 additions & 55 deletions cpp/core/shuffle/ShuffleReader.cc

This file was deleted.

49 changes: 4 additions & 45 deletions cpp/core/shuffle/ShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,63 +17,22 @@

#pragma once

#include "memory/ColumnarBatch.h"

#include <arrow/ipc/message.h>
#include <arrow/ipc/options.h>

#include "Options.h"
#include "compute/ResultIterator.h"
#include "utils/Compression.h"

namespace gluten {

class DeserializerFactory {
public:
virtual ~DeserializerFactory() = default;

virtual std::unique_ptr<ColumnarBatchIterator> createDeserializer(std::shared_ptr<arrow::io::InputStream> in) = 0;

virtual arrow::MemoryPool* getPool() = 0;

virtual int64_t getDecompressTime() = 0;

virtual int64_t getDeserializeTime() = 0;

virtual ShuffleWriterType getShuffleWriterType() = 0;
};

class ShuffleReader {
public:
explicit ShuffleReader(std::unique_ptr<DeserializerFactory> factory);

virtual ~ShuffleReader() = default;

// FIXME iterator should be unique_ptr or un-copyable singleton
virtual std::shared_ptr<ResultIterator> readStream(std::shared_ptr<arrow::io::InputStream> in);

arrow::Status close();

int64_t getDecompressTime() const;

int64_t getIpcTime() const;

int64_t getDeserializeTime() const;

arrow::MemoryPool* getPool() const;

ShuffleWriterType getShuffleWriterType() const;
virtual std::shared_ptr<ResultIterator> readStream(std::shared_ptr<arrow::io::InputStream> in) = 0;

protected:
arrow::MemoryPool* pool_;
int64_t decompressTime_ = 0;
int64_t deserializeTime_ = 0;
virtual int64_t getDecompressTime() const = 0;

ShuffleWriterType shuffleWriterType_;
virtual int64_t getDeserializeTime() const = 0;

private:
std::shared_ptr<arrow::Schema> schema_;
std::unique_ptr<DeserializerFactory> factory_;
virtual arrow::MemoryPool* getPool() const = 0;
};

} // namespace gluten
2 changes: 1 addition & 1 deletion cpp/velox/compute/VeloxRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ std::shared_ptr<ShuffleReader> VeloxRuntime::createShuffleReader(
auto codec = gluten::createArrowIpcCodec(options.compressionType, options.codecBackend);
auto ctxVeloxPool = memoryManager()->getLeafMemoryPool();
auto veloxCompressionType = facebook::velox::common::stringToCompressionKind(options.compressionTypeStr);
auto deserializerFactory = std::make_unique<gluten::VeloxColumnarBatchDeserializerFactory>(
auto deserializerFactory = std::make_unique<gluten::VeloxShuffleReaderDeserializerFactory>(
schema,
std::move(codec),
veloxCompressionType,
Expand Down
40 changes: 26 additions & 14 deletions cpp/velox/shuffle/VeloxShuffleReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
* limitations under the License.
*/

#include "VeloxShuffleReader.h"
#include "GlutenByteStream.h"
#include "shuffle/VeloxShuffleReader.h"

#include <arrow/array/array_binary.h>
#include <arrow/io/buffered.h>

#include "memory/VeloxColumnarBatch.h"
#include "shuffle/GlutenByteStream.h"
#include "shuffle/Payload.h"
#include "shuffle/Utils.h"
#include "utils/Common.h"
Expand Down Expand Up @@ -576,7 +576,7 @@ std::shared_ptr<ColumnarBatch> VeloxRssSortShuffleReaderDeserializer::next() {
return std::make_shared<VeloxColumnarBatch>(std::move(rowVector));
}

VeloxColumnarBatchDeserializerFactory::VeloxColumnarBatchDeserializerFactory(
VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory(
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::common::CompressionKind veloxCompressionType,
Expand All @@ -598,7 +598,7 @@ VeloxColumnarBatchDeserializerFactory::VeloxColumnarBatchDeserializerFactory(
initFromSchema();
}

std::unique_ptr<ColumnarBatchIterator> VeloxColumnarBatchDeserializerFactory::createDeserializer(
std::unique_ptr<ColumnarBatchIterator> VeloxShuffleReaderDeserializerFactory::createDeserializer(
std::shared_ptr<arrow::io::InputStream> in) {
switch (shuffleWriterType_) {
case ShuffleWriterType::kHashShuffle:
Expand Down Expand Up @@ -635,23 +635,19 @@ std::unique_ptr<ColumnarBatchIterator> VeloxColumnarBatchDeserializerFactory::cr
}
}

arrow::MemoryPool* VeloxColumnarBatchDeserializerFactory::getPool() {
arrow::MemoryPool* VeloxShuffleReaderDeserializerFactory::getPool() {
return memoryPool_;
}

ShuffleWriterType VeloxColumnarBatchDeserializerFactory::getShuffleWriterType() {
return shuffleWriterType_;
}

int64_t VeloxColumnarBatchDeserializerFactory::getDecompressTime() {
int64_t VeloxShuffleReaderDeserializerFactory::getDecompressTime() {
return decompressTime_;
}

int64_t VeloxColumnarBatchDeserializerFactory::getDeserializeTime() {
int64_t VeloxShuffleReaderDeserializerFactory::getDeserializeTime() {
return deserializeTime_;
}

void VeloxColumnarBatchDeserializerFactory::initFromSchema() {
void VeloxShuffleReaderDeserializerFactory::initFromSchema() {
GLUTEN_ASSIGN_OR_THROW(auto arrowColumnTypes, toShuffleTypeId(schema_->fields()));
isValidityBuffer_.reserve(arrowColumnTypes.size());
for (size_t i = 0; i < arrowColumnTypes.size(); ++i) {
Expand Down Expand Up @@ -681,7 +677,23 @@ void VeloxColumnarBatchDeserializerFactory::initFromSchema() {
}
}

VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr<DeserializerFactory> factory)
: ShuffleReader(std::move(factory)) {}
VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory)
: factory_(std::move(factory)) {}

std::shared_ptr<ResultIterator> VeloxShuffleReader::readStream(std::shared_ptr<arrow::io::InputStream> in) {
return std::make_shared<ResultIterator>(factory_->createDeserializer(in));
}

arrow::MemoryPool* VeloxShuffleReader::getPool() const {
return factory_->getPool();
}

int64_t VeloxShuffleReader::getDecompressTime() const {
return factory_->getDecompressTime();
}

int64_t VeloxShuffleReader::getDeserializeTime() const {
return factory_->getDeserializeTime();
}

} // namespace gluten
33 changes: 20 additions & 13 deletions cpp/velox/shuffle/VeloxShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@

#pragma once

#include "operators/serializer/VeloxColumnarBatchSerializer.h"
#include "shuffle/Payload.h"
#include "shuffle/ShuffleReader.h"
#include "shuffle/VeloxSortShuffleWriter.h"
#include "utils/Timer.h"

#include "velox/serializers/PrestoSerializer.h"
#include "velox/type/Type.h"
#include "velox/vector/ComplexVector.h"

#include <velox/serializers/PrestoSerializer.h>

namespace gluten {

class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator {
Expand Down Expand Up @@ -134,9 +132,9 @@ class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator {
std::shared_ptr<VeloxInputStream> in_;
};

class VeloxColumnarBatchDeserializerFactory : public DeserializerFactory {
class VeloxShuffleReaderDeserializerFactory {
public:
VeloxColumnarBatchDeserializerFactory(
VeloxShuffleReaderDeserializerFactory(
const std::shared_ptr<arrow::Schema>& schema,
const std::shared_ptr<arrow::util::Codec>& codec,
const facebook::velox::common::CompressionKind veloxCompressionType,
Expand All @@ -147,15 +145,13 @@ class VeloxColumnarBatchDeserializerFactory : public DeserializerFactory {
std::shared_ptr<facebook::velox::memory::MemoryPool> veloxPool,
ShuffleWriterType shuffleWriterType);

std::unique_ptr<ColumnarBatchIterator> createDeserializer(std::shared_ptr<arrow::io::InputStream> in) override;

arrow::MemoryPool* getPool() override;
std::unique_ptr<ColumnarBatchIterator> createDeserializer(std::shared_ptr<arrow::io::InputStream> in);

int64_t getDecompressTime() override;
arrow::MemoryPool* getPool();

int64_t getDeserializeTime() override;
int64_t getDecompressTime();

ShuffleWriterType getShuffleWriterType() override;
int64_t getDeserializeTime();

private:
void initFromSchema();
Expand All @@ -180,6 +176,17 @@ class VeloxColumnarBatchDeserializerFactory : public DeserializerFactory {

class VeloxShuffleReader final : public ShuffleReader {
public:
VeloxShuffleReader(std::unique_ptr<DeserializerFactory> factory);
VeloxShuffleReader(std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory);

std::shared_ptr<ResultIterator> readStream(std::shared_ptr<arrow::io::InputStream> in) override;

int64_t getDecompressTime() const override;

int64_t getDeserializeTime() const override;

arrow::MemoryPool* getPool() const override;

private:
std::unique_ptr<VeloxShuffleReaderDeserializerFactory> factory_;
};
} // namespace gluten
2 changes: 1 addition & 1 deletion cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class VeloxShuffleWriterTest : public ::testing::TestWithParam<ShuffleTestParams
facebook::velox::serializer::presto::PrestoVectorSerde::registerVectorSerde();
}
// Set batchSize to a large value to make all batches are merged by reader.
auto deserializerFactory = std::make_unique<gluten::VeloxColumnarBatchDeserializerFactory>(
auto deserializerFactory = std::make_unique<gluten::VeloxShuffleReaderDeserializerFactory>(
schema,
std::move(codec),
veloxCompressionType,
Expand Down
Loading