# CI workflow that builds and tests only the presto-base-arrow-flight module
# on every pull request.
name: arrow flight tests

on:
  pull_request:

env:
  CONTINUOUS_INTEGRATION: true
  MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError"
  MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError"
  MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true"
  MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end"
  RETRY: .github/bin/retry

jobs:
  # Exposes `codechange` = 'true' when the pull request touches anything
  # outside presto-docs/, so the test job can skip its steps otherwise.
  changes:
    runs-on: ubuntu-latest
    permissions:
      pull-requests: read
    outputs:
      codechange: ${{ steps.filter.outputs.codechange }}
    steps:
      # v3 runs on a supported Node runtime; v2 used the deprecated Node 16.
      - uses: dorny/paths-filter@v3
        id: filter
        with:
          filters: |
            codechange:
              - '!presto-docs/**'

  test:
    runs-on: ubuntu-latest
    needs: changes
    strategy:
      fail-fast: false
      matrix:
        modules:
          # Only run tests for the `presto-base-arrow-flight` module
          - ":presto-base-arrow-flight"

    timeout-minutes: 80
    # A newer push to the same pull request cancels the run in progress.
    concurrency:
      group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }}
      cancel-in-progress: true

    steps:
      # Every step is guarded individually (instead of guarding the job) so
      # the job itself still runs and reports a status for docs-only changes.

      # Checkout the code only if there are changes in the relevant files
      - uses: actions/checkout@v4
        if: needs.changes.outputs.codechange == 'true'
        with:
          show-progress: false

      # Set up Java for the build environment
      - uses: actions/setup-java@v4
        if: needs.changes.outputs.codechange == 'true'
        with:
          distribution: 'temurin'
          java-version: '8'

      # Cache Maven dependencies to speed up the build.
      # actions/cache v1/v2 have been retired by GitHub; v4 is required.
      - name: Cache local Maven repository
        if: needs.changes.outputs.codechange == 'true'
        id: cache-maven
        uses: actions/cache@v4
        with:
          path: ~/.m2/repository
          key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }}
restore-keys: | + ${{ runner.os }}-maven-2- + + # Resolve Maven dependencies (if cache is not found) + - name: Populate Maven cache + if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' + run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + + # Install dependencies for the target module + - name: Maven Install + if: needs.changes.outputs.codechange == 'true' + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} + + # Run Maven tests for the target module + - name: Maven Tests + if: needs.changes.outputs.codechange == 'true' + run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} diff --git a/.github/workflows/test-other-modules.yml b/.github/workflows/test-other-modules.yml index 3c3f84817a0d..1065fe3d0c6e 100644 --- a/.github/workflows/test-other-modules.yml +++ b/.github/workflows/test-other-modules.yml @@ -84,4 +84,5 @@ jobs: !presto-test-coverage, !presto-iceberg, !presto-singlestore, - !presto-native-sidecar-plugin' + !presto-native-sidecar-plugin, + !presto-base-arrow-flight' diff --git a/pom.xml b/pom.xml index 82099ffb5b04..338af56f350f 100644 --- a/pom.xml +++ b/pom.xml @@ -85,6 +85,7 @@ 3.25.5 4.1.115.Final 2.0 + 17.0.0 + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + + org.apache.arrow + flight-core + + + org.slf4j + slf4j-api + + + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + log + + + + com.google.guava + guava + + + + javax.inject + javax.inject + + + + com.facebook.presto + presto-spi + + + + io.airlift + slice + + + + com.facebook.airlift + log-manager + + + + com.fasterxml.jackson.core + jackson-annotations + + + + com.facebook.presto + presto-common + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.google.code.findbugs + jsr305 + true + + + + com.google.inject + guice + + + + com.facebook.airlift + 
configuration + + + + joda-time + joda-time + + + + org.jdbi + jdbi3-core + + + + + org.testng + testng + test + + + + io.airlift.tpch + tpch + test + + + + com.facebook.presto + presto-tpch + test + + + + com.facebook.airlift + json + test + + + + com.facebook.presto + presto-testng-services + test + + + + com.facebook.airlift + testing + test + + + + com.facebook.presto + presto-main + test + + + + com.facebook.presto + presto-tests + test + + + + com.h2database + h2 + test + + + + + + + com.google.errorprone + error_prone_annotations + ${error_prone_annotations.version} + + + + io.perfmark + perfmark-api + ${perfmark-api.version} + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + -Xss10M + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + org.apache.maven.plugins + maven-dependency-plugin + + + org.basepom.maven + duplicate-finder-maven-plugin + 1.2.1 + + + module-info + META-INF.versions.9.module-info + + + arrow-git.properties + about.html + + + + + + check + + + + + + + diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java new file mode 100644 index 000000000000..cf3f8f5e6f4f --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java @@ -0,0 +1,941 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.common.type.VarcharType; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import 
org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.JsonStringArrayList; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.LocalTime; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR; +import static java.util.Objects.requireNonNull; + +public class ArrowBlockBuilder +{ + public Block buildBlockFromFieldVector(FieldVector vector, Type type, DictionaryProvider dictionaryProvider) + { + if (vector.getField().getDictionary() != null) { + Dictionary dictionary = dictionaryProvider.lookup(vector.getField().getDictionary().getId()); + return buildBlockFromDictionaryVector(vector, dictionary.getVector()); + } + else { + return buildBlockFromValueVector(vector, type); + } + } + + public Block buildBlockFromDictionaryVector(FieldVector fieldVector, FieldVector dictionaryVector) + { + // Validate inputs + requireNonNull(fieldVector, "encoded vector is null"); + requireNonNull(dictionaryVector, "dictionary vector is null"); + + Type prestoType = getPrestoTypeFromArrowField(dictionaryVector.getField()); + + Block dictionaryblock = buildBlockFromValueVector(dictionaryVector, prestoType); + + // Return Presto DictionaryBlock + return getDictionaryBlock(fieldVector, dictionaryblock); + } + + protected Type 
getPrestoTypeFromArrowField(Field field) + { + switch (field.getType().getTypeID()) { + case Int: + ArrowType.Int intType = (ArrowType.Int) field.getType(); + return getPrestoTypeForArrowIntType(intType); + case Binary: + case LargeBinary: + case FixedSizeBinary: + return VarbinaryType.VARBINARY; + case Date: + return DateType.DATE; + case Timestamp: + return TimestampType.TIMESTAMP; + case Utf8: + case LargeUtf8: + return VarcharType.VARCHAR; + case FloatingPoint: + ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); + return getPrestoTypeForArrowFloatingPointType(floatingPoint); + case Decimal: + ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); + return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()); + case Bool: + return BooleanType.BOOLEAN; + case Time: + return TimeType.TIME; + default: + throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); + } + } + + private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint) + { + switch (floatingPoint.getPrecision()) { + case SINGLE: + return RealType.REAL; + case DOUBLE: + return DoubleType.DOUBLE; + default: + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected floating point precision " + floatingPoint.getPrecision()); + } + } + + private Type getPrestoTypeForArrowIntType(ArrowType.Int intType) + { + switch (intType.getBitWidth()) { + case 64: + return BigintType.BIGINT; + case 32: + return IntegerType.INTEGER; + case 16: + return SmallintType.SMALLINT; + case 8: + return TinyintType.TINYINT; + default: + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected bit width " + intType.getBitWidth()); + } + } + + private DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block dictionaryblock) + { + if (fieldVector instanceof IntVector) { + // Get the Arrow indices vector + IntVector indicesVector = (IntVector) fieldVector; + 
int[] ids = new int[indicesVector.getValueCount()]; + for (int i = 0; i < indicesVector.getValueCount(); i++) { + ids[i] = indicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof SmallIntVector) { + // Get the SmallInt indices vector + SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector; + int[] ids = new int[smallIntIndicesVector.getValueCount()]; + for (int i = 0; i < smallIntIndicesVector.getValueCount(); i++) { + ids[i] = smallIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof TinyIntVector) { + // Get the TinyInt indices vector + TinyIntVector tinyIntIndicesVector = (TinyIntVector) fieldVector; + int[] ids = new int[tinyIntIndicesVector.getValueCount()]; + for (int i = 0; i < tinyIntIndicesVector.getValueCount(); i++) { + ids[i] = tinyIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else { + // Handle the case where the FieldVector is of an unsupported type + throw new IllegalArgumentException("Unsupported FieldVector type: " + fieldVector.getClass()); + } + } + + private Block buildBlockFromValueVector(ValueVector vector, Type type) + { + if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, type); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, type); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, type); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, type); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, type); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, type); + } + else if (vector instanceof NullVector) { 
+ return buildBlockFromNullVector((NullVector) vector, type); + } + else if (vector instanceof TimeStampMicroVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliVector) { + return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, type); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, type); + } + else if (vector instanceof VarCharVector) { + if (type instanceof CharType) { + return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else if (type instanceof TimeType) { + return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else { + return buildBlockFromVarCharVector((VarCharVector) vector, type); + } + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, type); + } + else if (vector instanceof DateMilliVector) { + return buildBlockFromDateMilliVector((DateMilliVector) vector, type); + } + else if (vector instanceof TimeMilliVector) { + return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); + } + else if (vector instanceof TimeSecVector) { + return buildBlockFromTimeSecVector((TimeSecVector) vector, type); + } + else if (vector instanceof TimeStampSecVector) { + return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); + } + else if (vector instanceof TimeMicroVector) { + return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); + } + else if (vector instanceof ListVector) { + return 
buildBlockFromListVector((ListVector) vector, type); + } + else { + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); + } + } + + public Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildBlockFromBitVector(BitVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeBoolean(builder, vector.get(i) == 1); + } + } + return builder.build(); + } + + public Block buildBlockFromIntVector(IntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if 
(vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public Block buildBlockFromDecimalVector(DecimalVector vector, Type type) + { + if (!(type instanceof DecimalType)) { + throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); + } + + DecimalType decimalType = (DecimalType) type; + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value + if (decimalType.isShort()) { + builder.writeLong(decimal.unscaledValue().longValue()); + } + else { + Slice slice = Decimals.encodeScaledValue(decimal); + decimalType.writeSlice(builder, slice, 0, slice.length()); + } + } + } + return builder.build(); + } + + public Block buildBlockFromNullVector(NullVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + builder.appendNull(); + } + return builder.build(); + } + + public Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long 
micros = vector.get(i); + long millis = TimeUnit.MICROSECONDS.toMillis(micros); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeDouble(builder, vector.get(i)); + } + } + return builder.build(); + } + + public Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int intBits = Float.floatToIntBits(vector.get(i)); + type.writeLong(builder, intBits); + } + } + return builder.build(); + } + + public Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + byte[] value = vector.get(i); + type.writeSlice(builder, Slices.wrappedBuffer(value)); + } + } + return builder.build(); + } + + public Block buildBlockFromVarCharVector(VarCharVector vector, Type type) + { + if (!(type instanceof VarcharType)) { + throw 
new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + // Directly create a Slice from the raw byte array + byte[] rawBytes = vector.get(i); + Slice slice = Slices.wrappedBuffer(rawBytes); + // Write the Slice directly to the builder + type.writeSlice(builder, slice); + } + } + return builder.build(); + } + + public Block buildBlockFromDateDayVector(DateDayVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + DateType dateType = (DateType) type; + long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); + dateType.writeLong(builder, days); + } + } + return builder.build(); + } + + public Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if 
(vector.isNull(i)) { + builder.appendNull(); + } + else { + int value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); + } + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long micro = TimeUnit.MICROSECONDS.toMillis(value); + type.writeLong(builder, micro); + } + } + return builder.build(); + } + + public Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = 
type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); + } + } + return builder.build(); + } + + public Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String timeString = new String(vector.get(i), StandardCharsets.UTF_8); + LocalTime time = LocalTime.parse(timeString); + long millis = Duration.between(LocalTime.MIN, time).toMillis(); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public Block buildBlockFromListVector(ListVector vector, Type type) + { + if (!(type instanceof ArrayType)) { + throw new IllegalArgumentException("Type must be an ArrayType for ListVector"); + } + + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + arrayBuilder.appendNull(); + } + else { + BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); + FieldVector dataVector = vector.getDataVector(); + int startIndex = vector.getElementStartIndex(i); + int endIndex = vector.getElementEndIndex(i); + for (int j = startIndex; j < endIndex; j++) { + appendValueToBuilder(elementType, elementBuilder, dataVector.getObject(j)); + } + arrayBuilder.closeEntry(); + } + } + return arrayBuilder.build(); + } + + public void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + { + if (value == null) { + builder.appendNull(); + return; + } + + if (type 
instanceof VarcharType) { + writeVarcharType(type, builder, value); + } + else if (type instanceof SmallintType) { + writeSmallintType(type, builder, value); + } + else if (type instanceof TinyintType) { + writeTinyintType(type, builder, value); + } + else if (type instanceof BigintType) { + writeBigintType(type, builder, value); + } + else if (type instanceof IntegerType) { + writeIntegerType(type, builder, value); + } + else if (type instanceof DoubleType) { + writeDoubleType(type, builder, value); + } + else if (type instanceof BooleanType) { + writeBooleanType(type, builder, value); + } + else if (type instanceof DecimalType) { + writeDecimalType((DecimalType) type, builder, value); + } + else if (type instanceof ArrayType) { + writeArrayType((ArrayType) type, builder, value); + } + else if (type instanceof RowType) { + writeRowType((RowType) type, builder, value); + } + else if (type instanceof DateType) { + writeDateType(type, builder, value); + } + else if (type instanceof TimestampType) { + writeTimestampType(type, builder, value); + } + else { + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + public void writeVarcharType(Type type, BlockBuilder builder, Object value) + { + Slice slice = Slices.utf8Slice(value.toString()); + type.writeSlice(builder, slice); + } + + public void writeSmallintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).shortValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + short shortValue = Short.parseShort(obj.toString()); + type.writeLong(builder, shortValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for SmallintType: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for SmallintType: " + value.getClass()); + } + } + + public void 
writeTinyintType(Type type, BlockBuilder builder, Object value)
+    {
+        // Any boxed number is narrowed with byteValue(); out-of-range values are truncated, not rejected.
+        if (value instanceof Number) {
+            type.writeLong(builder, ((Number) value).byteValue());
+        }
+        else if (value instanceof JsonStringArrayList) {
+            // List-shaped input: each element is parsed from its string form and appended in order.
+            for (Object obj : (JsonStringArrayList<?>) value) {
+                try {
+                    byte byteValue = Byte.parseByte(obj.toString());
+                    type.writeLong(builder, byteValue);
+                }
+                catch (NumberFormatException e) {
+                    throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for TinyintType: " + obj, e);
+                }
+            }
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for TinyintType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends a BIGINT value. Accepts {@link Long}, {@link Integer} (widened), or a
+     * {@link JsonStringArrayList} whose elements are parsed with {@link Long#parseLong(String)}.
+     *
+     * @throws IllegalArgumentException if the value has another runtime type or an element is not a valid long
+     */
+    public void writeBigintType(Type type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof Long) {
+            type.writeLong(builder, (Long) value);
+        }
+        else if (value instanceof Integer) {
+            type.writeLong(builder, ((Integer) value).longValue());
+        }
+        else if (value instanceof JsonStringArrayList) {
+            for (Object obj : (JsonStringArrayList<?>) value) {
+                try {
+                    long longValue = Long.parseLong(obj.toString());
+                    type.writeLong(builder, longValue);
+                }
+                catch (NumberFormatException e) {
+                    throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e);
+                }
+            }
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends an INTEGER value. Accepts {@link Integer} or a {@link JsonStringArrayList}
+     * whose elements are parsed with {@link Integer#parseInt(String)}.
+     *
+     * @throws IllegalArgumentException if the value has another runtime type or an element is not a valid int
+     */
+    public void writeIntegerType(Type type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof Integer) {
+            type.writeLong(builder, (Integer) value);
+        }
+        else if (value instanceof JsonStringArrayList) {
+            for (Object obj : (JsonStringArrayList<?>) value) {
+                try {
+                    int intValue = Integer.parseInt(obj.toString());
+                    type.writeLong(builder, intValue);
+                }
+                catch (NumberFormatException e) {
+                    throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e);
+                }
+            }
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends a DOUBLE value. Accepts {@link Double}, {@link Float} (widened), or a
+     * {@link JsonStringArrayList} whose elements are parsed with {@link Double#parseDouble(String)}.
+     *
+     * @throws IllegalArgumentException if the value has another runtime type or an element is not a valid double
+     */
+    public void writeDoubleType(Type type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof Double) {
+            type.writeDouble(builder, (Double) value);
+        }
+        else if (value instanceof Float) {
+            type.writeDouble(builder, ((Float) value).doubleValue());
+        }
+        else if (value instanceof JsonStringArrayList) {
+            for (Object obj : (JsonStringArrayList<?>) value) {
+                try {
+                    double doubleValue = Double.parseDouble(obj.toString());
+                    type.writeDouble(builder, doubleValue);
+                }
+                catch (NumberFormatException e) {
+                    throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e);
+                }
+            }
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends a BOOLEAN value; only {@link Boolean} is accepted.
+     *
+     * @throws IllegalArgumentException for any other runtime type
+     */
+    public void writeBooleanType(Type type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof Boolean) {
+            type.writeBoolean(builder, (Boolean) value);
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends a DECIMAL value. A {@link BigDecimal} is written as its unscaled long for short
+     * decimals or as an encoded slice for long decimals; a {@link Long} is taken as the
+     * already-unscaled value and is only valid for short decimals.
+     *
+     * @throws IllegalArgumentException if a Long is supplied for a long decimal, or for any other runtime type
+     */
+    public void writeDecimalType(DecimalType type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof BigDecimal) {
+            BigDecimal decimalValue = (BigDecimal) value;
+            if (type.isShort()) {
+                // write ShortDecimalType
+                // NOTE(review): the unscaled value is used as-is; presumably the incoming scale
+                // already equals type.getScale() - confirm against the producer.
+                long unscaledValue = decimalValue.unscaledValue().longValue();
+                type.writeLong(builder, unscaledValue);
+            }
+            else {
+                // write LongDecimalType
+                Slice slice = Decimals.encodeScaledValue(decimalValue);
+                type.writeSlice(builder, slice);
+            }
+        }
+        else if (value instanceof Long) {
+            // Direct handling for ShortDecimalType using long
+            if (type.isShort()) {
+                type.writeLong(builder, (Long) value);
+            }
+            else {
+                throw new IllegalArgumentException("Long value is not supported for LongDecimalType.");
+            }
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends one ARRAY entry: opens a nested block entry, appends every element of the
+     * iterable {@code value} through {@code appendValueToBuilder}, then closes the entry.
+     */
+    public void writeArrayType(ArrayType type, BlockBuilder builder, Object value)
+    {
+        Type elementType = type.getElementType();
+        BlockBuilder arrayBuilder = builder.beginBlockEntry();
+        for (Object element : (Iterable<?>) value) {
+            appendValueToBuilder(elementType, arrayBuilder, element);
+        }
+        builder.closeEntry();
+    }
+
+    /**
+     * Appends one ROW entry. {@code value} must be a list holding one element per row field,
+     * in field order; element {@code i} is appended using the type of field {@code i}.
+     */
+    public void writeRowType(RowType type, BlockBuilder builder, Object value)
+    {
+        List<?> rowValues = (List<?>) value;
+        BlockBuilder rowBuilder = builder.beginBlockEntry();
+        List<RowType.Field> fields = type.getFields();
+        for (int i = 0; i < fields.size(); i++) {
+            Type fieldType = fields.get(i).getType();
+            appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i));
+        }
+        builder.closeEntry();
+    }
+
+    /**
+     * Appends a DATE value as days since the epoch. Accepts {@link java.sql.Date} and
+     * {@link java.time.LocalDate}.
+     *
+     * @throws IllegalArgumentException for any other runtime type
+     */
+    public void writeDateType(Type type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) {
+            int daysSinceEpoch = (int) (value instanceof java.sql.Date
+                    ? ((java.sql.Date) value).toLocalDate().toEpochDay()
+                    : ((java.time.LocalDate) value).toEpochDay());
+            type.writeLong(builder, daysSinceEpoch);
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for DateType: " + value.getClass());
+        }
+    }
+
+    /**
+     * Appends a TIMESTAMP value as epoch milliseconds. Accepts {@link java.sql.Timestamp},
+     * {@link java.time.Instant}, or a {@link Long} that is written unchanged.
+     *
+     * @throws IllegalArgumentException for any other runtime type
+     */
+    public void writeTimestampType(Type type, BlockBuilder builder, Object value)
+    {
+        if (value instanceof java.sql.Timestamp) {
+            long millis = ((java.sql.Timestamp) value).getTime();
+            type.writeLong(builder, millis);
+        }
+        else if (value instanceof java.time.Instant) {
+            long millis = ((java.time.Instant) value).toEpochMilli();
+            type.writeLong(builder, millis);
+        }
+        else if (value instanceof Long) { // write long epoch milliseconds directly
+            type.writeLong(builder, (Long) value);
+        }
+        else {
+            throw new IllegalArgumentException("Unsupported type for TimestampType: " + value.getClass());
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java
new file mode 100644
index 000000000000..1ee3791e3745
--- /dev/null
+++
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * JSON-serializable column handle of the Arrow Flight connector: a column name plus its Presto type.
+ * Two handles are equal when both the name and the type are equal.
+ */
+public class ArrowColumnHandle
+        implements ColumnHandle
+{
+    private final String columnName;
+    private final Type columnType;
+
+    /**
+     * @param columnName name of the column; must not be null
+     * @param columnType Presto type of the column; must not be null
+     */
+    @JsonCreator
+    public ArrowColumnHandle(
+            @JsonProperty("columnName") String columnName,
+            @JsonProperty("columnType") Type columnType)
+    {
+        this.columnName = requireNonNull(columnName, "columnName is null");
+        this.columnType = requireNonNull(columnType, "columnType is null");
+    }
+
+    @JsonProperty
+    public String getColumnName()
+    {
+        return columnName;
+    }
+
+    @JsonProperty
+    public Type getColumnType()
+    {
+        return columnType;
+    }
+
+    /**
+     * Builds the SPI column metadata (name and type only) for this handle.
+     */
+    public ColumnMetadata getColumnMetadata()
+    {
+        return new ColumnMetadata(columnName, columnType);
+    }
+
+    @Override
+    public boolean equals(Object obj)
+    {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        ArrowColumnHandle other = (ArrowColumnHandle) obj;
+        return columnName.equals(other.columnName) && columnType.equals(other.columnType);
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hash(columnName, columnType);
+    }
+
+    @Override
+    public String toString()
+    {
+        return columnName + ":" + columnType;
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java
new file mode 100644
index
000000000000..759129d562e3
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorMetadata;
+import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
+import com.facebook.presto.spi.connector.ConnectorSplitManager;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+import com.facebook.presto.spi.transaction.IsolationLevel;
+import com.google.inject.Inject;
+
+import java.util.Optional;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Connector implementation that hands out the injected metadata, split manager and page source
+ * provider. Every transaction shares the single {@link ArrowTransactionHandle#INSTANCE}.
+ */
+public class ArrowConnector
+        implements Connector
+{
+    private final ConnectorMetadata metadata;
+    private final ConnectorSplitManager splitManager;
+    private final ConnectorPageSourceProvider pageSourceProvider;
+    private final ConnectorHandleResolver handleResolver;
+
+    private final BaseArrowFlightClient arrowFlightClientHandler;
+
+    @Inject
+    public ArrowConnector(
+            ConnectorMetadata metadata,
+            ConnectorHandleResolver handleResolver,
+            ConnectorSplitManager splitManager,
+            ConnectorPageSourceProvider pageSourceProvider,
+            BaseArrowFlightClient arrowFlightClientHandler)
+    {
+        this.metadata = requireNonNull(metadata, "Metadata is null");
+        this.handleResolver = requireNonNull(handleResolver, "handleResolver is null");
+        this.splitManager = requireNonNull(splitManager, "SplitManager is null");
+        this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null");
+        this.arrowFlightClientHandler = requireNonNull(arrowFlightClientHandler, "arrow flight handler is null");
+    }
+
+    /**
+     * Returns the injected handle resolver; the result is never empty.
+     */
+    public Optional<ConnectorHandleResolver> getHandleResolver()
+    {
+        return Optional.of(handleResolver);
+    }
+
+    @Override
+    public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly)
+    {
+        // Isolation level and read-only flag are ignored: the same singleton handle is always returned.
+        return ArrowTransactionHandle.INSTANCE;
+    }
+
+    @Override
+    public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle)
+    {
+        return metadata;
+    }
+
+    @Override
+    public ConnectorSplitManager getSplitManager()
+    {
+        return splitManager;
+    }
+
+    @Override
+    public ConnectorPageSourceProvider getPageSourceProvider()
+    {
+        return pageSourceProvider;
+    }
+
+    @Override
+    public void shutdown()
+    {
+        // Delegates to the flight client so it can release its root allocator.
+        arrowFlightClientHandler.closeRootAllocator();
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java
new file mode 100644
index 000000000000..ce95cccdd051
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.bootstrap.Bootstrap;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.classloader.ThreadContextClassLoader;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorContext;
+import com.facebook.presto.spi.connector.ConnectorFactory;
+import com.facebook.presto.spi.function.FunctionMetadataManager;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.relation.RowExpressionService;
+import com.google.common.collect.ImmutableList;
+import com.google.inject.ConfigurationException;
+import com.google.inject.Injector;
+import com.google.inject.Module;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_INTERNAL_ERROR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Strings.isNullOrEmpty;
+import static com.google.common.base.Throwables.throwIfUnchecked;
+import static com.google.inject.util.Modules.override;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Factory that bootstraps an {@link ArrowConnector} from {@link ArrowModule}, overridden by the
+ * module supplied at construction time, plus any extra modules.
+ */
+public class ArrowConnectorFactory
+        implements ConnectorFactory
+{
+    private final String name;
+    private final Module module;
+    private final ImmutableList<Module> extraModules;
+    private final ClassLoader classLoader;
+
+    /**
+     * @param name connector name reported by {@link #getName()}; must be non-empty
+     * @param module module whose bindings override those of {@link ArrowModule}
+     * @param extraModules additional modules installed alongside; copied defensively
+     * @param classLoader class loader installed as the thread context loader while bootstrapping
+     */
+    public ArrowConnectorFactory(String name, Module module, List<Module> extraModules, ClassLoader classLoader)
+    {
+        checkArgument(!isNullOrEmpty(name), "name is null or empty");
+        this.name = name;
+        this.module = requireNonNull(module, "module is null");
+        this.extraModules = ImmutableList.copyOf(requireNonNull(extraModules, "extraModules is null"));
+        this.classLoader = requireNonNull(classLoader, "classLoader is null");
+    }
+
+    @Override
+    public String getName()
+    {
+        return name;
+    }
+
+    @Override
+    public ConnectorHandleResolver getHandleResolver()
+    {
+        return new ArrowHandleResolver();
+    }
+
+    /**
+     * Creates the connector for one catalog.
+     *
+     * @throws ArrowException with {@code ARROW_INTERNAL_ERROR} when Guice reports a configuration error;
+     * other checked exceptions are rethrown wrapped in {@link RuntimeException}
+     */
+    @Override
+    public Connector create(String catalogName, Map<String, String> requiredConfig, ConnectorContext context)
+    {
+        requireNonNull(requiredConfig, "requiredConfig is null");
+
+        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+            Bootstrap app = new Bootstrap(ImmutableList.<Module>builder()
+                    // Expose the engine-provided services from the connector context to the injector.
+                    .add(binder -> {
+                        binder.bind(TypeManager.class).toInstance(context.getTypeManager());
+                        binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager());
+                        binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution());
+                        binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService());
+                        binder.bind(NodeManager.class).toInstance(context.getNodeManager());
+                    })
+                    .add(override(new ArrowModule(catalogName)).with(module))
+                    .addAll(extraModules)
+                    .build());
+
+            Injector injector = app
+                    .doNotInitializeLogging()
+                    .setRequiredConfigurationProperties(requiredConfig)
+                    .initialize();
+
+            return injector.getInstance(ArrowConnector.class);
+        }
+        catch (ConfigurationException ex) {
+            throw new ArrowException(ARROW_INTERNAL_ERROR, "The connector instance could not be created.", ex);
+        }
+        catch (Exception e) {
+            throwIfUnchecked(e);
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java
new file mode 100644
index 000000000000..dce08bac4ac2
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Immutable value wrapper around a connector id string; equality and hashing follow the id.
+ */
+public class ArrowConnectorId
+{
+    private final String id;
+
+    /**
+     * @param id connector id; must not be null
+     */
+    public ArrowConnectorId(String id)
+    {
+        this.id = requireNonNull(id, "id is null");
+    }
+
+    @Override
+    public boolean equals(Object obj)
+    {
+        if (obj == this) {
+            return true;
+        }
+        boolean sameClass = obj != null && obj.getClass() == getClass();
+        if (!sameClass) {
+            return false;
+        }
+        // id is never null (enforced by the constructor), so a direct equals is sufficient.
+        return id.equals(((ArrowConnectorId) obj).id);
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hash(id);
+    }
+
+    @Override
+    public String toString()
+    {
+        return id;
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java
new file mode 100644
index 000000000000..989dd9810102
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.ErrorCode;
+import com.facebook.presto.common.ErrorType;
+import com.facebook.presto.spi.ErrorCodeSupplier;
+
+import static com.facebook.presto.common.ErrorType.EXTERNAL;
+import static com.facebook.presto.common.ErrorType.INTERNAL_ERROR;
+
+/**
+ * Error codes raised by the Arrow Flight connector. Each constant's numeric code is its
+ * ordinal-style offset added to the base {@code 0x0509_0000}.
+ */
+public enum ArrowErrorCode
+        implements ErrorCodeSupplier
+{
+    ARROW_FLIGHT_INFO_ERROR(0, EXTERNAL),
+    ARROW_INTERNAL_ERROR(1, INTERNAL_ERROR),
+    ARROW_FLIGHT_CLIENT_ERROR(2, EXTERNAL),
+    ARROW_FLIGHT_METADATA_ERROR(3, EXTERNAL),
+    ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL);
+
+    // Base added to every per-constant offset.
+    private static final int ERROR_CODE_BASE = 0x0509_0000;
+
+    private final ErrorCode errorCode;
+
+    ArrowErrorCode(int offset, ErrorType errorType)
+    {
+        int numericCode = offset + ERROR_CODE_BASE;
+        this.errorCode = new ErrorCode(numericCode, name(), errorType);
+    }
+
+    @Override
+    public ErrorCode toErrorCode()
+    {
+        return this.errorCode;
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java
new file mode 100644
index 000000000000..ba2c6edba589
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ErrorCodeSupplier;
+import com.facebook.presto.spi.PrestoException;
+
+/**
+ * {@link PrestoException} subtype thrown by the Arrow Flight connector.
+ */
+public class ArrowException
+        extends PrestoException
+{
+    /**
+     * @param errorCode supplier of the error code to report
+     * @param message human-readable description
+     * @param cause underlying failure
+     */
+    public ArrowException(ErrorCodeSupplier errorCode, String message, Throwable cause)
+    {
+        super(errorCode, message, cause);
+    }
+
+    /**
+     * @param errorCode supplier of the error code to report
+     * @param message human-readable description
+     */
+    public ArrowException(ErrorCodeSupplier errorCode, String message)
+    {
+        super(errorCode, message);
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java
new file mode 100644
index 000000000000..a30301dcdc36
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.configuration.Config;
+
+/**
+ * Configuration bean for the Arrow Flight connector. Every property is a nullable boxed or
+ * String value with no default; each setter returns {@code this} for chaining.
+ */
+public class ArrowFlightConfig
+{
+    private String server;
+    private Integer arrowFlightPort;
+    private Boolean verifyServer;
+    private Boolean arrowFlightServerSslEnabled;
+    private String flightServerSSLCertificate;
+
+    // --- arrow-flight.server ---
+
+    public String getFlightServerName()
+    {
+        return this.server;
+    }
+
+    @Config("arrow-flight.server")
+    public ArrowFlightConfig setFlightServerName(String server)
+    {
+        this.server = server;
+        return this;
+    }
+
+    // --- arrow-flight.server.port ---
+
+    public Integer getArrowFlightPort()
+    {
+        return this.arrowFlightPort;
+    }
+
+    @Config("arrow-flight.server.port")
+    public ArrowFlightConfig setArrowFlightPort(Integer arrowFlightPort)
+    {
+        this.arrowFlightPort = arrowFlightPort;
+        return this;
+    }
+
+    // --- arrow-flight.server.verify ---
+
+    public Boolean getVerifyServer()
+    {
+        return this.verifyServer;
+    }
+
+    @Config("arrow-flight.server.verify")
+    public ArrowFlightConfig setVerifyServer(Boolean verifyServer)
+    {
+        this.verifyServer = verifyServer;
+        return this;
+    }
+
+    // --- arrow-flight.server-ssl-enabled ---
+
+    public Boolean getArrowFlightServerSslEnabled()
+    {
+        return this.arrowFlightServerSslEnabled;
+    }
+
+    @Config("arrow-flight.server-ssl-enabled")
+    public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled)
+    {
+        this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled;
+        return this;
+    }
+
+    // --- arrow-flight.server-ssl-certificate ---
+
+    public String getFlightServerSSLCertificate()
+    {
+        return this.flightServerSSLCertificate;
+    }
+
+    @Config("arrow-flight.server-ssl-certificate")
+    public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCertificate)
+    {
+        this.flightServerSSLCertificate = flightServerSSLCertificate;
+        return this;
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java
new file mode 100644
index 000000000000..8b231b98a6ee
--- /dev/null
+++
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.ConnectorSplit;
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+
+/**
+ * Maps each SPI handle kind to the concrete Arrow connector class used for it.
+ */
+public class ArrowHandleResolver
+        implements ConnectorHandleResolver
+{
+    @Override
+    public Class<? extends ConnectorTableHandle> getTableHandleClass()
+    {
+        return ArrowTableHandle.class;
+    }
+
+    @Override
+    public Class<? extends ConnectorTableLayoutHandle> getTableLayoutHandleClass()
+    {
+        return ArrowTableLayoutHandle.class;
+    }
+
+    @Override
+    public Class<? extends ColumnHandle> getColumnHandleClass()
+    {
+        return ArrowColumnHandle.class;
+    }
+
+    @Override
+    public Class<? extends ConnectorSplit> getSplitClass()
+    {
+        return ArrowSplit.class;
+    }
+
+    @Override
+    public Class<? extends ConnectorTransactionHandle> getTransactionHandleClass()
+    {
+        return ArrowTransactionHandle.class;
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java
new file mode 100644
index 000000000000..5a55a7b3a9aa
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed under the Apache
License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.facebook.presto.spi.ConnectorTableLayout;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutResult;
+import com.facebook.presto.spi.ConnectorTableMetadata;
+import com.facebook.presto.spi.Constraint;
+import com.facebook.presto.spi.NotFoundException;
+import com.facebook.presto.spi.SchemaTableName;
+import com.facebook.presto.spi.SchemaTablePrefix;
+import com.facebook.presto.spi.connector.ConnectorMetadata;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import javax.inject.Inject;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Connector metadata backed by a {@link BaseArrowFlightClient}: schema and table listings are
+ * delegated to the client, and column information is derived from the Arrow schema the client
+ * returns for a table.
+ */
+public class ArrowMetadata
+        implements ConnectorMetadata
+{
+    private static final Logger logger = Logger.get(ArrowMetadata.class);
+    private final BaseArrowFlightClient clientHandler;
+    private final ArrowBlockBuilder arrowBlockBuilder;
+
+    @Inject
+    public ArrowMetadata(BaseArrowFlightClient clientHandler, ArrowBlockBuilder arrowBlockBuilder)
+    {
+        this.clientHandler = requireNonNull(clientHandler, "clientHandler is null");
+        this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowPageBuilder is null");
+    }
+
+    @Override
+    public final List<String> listSchemaNames(ConnectorSession session)
+    {
+        return clientHandler.listSchemaNames(session);
+    }
+
+    @Override
+    public final List<SchemaTableName> listTables(ConnectorSession session, Optional<String> schemaName)
+    {
+        return clientHandler.listTables(session, schemaName);
+    }
+
+    /**
+     * Resolves a table handle, or returns {@code null} when either the schema or the table is
+     * not present in the client's listings.
+     */
+    @Override
+    public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName)
+    {
+        if (!listSchemaNames(session).contains(tableName.getSchemaName())) {
+            return null;
+        }
+
+        if (!listTables(session, Optional.ofNullable(tableName.getSchemaName())).contains(tableName)) {
+            return null;
+        }
+        return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName());
+    }
+
+    /**
+     * Fetches the Arrow fields of a table from the flight client.
+     *
+     * @throws ArrowException with {@code ARROW_FLIGHT_METADATA_ERROR} if the schema lookup fails for any reason
+     */
+    public List<Field> getColumnsList(String schema, String table, ConnectorSession connectorSession)
+    {
+        try {
+            Schema flightSchema = clientHandler.getSchemaForTable(schema, table, connectorSession);
+            return flightSchema.getFields();
+        }
+        catch (Exception e) {
+            throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "The table columns could not be listed for the table " + table, e);
+        }
+    }
+
+    @Override
+    public Map<String, ColumnHandle> getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle)
+    {
+        Map<String, ColumnHandle> columnHandles = new HashMap<>();
+
+        String schemaValue = ((ArrowTableHandle) tableHandle).getSchema();
+        String tableValue = ((ArrowTableHandle) tableHandle).getTable();
+        List<Field> columnList = getColumnsList(schemaValue, tableValue, session);
+
+        for (Field field : columnList) {
+            String columnName = field.getName();
+            logger.debug("The value of the flight columnName is:- %s", columnName);
+
+            Type type = getPrestoTypeFromArrowField(field);
+            columnHandles.put(columnName, new ArrowColumnHandle(columnName, type));
+        }
+        return columnHandles;
+    }
+
+    /**
+     * Produces a single layout carrying the table handle, the desired columns (empty when none
+     * were requested) and the constraint summary; the summary is also returned as the
+     * unenforced constraint.
+     */
+    @Override
+    public List<ConnectorTableLayoutResult> getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint<ColumnHandle> constraint, Optional<Set<ColumnHandle>> desiredColumns)
+    {
+        checkArgument(table instanceof ArrowTableHandle,
+                "Invalid table handle: Expected an instance of ArrowTableHandle but received %s",
+                table.getClass().getSimpleName());
+        checkArgument(desiredColumns.orElse(Collections.emptySet()).stream().allMatch(f -> f instanceof ArrowColumnHandle),
+                "Invalid column handles: Expected desired columns to be of type ArrowColumnHandle");
+
+        ArrowTableHandle tableHandle = (ArrowTableHandle) table;
+
+        // Every handle was verified above to be an ArrowColumnHandle, so the per-element cast is safe.
+        List<ArrowColumnHandle> columns = new ArrayList<>();
+        desiredColumns.ifPresent(handles -> handles.forEach(handle -> columns.add((ArrowColumnHandle) handle)));
+
+        ConnectorTableLayout layout = new ConnectorTableLayout(new ArrowTableLayoutHandle(tableHandle, columns, constraint.getSummary()));
+        return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary()));
+    }
+
+    @Override
+    public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle)
+    {
+        return new ConnectorTableLayout(handle);
+    }
+
+    @Override
+    public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table)
+    {
+        ArrowTableHandle arrowTable = (ArrowTableHandle) table;
+        List<ColumnMetadata> meta = new ArrayList<>();
+        List<Field> columnList = getColumnsList(arrowTable.getSchema(), arrowTable.getTable(), session);
+
+        for (Field field : columnList) {
+            String columnName = field.getName();
+            Type fieldType = getPrestoTypeFromArrowField(field);
+            meta.add(new ColumnMetadata(columnName, fieldType));
+        }
+        return new ConnectorTableMetadata(new SchemaTableName(arrowTable.getSchema(), arrowTable.getTable()), meta);
+    }
+
+    @Override
+    public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle)
+    {
+        return ((ArrowColumnHandle) columnHandle).getColumnMetadata();
+    }
+
+    /**
+     * Lists column metadata for the table named by the prefix, or for every table matching a
+     * schema-only or empty prefix. Tables whose handle can no longer be resolved are skipped.
+     *
+     * @throws ArrowException with {@code ARROW_FLIGHT_METADATA_ERROR} if the metadata of a table cannot be read
+     */
+    @Override
+    public Map<SchemaTableName, List<ColumnMetadata>> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix)
+    {
+        requireNonNull(prefix, "prefix is null");
+        ImmutableMap.Builder<SchemaTableName, List<ColumnMetadata>> columns = ImmutableMap.builder();
+        List<SchemaTableName> tables;
+        if (prefix.getSchemaName() != null && prefix.getTableName() != null) {
+            tables = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName()));
+        }
+        else {
+            // The schema part of the prefix may be null, so Optional.of would throw here.
+            tables = listTables(session, Optional.ofNullable(prefix.getSchemaName()));
+        }
+
+        for (SchemaTableName tableName : tables) {
+            try {
+                ConnectorTableHandle tableHandle = getTableHandle(session, tableName);
+                if (tableHandle == null) {
+                    // getTableHandle reports a missing schema/table as null; there is nothing to list.
+                    continue;
+                }
+                columns.put(tableName, getTableMetadata(session, tableHandle).getColumns());
+            }
+            catch (ClassCastException | NotFoundException e) {
+                throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "The table columns could not be listed for the table " + tableName, e);
+            }
+            catch (Exception e) {
+                throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, e.getMessage(), e);
+            }
+        }
+        return columns.build();
+    }
+
+    private Type getPrestoTypeFromArrowField(Field field)
+    {
+        return arrowBlockBuilder.getPrestoTypeFromArrowField(field);
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java
new file mode 100644
index 000000000000..f2175b34f3ff
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorMetadata;
+import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
+import com.facebook.presto.spi.connector.ConnectorSplitManager;
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import com.google.inject.Scopes;
+
+import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Guice module with the default bindings of the Arrow Flight connector. All services are
+ * bound as singletons.
+ */
+public class ArrowModule
+        implements Module
+{
+    protected final String connectorId;
+
+    /**
+     * @param connectorId id exposed to the injector as an {@link ArrowConnectorId}; must not be null
+     */
+    public ArrowModule(String connectorId)
+    {
+        this.connectorId = requireNonNull(connectorId, "connector id is null");
+    }
+
+    @Override
+    public void configure(Binder binder)
+    {
+        configBinder(binder).bindConfig(ArrowFlightConfig.class);
+        binder.bind(ConnectorSplitManager.class).to(ArrowSplitManager.class).in(Scopes.SINGLETON);
+        binder.bind(ConnectorMetadata.class).to(ArrowMetadata.class).in(Scopes.SINGLETON);
+        binder.bind(ArrowConnector.class).in(Scopes.SINGLETON);
+        binder.bind(ArrowConnectorId.class).toInstance(new ArrowConnectorId(connectorId));
+        binder.bind(ConnectorHandleResolver.class).to(ArrowHandleResolver.class).in(Scopes.SINGLETON);
+        binder.bind(ArrowPageSourceProvider.class).in(Scopes.SINGLETON);
+        binder.bind(ConnectorPageSourceProvider.class).to(ArrowPageSourceProvider.class).in(Scopes.SINGLETON);
+        binder.bind(Connector.class).to(ArrowConnector.class).in(Scopes.SINGLETON);
+        // Untargetted binding: linking a key to itself is reported by Guice as a binding that points to itself.
+        binder.bind(ArrowBlockBuilder.class).in(Scopes.SINGLETON);
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java
new file mode 100644
index 000000000000..80334cbd7f14
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ConnectorPageSource;
+import com.facebook.presto.spi.ConnectorSession;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Page source that reads one Arrow Flight stream (opened from the split's serialized endpoint)
+ * and converts every received record batch into a Presto {@link Page}.
+ */
+public class ArrowPageSource
+        implements ConnectorPageSource
+{
+    private static final Logger logger = Logger.get(ArrowPageSource.class);
+    private final ArrowSplit split;
+    private final List<ArrowColumnHandle> columnHandles;
+    private final ArrowBlockBuilder arrowBlockBuilder;
+    private final FlightStream flightStream;
+    private final ClientClosingFlightStream flightStreamAndClient;
+    private boolean completed;
+    // Total number of rows handed out so far (sum of the row counts of all returned pages).
+    private long completedPositions;
+    private VectorSchemaRoot vectorSchemaRoot;
+
+    /**
+     * Opens the flight stream for the split's endpoint ticket.
+     *
+     * @throws ArrowException with {@code ARROW_FLIGHT_CLIENT_ERROR} if the endpoint cannot be deserialized
+     */
+    public ArrowPageSource(
+            ArrowSplit split,
+            List<ArrowColumnHandle> columnHandles,
+            BaseArrowFlightClient clientHandler,
+            ConnectorSession connectorSession,
+            ArrowBlockBuilder arrowBlockBuilder)
+    {
+        this.columnHandles = requireNonNull(columnHandles, "columnHandles is null");
+        this.split = requireNonNull(split, "split is null");
+        this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
+        requireNonNull(clientHandler, "clientHandler is null");
+        flightStreamAndClient = clientHandler.getFlightStream(split, getTicket(ByteBuffer.wrap(split.getFlightEndpoint())), connectorSession);
+        flightStream = flightStreamAndClient.getFlightStream();
+    }
+
+    private Ticket getTicket(ByteBuffer byteArray)
+    {
+        try {
+            return FlightEndpoint.deserialize(byteArray).getTicket();
+        }
+        catch (Exception e) {
+            throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e);
+        }
+    }
+
+    @Override
+    public long getCompletedBytes()
+    {
+        return 0;
+    }
+
+    @Override
+    public long getCompletedPositions()
+    {
+        return completedPositions;
+    }
+
+    @Override
+    public long getReadTimeNanos()
+    {
+        return 0;
+    }
+
+    @Override
+    public boolean isFinished()
+    {
+        return completed;
+    }
+
+    @Override
+    public long getSystemMemoryUsage()
+    {
+        return 0;
+    }
+
+    /**
+     * Returns the next record batch as a page, or {@code null} once the stream is exhausted or
+     * the source has been closed.
+     */
+    @Override
+    public Page getNextPage()
+    {
+        if (completed) {
+            // Do not poll the stream again after it has ended or after close().
+            return null;
+        }
+
+        if (!flightStream.next()) {
+            completed = true;
+            return null;
+        }
+        vectorSchemaRoot = flightStream.getRoot();
+
+        int rowCount = vectorSchemaRoot.getRowCount();
+        completedPositions += rowCount;
+
+        // NOTE(review): vectors are looked up by position, which assumes the stream's column
+        // order matches columnHandles - confirm against the split/ticket construction.
+        List<Block> blocks = new ArrayList<>(columnHandles.size());
+        for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
+            FieldVector vector = vectorSchemaRoot.getVector(columnIndex);
+            Type type = columnHandles.get(columnIndex).getColumnType();
+            Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStream.getDictionaryProvider());
+            blocks.add(block);
+        }
+
+        return new Page(rowCount, blocks.toArray(new Block[0]));
+    }
+
+    /**
+     * Marks the source finished and releases the current root, then the stream and its client.
+     * The stream is closed even if closing the root fails.
+     *
+     * @throws ArrowException with {@code ARROW_FLIGHT_CLIENT_ERROR} if any resource fails to close
+     */
+    @Override
+    public void close()
+    {
+        completed = true;
+
+        try {
+            try {
+                if (vectorSchemaRoot != null) {
+                    vectorSchemaRoot.close();
+                }
+            }
+            finally {
+                flightStreamAndClient.close();
+            }
+        }
+        catch (Exception e) {
+            logger.error(e, "Error closing flight stream");
+            throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e);
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java
new file mode 100644
index 000000000000..60d3ee120695
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ArrowPageSourceProvider + implements ConnectorPageSourceProvider +{ + private static final Logger logger = Logger.get(ArrowPageSourceProvider.class); + private final BaseArrowFlightClient clientHandler; + private final ArrowBlockBuilder arrowBlockBuilder; + + @Inject + public ArrowPageSourceProvider(BaseArrowFlightClient clientHandler, ArrowBlockBuilder arrowBlockBuilder) + { + this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); + this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null"); + } + + @Override + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, List columns, SplitContext splitContext) + { + ImmutableList.Builder columnHandles = ImmutableList.builder(); + for (ColumnHandle handle : columns) { + columnHandles.add((ArrowColumnHandle) handle); + } + ArrowSplit 
arrowSplit = (ArrowSplit) split; + return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session, arrowBlockBuilder); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java new file mode 100644 index 000000000000..45197317cf81 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class ArrowPlugin + implements Plugin +{ + private final String name; + private final Module module; + private final ImmutableList extraModules; + + public ArrowPlugin(String name, Module module, Module... 
extraModules) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.module = requireNonNull(module, "module is null"); + this.extraModules = ImmutableList.copyOf(requireNonNull(extraModules, "extraModules is null")); + } + + private static ClassLoader getClassLoader() + { + return firstNonNull(Thread.currentThread().getContextClassLoader(), ArrowPlugin.class.getClassLoader()); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ArrowConnectorFactory(name, module, extraModules, getClassLoader())); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java new file mode 100644 index 000000000000..f5aa3343d753 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; + +public class ArrowSplit + implements ConnectorSplit +{ + private final String schemaName; + private final String tableName; + private final byte[] flightEndpoint; + + @JsonCreator + public ArrowSplit( + @JsonProperty("schemaName") @Nullable String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("flightEndpoint") byte[] flightEndpoint) + { + this.schemaName = schemaName; + this.tableName = tableName; + this.flightEndpoint = flightEndpoint; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return this.getInfoMap(); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public byte[] getFlightEndpoint() + { + return flightEndpoint; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java new file mode 100644 index 000000000000..2a016a9441bd --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightInfo; + +import javax.inject.Inject; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class ArrowSplitManager + implements ConnectorSplitManager +{ + private final BaseArrowFlightClient clientHandler; + + private ArrowFlightConfig arrowFlightConfig; + + @Inject + public ArrowSplitManager(BaseArrowFlightClient clientHandler, ArrowFlightConfig arrowFlightConfig) + { + this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); + this.arrowFlightConfig = requireNonNull(arrowFlightConfig, "arrowFlightConfig is null"); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; + ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); + FlightInfo flightInfo = clientHandler.getFlightInfoForTableScan(tableLayoutHandle, session); + List splits = 
flightInfo.getEndpoints() + .stream() + .map(info -> new ArrowSplit( + tableHandle.getSchema(), + tableHandle.getTable(), + info.serialize().array())) + .collect(toImmutableList()); + return new FixedSplitSource(splits); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java new file mode 100644 index 000000000000..cef04e4372a5 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +public class ArrowTableHandle + implements ConnectorTableHandle +{ + private final String schema; + private final String table; + + @JsonCreator + public ArrowTableHandle( + @JsonProperty("schema") String schema, + @JsonProperty("table") String table) + { + this.schema = schema; + this.table = table; + } + + @JsonProperty("schema") + public String getSchema() + { + return schema; + } + + @JsonProperty("table") + public String getTable() + { + return table; + } + + @Override + public String toString() + { + return schema + ":" + table; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableHandle that = (ArrowTableHandle) o; + return Objects.equals(schema, that.schema) && Objects.equals(table, that.table); + } + + @Override + public int hashCode() + { + return Objects.hash(schema, table); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java new file mode 100644 index 000000000000..b997f81c13a0 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class ArrowTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final ArrowTableHandle tableHandle; + private final List columnHandles; + private final TupleDomain tupleDomain; + + @JsonCreator + public ArrowTableLayoutHandle( + @JsonProperty("table") ArrowTableHandle table, + @JsonProperty("columnHandles") List columnHandles, + @JsonProperty("tupleDomain") TupleDomain domain) + { + this.tableHandle = requireNonNull(table, "table is null"); + this.columnHandles = requireNonNull(columnHandles, "columns are null"); + this.tupleDomain = requireNonNull(domain, "domain is null"); + } + + @JsonProperty("table") + public ArrowTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty("tupleDomain") + public TupleDomain getTupleDomain() + { + return tupleDomain; + } + + @JsonProperty("columnHandles") + public List getColumnHandles() + { + return columnHandles; + } + + @Override + public String toString() + { + return "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return 
true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableLayoutHandle arrowTableLayoutHandle = (ArrowTableLayoutHandle) o; + return Objects.equals(tableHandle, arrowTableLayoutHandle.tableHandle) && Objects.equals(columnHandles, arrowTableLayoutHandle.columnHandles) && Objects.equals(tupleDomain, arrowTableLayoutHandle.tupleDomain); + } + + @Override + public int hashCode() + { + return Objects.hash(tableHandle, columnHandles, tupleDomain); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java new file mode 100644 index 000000000000..07eb7385cfbc --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum ArrowTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClient.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClient.java new file mode 100644 index 000000000000..fdecdbadda19 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClient.java @@ -0,0 +1,175 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SchemaTableName; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.file.Paths; +import java.util.List; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INFO_ERROR; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR; +import static java.nio.file.Files.newInputStream; +import static java.util.Objects.requireNonNull; + +public abstract class BaseArrowFlightClient +{ + private final ArrowFlightConfig config; + + private RootAllocator allocator; + + public BaseArrowFlightClient(ArrowFlightConfig config) + { + this.config = requireNonNull(config, "config is null"); + } + + protected FlightClient createArrowFlightClient(URI uri) + { + return createClient(new Location(uri)); + } + + protected FlightClient createArrowFlightClient() + { + Location location; + if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) { + location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort()); + } + else { + location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort()); + } + return createClient(location); + } + + 
protected FlightClient createClient(Location location) + { + try { + if (null == allocator) { + initializeAllocator(); + } + Optional trustedCertificate = Optional.empty(); + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); + if (config.getVerifyServer() != null && !config.getVerifyServer()) { + flightClientBuilder.verifyServer(false); + } + else if (config.getFlightServerSSLCertificate() != null) { + trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate()))); + flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls(); + } + + FlightClient flightClient = flightClientBuilder.build(); + if (trustedCertificate.isPresent()) { + trustedCertificate.get().close(); + } + + return flightClient; + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "The flight client could not be obtained." + e.getMessage(), e); + } + } + + private synchronized void initializeAllocator() + { + if (allocator == null) { + allocator = new RootAllocator(Long.MAX_VALUE); + } + } + + public abstract CredentialCallOption[] getCallOptions(ConnectorSession connectorSession); + + protected FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + try (FlightClient client = createArrowFlightClient()) { + CredentialCallOption[] auth = getCallOptions(connectorSession); + return client.getInfo(flightDescriptor, auth); + } + catch (InterruptedException e) { + throw new ArrowException(ARROW_FLIGHT_INFO_ERROR, "The flight information could not be obtained from the flight server." 
+ e.getMessage(), e); + } + } + + protected ClientClosingFlightStream getFlightStream(ArrowSplit split, Ticket ticket, ConnectorSession connectorSession) + { + ByteBuffer byteArray = ByteBuffer.wrap(split.getFlightEndpoint()); + try { + FlightClient client = FlightEndpoint.deserialize(byteArray).getLocations().stream() + .map(Location::getUri) + .findAny() + .map(this::createArrowFlightClient) + .orElseGet(this::createArrowFlightClient); + return new ClientClosingFlightStream( + client.getStream(ticket, getCallOptions(connectorSession)), + client); + } + catch (FlightRuntimeException | IOException | URISyntaxException e) { + throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e); + } + } + + public Schema getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + try (FlightClient client = createArrowFlightClient()) { + CredentialCallOption[] auth = this.getCallOptions(connectorSession); + return client.getSchema(flightDescriptor, auth).getSchema(); + } + catch (InterruptedException e) { + throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "The flight information could not be obtained from the flight server." 
+ e.getMessage(), e); + } + } + + public void closeRootAllocator() + { + if (null != allocator) { + allocator.close(); + } + } + + public abstract List listSchemaNames(ConnectorSession session); + + public abstract List listTables(ConnectorSession session, Optional schemaName); + + protected abstract FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName); + + protected abstract FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle); + + public Schema getSchemaForTable(String schemaName, String tableName, ConnectorSession connectorSession) + { + FlightDescriptor flightDescriptor = getFlightDescriptorForSchema(schemaName, tableName); + return getSchema(flightDescriptor, connectorSession); + } + + public FlightInfo getFlightInfoForTableScan(ArrowTableLayoutHandle tableLayoutHandle, ConnectorSession session) + { + FlightDescriptor flightDescriptor = getFlightDescriptorForTableScan(tableLayoutHandle); + return getFlightInfo(flightDescriptor, session); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java new file mode 100644 index 000000000000..61a3d3eb2598 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightStream; + +import static java.util.Objects.requireNonNull; + +public class ClientClosingFlightStream + implements AutoCloseable +{ + private final FlightStream flightStream; + private final FlightClient flightClient; + + public ClientClosingFlightStream(FlightStream flightStream, FlightClient flightClient) + { + this.flightStream = requireNonNull(flightStream, "flightStream is null"); + this.flightClient = requireNonNull(flightClient, "flightClient is null"); + } + + public FlightStream getFlightStream() + { + return flightStream; + } + + @Override + public void close() + throws Exception + { + try { + flightStream.close(); + } + finally { + flightClient.close(); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowBlockBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowBlockBuilder.java new file mode 100644 index 000000000000..e325c9375787 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowBlockBuilder.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import org.apache.arrow.vector.types.pojo.Field; + +import java.util.Optional; + +public class TestingArrowBlockBuilder + extends ArrowBlockBuilder +{ + @Override + protected Type getPrestoTypeFromArrowField(Field field) + { + String columnLength = field.getMetadata().get("columnLength"); + int length = columnLength != null ? Integer.parseInt(columnLength) : 0; + + String nativeType = Optional.ofNullable(field.getMetadata().get("columnNativeType")).orElse(""); + + switch (nativeType) { + case "CHAR": + case "CHARACTER": + return CharType.createCharType(length); + case "VARCHAR": + return VarcharType.createVarcharType(length); + case "TIME": + return TimeType.TIME; + default: + return super.getPrestoTypeFromArrowField(field); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClient.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClient.java new file mode 100644 index 000000000000..f21c67f80312 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClient.java @@ -0,0 +1,139 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SchemaTableName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.auth2.BearerCredentialWriter; +import org.apache.arrow.flight.grpc.CredentialCallOption; + +import javax.inject.Inject; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.common.Utils.checkArgument; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class TestingArrowFlightClient + extends BaseArrowFlightClient +{ + private final JsonCodec requestCodec; + private final JsonCodec responseCodec; + + @Inject + public TestingArrowFlightClient( + ArrowFlightConfig config, + JsonCodec requestCodec, + JsonCodec responseCodec) + { + super(config); + this.requestCodec = requireNonNull(requestCodec, "requestCodec is null"); + this.responseCodec = requireNonNull(responseCodec, "responseCodec is null"); + } + + @Override + public CredentialCallOption[] getCallOptions(ConnectorSession connectorSession) + { + return new CredentialCallOption[]{new CredentialCallOption(new BearerCredentialWriter(null))}; + } + + @Override + public FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName) + { + TestingArrowFlightRequest request = TestingArrowFlightRequest.createDescribeTableRequest(schemaName, tableName); + return FlightDescriptor.command(requestCodec.toBytes(request)); + } + + @Override + public List listSchemaNames(ConnectorSession session) + { + List res; + try (FlightClient client = createArrowFlightClient()) { + List names1 = 
new ArrayList<>(); + TestingArrowFlightRequest request = TestingArrowFlightRequest.createListSchemaRequest(); + Iterator iterator = client.doAction(new Action("discovery", requestCodec.toJsonBytes(request)), getCallOptions(session)); + while (iterator.hasNext()) { + Result result = iterator.next(); + TestingArrowFlightResponse response = responseCodec.fromJson(result.getBody()); + checkArgument(response != null, "response is null"); + checkArgument(response.getSchemaNames() != null, "response.getSchemaNames() is null"); + names1.addAll(response.getSchemaNames()); + } + res = names1; + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + List listSchemas = res; + List names = new ArrayList<>(); + for (String value : listSchemas) { + names.add(value.toLowerCase(ENGLISH)); + } + return ImmutableList.copyOf(names); + } + + @Override + public List listTables(ConnectorSession session, Optional schemaName) + { + String schemaValue = schemaName.orElse(""); + List res; + try (FlightClient client = createArrowFlightClient()) { + List names = new ArrayList<>(); + TestingArrowFlightRequest request = TestingArrowFlightRequest.createListTablesRequest(schemaName.orElse("")); + Iterator iterator = client.doAction(new Action("discovery", requestCodec.toJsonBytes(request)), getCallOptions(session)); + while (iterator.hasNext()) { + Result result = iterator.next(); + TestingArrowFlightResponse response = responseCodec.fromJson(result.getBody()); + checkArgument(response != null, "response is null"); + checkArgument(response.getTableNames() != null, "response.getTableNames() is null"); + names.addAll(response.getTableNames()); + } + res = names; + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + List listTables = res; + List tables = new ArrayList<>(); + for (String value : listTables) { + tables.add(new SchemaTableName(schemaValue.toLowerCase(ENGLISH), value.toLowerCase(ENGLISH))); + } + + return tables; + } + + @Override + public 
FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle) + { + ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); + String query = new TestingArrowQueryBuilder().buildSql( + tableHandle.getSchema(), + tableHandle.getTable(), + tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), + tableLayoutHandle.getTupleDomain()); + TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), tableHandle.getTable(), query); + return FlightDescriptor.command(requestCodec.toBytes(request)); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightPlugin.java new file mode 100644 index 000000000000..caeff6e83fcc --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightPlugin.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.json.JsonModule; + +public class TestingArrowFlightPlugin + extends ArrowPlugin +{ + public TestingArrowFlightPlugin() + { + super("arrow", new TestingArrowModule(), new JsonModule()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java new file mode 100644 index 000000000000..af8368d3c8b1 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Optional; + +public class TestingArrowFlightRequest +{ + private final Optional schema; + private final Optional table; + private final Optional query; + + @JsonCreator + public TestingArrowFlightRequest( + @JsonProperty("schema") Optional schema, + @JsonProperty("table") Optional table, + @JsonProperty("query") Optional query) + { + this.schema = schema; + this.table = table; + this.query = query; + } + + public static TestingArrowFlightRequest createListSchemaRequest() + { + return new TestingArrowFlightRequest(Optional.empty(), Optional.empty(), Optional.empty()); + } + + public static TestingArrowFlightRequest createListTablesRequest(String schema) + { + return new TestingArrowFlightRequest(Optional.of(schema), Optional.empty(), Optional.empty()); + } + + public static TestingArrowFlightRequest createDescribeTableRequest(String schema, String table) + { + return new TestingArrowFlightRequest(Optional.of(schema), Optional.of(table), Optional.empty()); + } + + public static TestingArrowFlightRequest createQueryRequest(String schema, String table, String query) + { + return new TestingArrowFlightRequest(Optional.of(schema), Optional.of(table), Optional.of(query)); + } + + @JsonProperty + public Optional getSchema() + { + return schema; + } + + @JsonProperty + public Optional getTable() + { + return table; + } + + @JsonProperty + public Optional getQuery() + { + return query; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightResponse.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightResponse.java new file mode 100644 index 000000000000..d2b06d52137a --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightResponse.java @@ -0,0 +1,47 @@ +/* + * Licensed under 
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class TestingArrowFlightResponse +{ + private final List schemaNames; + private final List tableNames; + + @JsonCreator + public TestingArrowFlightResponse(@JsonProperty("schemaNames") List schemaNames, @JsonProperty("tableNames") List tableNames) + { + this.schemaNames = ImmutableList.copyOf(requireNonNull(schemaNames, "schemaNames is null")); + this.tableNames = ImmutableList.copyOf(requireNonNull(tableNames, "tableNames is null")); + } + + @JsonProperty + public List getSchemaNames() + { + return schemaNames; + } + + @JsonProperty + public List getTableNames() + { + return tableNames; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java new file mode 100644 index 000000000000..2adc773ed018 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; + +public class TestingArrowModule + implements Module +{ + @Override + public void configure(Binder binder) + { + // Concrete implementation of the ArrowFlightClient + binder.bind(BaseArrowFlightClient.class).to(TestingArrowFlightClient.class).in(Scopes.SINGLETON); + // Override the ArrowBlockBuilder with an implementation that handles h2 types + binder.bind(ArrowBlockBuilder.class).to(TestingArrowBlockBuilder.class).in(Scopes.SINGLETON); + // Request/response objects + jsonCodecBinder(binder).bindJsonCodec(TestingArrowFlightRequest.class); + jsonCodecBinder(binder).bindJsonCodec(TestingArrowFlightResponse.class); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java new file mode 100644 index 000000000000..3e000af2fdcb --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.Range; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import java.sql.Time; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; +import static 
com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class TestingArrowQueryBuilder +{ + // not all databases support booleans, so use 1=1 and 1=0 instead + private static final String ALWAYS_TRUE = "1=1"; + private static final String ALWAYS_FALSE = "1=0"; + public static final String DATE_FORMAT = "yyyy-MM-dd"; + public static final String TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; + public static final String TIME_FORMAT = "HH:mm:ss"; + public static final TimeZone UTC_TIME_ZONE = TimeZone.getTimeZone(ZoneId.of("UTC")); + + public String buildSql( + String schema, + String table, + List columns, + Map columnExpressions, + TupleDomain tupleDomain) + { + StringBuilder sql = new StringBuilder(); + + sql.append("SELECT "); + sql.append(addColumnExpression(columns, columnExpressions)); + + sql.append(" FROM "); + if (!isNullOrEmpty(schema)) { + sql.append(quote(schema)).append('.'); + } + sql.append(quote(table)); + + List accumulator = new ArrayList<>(); + + if (tupleDomain != null && !tupleDomain.isAll()) { + List clauses = toConjuncts(columns, tupleDomain, accumulator); + if (!clauses.isEmpty()) { + sql.append(" WHERE ") + .append(Joiner.on(" AND ").join(clauses)); + } + } + + return sql.toString(); + } + + public static String convertEpochToString(long epochValue, Type type) + { + if (type instanceof DateType) { + long millis = TimeUnit.DAYS.toMillis(epochValue); + Date date = new Date(millis); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + dateFormat.setTimeZone(UTC_TIME_ZONE); + return dateFormat.format(date); + } + else if (type instanceof TimestampType) { + Timestamp timestamp = new Timestamp(epochValue); + SimpleDateFormat timestampFormat = new SimpleDateFormat(TIMESTAMP_FORMAT); + 
timestampFormat.setTimeZone(UTC_TIME_ZONE); + return timestampFormat.format(timestamp); + } + else if (type instanceof TimeType) { + long millis = TimeUnit.SECONDS.toMillis(epochValue / 1000); + Time time = new Time(millis); + SimpleDateFormat timeFormat = new SimpleDateFormat(TIME_FORMAT); + timeFormat.setTimeZone(UTC_TIME_ZONE); + return timeFormat.format(time); + } + else { + throw new UnsupportedOperationException(type + " is not supported."); + } + } + + protected static class TypeAndValue + { + private final Type type; + private final Object value; + + public TypeAndValue(Type type, Object value) + { + this.type = requireNonNull(type, "type is null"); + this.value = requireNonNull(value, "value is null"); + } + + public Type getType() + { + return type; + } + + public Object getValue() + { + return value; + } + } + + private String addColumnExpression(List columns, Map columnExpressions) + { + if (columns.isEmpty()) { + return "null"; + } + + return columns.stream() + .map(arrowColumnHandle -> { + String columnAlias = quote(arrowColumnHandle.getColumnName()); + String expression = columnExpressions.get(arrowColumnHandle.getColumnName()); + if (expression == null) { + return columnAlias; + } + return format("%s AS %s", expression, columnAlias); + }) + .collect(joining(", ")); + } + + private static boolean isAcceptedType(Type type) + { + Type validType = requireNonNull(type, "type is null"); + return validType.equals(BigintType.BIGINT) || + validType.equals(TinyintType.TINYINT) || + validType.equals(SmallintType.SMALLINT) || + validType.equals(IntegerType.INTEGER) || + validType.equals(DoubleType.DOUBLE) || + validType.equals(RealType.REAL) || + validType.equals(BooleanType.BOOLEAN) || + validType.equals(DateType.DATE) || + validType.equals(TimeType.TIME) || + validType.equals(TimestampType.TIMESTAMP) || + validType instanceof VarcharType || + validType instanceof CharType; + } + private List toConjuncts(List columns, TupleDomain tupleDomain, List accumulator) 
+ { + ImmutableList.Builder builder = ImmutableList.builder(); + for (ArrowColumnHandle column : columns) { + Type type = column.getColumnType(); + if (isAcceptedType(type)) { + Domain domain = tupleDomain.getDomains().get().get(column); + if (domain != null) { + builder.add(toPredicate(column.getColumnName(), domain, column, accumulator)); + } + } + } + return builder.build(); + } + + private String toPredicate(String columnName, Domain domain, ArrowColumnHandle columnHandle, List accumulator) + { + checkArgument(domain.getType().isOrderable(), "Domain type must be orderable"); + + if (domain.getValues().isNone()) { + return domain.isNullAllowed() ? quote(columnName) + " IS NULL" : ALWAYS_FALSE; + } + + if (domain.getValues().isAll()) { + return domain.isNullAllowed() ? ALWAYS_TRUE : quote(columnName) + " IS NOT NULL"; + } + + List disjuncts = new ArrayList<>(); + List singleValues = new ArrayList<>(); + for (Range range : domain.getValues().getRanges().getOrderedRanges()) { + checkState(!range.isAll()); // Already checked + if (range.isSingleValue()) { + singleValues.add(range.getSingleValue()); + } + else { + List rangeConjuncts = new ArrayList<>(); + if (!range.isLowUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), columnHandle, accumulator)); + } + if (!range.isHighUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isHighInclusive() ? 
"<=" : "<", range.getHighBoundedValue(), columnHandle, accumulator)); + } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")"); + } + } + + // Add back all of the possible single values either as an equality or an IN predicate + if (singleValues.size() == 1) { + disjuncts.add(toPredicate(columnName, "=", getOnlyElement(singleValues), columnHandle, accumulator)); + } + else if (singleValues.size() > 1) { + for (Object value : singleValues) { + bindValue(value, columnHandle, accumulator); + } + String values = Joiner.on(",").join(singleValues.stream().map(v -> + parameterValueToString(columnHandle.getColumnType(), v)).collect(Collectors.toList())); + disjuncts.add(quote(columnName) + " IN (" + values + ")"); + } + + // Add nullability disjuncts + checkState(!disjuncts.isEmpty()); + if (domain.isNullAllowed()) { + disjuncts.add(quote(columnName) + " IS NULL"); + } + + return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; + } + + private String toPredicate(String columnName, String operator, Object value, ArrowColumnHandle columnHandle, List accumulator) + { + bindValue(value, columnHandle, accumulator); + return quote(columnName) + " " + operator + " " + parameterValueToString(columnHandle.getColumnType(), value); + } + private String quote(String name) + { + return "\"" + name + "\""; + } + + private String quoteValue(String name) + { + return "'" + name + "'"; + } + + private void bindValue(Object value, ArrowColumnHandle columnHandle, List accumulator) + { + Type type = columnHandle.getColumnType(); + accumulator.add(new TypeAndValue(type, value)); + } + + public static String convertLongToFloatString(Long value) + { + float floatFromIntBits = intBitsToFloat(toIntExact(value)); + return String.valueOf(floatFromIntBits); + } + + private String parameterValueToString(Type type, Object value) + { + Class javaType = 
type.getJavaType(); + if (type instanceof DateType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimeType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimestampType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof RealType && javaType == long.class) { + return convertLongToFloatString((Long) value); + } + else if (javaType == boolean.class || javaType == double.class || javaType == long.class) { + return value.toString(); + } + else if (javaType == Slice.class) { + return quoteValue(((Slice) value).toStringUtf8()); + } + else { + return quoteValue(value.toString()); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/ArrowFlightQueryRunner.java new file mode 100644 index 000000000000..4f08cb996f71 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/ArrowFlightQueryRunner.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.log.Logging; +import com.facebook.plugin.arrow.TestingArrowFlightPlugin; +import com.facebook.presto.Session; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; + +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class ArrowFlightQueryRunner +{ + private ArrowFlightQueryRunner() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static DistributedQueryRunner createQueryRunner(int flightServerPort) throws Exception + { + return createQueryRunner(ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort))); + } + + private static DistributedQueryRunner createQueryRunner(Map catalogProperties) throws Exception + { + return createQueryRunner(ImmutableMap.of(), catalogProperties); + } + + private static DistributedQueryRunner createQueryRunner( + Map extraProperties, + Map catalogProperties) + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("arrow") + .setSchema("tpch") + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).setExtraProperties(extraProperties).build(); + + try { + queryRunner.installPlugin(new TestingArrowFlightPlugin()); + + ImmutableMap.Builder properties = ImmutableMap.builder() + .putAll(catalogProperties) + .put("arrow-flight.server", "localhost") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt") + .put("arrow-flight.server.verify", "true"); + + queryRunner.createCatalog("arrow", "arrow", properties.build()); + + return queryRunner; + } + catch (Exception e) 
{ + throw new RuntimeException("Failed to create ArrowQueryRunner", e); + } + } + + public static void main(String[] args) + throws Exception + { + Logging.initialize(); + + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + Location serverLocation = Location.forGrpcInsecure("127.0.0.1", 9443); + FlightServer server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)).build(); + + server.start(); + + Logger log = Logger.get(ArrowFlightQueryRunner.class); + + log.info("Server listening on port " + server.getPort()); + + DistributedQueryRunner queryRunner = createQueryRunner( + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443))); + Thread.sleep(10); + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/ArrowMetadataUtil.java new file mode 100644 index 000000000000..224c682fee79 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/ArrowMetadataUtil.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.plugin.arrow.ArrowColumnHandle; +import com.facebook.plugin.arrow.ArrowTableHandle; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; + +final class ArrowMetadataUtil +{ + private ArrowMetadataUtil() {} + + public static final JsonCodec COLUMN_CODEC; + public static final JsonCodec TABLE_CODEC; + + static { + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TestingTypeDeserializer())); + JsonCodecFactory codecFactory = new JsonCodecFactory(provider); + COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class); + TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class); + } + + public static final class TestingTypeDeserializer + extends FromStringDeserializer + { + private final Map types = ImmutableMap.of( + StandardTypes.BIGINT, BIGINT, + StandardTypes.VARCHAR, VARCHAR); + + public TestingTypeDeserializer() + { + super(Type.class); + } + + @Override + protected Type _deserialize(String value, DeserializationContext context) + { + Type type = types.get(value.toLowerCase(ENGLISH)); + checkArgument(type != null, "Unknown type %s", value); + return type; + } + } + + public static void assertJsonRoundTrip(JsonCodec 
codec, T object) + { + String json = codec.toJson(object); + T copy = codec.fromJson(json); + assertEquals(copy, object); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowBlockBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowBlockBuilder.java new file mode 100644 index 000000000000..2ab5852b0815 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowBlockBuilder.java @@ -0,0 +1,683 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.ArrowBlockBuilder; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import io.airlift.slice.Slice; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.testng.annotations.BeforeClass; +import 
org.testng.annotations.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestArrowBlockBuilder +{ + private static final Logger logger = Logger.get(TestArrowBlockBuilder.class); + private static final int DICTIONARY_LENGTH = 10; + private static final int VECTOR_LENGTH = 50; + private BufferAllocator allocator; + private ArrowBlockBuilder arrowBlockBuilder; + + @BeforeClass + public void setUp() + { + // Initialize the Arrow allocator + allocator = new RootAllocator(Integer.MAX_VALUE); + logger.debug("Allocator initialized: %s", allocator); + arrowBlockBuilder = new ArrowBlockBuilder(); + } + + @Test + public void testBuildBlockFromBitVector() + { + // Create a BitVector and populate it with values + BitVector bitVector = new BitVector("bitVector", allocator); + bitVector.allocateNew(3); // Allocating space for 3 elements + + bitVector.set(0, 1); // Set value to 1 (true) + bitVector.set(1, 0); // Set value to 0 (false) + bitVector.setNull(2); // Set null value + + bitVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromTinyIntVector() + { + // Create a TinyIntVector and populate it with values + TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator); + tinyIntVector.allocateNew(3); // Allocating space for 3 elements + tinyIntVector.set(0, 10); + tinyIntVector.set(1, 20); + 
tinyIntVector.setNull(2); // Set null value + + tinyIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromSmallIntVector() + { + // Create a SmallIntVector and populate it with values + SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator); + smallIntVector.allocateNew(3); // Allocating space for 3 elements + smallIntVector.set(0, 10); + smallIntVector.set(1, 20); + smallIntVector.setNull(2); // Set null value + + smallIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromIntVector() + { + // Create an IntVector and populate it with values + IntVector intVector = new IntVector("intVector", allocator); + intVector.allocateNew(3); // Allocating space for 3 elements + intVector.set(0, 10); + intVector.set(1, 20); + intVector.set(2, 30); + + intVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromIntVector(intVector, IntegerType.INTEGER); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertEquals(10, resultBlock.getInt(0)); // The 1st element should be 10 + assertEquals(20, resultBlock.getInt(1)); // The 2nd element should be 20 + assertEquals(30, resultBlock.getInt(2)); // The 3rd element should be 30 + } + + @Test + public 
void testBuildBlockFromBigIntVector() + throws InstantiationException, IllegalAccessException + { + // Create a BigIntVector and populate it with values + BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator); + bigIntVector.allocateNew(3); // Allocating space for 3 elements + + bigIntVector.set(0, 10L); + bigIntVector.set(1, 20L); + bigIntVector.set(2, 30L); + + bigIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT); + + // Now verify the result block (BIGINT is 64-bit, so read with getLong rather than the narrowing getInt) + assertEquals(10L, resultBlock.getLong(0)); // The 1st element should be 10L + assertEquals(20L, resultBlock.getLong(1)); // The 2nd element should be 20L + assertEquals(30L, resultBlock.getLong(2)); // The 3rd element should be 30L + } + + @Test + public void testBuildBlockFromDecimalVector() + { + // Create a DecimalVector and populate it with values + DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2); // Precision = 10, Scale = 2 + decimalVector.allocateNew(2); // Allocating space for 2 elements + decimalVector.set(0, new BigDecimal("123.45")); + decimalVector.setNull(1); // Set null value explicitly instead of relying on a zero-filled validity buffer + decimalVector.setValueCount(2); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromDecimalVector(decimalVector, DecimalType.createDecimalType(10, 2)); + + // Now verify the result block + assertEquals(2, resultBlock.getPositionCount()); // Should have 2 positions + assertTrue(resultBlock.isNull(1)); // The 2nd element should be null + } + + @Test + public void testBuildBlockFromTimeStampMicroVector() + { + // Create a TimeStampMicroVector and populate it with values + TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator); + timestampMicroVector.allocateNew(3); // Allocating space for 3 elements + timestampMicroVector.set(0, 1000000L); // 1 second in microseconds + timestampMicroVector.set(1, 2000000L); // 2 seconds in
microseconds + timestampMicroVector.setNull(2); // Set null value + + timestampMicroVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = arrowBlockBuilder.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second) + assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds) + } + + @Test + public void testBuildBlockFromListVector() + { + // Create a root allocator for Arrow vectors + try (BufferAllocator allocator = new RootAllocator(); + ListVector listVector = ListVector.empty("listVector", allocator)) { + // Allocate the vector and get the writer + listVector.allocateNew(); + UnionListWriter listWriter = listVector.getWriter(); + + int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000}; + int tmpIndex = 0; + + for (int i = 0; i < 4; i++) { // 4 lists to be added + listWriter.startList(); + for (int j = 0; j < 3; j++) { // Each list has 3 integers + listWriter.writeInt(data[tmpIndex]); + tmpIndex++; + } + listWriter.endList(); + } + + // Set the number of lists + listVector.setValueCount(4); + + // Create Presto ArrayType for Integer + ArrayType arrayType = new ArrayType(IntegerType.INTEGER); + + // Call the method to test + Block block = arrowBlockBuilder.buildBlockFromListVector(listVector, arrayType); + + // Validate the result + assertEquals(block.getPositionCount(), 4); // 4 lists in the block + for (int j = 0; j < block.getPositionCount(); j++) { + Block subBlock = block.getBlock(j); + assertEquals(subBlock.getPositionCount(), 3); // each list should have 3 elements + } + } + } + + @Test + public void testProcessDictionaryVector() + { + // Create dictionary vector + 
VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(DICTIONARY_LENGTH); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + dictionaryVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + } + dictionaryVector.setValueCount(DICTIONARY_LENGTH); + + // Create raw vector + VarCharVector rawVector = new VarCharVector("raw", allocator); + rawVector.allocateNew(VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int value = i % DICTIONARY_LENGTH; + rawVector.setSafe(i, String.valueOf(value).getBytes(StandardCharsets.UTF_8)); + } + rawVector.setValueCount(VECTOR_LENGTH); + + // Encode using dictionary + ArrowType.Int index = new ArrowType.Int(16, true); + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, index)); + BaseIntVector encodedVector = (BaseIntVector) DictionaryEncoder.encode(rawVector, dictionary); + + // Process the dictionary vector + Block result = arrowBlockBuilder.buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); + + // Verify the result: one position per encoded input value + assertNotNull(result, "The Block should not be null."); + assertEquals(result.getPositionCount(), VECTOR_LENGTH); + } + + @Test + public void testBuildBlockFromDictionaryVector() + { + IntVector indicesVector = new IntVector("indices", allocator); + indicesVector.allocateNew(4); // allocating space for the 4 index values written below + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values (explicit charset, matching the slice.toStringUtf8() checks) + dictionaryVector.set(0, "apple".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.set(1, "banana".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.set(2, "cherry".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.setValueCount(3); + + // Set up index values (this would reference the dictionary) + indicesVector.set(0, 0); // First index
points to "apple" + indicesVector.setSafe(1, 1); // Second index points to "banana" + indicesVector.setSafe(2, 2); // Third index points to "cherry" + indicesVector.setSafe(3, 2); // Fourth index also points to "cherry"; setSafe grows the vector if position 3 exceeds the allocated capacity + indicesVector.setValueCount(4); + // Call the method under test + Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + assertEquals(dictionaryBlock.getPositionCount(), 4); // one position per index, so the loop below cannot pass vacuously + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + else if (i == 3) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorSmallInt() + { + SmallIntVector indicesVector = new SmallIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (short) 0); + indicesVector.set(1, (short) 1); + indicesVector.set(2, (short) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test +
Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorTinyInt() + { + TinyIntVector indicesVector = new TinyIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (byte) 0); + indicesVector.set(1, (byte) 1); + indicesVector.set(2, (byte) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the 
dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testWriteVarcharType() + { + Type varcharType = VarcharType.createUnboundedVarcharType(); + BlockBuilder builder = varcharType.createBlockBuilder(null, 1); + + String value = "test_string"; + arrowBlockBuilder.writeVarcharType(varcharType, builder, value); + + Block block = builder.build(); + Slice result = varcharType.getSlice(block, 0); + assertEquals(result.toStringUtf8(), value); + } + + @Test + public void testWriteSmallintType() + { + Type smallintType = SmallintType.SMALLINT; + BlockBuilder builder = smallintType.createBlockBuilder(null, 1); + + short value = 42; + arrowBlockBuilder.writeSmallintType(smallintType, builder, value); + + Block block = builder.build(); + long result = smallintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteTinyintType() + { + Type tinyintType = TinyintType.TINYINT; + BlockBuilder builder = tinyintType.createBlockBuilder(null, 1); + + byte value = 7; + arrowBlockBuilder.writeTinyintType(tinyintType, builder, value); + + Block block = builder.build(); + long result = tinyintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteBigintType() + { + Type bigintType = BigintType.BIGINT; + BlockBuilder builder = bigintType.createBlockBuilder(null, 1); + + long value = 123456789L; + arrowBlockBuilder.writeBigintType(bigintType, builder, value); + + Block block = builder.build(); + long result = bigintType.getLong(block, 0); + 
assertEquals(result, value); + } + + @Test + public void testWriteIntegerType() + { + Type integerType = IntegerType.INTEGER; + BlockBuilder builder = integerType.createBlockBuilder(null, 1); + + int value = 42; + arrowBlockBuilder.writeIntegerType(integerType, builder, value); + + Block block = builder.build(); + long result = integerType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteDoubleType() + { + Type doubleType = DoubleType.DOUBLE; + BlockBuilder builder = doubleType.createBlockBuilder(null, 1); + + double value = 42.42; + arrowBlockBuilder.writeDoubleType(doubleType, builder, value); + + Block block = builder.build(); + double result = doubleType.getDouble(block, 0); + assertEquals(result, value, 0.001); + } + + @Test + public void testWriteBooleanType() + { + Type booleanType = BooleanType.BOOLEAN; + BlockBuilder builder = booleanType.createBlockBuilder(null, 1); + + boolean value = true; + arrowBlockBuilder.writeBooleanType(booleanType, builder, value); + + Block block = builder.build(); + boolean result = booleanType.getBoolean(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteArrayType() + { + Type elementType = IntegerType.INTEGER; + ArrayType arrayType = new ArrayType(elementType); + BlockBuilder builder = arrayType.createBlockBuilder(null, 1); + + List<Integer> values = Arrays.asList(1, 2, 3); // typed list: the loop below calls Integer.longValue() on each element + arrowBlockBuilder.writeArrayType(arrayType, builder, values); + + Block block = builder.build(); + Block arrayBlock = arrayType.getObject(block, 0); + assertEquals(arrayBlock.getPositionCount(), values.size()); + for (int i = 0; i < values.size(); i++) { + assertEquals(elementType.getLong(arrayBlock, i), values.get(i).longValue()); + } + } + + @Test + public void testWriteRowType() + { + RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER); + RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType()); + RowType rowType =
RowType.from(Arrays.asList(field1, field2)); + BlockBuilder builder = rowType.createBlockBuilder(null, 1); + + List rowValues = Arrays.asList(42, "test"); + arrowBlockBuilder.writeRowType(rowType, builder, rowValues); + + Block block = builder.build(); + Block rowBlock = rowType.getObject(block, 0); + assertEquals(IntegerType.INTEGER.getLong(rowBlock, 0), 42); + assertEquals(VarcharType.createUnboundedVarcharType().getSlice(rowBlock, 1).toStringUtf8(), "test"); + } + + @Test + public void testWriteDateType() + { + Type dateType = DateType.DATE; + BlockBuilder builder = dateType.createBlockBuilder(null, 1); + + LocalDate value = LocalDate.of(2020, 1, 1); + arrowBlockBuilder.writeDateType(dateType, builder, value); + + Block block = builder.build(); + long result = dateType.getLong(block, 0); + assertEquals(result, value.toEpochDay()); + } + + @Test + public void testWriteTimestampType() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC + arrowBlockBuilder.writeTimestampType(timestampType, builder, value); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteTimestampTypeWithSqlTimestamp() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00"); + long expectedMillis = timestamp.getTime(); + arrowBlockBuilder.writeTimestampType(timestampType, builder, timestamp); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, expectedMillis); + } + + @Test + public void testShortDecimalRetrieval() + { + DecimalType shortDecimalType = DecimalType.createDecimalType(10, 2); // Precision: 10, Scale: 2 + BlockBuilder builder = 
shortDecimalType.createBlockBuilder(null, 1); + + BigDecimal decimalValue = new BigDecimal("12345.67"); + arrowBlockBuilder.writeDecimalType(shortDecimalType, builder, decimalValue); + + Block block = builder.build(); + long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567 + BigDecimal result = BigDecimal.valueOf(unscaledValue).movePointLeft(shortDecimalType.getScale()); + assertEquals(result, decimalValue); + } + + @Test + public void testLongDecimalRetrieval() + { + // Create a DecimalType with precision 38 and scale 10 + DecimalType longDecimalType = DecimalType.createDecimalType(38, 10); + BlockBuilder builder = longDecimalType.createBlockBuilder(null, 1); + BigDecimal decimalValue = new BigDecimal("1234567890.1234567890"); + arrowBlockBuilder.writeDecimalType(longDecimalType, builder, decimalValue); + // Build the block after inserting the decimal value + Block block = builder.build(); + Slice unscaledSlice = longDecimalType.getSlice(block, 0); + BigInteger unscaledValue = Decimals.decodeUnscaledValue(unscaledSlice); + BigDecimal result = new BigDecimal(unscaledValue).movePointLeft(longDecimalType.getScale()); + // Assert the decoded result is equal to the original decimal value + assertEquals(result, decimalValue); + } + + @Test + public void testVarcharVector() + { + VarCharVector vector = new VarCharVector("test_string", allocator); + vector.setSafe(0, "apple".getBytes()); + vector.setSafe(1, "fig".getBytes()); + vector.setValueCount(2); + Block resultblock = arrowBlockBuilder.buildBlockFromVarCharVector(vector, VarcharType.VARCHAR); + assertEquals(2, resultblock.getPositionCount()); + // Extract values from the Block and compare with the values in the vector + for (int i = 0; i < vector.getValueCount(); i++) { + // Retrieve the value as a Slice for the ith position in the Block + Slice slice = resultblock.getSlice(i, 0, resultblock.getSliceLength(i)); + // Assert based on the expected values + 
assertEquals(slice.toStringUtf8(), new String(vector.get(i))); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowColumnHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowColumnHandle.java new file mode 100644 index 000000000000..1034b6a70836 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowColumnHandle.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.tests; + +import com.facebook.plugin.arrow.ArrowColumnHandle; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.spi.ColumnMetadata; +import org.testng.annotations.Test; + +import java.util.Locale; + +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestArrowColumnHandle +{ + @Test + public void testConstructorAndGetters() + { + // Given + String columnName = "testColumn"; + // When + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // Then + assertEquals(columnHandle.getColumnName(), columnName, "Column name should match the input"); + assertEquals(columnHandle.getColumnType(), IntegerType.INTEGER, "Column type should match the input"); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnName() + { + // Given + // When + new ArrowColumnHandle(null, IntegerType.INTEGER); // Should throw NullPointerException + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnType() + { + // Given + String columnName = "testColumn"; + + // When + new ArrowColumnHandle(columnName, null); // Should throw NullPointerException + } + + @Test + public void testGetColumnMetadata() + { + // Given + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // When + ColumnMetadata columnMetadata = columnHandle.getColumnMetadata(); + + // Then + assertNotNull(columnMetadata, "ColumnMetadata should not be null"); + assertEquals(columnMetadata.getName(), columnName.toLowerCase(Locale.ENGLISH), "ColumnMetadata name should match the column name"); + assertEquals(columnMetadata.getType(), IntegerType.INTEGER, "ColumnMetadata type should match the column type"); + } + + @Test + public void testToString() + { + String 
columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + String result = columnHandle.toString(); + String expected = columnName + ":" + IntegerType.INTEGER; + assertEquals(result, expected, "toString() should return the correct string representation"); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightIntegrationSmokeTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightIntegrationSmokeTest.java new file mode 100644 index 000000000000..9383eaf1ee23 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightIntegrationSmokeTest.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; + +import java.io.File; + +public class TestArrowFlightIntegrationSmokeTest + extends AbstractTestIntegrationSmokeTest +{ + private static final Logger logger = Logger.get(TestArrowFlightIntegrationSmokeTest.class); + private RootAllocator allocator; + private FlightServer server; + private Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = Location.forGrpcTls("127.0.0.1", 9442); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port " + server.getPort()); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return ArrowFlightQueryRunner.createQueryRunner(9442); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + server.close(); + allocator.close(); + arrowFlightQueryRunner.close(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightQueries.java new file mode 
100644 index 000000000000..55284aee955b --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightQueries.java @@ -0,0 +1,175 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueries; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; + +import static com.facebook.presto.common.type.CharType.createCharType; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static 
com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static java.lang.String.format; +import static org.testng.Assert.assertTrue; + +public class TestArrowFlightQueries + extends AbstractTestQueries +{ + private static final Logger logger = Logger.get(TestArrowFlightQueries.class); + private RootAllocator allocator; + private FlightServer server; + private Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = Location.forGrpcTls("localhost", 9443); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port " + server.getPort()); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return ArrowFlightQueryRunner.createQueryRunner(9443); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + server.close(); + allocator.close(); + arrowFlightQueryRunner.close(); + } + + @Test + public void testShowCharColumns() + { + MaterializedResult actual = computeActual("SHOW COLUMNS FROM member"); + + MaterializedResult expectedUnparametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar", "", "") + .row("sex", "char", "", "") + .row("state", "char", "", "") + .build(); + + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, 
VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar(50)", "", "") + .row("sex", "char(1)", "", "") + .row("state", "char(5)", "", "") + .build(); + + assertTrue(actual.equals(expectedParametrizedVarchar) || actual.equals(expectedUnparametrizedVarchar), + format("%s matches neither %s nor %s", actual, expectedParametrizedVarchar, expectedUnparametrizedVarchar)); + } + + @Test + public void testPredicateOnCharColumn() + { + MaterializedResult actualRow = computeActual("SELECT * from member WHERE state = 'CD'"); + MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER, createVarcharType(50), createCharType(1), createCharType(5)) + .row(2, "MARY", "F", "CD ") + .build(); + assertTrue(actualRow.equals(expectedRow), format("%s does not match %s", actualRow, expectedRow)); // report both results on failure + } + + @Test + public void testSelectTime() + { + MaterializedResult actualRow = computeActual("SELECT * from event WHERE id = 1"); + Session session = getSession(); + MaterializedResult expectedRow = resultBuilder(session, INTEGER, DATE, TIME, TIMESTAMP) + .row(1, + getDate("2004-12-31"), + getTimeAtZone("23:59:59", session.getTimeZoneKey()), + getDateTimeAtZone("2005-12-31 23:59:59", session.getTimeZoneKey())) + .build(); + assertTrue(actualRow.equals(expectedRow), format("%s does not match %s", actualRow, expectedRow)); // report both results on failure + } + + private LocalDate getDate(String dateString) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); + LocalDate localDate = LocalDate.parse(dateString, formatter); + + return localDate; + } + + private LocalTime getTimeAtZone(String timeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("HH:mm:ss"); + LocalTime localTime = LocalTime.parse(timeString, formatter); + + LocalDateTime localDateTime = LocalDateTime.of(LocalDate.of(1970, 1, 1), localTime); + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalTime localTimeAtZone
= zonedDateTime.toLocalTime(); + return localTimeAtZone; + } + + private LocalDateTime getDateTimeAtZone(String dateTimeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); + LocalDateTime localDateTime = LocalDateTime.parse(dateTimeString, formatter); + + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalDateTime localDateTimeAtZone = zonedDateTime.toLocalDateTime(); + return localDateTimeAtZone; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightQueriesWithDictionaryVector.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightQueriesWithDictionaryVector.java new file mode 100644 index 000000000000..3bff8663a99b --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowFlightQueriesWithDictionaryVector.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tests.AbstractTestQueryFramework;
+import com.facebook.presto.tests.DistributedQueryRunner;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.RootAllocator;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.File;
+
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.testing.MaterializedResult.resultBuilder;
+import static com.facebook.presto.testing.assertions.Assert.assertEquals;
+
+/**
+ * Runs a query through the Arrow Flight connector against a local TLS Flight server whose
+ * producer is {@link TestingArrowServerUsingDictionaryVector}, and checks the returned rows.
+ */
+public class TestArrowFlightQueriesWithDictionaryVector
+        extends AbstractTestQueryFramework
+{
+    private static final Logger logger = Logger.get(TestArrowFlightQueriesWithDictionaryVector.class);
+    // Single definition of the port used by both the Flight server and the query runner.
+    private static final int SERVER_PORT = 9444;
+
+    private RootAllocator allocator;
+    private FlightServer server;
+    private Location serverLocation;
+    private DistributedQueryRunner arrowFlightQueryRunner;
+
+    /**
+     * Starts the TLS Flight server used by the tests in this class.
+     *
+     * @throws Exception if the server cannot be built or started
+     */
+    @BeforeClass
+    public void setup()
+            throws Exception
+    {
+        arrowFlightQueryRunner = getDistributedQueryRunner();
+        File certChainFile = new File("src/test/resources/server.crt");
+        File privateKeyFile = new File("src/test/resources/server.key");
+
+        allocator = new RootAllocator(Long.MAX_VALUE);
+        serverLocation = Location.forGrpcTls("localhost", SERVER_PORT);
+
+        server = FlightServer.builder(allocator, serverLocation, new TestingArrowServerUsingDictionaryVector(allocator))
+                .useTls(certChainFile, privateKeyFile)
+                .build();
+
+        server.start();
+        logger.info("Server listening on port " + server.getPort());
+    }
+
+    @Override
+    protected QueryRunner createQueryRunner()
+            throws Exception
+    {
+        return ArrowFlightQueryRunner.createQueryRunner(SERVER_PORT);
+    }
+
+    /**
+     * Releases the query runner, the Flight server and the allocator.
+     *
+     * @throws InterruptedException if closing the server is interrupted
+     */
+    @AfterClass(alwaysRun = true)
+    public void close()
+            throws InterruptedException
+    {
+        // alwaysRun = true means this also executes when setup() failed part-way, so any of the
+        // fields may still be null. The nested finally blocks make sure one failing close()
+        // does not prevent the remaining resources from being released.
+        try {
+            if (arrowFlightQueryRunner != null) {
+                arrowFlightQueryRunner.close();
+            }
+        }
+        finally {
+            try {
+                if (server != null) {
+                    server.close();
+                }
+            }
+            finally {
+                // Closed last: the server was built with this allocator, so it has to be
+                // shut down before the allocator goes away.
+                if (allocator != null) {
+                    allocator.close();
+                }
+            }
+        }
+    }
+
+    @Test
+    public void testDictionaryBlock()
+    {
+        // Retrieve the actual result from the query
+        MaterializedResult actual = computeActual("SELECT shipmode, orderkey FROM lineitem");
+        // Validate the row count first
+        assertEquals(actual.getRowCount(), 3);
+        // Now validate every row: each shipmode string is paired with its orderkey
+        MaterializedResult expectedRows = resultBuilder(getSession(), VARCHAR, BIGINT)
+                .row("apple", 0L)
+                .row("banana", 1L)
+                .row("cherry", 2L)
+                .build();
+        // Argument order is (actual, expected), matching the row-count assertion above
+        assertEquals(actual, expectedRows);
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowHandleResolver.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowHandleResolver.java
new file mode 100644
index 000000000000..34ba773d1026
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowHandleResolver.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.plugin.arrow.ArrowColumnHandle;
+import com.facebook.plugin.arrow.ArrowHandleResolver;
+import com.facebook.plugin.arrow.ArrowSplit;
+import com.facebook.plugin.arrow.ArrowTableHandle;
+import com.facebook.plugin.arrow.ArrowTableLayoutHandle;
+import com.facebook.plugin.arrow.ArrowTransactionHandle;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+/**
+ * Checks that every accessor of {@link ArrowHandleResolver} reports the matching Arrow handle class.
+ * Each test builds its own resolver instance.
+ */
+public class TestArrowHandleResolver
+{
+    @Test
+    public void testGetTableHandleClass()
+    {
+        Class<?> reported = new ArrowHandleResolver().getTableHandleClass();
+        assertEquals(reported, ArrowTableHandle.class, "getTableHandleClass should return ArrowTableHandle class.");
+    }
+
+    @Test
+    public void testGetTableLayoutHandleClass()
+    {
+        Class<?> reported = new ArrowHandleResolver().getTableLayoutHandleClass();
+        assertEquals(reported, ArrowTableLayoutHandle.class, "getTableLayoutHandleClass should return ArrowTableLayoutHandle class.");
+    }
+
+    @Test
+    public void testGetColumnHandleClass()
+    {
+        Class<?> reported = new ArrowHandleResolver().getColumnHandleClass();
+        assertEquals(reported, ArrowColumnHandle.class, "getColumnHandleClass should return ArrowColumnHandle class.");
+    }
+
+    @Test
+    public void testGetSplitClass()
+    {
+        Class<?> reported = new ArrowHandleResolver().getSplitClass();
+        assertEquals(reported, ArrowSplit.class, "getSplitClass should return ArrowSplit class.");
+    }
+
+    @Test
+    public void testGetTransactionHandleClass()
+    {
+        Class<?> reported = new ArrowHandleResolver().getTransactionHandleClass();
+        assertEquals(reported, ArrowTransactionHandle.class, "getTransactionHandleClass should return ArrowTransactionHandle class.");
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowPlugin.java
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowPlugin.java
new file mode 100644
index 000000000000..05498c1c136f
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowPlugin.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.plugin.arrow.ArrowConnectorFactory;
+import com.facebook.plugin.arrow.ArrowPlugin;
+import com.facebook.plugin.arrow.TestingArrowFlightPlugin;
+import com.facebook.presto.spi.connector.ConnectorFactory;
+import org.testng.annotations.Test;
+
+import static com.facebook.airlift.testing.Assertions.assertInstanceOf;
+import static com.google.common.collect.Iterables.getOnlyElement;
+
+public class TestArrowPlugin
+{
+    @Test
+    public void testStartup()
+    {
+        // getOnlyElement fails unless the plugin exposes exactly one connector factory
+        ArrowPlugin arrowPlugin = new TestingArrowFlightPlugin();
+        ConnectorFactory onlyFactory = getOnlyElement(arrowPlugin.getConnectorFactories());
+        assertInstanceOf(onlyFactory, ArrowConnectorFactory.class);
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowSplit.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowSplit.java
new file mode 100644
index 000000000000..8c2293b08619
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowSplit.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed under the Apache License, Version
2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.plugin.arrow.ArrowSplit;
+import com.facebook.presto.spi.HostAddress;
+import com.facebook.presto.spi.schedule.NodeSelectionStrategy;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.Ticket;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Unit tests for {@link ArrowSplit}: constructor/getter round trip, node selection strategy
+ * and preferred nodes. A fresh split is built before every test method.
+ */
+@Test(singleThreaded = true)
+public class TestArrowSplit
+{
+    private ArrowSplit arrowSplit;
+    private String schemaName;
+    private String tableName;
+    private byte[] ticketArray;
+    private FlightEndpoint flightEndpoint;
+
+    /**
+     * Builds a one-location {@link FlightEndpoint} and an {@link ArrowSplit} carrying its serialized bytes.
+     *
+     * @throws URISyntaxException if the location URI is malformed
+     */
+    @BeforeMethod
+    public void setUp()
+            throws URISyntaxException
+    {
+        schemaName = "testSchema";
+        tableName = "testTable";
+        ticketArray = new byte[] {1, 2, 3, 4};
+
+        // The endpoint is now built from this list; previously the list was filled but never used.
+        List<Location> locations = new ArrayList<>();
+        locations.add(new Location("http://localhost:8080"));
+
+        // The Ticket takes the raw bytes directly; wrapping them in a ByteBuffer only to call
+        // array() handed back the very same array.
+        Ticket ticket = new Ticket(ticketArray);
+        flightEndpoint = new FlightEndpoint(ticket, locations.toArray(new Location[0]));
+
+        // Instantiate ArrowSplit with the serialized endpoint
+        arrowSplit = new ArrowSplit(schemaName, tableName, flightEndpoint.serialize().array());
+    }
+
+    @Test
+    public void testConstructorAndGetters()
+            throws IOException, URISyntaxException
+    {
+        // Test that the constructor correctly initializes fields
+        assertEquals(arrowSplit.getSchemaName(), schemaName, "Schema name should match.");
+        assertEquals(arrowSplit.getTableName(), tableName, "Table name should match.");
+
+        // Compare against a fresh serialization of the same endpoint
+        ByteBuffer expectedEndpoint = flightEndpoint.serialize();
+        assertEquals(arrowSplit.getFlightEndpoint(), expectedEndpoint.array(), "Byte array should match");
+    }
+
+    @Test
+    public void testNodeSelectionStrategy()
+    {
+        // Test that the node selection strategy is NO_PREFERENCE
+        assertEquals(arrowSplit.getNodeSelectionStrategy(), NodeSelectionStrategy.NO_PREFERENCE, "Node selection strategy should be NO_PREFERENCE.");
+    }
+
+    @Test
+    public void testGetPreferredNodes()
+    {
+        // Test that the preferred nodes list is empty
+        List<HostAddress> preferredNodes = arrowSplit.getPreferredNodes(null);
+        assertNotNull(preferredNodes, "Preferred nodes list should not be null.");
+        assertTrue(preferredNodes.isEmpty(), "Preferred nodes list should be empty.");
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowTableHandle.java
new file mode 100644
index 000000000000..29c8199b7540
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowTableHandle.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.airlift.testing.EquivalenceTester;
+import com.facebook.plugin.arrow.ArrowTableHandle;
+import org.testng.annotations.Test;
+
+import static com.facebook.plugin.arrow.tests.ArrowMetadataUtil.TABLE_CODEC;
+import static com.facebook.plugin.arrow.tests.ArrowMetadataUtil.assertJsonRoundTrip;
+
+/**
+ * Tests JSON round-tripping and equivalence of {@link ArrowTableHandle}.
+ */
+public class TestArrowTableHandle
+{
+    @Test
+    public void testJsonRoundTrip()
+    {
+        ArrowTableHandle handle = new ArrowTableHandle("schema", "table");
+        assertJsonRoundTrip(TABLE_CODEC, handle);
+    }
+
+    @Test
+    public void testEquivalence()
+    {
+        // The equivalence group holds a single handle
+        ArrowTableHandle handle = new ArrowTableHandle("tm_engine", "employees");
+        EquivalenceTester.equivalenceTester()
+                .addEquivalentGroup(handle)
+                .check();
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowTableLayoutHandle.java
new file mode 100644
index 000000000000..d832eb5f4600
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestArrowTableLayoutHandle.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.plugin.arrow.ArrowColumnHandle;
+import com.facebook.plugin.arrow.ArrowTableHandle;
+import com.facebook.plugin.arrow.ArrowTableLayoutHandle;
+import com.facebook.presto.common.predicate.TupleDomain;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.VarcharType;
+import com.facebook.presto.spi.ColumnHandle;
+import org.testng.annotations.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotEquals;
+
+/**
+ * Unit tests for {@link ArrowTableLayoutHandle}: getters, toString, equals/hashCode and the
+ * null checks performed by the constructor.
+ */
+// NOTE(review): the type arguments below (List<ArrowColumnHandle>, TupleDomain<ColumnHandle>)
+// are assumed from the imports and the values used; confirm against the constructor signature.
+public class TestArrowTableLayoutHandle
+{
+    @Test
+    public void testConstructorAndGetters()
+    {
+        ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table");
+        List<ArrowColumnHandle> columnHandles = Arrays.asList(
+                new ArrowColumnHandle("column1", IntegerType.INTEGER),
+                new ArrowColumnHandle("column2", VarcharType.VARCHAR));
+        TupleDomain<ColumnHandle> tupleDomain = TupleDomain.all();
+
+        ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain);
+
+        assertEquals(layoutHandle.getTableHandle(), tableHandle, "Table handle mismatch.");
+        assertEquals(layoutHandle.getColumnHandles(), columnHandles, "Column handles mismatch.");
+        assertEquals(layoutHandle.getTupleDomain(), tupleDomain, "Tuple domain mismatch.");
+    }
+
+    @Test
+    public void testToString()
+    {
+        ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table");
+        List<ArrowColumnHandle> columnHandles = Arrays.asList(
+                new ArrowColumnHandle("column1", IntegerType.INTEGER),
+                new ArrowColumnHandle("column2", BigintType.BIGINT));
+        TupleDomain<ColumnHandle> tupleDomain = TupleDomain.all();
+
+        ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain);
+
+        // The expected text is assembled from the same three parts the handle was built from
+        String expectedString = "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain;
+        assertEquals(layoutHandle.toString(), expectedString, "toString output mismatch.");
+    }
+
+    @Test
+    public void testEqualsAndHashCode()
+    {
+        ArrowTableHandle tableHandle1 = new ArrowTableHandle("schema", "table");
+        ArrowTableHandle tableHandle2 = new ArrowTableHandle("schema", "different_table");
+
+        List<ArrowColumnHandle> columnHandles1 = Arrays.asList(
+                new ArrowColumnHandle("column1", IntegerType.INTEGER),
+                new ArrowColumnHandle("column2", VarcharType.VARCHAR));
+        List<ArrowColumnHandle> columnHandles2 = Collections.singletonList(
+                new ArrowColumnHandle("column1", IntegerType.INTEGER));
+
+        TupleDomain<ColumnHandle> tupleDomain1 = TupleDomain.all();
+        TupleDomain<ColumnHandle> tupleDomain2 = TupleDomain.none();
+
+        // layoutHandle2 is built from the same parts as layoutHandle1; handles 3 to 5 each differ in exactly one part
+        ArrowTableLayoutHandle layoutHandle1 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1);
+        ArrowTableLayoutHandle layoutHandle2 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1);
+        ArrowTableLayoutHandle layoutHandle3 = new ArrowTableLayoutHandle(tableHandle2, columnHandles1, tupleDomain1);
+        ArrowTableLayoutHandle layoutHandle4 = new ArrowTableLayoutHandle(tableHandle1, columnHandles2, tupleDomain1);
+        ArrowTableLayoutHandle layoutHandle5 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain2);
+
+        // Test equality
+        assertEquals(layoutHandle1, layoutHandle2, "Handles with same attributes should be equal.");
+        assertNotEquals(layoutHandle1, layoutHandle3, "Handles with different tableHandles should not be equal.");
+        assertNotEquals(layoutHandle1, layoutHandle4, "Handles with different columnHandles should not be equal.");
+        assertNotEquals(layoutHandle1, layoutHandle5, "Handles with different tupleDomains should not be equal.");
+        assertNotEquals(layoutHandle1, null, "Handle should not be equal to null.");
+        assertNotEquals(layoutHandle1, new Object(), "Handle should not be equal to an object of another class.");
+
+        // Test hash codes
+        // NOTE(review): the hashCode contract does not require unequal objects to hash differently;
+        // the two inequality checks below hold for these fixed inputs only.
+        assertEquals(layoutHandle1.hashCode(), layoutHandle2.hashCode(), "Equal handles should have same hash code.");
+        assertNotEquals(layoutHandle1.hashCode(), layoutHandle3.hashCode(), "Handles with different tableHandles should have different hash codes.");
+        assertNotEquals(layoutHandle1.hashCode(), layoutHandle4.hashCode(), "Handles with different columnHandles should have different hash codes.");
+    }
+
+    @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "table is null")
+    public void testConstructorNullTableHandle()
+    {
+        new ArrowTableLayoutHandle(null, Collections.emptyList(), TupleDomain.all());
+    }
+
+    @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "columns are null")
+    public void testConstructorNullColumnHandles()
+    {
+        new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), null, TupleDomain.all());
+    }
+
+    @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "domain is null")
+    public void testConstructorNullTupleDomain()
+    {
+        new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), Collections.emptyList(), null);
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingArrowServer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingArrowServer.java
new file mode 100644
index 000000000000..36261c250da7
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingArrowServer.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.plugin.arrow.tests;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.plugin.arrow.TestingArrowFlightRequest;
+import com.facebook.plugin.arrow.TestingArrowFlightResponse;
+import com.google.common.collect.ImmutableList;
+import org.apache.arrow.adapter.jdbc.ArrowVectorIterator;
+import org.apache.arrow.adapter.jdbc.JdbcToArrow;
+import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.ActionType;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.PutResult;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import java.io.IOException;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.Statement;
+import java.util.ArrayList;
+import java.util.Calendar;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.TimeZone;
+import java.util.concurrent.ThreadLocalRandom;
+
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
+import static com.facebook.presto.common.Utils.checkArgument;
+
+/**
+ * Test-only {@link FlightProducer} that answers Flight requests from an in-memory H2 database.
+ * Requests arrive as JSON-encoded {@link TestingArrowFlightRequest} payloads; query results are
+ * converted to Arrow batches through the JDBC adapter.
+ */
+public class TestingArrowServer
+        implements FlightProducer
+{
+    private static final Logger logger = Logger.get(TestingArrowServer.class);
+
+    private final RootAllocator allocator;
+    private final Connection connection;
+    private final JsonCodec<TestingArrowFlightRequest> requestCodec;
+    private final JsonCodec<TestingArrowFlightResponse> responseCodec;
+
+    /**
+     * Sets up a uniquely named in-memory H2 database and opens the connection used by all requests.
+     *
+     * @param allocator allocator used for every Arrow vector this producer creates
+     * @throws Exception if the database cannot be set up or connected to
+     */
+    public TestingArrowServer(RootAllocator allocator)
+            throws Exception
+    {
+        this.allocator = allocator;
+        // nanoTime plus a random int keeps the database name distinct between server instances
+        String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1";
+        TestingH2DatabaseSetup.setup(h2JdbcUrl);
+        this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", "");
+        this.requestCodec = jsonCodec(TestingArrowFlightRequest.class);
+        this.responseCodec = jsonCodec(TestingArrowFlightResponse.class);
+    }
+
+    /**
+     * Executes the query carried by the ticket and streams the result as Arrow record batches.
+     * Any failure is reported to the listener and then rethrown as a {@link RuntimeException}.
+     */
+    @Override
+    public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener)
+    {
+        try (Statement stmt = connection.createStatement()) {
+            TestingArrowFlightRequest request = requestCodec.fromJson(ticket.getBytes());
+            checkArgument(request != null, "Request is null");
+            checkArgument(request.getQuery().isPresent(), "Query is missing");
+
+            // Extract and validate the SQL query
+            String query = request.getQuery().get();
+            if (query.trim().isEmpty()) {
+                throw new IllegalArgumentException("Query cannot be empty.");
+            }
+
+            logger.info("Executing query: " + query);
+            query = query.toUpperCase(); // Optionally, to maintain consistency
+
+            try (ResultSet rs = stmt.executeQuery(query)) {
+                JdbcToArrowConfigBuilder config = new JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048)
+                        .setCalendar(Calendar.getInstance(TimeZone.getDefault()));
+                ArrowVectorIterator iterator = JdbcToArrow.sqlToArrowVectorIterator(rs, config.build());
+
+                // streamRoot is created lazily from the first batch's schema and reused for every
+                // following batch; it stays null when the result set is empty.
+                VectorLoader vectorLoader = null;
+                VectorSchemaRoot streamRoot = null;
+                try {
+                    while (iterator.hasNext()) {
+                        try (VectorSchemaRoot batchRoot = iterator.next();
+                                ArrowRecordBatch batch = new VectorUnloader(batchRoot).getRecordBatch()) {
+                            if (streamRoot == null) {
+                                streamRoot = VectorSchemaRoot.create(batchRoot.getSchema(), allocator);
+                                vectorLoader = new VectorLoader(streamRoot);
+                                serverStreamListener.start(streamRoot);
+                            }
+                            vectorLoader.load(batch);
+                            serverStreamListener.putNext();
+                        }
+                    }
+                }
+                finally {
+                    // Release the root even when streaming fails half-way; previously it was
+                    // only closed on the success path.
+                    if (streamRoot != null) {
+                        streamRoot.close();
+                    }
+                }
+                serverStreamListener.completed();
+            }
+        }
+        // Handle Arrow processing errors
+        catch (IOException e) {
+            logger.error("Arrow data processing failed", e);
+            serverStreamListener.error(e);
+            throw new RuntimeException("Failed to process Arrow data", e);
+        }
+        // Handle all other exceptions, including parsing errors
+        catch (Exception e) {
+            logger.error("Ticket processing failed", e);
+            serverStreamListener.error(e);
+            throw new RuntimeException("Failed to process the ticket", e);
+        }
+    }
+
+    @Override
+    public void listFlights(CallContext callContext, Criteria criteria, StreamListener<FlightInfo> streamListener)
+    {
+        throw new UnsupportedOperationException("This operation is not supported");
+    }
+
+    /**
+     * Builds the Arrow schema for the request in the descriptor's command. When a table name is
+     * present the columns are read from INFORMATION_SCHEMA.COLUMNS; otherwise the request's query
+     * is executed and the schema is taken from its result-set metadata. The returned
+     * {@link FlightInfo} has a single endpoint whose ticket is the original command.
+     *
+     * @throws RuntimeException wrapping any failure, including a request with neither table nor query
+     */
+    @Override
+    public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor)
+    {
+        try {
+            TestingArrowFlightRequest request = requestCodec.fromJson(flightDescriptor.getCommand());
+            checkArgument(request != null, "Request is null");
+
+            checkArgument(request.getSchema().isPresent(), "Schema is missing");
+            String schemaName = request.getSchema().get();
+            Optional<String> tableName = request.getTable();
+            String selectStatement = request.getQuery().orElse(null);
+
+            List<Field> fields = new ArrayList<>();
+            if (tableName.isPresent()) {
+                String query = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " +
+                        "WHERE TABLE_SCHEMA='" + schemaName.toUpperCase() + "' " +
+                        "AND TABLE_NAME='" + tableName.get().toUpperCase() + "'";
+
+                // The Statement is a managed resource too, so it no longer leaks
+                try (Statement statement = connection.createStatement();
+                        ResultSet rs = statement.executeQuery(query)) {
+                    while (rs.next()) {
+                        String columnName = rs.getString("COLUMN_NAME");
+                        String dataType = rs.getString("TYPE_NAME");
+                        String charMaxLength = rs.getString("CHARACTER_MAXIMUM_LENGTH");
+                        int precision = rs.getInt("NUMERIC_PRECISION");
+                        int scale = rs.getInt("NUMERIC_SCALE");
+
+                        ArrowType arrowType = convertSqlTypeToArrowType(dataType, precision, scale);
+                        // The native type (and the length, when reported) travel as field metadata
+                        Map<String, String> metaDataMap = new HashMap<>();
+                        metaDataMap.put("columnNativeType", dataType);
+                        if (charMaxLength != null) {
+                            metaDataMap.put("columnLength", charMaxLength);
+                        }
+                        FieldType fieldType = new FieldType(true, arrowType, null, metaDataMap);
+                        Field field = new Field(columnName, fieldType, null);
+                        fields.add(field);
+                    }
+                }
+            }
+            else if (selectStatement != null) {
+                selectStatement = selectStatement.toUpperCase();
+                logger.info("Executing SELECT query: " + selectStatement);
+                try (Statement statement = connection.createStatement();
+                        ResultSet rs = statement.executeQuery(selectStatement)) {
+                    ResultSetMetaData metaData = rs.getMetaData();
+                    int columnCount = metaData.getColumnCount();
+
+                    // JDBC column indexes are 1-based
+                    for (int i = 1; i <= columnCount; i++) {
+                        String columnName = metaData.getColumnName(i);
+                        String columnType = metaData.getColumnTypeName(i);
+                        int precision = metaData.getPrecision(i);
+                        int scale = metaData.getScale(i);
+
+                        ArrowType arrowType = convertSqlTypeToArrowType(columnType, precision, scale);
+                        Field field = new Field(columnName, FieldType.nullable(arrowType), null);
+                        fields.add(field);
+                    }
+                }
+            }
+            else {
+                throw new IllegalArgumentException("Either schema_name/table_name or select_statement must be provided.");
+            }
+
+            Schema schema = new Schema(fields);
+            FlightEndpoint endpoint = new FlightEndpoint(new Ticket(flightDescriptor.getCommand()));
+            return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1);
+        }
+        catch (Exception e) {
+            logger.error(e);
+            throw new RuntimeException("Failed to retrieve FlightInfo", e);
+        }
+    }
+
+    @Override
+    public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener<PutResult> streamListener)
+    {
+        throw new UnsupportedOperationException("This operation is not supported");
+    }
+
+    /**
+     * Lists schema names when the request has no schema, or the table names of the requested
+     * schema otherwise, and sends them back as one JSON-encoded {@link TestingArrowFlightResponse}.
+     * Failures are reported through {@code streamListener.onError}.
+     */
+    @Override
+    public void doAction(CallContext callContext, Action action, StreamListener<Result> streamListener)
+    {
+        try {
+            TestingArrowFlightRequest request = requestCodec.fromJson(action.getBody());
+            Optional<String> schemaName = request.getSchema();
+
+            String query;
+            if (!schemaName.isPresent()) {
+                query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA";
+            }
+            else {
+                query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='" + schemaName.get().toUpperCase() + "'";
+            }
+
+            List<String> names = new ArrayList<>();
+            // Statement and ResultSet were previously never closed
+            try (Statement statement = connection.createStatement();
+                    ResultSet rs = statement.executeQuery(query)) {
+                while (rs.next()) {
+                    names.add(rs.getString(1));
+                }
+            }
+
+            TestingArrowFlightResponse response;
+            if (!schemaName.isPresent()) {
+                response = new TestingArrowFlightResponse(names, ImmutableList.of());
+            }
+            else {
+                response = new TestingArrowFlightResponse(ImmutableList.of(), names);
+            }
+
+            streamListener.onNext(new Result(responseCodec.toJsonBytes(response)));
+            streamListener.onCompleted();
+        }
+        catch (Exception e) {
+            streamListener.onError(e);
+        }
+    }
+
+    @Override
+    public void listActions(CallContext callContext, StreamListener<ActionType> streamListener)
+    {
+        throw new UnsupportedOperationException("This operation is not supported");
+    }
+
+    /**
+     * Maps a SQL type name (case-insensitive) to the Arrow type used in the reported schema.
+     *
+     * @param sqlType SQL type name
+     * @param precision numeric precision, used for DECIMAL/NUMERIC only
+     * @param scale numeric scale, used for DECIMAL/NUMERIC only
+     * @return the corresponding Arrow type
+     * @throws IllegalArgumentException if the type name is not one of the handled names
+     */
+    private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int scale)
+    {
+        switch (sqlType.toUpperCase()) {
+            case "VARCHAR":
+            case "CHAR":
+            case "CHARACTER VARYING":
+            case "CHARACTER":
+            case "CLOB":
+                return new ArrowType.Utf8();
+            case "INTEGER":
+            case "INT":
+                return new ArrowType.Int(32, true);
+            case "BIGINT":
+                return new ArrowType.Int(64, true);
+            case "SMALLINT":
+                return new ArrowType.Int(16, true);
+            case "TINYINT":
+                return new ArrowType.Int(8, true);
+            case "DOUBLE":
+            case "DOUBLE PRECISION":
+            case "FLOAT":
+                return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
+            case "REAL":
+                return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
+            case "BOOLEAN":
+                return new ArrowType.Bool();
+            case "DATE":
+                return new ArrowType.Date(DateUnit.DAY);
+            case "TIMESTAMP":
+                return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null);
+            case "TIME":
+                return new ArrowType.Time(TimeUnit.MILLISECOND, 32);
+            case "DECIMAL":
+            case "NUMERIC":
+                return new ArrowType.Decimal(precision, scale);
+            case "BINARY":
+            case "VARBINARY":
+                return new ArrowType.Binary();
+            case "NULL":
+                return new ArrowType.Null();
+            default:
+                throw new IllegalArgumentException("Unsupported SQL type: " + sqlType);
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingArrowServerUsingDictionaryVector.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingArrowServerUsingDictionaryVector.java
new file mode 100644
index 000000000000..435c33ae1860
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingArrowServerUsingDictionaryVector.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.TestingArrowFlightRequest; +import com.facebook.plugin.arrow.TestingArrowFlightResponse; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Optional; + +import static com.facebook.airlift.json.JsonCodec.jsonCodec; + +public class TestingArrowServerUsingDictionaryVector + implements FlightProducer +{ + private final RootAllocator allocator; + private final ObjectMapper objectMapper = new ObjectMapper(); + private final JsonCodec requestCodec; + private final JsonCodec responseCodec; + + private static final Logger logger = Logger.get(TestingArrowServerUsingDictionaryVector.class); + + 
public TestingArrowServerUsingDictionaryVector(RootAllocator allocator) + throws Exception + { + this.allocator = allocator; + this.requestCodec = jsonCodec(TestingArrowFlightRequest.class); + this.responseCodec = jsonCodec(TestingArrowFlightResponse.class); + } + + @Override + public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) + { + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + ArrowType.Int index = new ArrowType.Int(32, true); + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), new DictionaryEncoding(1L, false, index)); + + Field shipmode = new Field("shipmode", fieldType, null); + Field orderkey = new Field("orderkey", FieldType.nullable(new ArrowType.Int(64, true)), null); + + Schema schema = new Schema(Arrays.asList(orderkey, shipmode)); + + // Create a VectorSchemaRoot from the schema and vectors + try (VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(schema, allocator)) { + schemaRoot.allocateNew(); + + BigIntVector orderkeyVector = (BigIntVector) schemaRoot.getVector("orderkey"); + + orderkeyVector.allocateNew(3); + + for (int i = 0; i < 3; i++) { + orderkeyVector.setSafe(i, i); + } + orderkeyVector.setValueCount(3); + IntVector intFieldVector = (IntVector) schemaRoot.getVector("shipmode"); + intFieldVector.allocateNew(3); + + for (int i = 0; i < 3; i++) { + intFieldVector.setSafe(i, i); + } + intFieldVector.setValueCount(3); + + schemaRoot.setRowCount(3); + + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, index)); + DictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(dictionary); + 
serverStreamListener.start(schemaRoot, provider); + + serverStreamListener.putNext(); + + serverStreamListener.completed(); + dictionaryVector.close(); + } + } + + @Override + public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) + { + throw new UnsupportedOperationException("This operation is not supported"); + } + + @Override + public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) + { + try { + Field shipmode = new Field("shipmode", FieldType.nullable(new ArrowType.Utf8()), null); + Field orderkey = new Field("orderkey", FieldType.nullable(new ArrowType.Int(64, true)), null); + + Schema schema = new Schema(Arrays.asList(orderkey, shipmode)); + + FlightEndpoint endpoint = new FlightEndpoint(new Ticket(flightDescriptor.getCommand())); + return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1); + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to retrieve FlightInfo", e); + } + } + + @Override + public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) + { + throw new UnsupportedOperationException("This operation is not supported"); + } + + @Override + public void doAction(CallContext callContext, Action action, StreamListener streamListener) + { + try { + TestingArrowFlightRequest request = requestCodec.fromJson(action.getBody()); + Optional schemaName = request.getSchema(); + + TestingArrowFlightResponse response; + if (!schemaName.isPresent()) { + // Return the list of schemas + response = new TestingArrowFlightResponse(ImmutableList.of("TPCH"), ImmutableList.of()); + } + else { + // Return the list of tables + response = new TestingArrowFlightResponse(ImmutableList.of(), ImmutableList.of("LINEITEM")); + } + + streamListener.onNext(new Result(responseCodec.toJsonBytes(response))); + streamListener.onCompleted(); + } + catch (Exception e) { + 
streamListener.onError(e); + } + } + + @Override + public void listActions(CallContext callContext, StreamListener streamListener) + { + throw new UnsupportedOperationException("This operation is not supported"); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingH2DatabaseSetup.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingH2DatabaseSetup.java new file mode 100644 index 000000000000..60c7f0c07a5a --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/tests/TestingH2DatabaseSetup.java @@ -0,0 +1,273 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.tests; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.RecordSet; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.tpch.TpchMetadata; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.base.Joiner; +import io.airlift.tpch.TpchTable; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.core.statement.PreparedBatch; +import org.joda.time.DateTimeZone; + +import java.sql.Connection; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.LINE_ITEM; +import static io.airlift.tpch.TpchTable.NATION; +import static io.airlift.tpch.TpchTable.ORDERS; +import static io.airlift.tpch.TpchTable.PART; +import static io.airlift.tpch.TpchTable.PART_SUPPLIER; +import static io.airlift.tpch.TpchTable.REGION; +import static io.airlift.tpch.TpchTable.SUPPLIER; +import static 
java.lang.String.format; +import static java.util.Collections.nCopies; + +public class TestingH2DatabaseSetup +{ + private static final Logger logger = Logger.get(TestingH2DatabaseSetup.class); + private TestingH2DatabaseSetup() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static void setup(String h2JdbcUrl) throws Exception + { + Class.forName("org.h2.Driver"); + + Connection conn = DriverManager.getConnection(h2JdbcUrl, "sa", ""); + + Jdbi jdbi = Jdbi.create(h2JdbcUrl, "sa", ""); + Handle handle = jdbi.open(); // Get a handle for the database connection + + TpchMetadata tpchMetadata = new TpchMetadata(""); + + Statement stmt = conn.createStatement(); + + // Create schema + stmt.execute("CREATE SCHEMA IF NOT EXISTS tpch"); + + stmt.execute("CREATE TABLE tpch.member (" + + " id INTEGER PRIMARY KEY," + + " name VARCHAR(50)," + + " sex CHAR(1)," + + " state CHAR(5)" + + ")"); + stmt.execute("INSERT INTO tpch.member VALUES(1, 'TOM', 'M', 'TMX '),(2, 'MARY', 'F', 'CD ')"); + + stmt.execute("CREATE TABLE tpch.event (" + + " id INTEGER PRIMARY KEY," + + " startDate DATE," + + " startTime TIME," + + " startTimestamp TIMESTAMP" + + ")"); + stmt.execute("INSERT INTO tpch.event VALUES(1, DATE '2004-12-31', TIME '23:59:59'," + + " TIMESTAMP '2005-12-31 23:59:59')"); + + stmt.execute("CREATE TABLE tpch.orders (\n" + + " orderkey BIGINT PRIMARY KEY,\n" + + " custkey BIGINT NOT NULL,\n" + + " orderstatus VARCHAR(1) NOT NULL,\n" + + " totalprice DOUBLE NOT NULL,\n" + + " orderdate DATE NOT NULL,\n" + + " orderpriority VARCHAR(15) NOT NULL,\n" + + " clerk VARCHAR(15) NOT NULL,\n" + + " shippriority INTEGER NOT NULL,\n" + + " comment VARCHAR(79) NOT NULL\n" + + ")"); + stmt.execute("CREATE INDEX custkey_index ON tpch.orders (custkey)"); + insertRows(tpchMetadata, ORDERS, handle); + + handle.execute("CREATE TABLE tpch.lineitem (\n" + + " orderkey BIGINT,\n" + + " partkey BIGINT NOT NULL,\n" + + " suppkey BIGINT 
NOT NULL,\n" + + " linenumber INTEGER,\n" + + " quantity DOUBLE NOT NULL,\n" + + " extendedprice DOUBLE NOT NULL,\n" + + " discount DOUBLE NOT NULL,\n" + + " tax DOUBLE NOT NULL,\n" + + " returnflag CHAR(1) NOT NULL,\n" + + " linestatus CHAR(1) NOT NULL,\n" + + " shipdate DATE NOT NULL,\n" + + " commitdate DATE NOT NULL,\n" + + " receiptdate DATE NOT NULL,\n" + + " shipinstruct VARCHAR(25) NOT NULL,\n" + + " shipmode VARCHAR(10) NOT NULL,\n" + + " comment VARCHAR(44) NOT NULL,\n" + + " PRIMARY KEY (orderkey, linenumber)" + + ")"); + insertRows(tpchMetadata, LINE_ITEM, handle); + + handle.execute(" CREATE TABLE tpch.partsupp (\n" + + " partkey BIGINT NOT NULL,\n" + + " suppkey BIGINT NOT NULL,\n" + + " availqty INTEGER NOT NULL,\n" + + " supplycost DOUBLE NOT NULL,\n" + + " comment VARCHAR(199) NOT NULL,\n" + + " PRIMARY KEY(partkey, suppkey)" + + ")"); + insertRows(tpchMetadata, PART_SUPPLIER, handle); + + handle.execute("CREATE TABLE tpch.nation (\n" + + " nationkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(25) NOT NULL,\n" + + " regionkey BIGINT NOT NULL,\n" + + " comment VARCHAR(152) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, NATION, handle); + + handle.execute("CREATE TABLE tpch.region(\n" + + " regionkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(25) NOT NULL,\n" + + " comment VARCHAR(115) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, REGION, handle); + handle.execute("CREATE TABLE tpch.part(\n" + + " partkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(55) NOT NULL,\n" + + " mfgr VARCHAR(25) NOT NULL,\n" + + " brand VARCHAR(10) NOT NULL,\n" + + " type VARCHAR(25) NOT NULL,\n" + + " size INTEGER NOT NULL,\n" + + " container VARCHAR(10) NOT NULL,\n" + + " retailprice DOUBLE NOT NULL,\n" + + " comment VARCHAR(23) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, PART, handle); + handle.execute(" CREATE TABLE tpch.customer ( \n" + + " custkey BIGINT NOT NULL, \n" + + " name VARCHAR(25) NOT NULL, \n" + + " address VARCHAR(40) NOT NULL, \n" + + " nationkey 
BIGINT NOT NULL, \n" + + " phone VARCHAR(15) NOT NULL, \n" + + " acctbal DOUBLE NOT NULL, \n" + + " mktsegment VARCHAR(10) NOT NULL, \n" + + " comment VARCHAR(117) NOT NULL \n" + + " ) "); + insertRows(tpchMetadata, CUSTOMER, handle); + handle.execute(" CREATE TABLE tpch.supplier ( \n" + + " suppkey bigint NOT NULL, \n" + + " name varchar(25) NOT NULL, \n" + + " address varchar(40) NOT NULL, \n" + + " nationkey bigint NOT NULL, \n" + + " phone varchar(15) NOT NULL, \n" + + " acctbal double NOT NULL, \n" + + " comment varchar(101) NOT NULL \n" + + " ) "); + insertRows(tpchMetadata, SUPPLIER, handle); + + ResultSet resultSet1 = stmt.executeQuery("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'TPCH'"); + List tables = new ArrayList<>(); + while (resultSet1.next()) { + String tableName = resultSet1.getString("TABLE_NAME"); + tables.add(tableName); + } + logger.info("Tables in 'tpch' schema: %s", tables.stream().collect(Collectors.joining(", "))); + + ResultSet resultSet = stmt.executeQuery("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA"); + List schemas = new ArrayList<>(); + while (resultSet.next()) { + String schemaName = resultSet.getString("SCHEMA_NAME"); + schemas.add(schemaName); + } + logger.info("Schemas: %s", schemas.stream().collect(Collectors.joining(", "))); + } + + private static void insertRows(TpchMetadata tpchMetadata, TpchTable tpchTable, Handle handle) + { + TpchTableHandle tableHandle = tpchMetadata.getTableHandle(null, new SchemaTableName(TINY_SCHEMA_NAME, tpchTable.getTableName())); + insertRows(tpchMetadata.getTableMetadata(null, tableHandle), handle, createTpchRecordSet(tpchTable, tableHandle.getScaleFactor())); + } + + private static void insertRows(ConnectorTableMetadata tableMetadata, Handle handle, RecordSet data) + { + List columns = tableMetadata.getColumns().stream() + .filter(columnMetadata -> !columnMetadata.isHidden()) + .collect(toImmutableList()); + + String schemaName = "tpch"; + String 
tableNameWithSchema = schemaName + "." + tableMetadata.getTable().getTableName(); + String vars = Joiner.on(',').join(nCopies(columns.size(), "?")); + String sql = format("INSERT INTO %s VALUES (%s)", tableNameWithSchema, vars); + + RecordCursor cursor = data.cursor(); + while (true) { + // insert 1000 rows at a time + PreparedBatch batch = handle.prepareBatch(sql); + for (int row = 0; row < 1000; row++) { + if (!cursor.advanceNextPosition()) { + if (batch.size() > 0) { + batch.execute(); + } + return; + } + for (int column = 0; column < columns.size(); column++) { + Type type = columns.get(column).getType(); + if (BOOLEAN.equals(type)) { + batch.bind(column, cursor.getBoolean(column)); + } + else if (BIGINT.equals(type)) { + batch.bind(column, cursor.getLong(column)); + } + else if (INTEGER.equals(type)) { + batch.bind(column, (int) cursor.getLong(column)); + } + else if (DOUBLE.equals(type)) { + batch.bind(column, cursor.getDouble(column)); + } + else if (type instanceof VarcharType) { + batch.bind(column, cursor.getSlice(column).toStringUtf8()); + } + else if (DATE.equals(type)) { + long millisUtc = TimeUnit.DAYS.toMillis(cursor.getLong(column)); + // H2 expects dates in to be millis at midnight in the JVM timezone + long localMillis = DateTimeZone.UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millisUtc); + batch.bind(column, new Date(localMillis)); + } + else { + throw new IllegalArgumentException("Unsupported type " + type); + } + } + batch.add(); + } + batch.execute(); + } + } +} diff --git a/presto-base-arrow-flight/src/test/resources/server.crt b/presto-base-arrow-flight/src/test/resources/server.crt new file mode 100644 index 000000000000..f070253c9e11 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.crt @@ -0,0 +1,32 @@ +-----BEGIN CERTIFICATE----- +MIIFlTCCA32gAwIBAgIUf2p3qdxzpsOofYDpXNXDJA1fjbgwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM 
+GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MCAX +DTI0MTIyMTA3MzUwMFoYDzIxMjQxMTI3MDczNTAwWjBZMQswCQYDVQQGEwJBVTET +MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMRIwEAYDVQQDDAlsb2NhbGhvc3QwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQCmsVs+OHGQ/GlR9+KpgcKkBvXZQAUtCJzFG27GXBYoltJDAxsi +o47KbGkUWxf5E+4oFbbz7s8OHSuQnjP2XVHbKvmgicf3uP98fKkQtjvl+xn8vT18 +s6a8Kv8hj+f6MRWGwpHa/sQKU9uZmRYTRh+32vDZRtAvmpzkf6+2K8B1fFbwVmuC +j3ULb+0iYHnomC3aMWBFXkxjEmsamx4YK74NtQU98+EjQZwhWgWXhW5shS1kSs4r +3N6++tonBz+tDKAhCMueRRJAQXGjKqL7qDZn7wpk53L/fZT9mgRYyA+PN2ND9L0H +nMGMjJl71p42kkgIGOpllsmK6+g0Bj6aC/uCEnX2AtM0g57Th7U2aLwgOmRshC0s +uuBHxMWUgzJsB1dscXrFPPB+XVcUlgcwRbGsQG5VzK/rYRV4y1FmF5vSLY60mF43 +hFDAcnh4mnnBkobca5Dl6PSUpFmDkF56IUgYQCTrE6hPYnrIJKhPcfpmr/tm/Acr +ra1sPp/QPSFIxI9j7Nzm/QsOBF3Zy4AbbbOmhJOjNtwEi59r9za/FhxV5kSN/YM9 +HyyYYebxW/jXF/7hzQzWYfJBz1SgdD9prl4ml8VMVJdhZmBTzhVciPXizRUKkbD+ +LvQKw8q4a24/VruUvW15J39qalhdyWf3vqGuORWowpG7oYnGXik3kYVT5QIDAQAB +o1MwUTAdBgNVHQ4EFgQUkkKtG5568IDoFtn3AK0Yes0CqGgwHwYDVR0jBBgwFoAU +kkKtG5568IDoFtn3AK0Yes0CqGgwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B +AQsFAAOCAgEAVjwxyO5N4EGQHf9euKiTjksUMeJrJkPMFYS/1UmQB5/yc0TkOxv3 +M8yvTzkzxWAQV2X+CPbOLgqO2eCoSev2GHVUyHLLUEFATzAgTpF79Ovzq6JkSgFo +GFxTL0TuA+d++430wYyts7iYY9uT5Xt8xp4+SwXrCvLWxHosP3xGUmNvY4sP0hdS +beIvdGE/J8izgy5DRt2fWZ03mmwfKqiz/qhKGj9DDsHkA/1jyKIivP/nufr9dzDr +41lhk1N7qFWkOjbMd06NYySIe0MaapIkenjT1IOgqGw2f98RfSEomoaXuxN0VoSM +6dZ4rN97cER25X7/zE0zCZurjCLHzPTyuTspYGEK+9U4plOeWh0keQEcdpcHTAr+ +NqU3VlhXVxz91nVREpRJmKk+r4c+xdrfY3YkDcSJ1dazbd3eS05Ggx51KOqes8Zb +hFQfhIDqvaqXDNlkBezLpr4v/MU69+cp7SOn5uPnOccS6sd7fzl/PUZhEjd5ZIws +8SX79OwhQjbYZYRHJSPPasb8B1amULtoo0pJ5izSkXxileiGuXhRO5stiBux7SL+ +oJAztjuRf0IvP6LWOMHgqquzc2JiEDCz0DPnTCqoXGZlT2HPGzXOoDSTOmRdFx+L +qi/DeY+MpIMVov/rplqjydXw6AuQDxcV1GvyjMvaHxJG5MEBC/mVeqQ= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/presto-base-arrow-flight/src/test/resources/server.key 
b/presto-base-arrow-flight/src/test/resources/server.key new file mode 100644 index 000000000000..4524c26943f2 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCmsVs+OHGQ/GlR +9+KpgcKkBvXZQAUtCJzFG27GXBYoltJDAxsio47KbGkUWxf5E+4oFbbz7s8OHSuQ +njP2XVHbKvmgicf3uP98fKkQtjvl+xn8vT18s6a8Kv8hj+f6MRWGwpHa/sQKU9uZ +mRYTRh+32vDZRtAvmpzkf6+2K8B1fFbwVmuCj3ULb+0iYHnomC3aMWBFXkxjEmsa +mx4YK74NtQU98+EjQZwhWgWXhW5shS1kSs4r3N6++tonBz+tDKAhCMueRRJAQXGj +KqL7qDZn7wpk53L/fZT9mgRYyA+PN2ND9L0HnMGMjJl71p42kkgIGOpllsmK6+g0 +Bj6aC/uCEnX2AtM0g57Th7U2aLwgOmRshC0suuBHxMWUgzJsB1dscXrFPPB+XVcU +lgcwRbGsQG5VzK/rYRV4y1FmF5vSLY60mF43hFDAcnh4mnnBkobca5Dl6PSUpFmD +kF56IUgYQCTrE6hPYnrIJKhPcfpmr/tm/Acrra1sPp/QPSFIxI9j7Nzm/QsOBF3Z +y4AbbbOmhJOjNtwEi59r9za/FhxV5kSN/YM9HyyYYebxW/jXF/7hzQzWYfJBz1Sg +dD9prl4ml8VMVJdhZmBTzhVciPXizRUKkbD+LvQKw8q4a24/VruUvW15J39qalhd +yWf3vqGuORWowpG7oYnGXik3kYVT5QIDAQABAoICAE0MeIzTgSbPjQz+w82u9V1k +/DlNfrb4nqH7EqJsSS+8uvaPlnjV2fgV0SJAEt4mCLSNiPHKpfkzoYHopkMPknj4 +Lcc3OG94GtubMXhQi3I7tSDeBfBAh+a9Bw2n20WJb5ZJFCsCDHJrnXsrSAljpeCR +OjdsJGmEkVWK8ZiGM6D6dqMDhxEjpynAs/7qUh8hTDxpC0M1GaDHkCsNnQT2HxRt +4jznH97wgi7mUeReIBLYIgmUDCU5I9ppz/EvSA8AYXmze46uBYge19xgJlKlR3SW +CJtoYf7XOMlZ6f1xh8OeiesM0l0U51/EU2Nq6dl2lwXrIlkPsBve/AckBcalmDwh +tJgAOHM8VEMp0imJ8+SPGZupiIt+sTVSt4aq2uB5dtTBDcXGx4gVUYJoekI4yQtv +OWKAQAobq0Cutogx8hyXr2tSizp5pGpJDKx8/G/3YcNmRqCq+HQgwshlHQy/hPqs +QC/jcuOkr5GOZASF3wLCuLmAwFCELL96iVEJvPLe5Yb1u4HbCrsy9/N8rHV1s1qa +xMcGiIS4m/fwLbD+TjgM1foESeQciHH7GWVZ3Osn6Inpp6vLJKDhLhGBgpaYF72S +0fft6CkIy/CqbwXPe/3Wm0PL0SLTsSk0lKhw6Zeuvg0Z0jUrmDmlTrGo8j4zU7cA +Gc8uiZuFAU2ZoS9R7SZXAoIBAQDSfF8H4C8jn9IvTw/S3APkP7VyLKFSTitEfagt +qcAZr7sYWYGdYDL1hX+jBopAWdzbJ+N3MZBccH67f3XDZWVH4GDQWF6tyOpyseBe +42aU/yCr6ZSoxi91W9V+oyiTFh+t2cEkNyGey/pbXVBw45jncHpzIAaERm5gD8DD +MxOadIvDxEGquGs9MUgZg6ABFmqvQIfiZ7I3PvfoHJGrHsPyM9i6AqAkJe1rKiN3 
+tt2wiDk6zDor4vFKn+5LaRQ5INbpgBuCJrvDsRh3MeKLb0ddCYcAStHeAjHJZbH4 +OiC5y9hpaIcSQN3Z4gEhdMd+Lhi1+z36uN5k7KCvbGmO3sG/AoIBAQDKvMiVqTgm +a36FFswwZJd5o+Hx6G8Nt5y7ZYLJgoq4wyTPYA5ZIHBp2Jtb//IGJMKpeLcRigxx +nKu4zrLZPHesgYzHbCH9i5fOBFl8yF9pZnbCU/zGOYhiVdn7MHZgJw8nlstWr1Jr +cMnxBBXEjgL65Lrse1noiPATNrvBsirFaJky6HhxRPbkijKKd0Xqc0FeHMxt6IHU +y6yZRzguI0M1A8RGR4CrqsOGrdv0vkMSiumkUdz8JW4w9R69n6ax/qNAZSMql0ss +PIPOqGtYRPhibOhXKPl0p30X32YXx/SKm8+L9Sr1ny76dab/bSnxStOdihGJLCs4 +l7vFkuJgTMtbAoIBAQCYA3CSfIsu3EbtGdlgvLsmxggh7C+aBJBlB6dFSzpMksi5 +njLo2MgU35Q9xgRk00GZGWbC94295RTyDuya8IjD7z2cWqYONnNz4BkeDndQli0f +WzOc7Hzr8iXvLqCoEatRYFmH8TUbvU8TWwI0dXtBcs9Mg82RDFi8kcPyddnri85A +1WVjiYsRh5z9qD0PbAQii6VXkvJ3ycc64B8oCbEUI/Oa6ziCws2DvswcsnnK+6bx +WvuMJHuFHJn55mrPk3MC8h1r0tN6UlVMCEAH2ZcdjzrrsB1/i/Au9n4gusJVzO1/ +uxkJysUujXWplvBYpav9CfVKNOeQ1gB6kP5vS1t7AoIBAB1DquCPkJ9bHOQxKkBC +BOt2EINOvdkJDAKw4HQd99A7uvCEOQ38dL2Smrpo85KXc9HqruJFPw6XQuJmU8Kv +y8aG3L9ciHuEzuDaF+C/O6aHN9VNMkuaumkXY2Oy1yOB/9oDFk7o98iyezPjFxFM +Pnng0mqYU54RRjY/zFJlWW8tbg+/JsOS5OCQYkNCfEEfaewf1BJ5YWRKEhv9/8oJ +JQZeCNLsN1KQT7D9H6bwX9YpXxhtCK0M6h7/AvT0OqeuzfnZn33iYON9yLjn7rbL +Hd93QQJz065XDuOHR8FfB5mKbCcTuKPD2pAks3pjU46U8n7nEyjtyz9cB6q5TRwB +eckCggEBAMd/3riFnoc2VF6LsA0VQzZiJleN0G9g3AF/5acqHHjMv6aJs/AjV7ZF +hFdiqTcQFEg87HSOzrtV2DVNWJc7DMcMIuwGhVec+D4wedJ2wrQv+MEgySltSBZq +wcPVn5IQiml38NnG/tPIrHETb0PIoa8+iu/Jg80o7j8M3+DKVKfCyfh334PjFK6W +B/mkgC9PfcfeA/Doby9pJsRnqmAJTeWjxbefksckI4PcRCLMEwggB2ReiIWef8Q+ +IooNIxypWtBWtpNEl5lhO7Y2f65Whp34TjooXmGBl1lj4szO8PNnf9QA865J86OS +kOoFda4Sn7LajkbSX0wTGMuXDpmx34M= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/connector.rst b/presto-docs/src/main/sphinx/connector.rst index f07506b460ca..d337fe4ed12d 100644 --- a/presto-docs/src/main/sphinx/connector.rst +++ b/presto-docs/src/main/sphinx/connector.rst @@ -9,6 +9,7 @@ from different data sources. 
:maxdepth: 1 connector/accumulo + connector/base-arrow-flight connector/bigquery connector/blackhole connector/cassandra diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst new file mode 100644 index 000000000000..e74e3eec9c63 --- /dev/null +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -0,0 +1,93 @@ + +====================== +Arrow Flight Connector +====================== +This connector allows querying multiple data sources that are supported by an Arrow Flight server. Flight supports parallel transfers, allowing data to be streamed to or from a cluster of servers simultaneously. Official documentation for Arrow Flight can be found at https://arrow.apache.org/docs/format/Flight.html + +Getting Started with presto-base-arrow-flight module: Essential Abstract Methods for Developers +----------------------------------------------------------------------------------------------- +To create a plugin extending the presto-base-arrow-flight module, you need to implement certain abstract methods that are specific to your use case. Below are the required classes and their purposes: + +- ``BaseArrowFlightClient.java`` + This class contains the core functionality for the Arrow Flight module. A plugin extending the presto-base-arrow-flight module should extend this abstract class and implement the following methods. + + - ``getCallOptions`` This method should return an array of credential call options to authenticate to the Flight server. + - ``getFlightDescriptorForSchema`` This method should return the flight descriptor to fetch Arrow Schema for the table. + - ``listSchemaNames`` This method should return the list of schemas in the catalog. + - ``listTables`` This method should return the list of tables in the catalog. + - ``getFlightDescriptorForTableScan`` This method should return the flight descriptor for fetching data from the table. 
+ +- ``ArrowPlugin.java`` + Register your connector name by extending the ArrowPlugin class. +- ``ArrowBlockBuilder.java`` + This class builds Presto blocks from Arrow vectors. Extend this class if needed and override the ``getPrestoTypeFromArrowField`` method if any customizations are needed for the conversion of Arrow vectors to Presto types. A binding for this class should be created in the ``Module`` for the plugin. + +There is a reference implementation of the presto-base-arrow-flight module available in the test folder of this module which uses a locally started Flight server to fetch data from an H2 database. The classes ``TestingArrowBlockBuilder``, ``TestingArrowFlightClient``, ``TestingArrowFlightPlugin``, ``TestingArrowFlightRequest``, ``TestingArrowFlightResponse``, ``TestingArrowModule`` and ``TestingArrowQueryBuilder`` make up the reference implementation. + + +Configuration +------------- +Create a catalog file +in ``etc/catalog`` named, for example, ``arrowmariadb.properties``, to +mount the Flight connector as the ``arrowmariadb`` catalog. +Create the file with the following contents, replacing the +connection properties as appropriate for your setup: + + +.. code-block:: none + + + connector.name=<CONNECTOR_NAME> + arrow-flight.server=<SERVER_ENDPOINT> + arrow-flight.server.port=<SERVER_PORT> + + + +Add other properties that are required for your Flight server to connect. 
+ +========================================== ============================================================== +Property Name Description +========================================== ============================================================== +``arrow-flight.server`` Endpoint of the Arrow Flight server +``arrow-flight.server.port`` Flight server port +``arrow-flight.server-ssl-certificate`` SSL certificate to pass to the server +``arrow-flight.server.verify`` Whether to verify the server +``arrow-flight.server-ssl-enabled`` Whether the port is SSL enabled +========================================== ============================================================== + +Querying Arrow-Flight +--------------------- + +The Flight connector provides a schema for each supported *database*. +An example for MariaDB is shown below. +To see the available schemas, run ``SHOW SCHEMAS``:: + + SHOW SCHEMAS FROM arrowmariadb; + +To view the tables in the MariaDB database named ``user``, +run ``SHOW TABLES``:: + + SHOW TABLES FROM arrowmariadb.user; + +To see a list of the columns in the ``admin`` table in the ``user`` database, +use either of the following commands:: + + DESCRIBE arrowmariadb.user.admin; + SHOW COLUMNS FROM arrowmariadb.user.admin; + +Finally, you can access the ``admin`` table in the ``user`` database:: + + SELECT * FROM arrowmariadb.user.admin; + +If you used a different name for your catalog properties file, use +that catalog name instead of ``arrowmariadb`` in the above examples. + + +Flight Connector Limitations +---------------------------- + +* SELECT and DESCRIBE queries are supported. Implementing modules can add support for additional features. + +* The Flight connector can query against only those data sources which are supported by the Flight server. + +* The Flight server must be running for the Flight connector to work.