diff --git a/pom.xml b/pom.xml index 8487bfedfa0bf..0259042c7380c 100644 --- a/pom.xml +++ b/pom.xml @@ -200,6 +200,7 @@ presto-singlestore presto-hana presto-openapi + presto-base-arrow-flight diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml new file mode 100644 index 0000000000000..fd9871e6b2490 --- /dev/null +++ b/presto-base-arrow-flight/pom.xml @@ -0,0 +1,400 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.290-SNAPSHOT + + presto-base-arrow-flight + presto-base-arrow-flight + Presto - Base Arrow Flight Connector + + + ${project.parent.basedir} + 4.10.0 + 17.0.0 + 4.1.110.Final + + + + + com.facebook.airlift + bootstrap + + + ch.qos.logback + logback-core + + + + + + com.facebook.airlift + log + + + + com.google.guava + guava + + + org.checkerframework + checker-qual + + + com.google.errorprone + error_prone_annotations + + + com.google.j2objc + j2objc-annotations + + + + + + javax.inject + javax.inject + + + + com.facebook.presto + presto-spi + provided + + + + io.airlift + slice + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + com.facebook.presto + presto-common + provided + + + + io.airlift + units + provided + + + + com.google.code.findbugs + jsr305 + true + + + + org.apache.arrow + arrow-memory-core + ${arrow.version} + + + + com.google.inject + guice + + + + com.facebook.airlift + configuration + + + + org.apache.arrow + arrow-jdbc + ${arrow.version} + + + + io.netty + netty-codec-http2 + ${netty.version} + + + + io.netty + netty-handler-proxy + ${netty.version} + + + io.netty + netty-codec-http + + + + + + io.netty + netty-tcnative-boringssl-static + 2.0.65.Final + + + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + + + org.testng + testng + test + + + + io.airlift.tpch + tpch + test + + + + joda-time + joda-time + + + + org.jdbi + jdbi3-core + + + + com.facebook.presto + presto-tpch + test + + + + 
com.facebook.airlift + json + test + + + + org.checkerframework + checker-qual + 3.4.1 + test + + + + com.facebook.presto + presto-testng-services + test + + + + com.facebook.airlift + testing + test + ${dep.airlift.version} + + + + com.fasterxml.jackson.core + jackson-databind + provided + + + + com.fasterxml.jackson.core + jackson-core + provided + + + + org.apache.arrow + flight-core + ${arrow.version} + + + com.google.j2objc + j2objc-annotations + + + + + + com.facebook.presto + presto-main + test + + + + com.facebook.presto + presto-tests + test + + + + com.h2database + h2 + test + + + + + + + + org.codehaus.mojo + animal-sniffer-annotations + 1.23 + + + + com.google.j2objc + j2objc-annotations + 1.3 + + + + com.google.errorprone + error_prone_annotations + 2.14.0 + + + + com.google.protobuf + protobuf-java + 3.25.5 + + + + commons-codec + commons-codec + 1.17.0 + + + + io.netty + netty-transport-native-unix-common + ${netty.version} + + + + io.netty + netty-common + ${netty.version} + + + + io.netty + netty-buffer + ${netty.version} + + + + io.netty + netty-handler + ${netty.version} + + + + io.netty + netty-transport + ${netty.version} + + + + io.netty + netty-codec + ${netty.version} + + + + org.objenesis + objenesis + 3.3 + + + + org.jetbrains.kotlin + kotlin-stdlib-common + 1.6.20 + + + + org.slf4j + slf4j-api + 2.0.13 + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + -Xss10M + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + + + com.google.errorprone:error_prone_annotations + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + io.netty:netty-codec-http2 + io.netty:netty-handler-proxy + io.netty:netty-tcnative-boringssl-static + + + + + org.basepom.maven + duplicate-finder-maven-plugin + 1.2.1 + + + module-info + META-INF.versions.9.module-info + + + arrow-git.properties + about.html + + + + + + check + + + + + + + diff --git 
a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java new file mode 100644 index 0000000000000..0507a76749977 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -0,0 +1,353 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import 
com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NotFoundException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; +import static java.util.Objects.requireNonNull; + +public abstract class AbstractArrowMetadata + implements ConnectorMetadata +{ + private static final Logger logger = Logger.get(AbstractArrowMetadata.class); + private final ArrowFlightConfig config; + private final ArrowFlightClientHandler clientHandler; + + public AbstractArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler) + { + this.config = config; + this.clientHandler = requireNonNull(clientHandler); + } + + private ArrowColumnHandle createArrowColumnHandleForFloatingPointType(String columnName, ArrowType.FloatingPoint floatingPoint) + { + switch (floatingPoint.getPrecision()) { + case SINGLE: + return new ArrowColumnHandle(columnName, RealType.REAL); + case DOUBLE: + return new ArrowColumnHandle(columnName, DoubleType.DOUBLE); + default: + throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision()); + } + } + + private ArrowColumnHandle createArrowColumnHandleForIntType(String columnName, 
ArrowType.Int intType) + { + switch (intType.getBitWidth()) { + case 64: + return new ArrowColumnHandle(columnName, BigintType.BIGINT); + case 32: + return new ArrowColumnHandle(columnName, IntegerType.INTEGER); + case 16: + return new ArrowColumnHandle(columnName, SmallintType.SMALLINT); + case 8: + return new ArrowColumnHandle(columnName, TinyintType.TINYINT); + default: + throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth()); + } + } + + private ColumnMetadata createIntColumnMetadata(String columnName, ArrowType.Int intType) + { + switch (intType.getBitWidth()) { + case 64: + return new ColumnMetadata(columnName, BigintType.BIGINT); + case 32: + return new ColumnMetadata(columnName, IntegerType.INTEGER); + case 16: + return new ColumnMetadata(columnName, SmallintType.SMALLINT); + case 8: + return new ColumnMetadata(columnName, TinyintType.TINYINT); + default: + throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth()); + } + } + + private ColumnMetadata createFloatingPointColumnMetadata(String columnName, ArrowType.FloatingPoint floatingPointType) + { + switch (floatingPointType.getPrecision()) { + case SINGLE: + return new ColumnMetadata(columnName, RealType.REAL); + case DOUBLE: + return new ColumnMetadata(columnName, DoubleType.DOUBLE); + default: + throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPointType.getPrecision()); + } + } + + /** + * Provides the field type, which can be overridden by concrete implementations + * with their own custom type. 
+ * + * @return the field type + */ + protected Type overrideFieldType(Field field, Type type) + { + return type; + } + + protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table); + + protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema); + + protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName); + + protected abstract String getDataSourceSpecificTableName(ArrowFlightConfig config, String tableName); + + @Override + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) + { + if (!listSchemaNames(session).contains(tableName.getSchemaName())) { + return null; + } + + if (!listTables(session, Optional.ofNullable(tableName.getSchemaName())).contains(tableName)) { + return null; + } + return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName()); + } + + public List getColumnsList(String schema, String table, ConnectorSession connectorSession) + { + try { + String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); + String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); + ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), Optional.empty(), + dataSourceSpecificSchemaName, dataSourceSpecificTableName); + + FlightInfo flightInfo = clientHandler.getFlightInfo(request, connectorSession); + List fields = flightInfo.getSchema().getFields(); + return fields; + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The table columns could not be listed for the table " + table, e); + } + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + Map column = new HashMap<>(); + + String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); + String tableValue = ((ArrowTableHandle) 
tableHandle).getTable(); + String dbSpecificSchemaValue = getDataSourceSpecificSchemaName(config, schemaValue); + String dBSpecificTableName = getDataSourceSpecificTableName(config, tableValue); + List columnList = getColumnsList(dbSpecificSchemaValue, dBSpecificTableName, session); + + for (Field field : columnList) { + String columnName = field.getName(); + logger.debug("The value of the flight columnName is:- %s", columnName); + + ArrowColumnHandle handle; + switch (field.getType().getTypeID()) { + case Int: + ArrowType.Int intType = (ArrowType.Int) field.getType(); + handle = createArrowColumnHandleForIntType(columnName, intType); + break; + case Binary: + case LargeBinary: + case FixedSizeBinary: + handle = new ArrowColumnHandle(columnName, VarbinaryType.VARBINARY); + break; + case Date: + handle = new ArrowColumnHandle(columnName, DateType.DATE); + break; + case Timestamp: + handle = new ArrowColumnHandle(columnName, TimestampType.TIMESTAMP); + break; + case Utf8: + case LargeUtf8: + handle = new ArrowColumnHandle(columnName, VarcharType.VARCHAR); + break; + case FloatingPoint: + ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); + handle = createArrowColumnHandleForFloatingPointType(columnName, floatingPoint); + break; + case Decimal: + ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); + handle = new ArrowColumnHandle(columnName, DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale())); + break; + case Bool: + handle = new ArrowColumnHandle(columnName, BooleanType.BOOLEAN); + break; + case Time: + handle = new ArrowColumnHandle(columnName, TimeType.TIME); + break; + default: + throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); + } + Type type = overrideFieldType(field, handle.getColumnType()); + if (!type.equals(handle.getColumnType())) { + handle = new ArrowColumnHandle(columnName, type); + } + column.put(columnName, 
handle); + } + return column; + } + + @Override + public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + { + ArrowTableHandle tableHandle = (ArrowTableHandle) table; + + List columns = new ArrayList<>(); + if (desiredColumns.isPresent()) { + List arrowColumns = new ArrayList<>(desiredColumns.get()); + columns = (List) (List) arrowColumns; + } + + ConnectorTableLayout layout = new ConnectorTableLayout(new ArrowTableLayoutHandle(tableHandle, columns, constraint.getSummary())); + return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + return new ConnectorTableLayout(handle); + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) + { + List meta = new ArrayList<>(); + List columnList = getColumnsList(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable(), session); + + for (Field field : columnList) { + String columnName = field.getName(); + ArrowType type = field.getType(); + + ColumnMetadata columnMetadata; + + switch (type.getTypeID()) { + case Int: + ArrowType.Int intType = (ArrowType.Int) type; + columnMetadata = createIntColumnMetadata(columnName, intType); + break; + case Binary: + case LargeBinary: + case FixedSizeBinary: + columnMetadata = new ColumnMetadata(columnName, VarbinaryType.VARBINARY); + break; + case Date: + columnMetadata = new ColumnMetadata(columnName, DateType.DATE); + break; + case Timestamp: + columnMetadata = new ColumnMetadata(columnName, TimestampType.TIMESTAMP); + break; + case Utf8: + case LargeUtf8: + columnMetadata = new ColumnMetadata(columnName, VarcharType.VARCHAR); + break; + case FloatingPoint: + ArrowType.FloatingPoint floatingPointType = (ArrowType.FloatingPoint) type; + columnMetadata = 
createFloatingPointColumnMetadata(columnName, floatingPointType); + break; + case Decimal: + ArrowType.Decimal decimalType = (ArrowType.Decimal) type; + columnMetadata = new ColumnMetadata(columnName, DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale())); + break; + case Time: + columnMetadata = new ColumnMetadata(columnName, TimeType.TIME); + break; + case Bool: + columnMetadata = new ColumnMetadata(columnName, BooleanType.BOOLEAN); + break; + default: + throw new UnsupportedOperationException("The data type " + type.getTypeID() + " is not supported."); + } + + Type fieldType = overrideFieldType(field, columnMetadata.getType()); + if (!fieldType.equals(columnMetadata.getType())) { + columnMetadata = new ColumnMetadata(columnName, fieldType); + } + meta.add(columnMetadata); + } + return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return ((ArrowColumnHandle) columnHandle).getColumnMetadata(); + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + requireNonNull(prefix, "prefix is null"); + ImmutableMap.Builder> columns = ImmutableMap.builder(); + List tables; + if (prefix.getSchemaName() != null && prefix.getTableName() != null) { + tables = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); + } + else { + tables = listTables(session, Optional.of(prefix.getSchemaName())); + } + + for (SchemaTableName tableName : tables) { + try { + ConnectorTableHandle tableHandle = getTableHandle(session, tableName); + columns.put(tableName, getTableMetadata(session, tableHandle).getColumns()); + } + catch (ClassCastException | NotFoundException e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The table columns could not be 
listed for the table " + tableName, e); + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, e.getMessage(), e); + } + } + return columns.build(); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java new file mode 100644 index 0000000000000..253158c5d8884 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightInfo; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public abstract class AbstractArrowSplitManager + implements ConnectorSplitManager +{ + private static final Logger logger = Logger.get(AbstractArrowSplitManager.class); + private final ArrowFlightClientHandler clientHandler; + + public AbstractArrowSplitManager(ArrowFlightClientHandler client) + { + this.clientHandler = client; + } + + protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; + ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); + ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), + tableLayoutHandle); + + FlightInfo flightInfo = clientHandler.getFlightInfo(request, session); + List splits = flightInfo.getEndpoints() + .stream() + .map(info -> new ArrowSplit( + tableHandle.getSchema(), + tableHandle.getTable(), + info.getTicket().getBytes(), + info.getLocations().stream().map(location -> location.getUri().toString()).collect(Collectors.toList()))) + .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList)); + return new FixedSplitSource(splits); + } +} diff 
--git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java new file mode 100644 index 0000000000000..1ee3791e37457 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public class ArrowColumnHandle + implements ColumnHandle +{ + private final String columnName; + private final Type columnType; + + @JsonCreator + public ArrowColumnHandle( + @JsonProperty("columnName") String columnName, + @JsonProperty("columnType") Type columnType) + { + this.columnName = requireNonNull(columnName, "columnName is null"); + this.columnType = requireNonNull(columnType, "type is null"); + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @JsonProperty + public Type getColumnType() + { + return columnType; + } + + public ColumnMetadata getColumnMetadata() + { + return new ColumnMetadata(columnName, columnType); + } + + @Override + public String toString() + { + 
return columnName + ":" + columnType; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java new file mode 100644 index 0000000000000..9523bd210faa5 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.transaction.IsolationLevel; +import com.google.inject.Inject; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ArrowConnector + implements Connector +{ + private final ConnectorMetadata metadata; + private final ConnectorSplitManager splitManager; + private final ConnectorPageSourceProvider pageSourceProvider; + private final ConnectorHandleResolver handleResolver; + + @Inject + public ArrowConnector(ConnectorMetadata metadata, + ConnectorHandleResolver handleResolver, + ConnectorSplitManager splitManager, + 
ConnectorPageSourceProvider pageSourceProvider) + { + this.metadata = requireNonNull(metadata, "Metadata is null"); + this.handleResolver = requireNonNull(handleResolver, "Metadata is null"); + this.splitManager = requireNonNull(splitManager, "SplitManager is null"); + this.pageSourceProvider = requireNonNull(pageSourceProvider, "PageSinkProvider is null"); + } + + public Optional getHandleResolver() + { + return Optional.of(handleResolver); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return ArrowTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return splitManager; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return pageSourceProvider; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java new file mode 100644 index 0000000000000..f17317dd42b3a --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpressionService; +import com.google.inject.ConfigurationException; +import com.google.inject.Injector; +import com.google.inject.Module; + +import java.util.Map; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.util.Objects.requireNonNull; + +public class ArrowConnectorFactory + implements ConnectorFactory +{ + private final String name; + private final Module module; + private final ClassLoader classLoader; + + public ArrowConnectorFactory(String name, Module module, ClassLoader classLoader) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.module = requireNonNull(module, "module is null"); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public String getName() + { + return name; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new ArrowHandleResolver(); + } + + @Override + public Connector create(String catalogName, Map requiredConfig, ConnectorContext context) + { + requireNonNull(requiredConfig, "requiredConfig is null"); + + try 
(ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + Bootstrap app = new Bootstrap( + binder -> { + binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager()); + binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); + binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService()); + binder.bind(NodeManager.class).toInstance(context.getNodeManager()); + }, + new ArrowModule(catalogName), + module); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(requiredConfig) + .initialize(); + + return injector.getInstance(ArrowConnector.class); + } + catch (ConfigurationException ex) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The connector instance could not be created.", ex); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java new file mode 100644 index 0000000000000..dce08bac4ac24 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class ArrowConnectorId +{ + private final String id; + + public ArrowConnectorId(String id) + { + this.id = requireNonNull(id, "id is null"); + } + + @Override + public String toString() + { + return id; + } + + @Override + public int hashCode() + { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + ArrowConnectorId other = (ArrowConnectorId) obj; + return Objects.equals(this.id, other.id); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java new file mode 100644 index 0000000000000..2e33f736a62c5 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.ErrorCode; +import com.facebook.presto.common.ErrorType; +import com.facebook.presto.spi.ErrorCodeSupplier; + +import static com.facebook.presto.common.ErrorType.EXTERNAL; +import static com.facebook.presto.common.ErrorType.INTERNAL_ERROR; + +public enum ArrowErrorCode + implements ErrorCodeSupplier +{ + ARROW_INVALID_TABLE(0, EXTERNAL), + ARROW_INVALID_CREDENTAILS(1, EXTERNAL), + ARROW_FLIGHT_ERROR(2, EXTERNAL), + ARROW_INTERNAL_ERROR(3, INTERNAL_ERROR); + + private final ErrorCode errorCode; + + ArrowErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0509_0000, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java new file mode 100644 index 0000000000000..ba2c6edba589c --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PrestoException; + +public class ArrowException + extends PrestoException +{ + public ArrowException(ErrorCodeSupplier errorCode, String message) + { + super(errorCode, message); + } + + public ArrowException(ErrorCodeSupplier errorCode, String message, Throwable cause) + { + super(errorCode, message, cause); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java new file mode 100644 index 0000000000000..1a9d964d7458d --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.memory.RootAllocator; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ArrowFlightClient + implements AutoCloseable +{ + private final FlightClient flightClient; + private final Optional trustedCertificate; + private RootAllocator allocator; + + public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate, RootAllocator allocator) + { + this.flightClient = requireNonNull(flightClient, "flightClient cannot be null"); + this.trustedCertificate = trustedCertificate; + this.allocator = allocator; + } + + public FlightClient getFlightClient() + { + return flightClient; + } + + public Optional getTrustedCertificate() + { + return trustedCertificate; + } + + @Override + public void close() throws InterruptedException, IOException + { + flightClient.close(); + if (trustedCertificate.isPresent()) { + trustedCertificate.get().close(); + } + if (allocator != null) { + allocator.close(); + allocator = null; + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java new file mode 100644 index 0000000000000..ba75cbffc088c --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ConnectorSession; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.memory.RootAllocator; + +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; + +public abstract class ArrowFlightClientHandler +{ + private static final Logger logger = Logger.get(ArrowFlightClientHandler.class); + private final ArrowFlightConfig config; + + public ArrowFlightClientHandler(ArrowFlightConfig config) + { + this.config = config; + } + + private ArrowFlightClient initializeClient(Optional uri) + { + try { + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + Optional trustedCertificate = Optional.empty(); + + Location location; + if (uri.isPresent()) { + location = new Location(uri.get()); + } + else { + if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) { + location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort()); + } + else { + location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort()); + } + } + + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); + if 
(config.getVerifyServer() != null && !config.getVerifyServer()) { + flightClientBuilder.verifyServer(false); + } + else if (config.getFlightServerSSLCertificate() != null) { + trustedCertificate = Optional.of(new FileInputStream(config.getFlightServerSSLCertificate())); + flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls(); + } + + FlightClient flightClient = flightClientBuilder.build(); + return new ArrowFlightClient(flightClient, trustedCertificate, allocator); + } + catch (Exception ex) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex); + } + } + + protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); + + /** + * Connector implementations can override this method to get a FlightDescriptor + * from command or path. + * @param flightRequest + * @return + */ + protected FlightDescriptor getFlightDescriptor(ArrowFlightRequest flightRequest) + { + return FlightDescriptor.command(flightRequest.getCommand()); + } + + public ArrowFlightConfig getConfig() + { + return config; + } + + public ArrowFlightClient getClient(Optional uri) + { + return initializeClient(uri); + } + + public FlightInfo getFlightInfo(ArrowFlightRequest request, ConnectorSession connectorSession) + { + try (ArrowFlightClient client = getClient(Optional.empty())) { + CredentialCallOption auth = this.getCallOptions(connectorSession); + FlightDescriptor descriptor = getFlightDescriptor(request); + logger.debug("Fetching flight info"); + FlightInfo flightInfo = client.getFlightClient().getInfo(descriptor, auth); + logger.debug("got flight info"); + return flightInfo; + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight information could not be obtained from the flight server." 
+ e.getMessage(), e); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java new file mode 100644 index 0000000000000..a30301dcdc361 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.configuration.Config; + +public class ArrowFlightConfig +{ + private String server; + private Boolean verifyServer; + private String flightServerSSLCertificate; + private Boolean arrowFlightServerSslEnabled; + private Integer arrowFlightPort; + public String getFlightServerName() + { + return server; + } + + public Boolean getVerifyServer() + { + return verifyServer; + } + + public Boolean getArrowFlightServerSslEnabled() + { + return arrowFlightServerSslEnabled; + } + + public String getFlightServerSSLCertificate() + { + return flightServerSSLCertificate; + } + + public Integer getArrowFlightPort() + { + return arrowFlightPort; + } + + @Config("arrow-flight.server") + public ArrowFlightConfig setFlightServerName(String server) + { + this.server = server; + return this; + } + + @Config("arrow-flight.server.verify") + public ArrowFlightConfig setVerifyServer(Boolean verifyServer) + { + this.verifyServer = verifyServer; + return this; + } + + 
@Config("arrow-flight.server.port") + public ArrowFlightConfig setArrowFlightPort(Integer arrowFlightPort) + { + this.arrowFlightPort = arrowFlightPort; + return this; + } + + @Config("arrow-flight.server-ssl-certificate") + public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCertificate) + { + this.flightServerSSLCertificate = flightServerSSLCertificate; + return this; + } + + @Config("arrow-flight.server-ssl-enabled") + public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled) + { + this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled; + return this; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java new file mode 100644 index 0000000000000..7e04e0a6066e3 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * A request to a Flight server, exposed as a raw command payload.
 */
public interface ArrowFlightRequest
{
    /**
     * @return the command bytes for this request
     */
    byte[] getCommand();
}
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public class ArrowHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return ArrowTableHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return ArrowTableLayoutHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return ArrowColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return ArrowSplit.class; + } + + @Override + public Class getTransactionHandleClass() + { + return ArrowTransactionHandle.class; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java new file mode 100644 index 0000000000000..b20762f8ad497 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static java.util.Objects.requireNonNull; + +public class ArrowModule + implements Module +{ + protected final String connectorId; + + public ArrowModule(String connectorId) + { + this.connectorId = requireNonNull(connectorId, "connector id is null"); + } + + public void configure(Binder binder) + { + configBinder(binder).bindConfig(ArrowFlightConfig.class); + binder.bind(ArrowConnector.class).in(Scopes.SINGLETON); + binder.bind(ArrowConnectorId.class).toInstance(new ArrowConnectorId(connectorId)); + binder.bind(ConnectorHandleResolver.class).to(ArrowHandleResolver.class).in(Scopes.SINGLETON); + binder.bind(ArrowPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(ConnectorPageSourceProvider.class).to(ArrowPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(Connector.class).to(ArrowConnector.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java new file mode 100644 index 0000000000000..b83e791f2163e --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -0,0 +1,640 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import 
org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.LocalTime; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; + +public class ArrowPageSource + implements ConnectorPageSource +{ + private static final Logger logger = Logger.get(ArrowPageSource.class); + private final ArrowSplit split; + private final List columnHandles; + private boolean completed; + private int currentPosition; + private Optional vectorSchemaRoot = Optional.empty(); + private ArrowFlightClient flightClient; + private FlightStream flightStream; + + public ArrowPageSource(ArrowSplit split, List columnHandles, ArrowFlightClientHandler clientHandler, + ConnectorSession connectorSession) + { + this.columnHandles = columnHandles; + this.split = split; + getFlightStream(clientHandler, split.getTicket(), connectorSession); + } + + private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] ticket, ConnectorSession connectorSession) + { + try { + Optional uri = (split.getLocationUrls().isEmpty()) ? 
+ Optional.empty() : Optional.of(split.getLocationUrls().get(0)); + flightClient = clientHandler.getClient(uri); + flightStream = flightClient.getFlightClient().getStream(new Ticket(ticket), clientHandler.getCallOptions(connectorSession)); + } + catch (FlightRuntimeException e) { + throw new ArrowException(ARROW_FLIGHT_ERROR, e.getMessage(), e); + } + } + + private Block buildBlockFromVector(FieldVector vector, Type type) + { + if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, type); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, type); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, type); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, type); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, type); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, type); + } + else if (vector instanceof NullVector) { + return buildBlockFromNullVector((NullVector) vector, type); + } + else if (vector instanceof TimeStampMicroVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliVector) { + return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, type); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, type); + } + else if (vector instanceof VarCharVector) { + if (type instanceof CharType) { + return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else if (type instanceof TimeType) { + return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, 
type); + } + else { + return buildBlockFromVarCharVector((VarCharVector) vector, type); + } + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, type); + } + else if (vector instanceof DateMilliVector) { + return buildBlockFromDateMilliVector((DateMilliVector) vector, type); + } + else if (vector instanceof TimeMilliVector) { + return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); + } + else if (vector instanceof TimeSecVector) { + return buildBlockFromTimeSecVector((TimeSecVector) vector, type); + } + else if (vector instanceof TimeStampSecVector) { + return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); + } + else if (vector instanceof TimeMicroVector) { + return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); + } + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); + } + + private Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + private Block buildBlockFromBitVector(BitVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + 
builder.appendNull(); + } + else { + type.writeBoolean(builder, vector.get(i) == 1); + } + } + return builder.build(); + } + + private Block buildBlockFromIntVector(IntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + private Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + private Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + private Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + private Block buildBlockFromDecimalVector(DecimalVector vector, Type type) + { + if (!(type instanceof DecimalType)) { + throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); + } + + DecimalType decimalType = (DecimalType) type; + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); 
+ } + else { + BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value + if (decimalType.isShort()) { + builder.writeLong(decimal.unscaledValue().longValue()); + } + else { + Slice slice = Decimals.encodeScaledValue(decimal); + decimalType.writeSlice(builder, slice, 0, slice.length()); + } + } + } + return builder.build(); + } + + private Block buildBlockFromNullVector(NullVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + builder.appendNull(); + } + return builder.build(); + } + + private Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long micros = vector.get(i); + long millis = TimeUnit.MICROSECONDS.toMillis(micros); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + private Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + private Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { 
+ builder.appendNull(); + } + else { + type.writeDouble(builder, vector.get(i)); + } + } + return builder.build(); + } + + private Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int intBits = Float.floatToIntBits(vector.get(i)); + type.writeLong(builder, intBits); + } + } + return builder.build(); + } + + private Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + byte[] value = vector.get(i); + type.writeSlice(builder, Slices.wrappedBuffer(value)); + } + } + return builder.build(); + } + + private Block buildBlockFromVarCharVector(VarCharVector vector, Type type) + { + if (!(type instanceof VarcharType)) { + throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(value)); + } + } + return builder.build(); + } + + private Block buildBlockFromDateDayVector(DateDayVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return 
builder.build(); + } + + private Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + DateType dateType = (DateType) type; + long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); + dateType.writeLong(builder, days); + } + } + return builder.build(); + } + + private Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + private Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + private Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); + } + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i 
< vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long micro = TimeUnit.MICROSECONDS.toMillis(value); + type.writeLong(builder, micro); + } + } + return builder.build(); + } + + private Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + private Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); + } + } + return builder.build(); + } + + private Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String timeString = new String(vector.get(i), StandardCharsets.UTF_8); + LocalTime time = LocalTime.parse(timeString); + long millis = Duration.between(LocalTime.MIN, time).toMillis(); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getCompletedPositions() + { + return 
currentPosition; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public boolean isFinished() + { + return completed; + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public Page getNextPage() + { + if (vectorSchemaRoot.isPresent()) { + vectorSchemaRoot.get().close(); + vectorSchemaRoot = Optional.empty(); + } + + if (flightStream.next()) { + vectorSchemaRoot = Optional.ofNullable(flightStream.getRoot()); + } + + if (!vectorSchemaRoot.isPresent()) { + completed = true; + } + + if (isFinished()) { + return null; + } + + currentPosition++; + + List blocks = new ArrayList<>(); + for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) { + FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex); + Type type = columnHandles.get(columnIndex).getColumnType(); + + Block block = buildBlockFromVector(vector, type); + blocks.add(block); + } + + return new Page(vectorSchemaRoot.get().getRowCount(), blocks.toArray(new Block[0])); + } + + @Override + public void close() + { + if (vectorSchemaRoot.isPresent()) { + vectorSchemaRoot.get().close(); + vectorSchemaRoot = Optional.empty(); + } + if (flightStream != null) { + try { + flightStream.close(); + } + catch (Exception e) { + logger.error(e); + } + } + try { + if (flightClient != null) { + flightClient.close(); + flightClient = null; + } + } + catch (Exception ex) { + logger.error("Failed to close the flight client: %s", ex.getMessage(), ex); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java new file mode 100644 index 0000000000000..f3bb41c3e35d4 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; + +import java.util.List; + +public class ArrowPageSourceProvider + implements ConnectorPageSourceProvider +{ + private static final Logger logger = Logger.get(ArrowPageSourceProvider.class); + private ArrowFlightClientHandler clientHandler; + @Inject + public ArrowPageSourceProvider(ArrowFlightClientHandler clientHandler) + { + this.clientHandler = clientHandler; + } + + @Override + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, List columns, SplitContext splitContext) + { + ImmutableList.Builder columnHandles = ImmutableList.builder(); + for (ColumnHandle handle : columns) { + columnHandles.add((ArrowColumnHandle) handle); + } + ArrowSplit arrowSplit = (ArrowSplit) split; + logger.debug("Processing split with flight ticket"); + return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session); + } +} diff --git 
a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java new file mode 100644 index 0000000000000..bb5599b00b625 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class ArrowPlugin + implements Plugin +{ + protected final String name; + protected final Module module; + + public ArrowPlugin(String name, Module module) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.module = requireNonNull(module, "module is null"); + } + + private static ClassLoader getClassLoader() + { + return firstNonNull(Thread.currentThread().getContextClassLoader(), ArrowPlugin.class.getClassLoader()); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ArrowConnectorFactory(name, module, 
getClassLoader())); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java new file mode 100644 index 0000000000000..db65912de8c58 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; + +public class ArrowSplit + implements ConnectorSplit +{ + private final String schemaName; + private final String tableName; + private final byte[] ticket; + private final List locationUrls; + + @JsonCreator + public ArrowSplit( + @JsonProperty("schemaName") @Nullable String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("ticket") byte[] ticket, + @JsonProperty("locationUrls") List locationUrls) + { + this.schemaName = schemaName; + this.tableName = tableName; + this.ticket = ticket; + this.locationUrls = locationUrls; + } + + @Override + public NodeSelectionStrategy 
getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return this.getInfoMap(); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public byte[] getTicket() + { + return ticket; + } + + @JsonProperty + public List getLocationUrls() + { + return locationUrls; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java new file mode 100644 index 0000000000000..e0f8c6586791d --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +public class ArrowTableHandle + implements ConnectorTableHandle +{ + private final String schema; + private final String table; + + @JsonCreator + public ArrowTableHandle( + @JsonProperty("schema") String schema, + @JsonProperty("table") String table) + { + this.schema = schema; + this.table = table; + } + + @JsonProperty("schema") + public String getSchema() + { + return schema; + } + + @JsonProperty("table") + public String getTable() + { + return table; + } + + @Override + public String toString() + { + return schema + ":" + table; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..512246683562f --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ArrowTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final ArrowTableHandle tableHandle; + private final List columnHandles; + private final TupleDomain tupleDomain; + + @JsonCreator + public ArrowTableLayoutHandle(@JsonProperty("table") ArrowTableHandle table, + @JsonProperty("columnHandles") List columnHandles, + @JsonProperty("tupleDomain") TupleDomain domain) + { + this.tableHandle = requireNonNull(table, "table is null"); + this.columnHandles = requireNonNull(columnHandles, "columns are null"); + this.tupleDomain = requireNonNull(domain, "domain is null"); + } + + @JsonProperty("table") + public ArrowTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty("tupleDomain") + public TupleDomain getTupleDomain() + { + return tupleDomain; + } + + @JsonProperty("columnHandles") + public List getColumnHandles() + { + return columnHandles; + } + + @Override + public String toString() + { + return "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java new file mode 100644 index 0000000000000..07eb7385cfbcf --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum ArrowTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightSmoke.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightSmoke.java new file mode 100644 index 0000000000000..fd869748096a6 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightSmoke.java @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueries; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; + +import static com.facebook.presto.common.type.CharType.createCharType; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static java.lang.String.format; +import static org.testng.Assert.assertTrue; + +public class TestArrowFlightSmoke + extends AbstractTestQueries +{ + private static final Logger logger = Logger.get(TestArrowFlightSmoke.class); + private static RootAllocator allocator; + private static FlightServer server; + private static Location serverLocation; + + @BeforeClass + public void setup() + throws Exception + { + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = 
Location.forGrpcTls("127.0.0.1", 9443); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port " + server.getPort()); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return TestingArrowQueryRunner.createQueryRunner(); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + server.close(); + allocator.close(); + } + + @Test + public void testShowCharColumns() + { + MaterializedResult actual = computeActual("SHOW COLUMNS FROM member"); + + MaterializedResult expectedUnparametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar", "", "") + .row("sex", "char", "", "") + .row("state", "char", "", "") + .build(); + + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar(50)", "", "") + .row("sex", "char(1)", "", "") + .row("state", "char(5)", "", "") + .build(); + + assertTrue(actual.equals(expectedParametrizedVarchar) || actual.equals(expectedUnparametrizedVarchar), + format("%s matches neither %s nor %s", actual, expectedParametrizedVarchar, expectedUnparametrizedVarchar)); + } + + @Test + public void testPredicateOnCharColumn() + { + MaterializedResult actualRow = computeActual("SELECT * from member WHERE state = 'CD'"); + MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER, createVarcharType(50), createCharType(1), createCharType(5)) + .row(2, "MARY", "F", "CD ") + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + @Test + public void testSelectTime() + { + MaterializedResult actualRow = computeActual("SELECT * from event WHERE id = 1"); + Session session = getSession(); + MaterializedResult expectedRow = 
resultBuilder(session, INTEGER, DATE, TIME, TIMESTAMP) + .row(1, + getDate("2004-12-31"), + getTimeAtZone("23:59:59", session.getTimeZoneKey()), + getDateTimeAtZone("2005-12-31 23:59:59", session.getTimeZoneKey())) + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + private LocalDate getDate(String dateString) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); + LocalDate localDate = LocalDate.parse(dateString, formatter); + + return localDate; + } + + private LocalTime getTimeAtZone(String timeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("HH:mm:ss"); + LocalTime localTime = LocalTime.parse(timeString, formatter); + + LocalDateTime localDateTime = LocalDateTime.of(LocalDate.of(1970, 1, 1), localTime); + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalTime localTimeAtZone = zonedDateTime.toLocalTime(); + return localTimeAtZone; + } + + private LocalDateTime getDateTimeAtZone(String dateTimeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); + LocalDateTime localDateTime = LocalDateTime.parse(dateTimeString, formatter); + + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalDateTime localDateTimeAtZone = zonedDateTime.toLocalDateTime(); + return localDateTimeAtZone; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java new file mode 100644 index 0000000000000..99e2d6fb96f46 --- /dev/null +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFactory.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +public class TestingArrowFactory + extends ArrowConnectorFactory +{ + public TestingArrowFactory() + { + super("arrow", new TestingArrowModule(), TestingArrowFactory.class.getClassLoader()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java new file mode 100644 index 0000000000000..7eb2bc88a7907 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightClientHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSession; +import org.apache.arrow.flight.auth2.BearerCredentialWriter; +import org.apache.arrow.flight.grpc.CredentialCallOption; + +import javax.inject.Inject; + +public class TestingArrowFlightClientHandler + extends ArrowFlightClientHandler +{ + @Inject + public TestingArrowFlightClientHandler(ArrowFlightConfig config) + { + super(config); + } + + @Override + protected CredentialCallOption getCallOptions(ConnectorSession connectorSession) + { + return new CredentialCallOption(new BearerCredentialWriter(null)); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java new file mode 100644 index 0000000000000..c55eab7ae97ef --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightConfig.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigSecuritySensitive; + +public class TestingArrowFlightConfig +{ + private String host; // non-static field + private String database; // non-static field + private String username; // non-static field + private String password; // non-static field + private String name; // non-static field + private Integer port; // non-static field + private Boolean ssl; + + public String getDataSourceHost() + { // non-static getter + return host; + } + + public String getDataSourceDatabase() + { // non-static getter + return database; + } + + public String getDataSourceUsername() + { // non-static getter + return username; + } + + public String getDataSourcePassword() + { // non-static getter + return password; + } + + public String getDataSourceName() + { // non-static getter + return name; + } + + public Integer getDataSourcePort() + { // non-static getter + return port; + } + + public Boolean getDataSourceSSL() + { // non-static getter + return ssl; + } + + @Config("data-source.host") + public TestingArrowFlightConfig setDataSourceHost(String host) + { // non-static setter + this.host = host; + return this; + } + + @Config("data-source.database") + public TestingArrowFlightConfig setDataSourceDatabase(String database) + { // non-static setter + this.database = database; + return this; + } + + @Config("data-source.username") + public TestingArrowFlightConfig setDataSourceUsername(String username) + { // non-static setter + this.username = username; + return this; + } + + @Config("data-source.password") + @ConfigSecuritySensitive + public TestingArrowFlightConfig setDataSourcePassword(String password) + { // non-static setter + this.password = password; + return this; + } + + @Config("data-source.name") + public TestingArrowFlightConfig setDataSourceName(String name) + { // non-static setter + this.name = name; + return this; + } + + 
@Config("data-source.port") + public TestingArrowFlightConfig setDataSourcePort(Integer port) + { // non-static setter + this.port = port; + return this; + } + + @Config("data-source.ssl") + public TestingArrowFlightConfig setDataSourceSSL(Boolean ssl) + { // non-static setter + this.ssl = ssl; + return this; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java new file mode 100644 index 0000000000000..2497a7c68ae3f --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +public class TestingArrowFlightRequest + implements ArrowFlightRequest +{ + private final String schema; + private final String table; + private final Optional query; + private final ArrowFlightConfig config; + private final int noOfPartitions; + + private final TestingArrowFlightConfig testconfig; + + public TestingArrowFlightRequest(ArrowFlightConfig config, TestingArrowFlightConfig testconfig, String schema, String table, Optional query, int noOfPartitions) + { + this.config = config; + this.schema = schema; + this.table = table; + this.query = query; + this.testconfig = testconfig; + this.noOfPartitions = noOfPartitions; + } + + public TestingArrowFlightRequest(ArrowFlightConfig config, String schema, int noOfPartitions, TestingArrowFlightConfig testconfig) + { + this.schema = schema; + this.table = null; + this.query = Optional.empty(); + this.config = config; + this.testconfig = testconfig; + this.noOfPartitions = noOfPartitions; + } + + public String getSchema() + { + return schema; + } + + public String getTable() + { + return table; + } + + public Optional getQuery() + { + return query; + } + + public TestingRequestData build() + { + TestingRequestData requestData = new TestingRequestData(); + requestData.setConnectionProperties(getConnectionProperties()); + requestData.setInteractionProperties(createInteractionProperties()); + return requestData; + } + + @Override + public byte[] getCommand() + { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + try { + String jsonString = objectMapper.writeValueAsString(build()); + return jsonString.getBytes(StandardCharsets.UTF_8); + } + catch 
(JsonProcessingException e) { + throw new ArrowException(ArrowErrorCode.ARROW_FLIGHT_ERROR, "JSON request cannot be created.", e); + } + } + + private TestingConnectionProperties getConnectionProperties() + { + TestingConnectionProperties properties = new TestingConnectionProperties(); + properties.database = testconfig.getDataSourceDatabase(); + properties.host = testconfig.getDataSourceHost(); + properties.port = testconfig.getDataSourcePort(); + properties.username = testconfig.getDataSourceUsername(); + properties.password = testconfig.getDataSourcePassword(); + return properties; + } + + private TestingInteractionProperties createInteractionProperties() + { + TestingInteractionProperties interactionProperties = new TestingInteractionProperties(); + if (getQuery().isPresent()) { + interactionProperties.setSelectStatement(getQuery().get()); + } + else { + interactionProperties.setSchema(getSchema()); + interactionProperties.setTable(getTable()); + } + return interactionProperties; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java new file mode 100644 index 0000000000000..d7c101945fa34 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.SchemaTableName; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.Result; +import org.apache.arrow.vector.types.pojo.Field; + +import javax.inject.Inject; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static java.util.Locale.ENGLISH; + +public class TestingArrowMetadata + extends AbstractArrowMetadata +{ + private static final Logger logger = Logger.get(TestingArrowMetadata.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + private final NodeManager nodeManager; + private final TestingArrowFlightConfig testconfig; + private final ArrowFlightClientHandler clientHandler; + private final ArrowFlightConfig config; + + @Inject + public TestingArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testconfig) + { + super(config, clientHandler); + this.nodeManager = nodeManager; + this.testconfig = testconfig; + this.clientHandler = clientHandler; + this.config = config; + } + + @Override + protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) + { + return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); + } 
+ + @Override + public List listSchemaNames(ConnectorSession session) + { + List listSchemas = extractSchemaAndTableData(Optional.empty(), session); + List names = new ArrayList<>(); + for (String value : listSchemas) { + names.add(value.toLowerCase(ENGLISH)); + } + return ImmutableList.copyOf(names); + } + + @Override + public List listTables(ConnectorSession session, Optional schemaName) + { + String schemaValue = schemaName.orElse(""); + String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schemaValue); + List listTables = extractSchemaAndTableData(Optional.ofNullable(dataSourceSpecificSchemaName), session); + List tables = new ArrayList<>(); + for (String value : listTables) { + tables.add(new SchemaTableName(dataSourceSpecificSchemaName.toLowerCase(ENGLISH), value.toLowerCase(ENGLISH))); + } + + return tables; + } + + public List extractSchemaAndTableData(Optional schema, ConnectorSession connectorSession) + { + try (ArrowFlightClient client = clientHandler.getClient(Optional.empty())) { + List names = new ArrayList<>(); + ArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); + ObjectNode rootNode = (ObjectNode) objectMapper.readTree(request.getCommand()); + + String modifiedQueryJson = objectMapper.writeValueAsString(rootNode); + byte[] queryJsonBytes = modifiedQueryJson.getBytes(StandardCharsets.UTF_8); + Iterator iterator = client.getFlightClient().doAction(new Action("discovery", queryJsonBytes), clientHandler.getCallOptions(connectorSession)); + while (iterator.hasNext()) { + Result result = iterator.next(); + String jsonResult = new String(result.getBody(), StandardCharsets.UTF_8); + List tableNames = objectMapper.readValue(jsonResult, new TypeReference>() { + }); + names.addAll(tableNames); + } + return names; + } + catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + protected Type overrideFieldType(Field field, Type type) + { + String columnLength = 
field.getMetadata().get("columnLength"); + int length = columnLength != null ? Integer.parseInt(columnLength) : 0; + + String nativeType = field.getMetadata().get("columnNativeType"); + + if ("CHAR".equals(nativeType) || "CHARACTER".equals(nativeType)) { + return CharType.createCharType(length); + } + else if ("VARCHAR".equals(nativeType)) { + return VarcharType.createVarcharType(length); + } + else if ("TIME".equals(nativeType)) { + return TimeType.TIME; + } + else { + return type; + } + } + + @Override + protected String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName) + { + return schemaName; + } + + @Override + protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String tableName) + { + return tableName; + } + + @Override + protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table) + { + return new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java new file mode 100644 index 0000000000000..cab5872a5507a --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; + +public class TestingArrowModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(TestingArrowFlightConfig.class); + binder.bind(ConnectorSplitManager.class).to(TestingArrowSplitManager.class).in(Scopes.SINGLETON); + binder.bind(ArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON); + binder.bind(ConnectorMetadata.class).to(TestingArrowMetadata.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java new file mode 100644 index 0000000000000..cf67aaa983121 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.ConnectorFactory; +import org.testng.annotations.Test; + +import static com.facebook.airlift.testing.Assertions.assertInstanceOf; +import static com.google.common.collect.Iterables.getOnlyElement; + +public class TestingArrowPlugin +{ + @Test + public void testStartup() + { + ArrowModule testModule = new ArrowModule("arrow-flight"); + ArrowPlugin plugin = new ArrowPlugin("arrow-flight", testModule); + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + assertInstanceOf(factory, ArrowConnectorFactory.class); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java new file mode 100644 index 0000000000000..600f100a7e6ec --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryBuilder.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.Range; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import java.sql.Time; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class TestingArrowQueryBuilder +{ + // not all databases support booleans, so use 1=1 and 1=0 instead + private static final String 
ALWAYS_TRUE = "1=1"; + private static final String ALWAYS_FALSE = "1=0"; + public static final String DATE_FORMAT = "yyyy-MM-dd"; + public static final String TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; + public static final String TIME_FORMAT = "HH:mm:ss"; + public static final TimeZone UTC_TIME_ZONE = TimeZone.getTimeZone(ZoneId.of("UTC")); + + public String buildSql( + String schema, + String table, + List columns, + Map columnExpressions, + TupleDomain tupleDomain) + { + StringBuilder sql = new StringBuilder(); + + sql.append("SELECT "); + sql.append(addColumnExpression(columns, columnExpressions)); + + sql.append(" FROM "); + if (!isNullOrEmpty(schema)) { + sql.append(quote(schema)).append('.'); + } + sql.append(quote(table)); + + List accumulator = new ArrayList<>(); + + if (tupleDomain != null && !tupleDomain.isAll()) { + List clauses = toConjuncts(columns, tupleDomain, accumulator); + if (!clauses.isEmpty()) { + sql.append(" WHERE ") + .append(Joiner.on(" AND ").join(clauses)); + } + } + + return sql.toString(); + } + + public static String convertEpochToString(long epochValue, Type type) + { + if (type instanceof DateType) { + long millis = TimeUnit.DAYS.toMillis(epochValue); + Date date = new Date(millis); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + dateFormat.setTimeZone(UTC_TIME_ZONE); + return dateFormat.format(date); + } + else if (type instanceof TimestampType) { + Timestamp timestamp = new Timestamp(epochValue); + SimpleDateFormat timestampFormat = new SimpleDateFormat(TIMESTAMP_FORMAT); + timestampFormat.setTimeZone(UTC_TIME_ZONE); + return timestampFormat.format(timestamp); + } + else if (type instanceof TimeType) { + long millis = TimeUnit.SECONDS.toMillis(epochValue / 1000); + Time time = new Time(millis); + SimpleDateFormat timeFormat = new SimpleDateFormat(TIME_FORMAT); + timeFormat.setTimeZone(UTC_TIME_ZONE); + return timeFormat.format(time); + } + else { + throw new UnsupportedOperationException(type + " is not 
supported."); + } + } + + protected static class TypeAndValue + { + private final Type type; + private final Object value; + + public TypeAndValue(Type type, Object value) + { + this.type = requireNonNull(type, "type is null"); + this.value = requireNonNull(value, "value is null"); + } + + public Type getType() + { + return type; + } + + public Object getValue() + { + return value; + } + } + + private String addColumnExpression(List columns, Map columnExpressions) + { + if (columns.isEmpty()) { + return "null"; + } + + return columns.stream() + .map(ArrowColumnHandle -> { + String columnAlias = quote(ArrowColumnHandle.getColumnName()); + String expression = columnExpressions.get(ArrowColumnHandle.getColumnName()); + if (expression == null) { + return columnAlias; + } + return format("%s AS %s", expression, columnAlias); + }) + .collect(joining(", ")); + } + + private static boolean isAcceptedType(Type type) + { + Type validType = requireNonNull(type, "type is null"); + return validType.equals(BigintType.BIGINT) || + validType.equals(TinyintType.TINYINT) || + validType.equals(SmallintType.SMALLINT) || + validType.equals(IntegerType.INTEGER) || + validType.equals(DoubleType.DOUBLE) || + validType.equals(RealType.REAL) || + validType.equals(BooleanType.BOOLEAN) || + validType.equals(DateType.DATE) || + validType.equals(TimeType.TIME) || + validType.equals(TimestampType.TIMESTAMP) || + validType instanceof VarcharType || + validType instanceof CharType; + } + private List toConjuncts(List columns, TupleDomain tupleDomain, List accumulator) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (ArrowColumnHandle column : columns) { + Type type = column.getColumnType(); + if (isAcceptedType(type)) { + Domain domain = tupleDomain.getDomains().get().get(column); + if (domain != null) { + builder.add(toPredicate(column.getColumnName(), domain, column, accumulator)); + } + } + } + return builder.build(); + } + + private String toPredicate(String columnName, 
Domain domain, ArrowColumnHandle columnHandle, List accumulator) + { + checkArgument(domain.getType().isOrderable(), "Domain type must be orderable"); + + if (domain.getValues().isNone()) { + return domain.isNullAllowed() ? quote(columnName) + " IS NULL" : ALWAYS_FALSE; + } + + if (domain.getValues().isAll()) { + return domain.isNullAllowed() ? ALWAYS_TRUE : quote(columnName) + " IS NOT NULL"; + } + + List disjuncts = new ArrayList<>(); + List singleValues = new ArrayList<>(); + for (Range range : domain.getValues().getRanges().getOrderedRanges()) { + checkState(!range.isAll()); // Already checked + if (range.isSingleValue()) { + singleValues.add(range.getSingleValue()); + } + else { + List rangeConjuncts = new ArrayList<>(); + if (!range.isLowUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), columnHandle, accumulator)); + } + if (!range.isHighUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isHighInclusive() ? 
"<=" : "<", range.getHighBoundedValue(), columnHandle, accumulator)); + } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")"); + } + } + + // Add back all of the possible single values either as an equality or an IN predicate + if (singleValues.size() == 1) { + disjuncts.add(toPredicate(columnName, "=", getOnlyElement(singleValues), columnHandle, accumulator)); + } + else if (singleValues.size() > 1) { + for (Object value : singleValues) { + bindValue(value, columnHandle, accumulator); + } + String values = Joiner.on(",").join(singleValues.stream().map(v -> + parameterValueToString(columnHandle.getColumnType(), v)).collect(Collectors.toList())); + disjuncts.add(quote(columnName) + " IN (" + values + ")"); + } + + // Add nullability disjuncts + checkState(!disjuncts.isEmpty()); + if (domain.isNullAllowed()) { + disjuncts.add(quote(columnName) + " IS NULL"); + } + + return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; + } + + private String toPredicate(String columnName, String operator, Object value, ArrowColumnHandle columnHandle, List accumulator) + { + bindValue(value, columnHandle, accumulator); + return quote(columnName) + " " + operator + " " + parameterValueToString(columnHandle.getColumnType(), value); + } + private String quote(String name) + { + return "\"" + name + "\""; + } + + private String quoteValue(String name) + { + return "'" + name + "'"; + } + + private void bindValue(Object value, ArrowColumnHandle columnHandle, List accumulator) + { + Type type = columnHandle.getColumnType(); + accumulator.add(new TypeAndValue(type, value)); + } + + public static String convertLongToFloatString(Long value) + { + float floatFromIntBits = intBitsToFloat(toIntExact(value)); + return String.valueOf(floatFromIntBits); + } + + private String parameterValueToString(Type type, Object value) + { + Class javaType = 
type.getJavaType(); + if (type instanceof DateType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimeType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimestampType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof RealType && javaType == long.class) { + return convertLongToFloatString((Long) value); + } + else if (javaType == boolean.class || javaType == double.class || javaType == long.class) { + return value.toString(); + } + else if (javaType == Slice.class) { + return quoteValue(((Slice) value).toStringUtf8()); + } + else { + return quoteValue(value.toString()); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java new file mode 100644 index 0000000000000..aeed01e6ca473 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestingArrowQueryRunner +{ + private static DistributedQueryRunner queryRunner; + private static final Logger logger = Logger.get(TestingArrowQueryRunner.class); + private TestingArrowQueryRunner() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static DistributedQueryRunner createQueryRunner() throws Exception + { + if (queryRunner == null) { + queryRunner = createQueryRunner(ImmutableMap.of(), TestingArrowFactory.class); + } + return queryRunner; + } + + private static DistributedQueryRunner createQueryRunner(Map catalogProperties, Class factoryClass) throws Exception + { + Session session = testSessionBuilder() + .setCatalog("arrow") + .setSchema("testdb") + .build(); + + if (queryRunner == null) { + queryRunner = DistributedQueryRunner.builder(session).build(); + } + + try { + String connectorName = "arrow"; + queryRunner.installPlugin(new ArrowPlugin(connectorName, new TestingArrowModule())); + + ImmutableMap.Builder properties = ImmutableMap.builder() + .putAll(catalogProperties) + .put("arrow-flight.server", "127.0.0.1") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server.port", "9443") + .put("arrow-flight.server.verify", "false"); + + queryRunner.createCatalog(connectorName, connectorName, properties.build()); + + return queryRunner; + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to create ArrowQueryRunner", e); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java new file mode 100644 index 0000000000000..b3124d0425913 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java @@ -0,0 +1,308 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcToArrow; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.DateUnit; +import 
org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.ThreadLocalRandom; + +public class TestingArrowServer + implements FlightProducer +{ + private final RootAllocator allocator; + private final ObjectMapper objectMapper = new ObjectMapper(); + private static Connection connection; + private static final Logger logger = Logger.get(TestingArrowServer.class); + + public TestingArrowServer(RootAllocator allocator) throws Exception + { + this.allocator = allocator; + String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1"; + TestingH2DatabaseSetup.setup(h2JdbcUrl); + this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", ""); + } + + @Override + public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) + { + try (Statement stmt = connection.createStatement()) { + // Convert ticket bytes to String and parse into JSON + String ticketString = new String(ticket.getBytes(), StandardCharsets.UTF_8); + JsonNode ticketJson = objectMapper.readTree(ticketString); + + // Extract interaction properties and validate + JsonNode interactionProperties = ticketJson.get("interactionProperties"); + if (interactionProperties == null || 
!interactionProperties.has("select_statement")) { + throw new IllegalArgumentException("Invalid ticket format: missing select_statement."); + } + + // Extract and validate the SQL query + String query = interactionProperties.get("select_statement").asText(); + if (query == null || query.trim().isEmpty()) { + throw new IllegalArgumentException("Query cannot be null or empty."); + } + + logger.info("Executing query: " + query); + query = query.toUpperCase(); // Optionally, to maintain consistency + + try (ResultSet rs = stmt.executeQuery(query)) { + JdbcToArrowConfigBuilder config = new JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048) + .setCalendar(Calendar.getInstance(TimeZone.getDefault())); + ArrowVectorIterator iterator = JdbcToArrow.sqlToArrowVectorIterator(rs, config.build()); + boolean firstBatch = true; + + VectorLoader vectorLoader = null; + VectorSchemaRoot newRoot = null; + while (iterator.hasNext()) { + try (VectorSchemaRoot root = iterator.next()) { + VectorUnloader vectorUnloader = new VectorUnloader(root); + try (ArrowRecordBatch batch = vectorUnloader.getRecordBatch()) { + if (firstBatch) { + firstBatch = false; + newRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + vectorLoader = new VectorLoader(newRoot); + serverStreamListener.start(newRoot); + } + if (vectorLoader != null) { + vectorLoader.load(batch); + } + serverStreamListener.putNext(); + } + } + } + if (newRoot != null) { + newRoot.close(); + } + serverStreamListener.completed(); + } + } + // Handle Arrow processing errors + catch (IOException e) { + logger.error("Arrow data processing failed", e); + serverStreamListener.error(e); + throw new RuntimeException("Failed to process Arrow data", e); + } + // Handle all other exceptions, including parsing errors + catch (Exception e) { + logger.error("Ticket processing failed", e); + serverStreamListener.error(e); + throw new RuntimeException("Failed to process the ticket", e); + } + } + + @Override + public 
void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) + { + } + + @Override + public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) + { + try { + String jsonRequest = new String(flightDescriptor.getCommand(), StandardCharsets.UTF_8); + JsonNode rootNode = objectMapper.readTree(jsonRequest); + + String schemaName = rootNode.get("interactionProperties").get("schema_name").asText(null); + String tableName = rootNode.get("interactionProperties").get("table_name").asText(null); + String selectStatement = rootNode.get("interactionProperties").get("select_statement").asText(null); + + List fields = new ArrayList<>(); + if (schemaName != null && tableName != null) { + String query = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " + + "WHERE TABLE_SCHEMA='" + schemaName.toUpperCase() + "' " + + "AND TABLE_NAME='" + tableName.toUpperCase() + "'"; + + try (ResultSet rs = connection.createStatement().executeQuery(query)) { + while (rs.next()) { + String columnName = rs.getString("COLUMN_NAME"); + String dataType = rs.getString("TYPE_NAME"); + String charMaxLength = rs.getString("CHARACTER_MAXIMUM_LENGTH"); + int precision = rs.getInt("NUMERIC_PRECISION"); + int scale = rs.getInt("NUMERIC_SCALE"); + + ArrowType arrowType = convertSqlTypeToArrowType(dataType, precision, scale); + Map metaDataMap = new HashMap<>(); + metaDataMap.put("columnNativeType", dataType); + if (charMaxLength != null) { + metaDataMap.put("columnLength", charMaxLength); + } + FieldType fieldType = new FieldType(true, arrowType, null, metaDataMap); + Field field = new Field(columnName, fieldType, null); + fields.add(field); + } + } + } + else if (selectStatement != null) { + selectStatement = selectStatement.toUpperCase(); + logger.info("Executing SELECT query: " + selectStatement); + try (ResultSet rs = connection.createStatement().executeQuery(selectStatement)) { + ResultSetMetaData metaData = rs.getMetaData(); + int columnCount = 
metaData.getColumnCount(); + + for (int i = 1; i <= columnCount; i++) { + String columnName = metaData.getColumnName(i); + String columnType = metaData.getColumnTypeName(i); + int precision = metaData.getPrecision(i); + int scale = metaData.getScale(i); + + ArrowType arrowType = convertSqlTypeToArrowType(columnType, precision, scale); + Field field = new Field(columnName, FieldType.nullable(arrowType), null); + fields.add(field); + } + } + } + else { + throw new IllegalArgumentException("Either schema_name/table_name or select_statement must be provided."); + } + + Schema schema = new Schema(fields); + FlightEndpoint endpoint = new FlightEndpoint(new Ticket(flightDescriptor.getCommand())); + return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1); + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to retrieve FlightInfo", e); + } + } + + @Override + public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) + { + return null; + } + + @Override + public void doAction(CallContext callContext, Action action, StreamListener streamListener) + { + try { + String jsonRequest = new String(action.getBody(), StandardCharsets.UTF_8); + JsonNode rootNode = objectMapper.readTree(jsonRequest); + String schemaName = rootNode.get("interactionProperties").get("schema_name").asText(null); + + String query; + if (schemaName == null) { + query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA"; + } + else { + schemaName = schemaName.toUpperCase(); + query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='" + schemaName + "'"; + } + ResultSet rs = connection.createStatement().executeQuery(query); + List names = new ArrayList<>(); + while (rs.next()) { + names.add(rs.getString(1)); + } + + String jsonResponse = objectMapper.writeValueAsString(names); + streamListener.onNext(new Result(jsonResponse.getBytes(StandardCharsets.UTF_8))); + 
streamListener.onCompleted(); + } + catch (Exception e) { + streamListener.onError(e); + } + } + + @Override + public void listActions(CallContext callContext, StreamListener streamListener) + { + } + + private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int scale) + { + switch (sqlType.toUpperCase()) { + case "VARCHAR": + case "CHAR": + case "CHARACTER VARYING": + case "CHARACTER": + case "CLOB": + return new ArrowType.Utf8(); + case "INTEGER": + case "INT": + return new ArrowType.Int(32, true); + case "BIGINT": + return new ArrowType.Int(64, true); + case "SMALLINT": + return new ArrowType.Int(16, true); + case "TINYINT": + return new ArrowType.Int(8, true); + case "DOUBLE": + case "DOUBLE PRECISION": + case "FLOAT": + return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case "REAL": + return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + case "BOOLEAN": + return new ArrowType.Bool(); + case "DATE": + return new ArrowType.Date(DateUnit.DAY); + case "TIMESTAMP": + return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null); + case "TIME": + return new ArrowType.Time(TimeUnit.MILLISECOND, 32); + case "DECIMAL": + case "NUMERIC": + return new ArrowType.Decimal(precision, scale); + case "BINARY": + case "VARBINARY": + return new ArrowType.Binary(); + case "NULL": + return new ArrowType.Null(); + default: + throw new IllegalArgumentException("Unsupported SQL type: " + sqlType); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java new file mode 100644 index 0000000000000..5206c65717933 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.NodeManager; +import com.google.common.collect.ImmutableMap; + +import javax.inject.Inject; + +import java.util.Optional; + +public class TestingArrowSplitManager + extends AbstractArrowSplitManager +{ + private TestingArrowFlightConfig testconfig; + + private final NodeManager nodeManager; + + @Inject + public TestingArrowSplitManager(ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) + { + super(client); + this.testconfig = testconfig; + this.nodeManager = nodeManager; + } + + @Override + protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) + { + ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); + Optional query = Optional.of(new TestingArrowQueryBuilder().buildSql(tableHandle.getSchema(), + tableHandle.getTable(), + tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), + tableLayoutHandle.getTupleDomain())); + return new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java new file mode 100644 index 0000000000000..7f265036eeac0 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java 
@@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +/** + * Plain holder for connection settings (database, credentials, host, port, SSL flag). + * All fields are public and mutable; the class performs no validation. + * NOTE(review): presumably populated and read through JSON mapping of the request payload - confirm against the callers. + */ +public class TestingConnectionProperties +{ + public String database; + public String password; + public Integer port; + public String host; + public Boolean ssl; + public String username; +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java new file mode 100644 index 0000000000000..7f21afc8c69fa --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingH2DatabaseSetup.java @@ -0,0 +1,273 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.RecordSet; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.tpch.TpchMetadata; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.base.Joiner; +import io.airlift.tpch.TpchTable; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.core.statement.PreparedBatch; +import org.joda.time.DateTimeZone; + +import java.sql.Connection; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.LINE_ITEM; +import static io.airlift.tpch.TpchTable.NATION; +import static io.airlift.tpch.TpchTable.ORDERS; +import static io.airlift.tpch.TpchTable.PART; +import static io.airlift.tpch.TpchTable.PART_SUPPLIER; +import static io.airlift.tpch.TpchTable.REGION; +import static io.airlift.tpch.TpchTable.SUPPLIER; +import static
java.lang.String.format; +import static java.util.Collections.nCopies; +
/**
 * Creates and populates the {@code testdb} schema of an H2 database: two small hand-written
 * tables ({@code member}, {@code event}) plus the TPC-H tables filled from the tiny TPC-H schema.
 */
public class TestingH2DatabaseSetup
{
    private static final Logger logger = Logger.get(TestingH2DatabaseSetup.class);

    private TestingH2DatabaseSetup()
    {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Creates the schema, tables and rows, then logs the table and schema names that exist.
     *
     * @param h2JdbcUrl JDBC URL of the H2 database; connections are opened as user {@code sa}
     * with an empty password
     * @throws Exception if the driver cannot be loaded or any SQL statement fails
     */
    public static void setup(String h2JdbcUrl) throws Exception
    {
        Class.forName("org.h2.Driver");

        // NOTE(review): conn and handle are left open, as before. Closing them may drop an
        // in-memory H2 database unless the URL keeps it alive - confirm before closing here.
        Connection conn = DriverManager.getConnection(h2JdbcUrl, "sa", "");

        Jdbi jdbi = Jdbi.create(h2JdbcUrl, "sa", "");
        Handle handle = jdbi.open(); // Get a handle for the database connection

        TpchMetadata tpchMetadata = new TpchMetadata("");

        // The statement (and the result sets below) are closed once setup finishes.
        try (Statement stmt = conn.createStatement()) {
            // Create schema
            stmt.execute("CREATE SCHEMA IF NOT EXISTS testdb");

            stmt.execute("CREATE TABLE testdb.member (" +
                    " id INTEGER PRIMARY KEY," +
                    " name VARCHAR(50)," +
                    " sex CHAR(1)," +
                    " state CHAR(5)" +
                    ")");
            stmt.execute("INSERT INTO testdb.member VALUES(1, 'TOM', 'M', 'TMX '),(2, 'MARY', 'F', 'CD ')");

            stmt.execute("CREATE TABLE testdb.event (" +
                    " id INTEGER PRIMARY KEY," +
                    " startDate DATE," +
                    " startTime TIME," +
                    " startTimestamp TIMESTAMP" +
                    ")");
            stmt.execute("INSERT INTO testdb.event VALUES(1, DATE '2004-12-31', TIME '23:59:59'," +
                    " TIMESTAMP '2005-12-31 23:59:59')");

            stmt.execute("CREATE TABLE testdb.orders (\n" +
                    " orderkey BIGINT PRIMARY KEY,\n" +
                    " custkey BIGINT NOT NULL,\n" +
                    " orderstatus VARCHAR(1) NOT NULL,\n" +
                    " totalprice DOUBLE NOT NULL,\n" +
                    " orderdate DATE NOT NULL,\n" +
                    " orderpriority VARCHAR(15) NOT NULL,\n" +
                    " clerk VARCHAR(15) NOT NULL,\n" +
                    " shippriority INTEGER NOT NULL,\n" +
                    " comment VARCHAR(79) NOT NULL\n" +
                    ")");
            stmt.execute("CREATE INDEX custkey_index ON testdb.orders (custkey)");
            insertRows(tpchMetadata, ORDERS, handle);

            handle.execute("CREATE TABLE testdb.lineitem (\n" +
                    " orderkey BIGINT,\n" +
                    " partkey BIGINT NOT NULL,\n" +
                    " suppkey BIGINT NOT NULL,\n" +
                    " linenumber INTEGER,\n" +
                    " quantity DOUBLE NOT NULL,\n" +
                    " extendedprice DOUBLE NOT NULL,\n" +
                    " discount DOUBLE NOT NULL,\n" +
                    " tax DOUBLE NOT NULL,\n" +
                    " returnflag CHAR(1) NOT NULL,\n" +
                    " linestatus CHAR(1) NOT NULL,\n" +
                    " shipdate DATE NOT NULL,\n" +
                    " commitdate DATE NOT NULL,\n" +
                    " receiptdate DATE NOT NULL,\n" +
                    " shipinstruct VARCHAR(25) NOT NULL,\n" +
                    " shipmode VARCHAR(10) NOT NULL,\n" +
                    " comment VARCHAR(44) NOT NULL,\n" +
                    " PRIMARY KEY (orderkey, linenumber)" +
                    ")");
            insertRows(tpchMetadata, LINE_ITEM, handle);

            handle.execute(" CREATE TABLE testdb.partsupp (\n" +
                    " partkey BIGINT NOT NULL,\n" +
                    " suppkey BIGINT NOT NULL,\n" +
                    " availqty INTEGER NOT NULL,\n" +
                    " supplycost DOUBLE NOT NULL,\n" +
                    " comment VARCHAR(199) NOT NULL,\n" +
                    " PRIMARY KEY(partkey, suppkey)" +
                    ")");
            insertRows(tpchMetadata, PART_SUPPLIER, handle);

            handle.execute("CREATE TABLE testdb.nation (\n" +
                    " nationkey BIGINT PRIMARY KEY,\n" +
                    " name VARCHAR(25) NOT NULL,\n" +
                    " regionkey BIGINT NOT NULL,\n" +
                    " comment VARCHAR(152) NOT NULL\n" +
                    ")");
            insertRows(tpchMetadata, NATION, handle);

            handle.execute("CREATE TABLE testdb.region(\n" +
                    " regionkey BIGINT PRIMARY KEY,\n" +
                    " name VARCHAR(25) NOT NULL,\n" +
                    " comment VARCHAR(115) NOT NULL\n" +
                    ")");
            insertRows(tpchMetadata, REGION, handle);
            handle.execute("CREATE TABLE testdb.part(\n" +
                    " partkey BIGINT PRIMARY KEY,\n" +
                    " name VARCHAR(55) NOT NULL,\n" +
                    " mfgr VARCHAR(25) NOT NULL,\n" +
                    " brand VARCHAR(10) NOT NULL,\n" +
                    " type VARCHAR(25) NOT NULL,\n" +
                    " size INTEGER NOT NULL,\n" +
                    " container VARCHAR(10) NOT NULL,\n" +
                    " retailprice DOUBLE NOT NULL,\n" +
                    " comment VARCHAR(23) NOT NULL\n" +
                    ")");
            insertRows(tpchMetadata, PART, handle);
            handle.execute(" CREATE TABLE testdb.customer ( \n" +
                    " custkey BIGINT NOT NULL, \n" +
                    " name VARCHAR(25) NOT NULL, \n" +
                    " address VARCHAR(40) NOT NULL, \n" +
                    " nationkey BIGINT NOT NULL, \n" +
                    " phone VARCHAR(15) NOT NULL, \n" +
                    " acctbal DOUBLE NOT NULL, \n" +
                    " mktsegment VARCHAR(10) NOT NULL, \n" +
                    " comment VARCHAR(117) NOT NULL \n" +
                    " ) ");
            insertRows(tpchMetadata, CUSTOMER, handle);
            handle.execute(" CREATE TABLE testdb.supplier ( \n" +
                    " suppkey bigint NOT NULL, \n" +
                    " name varchar(25) NOT NULL, \n" +
                    " address varchar(40) NOT NULL, \n" +
                    " nationkey bigint NOT NULL, \n" +
                    " phone varchar(15) NOT NULL, \n" +
                    " acctbal double NOT NULL, \n" +
                    " comment varchar(101) NOT NULL \n" +
                    " ) ");
            insertRows(tpchMetadata, SUPPLIER, handle);

            List<String> tables = new ArrayList<>();
            try (ResultSet tableRows = stmt.executeQuery("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'TESTDB'")) {
                while (tableRows.next()) {
                    tables.add(tableRows.getString("TABLE_NAME"));
                }
            }
            logger.info("Tables in 'testdb' schema: %s", tables.stream().collect(Collectors.joining(", ")));

            List<String> schemas = new ArrayList<>();
            try (ResultSet schemaRows = stmt.executeQuery("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA")) {
                while (schemaRows.next()) {
                    schemas.add(schemaRows.getString("SCHEMA_NAME"));
                }
            }
            logger.info("Schemas: %s", schemas.stream().collect(Collectors.joining(", ")));
        }
    }

    /**
     * Copies every row of the given TPC-H table (tiny schema) into the matching {@code testdb} table.
     */
    private static void insertRows(TpchMetadata tpchMetadata, TpchTable<?> tpchTable, Handle handle)
    {
        TpchTableHandle tableHandle = tpchMetadata.getTableHandle(null, new SchemaTableName(TINY_SCHEMA_NAME, tpchTable.getTableName()));
        insertRows(tpchMetadata.getTableMetadata(null, tableHandle), handle, createTpchRecordSet(tpchTable, tableHandle.getScaleFactor()));
    }

    /**
     * Inserts all records of {@code data} into {@code testdb.<table>} in batches of up to 1000 rows,
     * binding one positional parameter per non-hidden column.
     *
     * @throws IllegalArgumentException if a column has a type other than BOOLEAN, BIGINT, INTEGER,
     * DOUBLE, VARCHAR or DATE
     */
    private static void insertRows(ConnectorTableMetadata tableMetadata, Handle handle, RecordSet data)
    {
        List<ColumnMetadata> columns = tableMetadata.getColumns().stream()
                .filter(columnMetadata -> !columnMetadata.isHidden())
                .collect(toImmutableList());

        String schemaName = "testdb";
        String tableNameWithSchema = schemaName + "." + tableMetadata.getTable().getTableName();
        String vars = Joiner.on(',').join(nCopies(columns.size(), "?"));
        String sql = format("INSERT INTO %s VALUES (%s)", tableNameWithSchema, vars);

        RecordCursor cursor = data.cursor();
        while (true) {
            // insert 1000 rows at a time
            PreparedBatch batch = handle.prepareBatch(sql);
            for (int row = 0; row < 1000; row++) {
                if (!cursor.advanceNextPosition()) {
                    // Cursor exhausted: flush the partially filled batch, if any, and stop.
                    if (batch.size() > 0) {
                        batch.execute();
                    }
                    return;
                }
                for (int column = 0; column < columns.size(); column++) {
                    Type type = columns.get(column).getType();
                    if (BOOLEAN.equals(type)) {
                        batch.bind(column, cursor.getBoolean(column));
                    }
                    else if (BIGINT.equals(type)) {
                        batch.bind(column, cursor.getLong(column));
                    }
                    else if (INTEGER.equals(type)) {
                        batch.bind(column, (int) cursor.getLong(column));
                    }
                    else if (DOUBLE.equals(type)) {
                        batch.bind(column, cursor.getDouble(column));
                    }
                    else if (type instanceof VarcharType) {
                        batch.bind(column, cursor.getSlice(column).toStringUtf8());
                    }
                    else if (DATE.equals(type)) {
                        long millisUtc = TimeUnit.DAYS.toMillis(cursor.getLong(column));
                        // H2 expects dates in to be millis at midnight in the JVM timezone
                        long localMillis = DateTimeZone.UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millisUtc);
                        batch.bind(column, new Date(localMillis));
                    }
                    else {
                        throw new IllegalArgumentException("Unsupported type " + type);
                    }
                }
                batch.add();
            }
            batch.execute();
        }
    }
}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java new file mode 100644 index 0000000000000..81b4ee5f7e811 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class TestingInteractionProperties +{ + @JsonProperty("select_statement") + private String selectStatement; + + @JsonProperty("schema_name") + private String schema; + @JsonProperty("table_name") + private String table; + public String getSelectStatement() + { + return selectStatement; + } + + public String getSchema() + { + return schema; + } + + public String getTable() + { + return table; + } + + public void setSchema(String schema) + { + this.schema = schema; + } + + public void setSelectStatement(String selectStatement) + { + this.selectStatement = selectStatement; + } + + public void setTable(String table) + { + this.table = table; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java new file mode 100644 index 0000000000000..aee379dc04311 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingRequestData.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +public class TestingRequestData +{ + private TestingConnectionProperties connectionProperties; + private TestingInteractionProperties interactionProperties; + + public TestingConnectionProperties getConnectionProperties() + { + return connectionProperties; + } + + public TestingInteractionProperties getInteractionProperties() + { + return interactionProperties; + } + + public void setConnectionProperties(TestingConnectionProperties connectionProperties) + { + this.connectionProperties = connectionProperties; + } + + public void setInteractionProperties(TestingInteractionProperties interactionProperties) + { + this.interactionProperties = interactionProperties; + } +} diff --git a/presto-base-arrow-flight/src/test/resources/server.crt b/presto-base-arrow-flight/src/test/resources/server.crt new file mode 100644 index 0000000000000..9af01752e2959 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUBYcCyz9qphcpgV9wIx2caf4ikoEwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNDA4MzAxMzM0MjJaFw0yNTA4 +MzAxMzM0MjJaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCTu4vGRFav9/aNuGorx2v+xAevMRFfSi6x2onRCpMp +u+9L3bK3em+vSJcC/+QwvNkvTmf0ac8/G79BXQiUtPSV3tDSFnpO7LIICWpO2gJX +UKjTN+SZ3LfLEXSYAijVUCkWf8RswJLvJ12qNQRiv4IDrpBh/X+ICr+ALMa7y6cn 
+hBfpFtVyxlLkCp6Q71JgMGDjocv2195pc4uiMEDsn6tZfEKxmw1phX+CLke42srX ++AHoBQvlfcOBVIuMkxU7pZvvDDOuw0tSm0qomjRV+Azjn4oWbPDfkIEZIbBvgLwN +nxtKzSAgW+7xWQ0eVMRq1f5JPk6vLAV/yit8lFk8TCtdAgMBAAGjUzBRMB0GA1Ud +DgQWBBS4K4nHKbp4qw77KxR+tqUWXKCADjAfBgNVHSMEGDAWgBS4K4nHKbp4qw77 +KxR+tqUWXKCADjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBI +yraxUUgEb53q0aF2ID1ml3/9uMZALK4xQHKYP0dtkqdaS0CvuMsqSe0AT7PMDJOD +Q+s1lxe0gJ/lb7dLSOQpFGAU7nnUsBaL5alCGmYgeF6ayLzNWobsn4amnjFpSlOH +P/nOQvOrRgVbphozhaYBSjMtwy2Uzj8G7ZHn+Vg/fdEUf/mGb7/4M9sL7iQWJmIg +slzNnA/LssAqGvyJ65cSkjGrb82VvDJv6JZ7fam6nrMEpq5D4H4GzOwupNvdItLH +Zf9IBsedkDEu8P+Rzp9kGIMbGTLfL3u/pP1st+pD4PfHYRsbl8aAUqYA7P+d29mf +aVcp98Z1h1s5e005BT1T +-----END CERTIFICATE----- diff --git a/presto-base-arrow-flight/src/test/resources/server.key b/presto-base-arrow-flight/src/test/resources/server.key new file mode 100644 index 0000000000000..894dc1fad8e61 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCTu4vGRFav9/aN +uGorx2v+xAevMRFfSi6x2onRCpMpu+9L3bK3em+vSJcC/+QwvNkvTmf0ac8/G79B +XQiUtPSV3tDSFnpO7LIICWpO2gJXUKjTN+SZ3LfLEXSYAijVUCkWf8RswJLvJ12q +NQRiv4IDrpBh/X+ICr+ALMa7y6cnhBfpFtVyxlLkCp6Q71JgMGDjocv2195pc4ui +MEDsn6tZfEKxmw1phX+CLke42srX+AHoBQvlfcOBVIuMkxU7pZvvDDOuw0tSm0qo +mjRV+Azjn4oWbPDfkIEZIbBvgLwNnxtKzSAgW+7xWQ0eVMRq1f5JPk6vLAV/yit8 +lFk8TCtdAgMBAAECggEAAKgSdrLajMUmFhql9CRafUMbQqLN8DW47+bn+mMY5NRW +O6jUUL7tTKLesu92sOXB9FUdnqcyudXSe4Shk2GbfagEFw7tA7lHEESUcZ3D6WXt +HiUvMaTatz8QXNWTn2EQEa7HLXGMpZ3v61/5cUPnHMOTli/ld3IOyE/KoU6GI2WP +5+8eKJUdJiZDLdywLTF/OefTuHhnbkNxsoPtR1DwO1UZKsX1OFo0b3x6ibqPZSu3 +x01QkYfTlUoVjUBNFQTcyfXKWX5rrAK9C/2h7TVJxbktj7V4DOGLb8ee6oN+yr4e +lIg6Vfz/KAakXKaRVappoxSxvFTZXqGhL8kBnLPFoQKBgQDIfr3ir65pi2Gzc+6e +0DWCkznfDqVvfgzoAxCrIqKp/6MqVUIJmProPjBztSAhiaNhK8sDbLFBssw0AOeZ +YEqMcY69ClyVrtDoGYBNeuCTTz3S8yXf0mzZe+CgWA6vNnbzEgKe7NqB+EbGjLnH 
+BcNoSE1GuBjt7Juvuor8P9XW/QKBgQC8oXtd0gQmpIyApgTXtE4AdbckPZE6sX+5 +aWWk9RY7yL9ivx86ej+gG2oj0YqmBtCbcabrYdF2Dv1idSpr+UoW5cxMa+jfAwSQ +8WNeUOeWk1OfEMrnbp8XLZMTDMGMTTyHdPMNpXq0zw7bvIReiRzeFb7fEgdE2+q+ +Me+GHU5D4QKBgDK+eTrFciQ+ZbTwk6VYVyK8NnpxD4f/ZC7Yj8BwnLDgBaDyQSuC +r4ZWLxcp8X7rghFW7yPnv5k8Mpi63eMgzt1q5FCOLc6olzEXOzTg87P061XXum9C +p9AHnVuXzeekpkhw937XvZoFh4w7E83+dG2RVxWeBJk7OFAqq4Cae3nVAoGAYXOD +yqqvnk8wj1417kKmcbJfFYgBObNt6xo6ewhrniNOTPO0bH+v00WWhj7BRJkMuOH0 +fHKixj1kRrOFYRb/YekCrRCq1Fw4xbEPxzBBFRe0Ad+pE/ugkVboPtU+QP++H7UZ +xJkTVcoLQRaZxEVN9qaBX7luq/J5yhz+Q+lr/8ECgYEAsDUXgtuRyMvKr1PFzrfQ +n/aVb4wmY39Zyr1sERHwHYLUvX2gsiNMaUq7lFWUNBGy0x0gcaMCWOlxXT3SMUP9 +aE3LCjrkUWP69/RuFjKdcustLGOtJmnr9ioSpZkkrR49tmiu1SWFsK2IunTd4GUg +/p/ITUEsA2oiyDZcpTkWx0Q= +-----END PRIVATE KEY----- diff --git a/presto-docs/src/main/sphinx/connector.rst b/presto-docs/src/main/sphinx/connector.rst index f07506b460cab..d337fe4ed12d1 100644 --- a/presto-docs/src/main/sphinx/connector.rst +++ b/presto-docs/src/main/sphinx/connector.rst @@ -9,6 +9,7 @@ from different data sources. :maxdepth: 1 connector/accumulo + connector/base-arrow-flight connector/bigquery connector/blackhole connector/cassandra diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst new file mode 100644 index 0000000000000..80c00311cee7d --- /dev/null +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -0,0 +1,96 @@ + +====================== +Arrow Flight Connector +====================== +This connector allows querying multiple data sources that are supported by an Arrow Flight server. +Apache Arrow enhances performance and efficiency in data-intensive applications through its columnar memory layout, zero-copy reads, vectorized execution, cross-language interoperability, rich data type support, and optimization for modern hardware. 
These features collectively reduce overhead, improve data processing speeds, and facilitate seamless data exchange between different systems and languages. + +Getting Started with base-arrow-module: Essential Abstract Methods for Developers +--------------------------------------------------------------------------------- +To use the base-arrow-module, you need to implement certain abstract methods that are specific to your use case. Below are the required classes and their purposes: + +* ``ArrowFlightClientHandler.java`` + This class is responsible for initializing the Flight client and retrieving Flight information from the Flight server. To authenticate with the Flight server, you must implement the abstract method ``getCallOptions`` in ``ArrowFlightClientHandler``, which returns the ``CredentialCallOption`` specific to your Flight server. + +* ``AbstractArrowFlightRequest.java`` + Implement this class to define the request data, including the data source type, connection properties, the number of partitions, and other data required to interact with the database. + +* ``AbstractArrowMetadata.java`` + To retrieve metadata (schema and table information), implement the abstract methods in the ``AbstractArrowMetadata`` class. + +* ``AbstractArrowSplitManager.java`` + Extend the ``AbstractArrowSplitManager`` class to implement the Arrow Flight request, defining the Arrow split. + +* ``ArrowPlugin.java`` + Register your connector name by extending the ``ArrowPlugin`` class. + +* ``ArrowFlightRequest`` + The ``getCommand`` method in the ``ArrowFlightRequest`` interface should return a byte array for the Flight request. + + +Configuration +------------- +Create a catalog file +in ``etc/catalog`` named, for example, ``arrowmariadb.properties``, to +mount the Flight connector as the ``arrowmariadb`` catalog. +Create the file with the following contents, replacing the +connection properties as appropriate for your setup: + + +..
code-block:: none + + + connector.name=<CONNECTOR_NAME> + arrow-flight.server=<ENDPOINT> + arrow-flight.server.port=<PORT> + + + +Add any other properties that your Flight server requires to connect. + +========================================== ============================================================== +Property Name Description +========================================== ============================================================== +``arrow-flight.server`` Endpoint of the Arrow Flight server +``arrow-flight.server.port`` Port of the Flight server +``arrow-flight.server-ssl-certificate`` SSL certificate for the Flight server connection +``arrow-flight.server.verify`` Whether to verify the server +``arrow-flight.server-ssl-enabled`` Whether SSL is enabled on the Flight server port +========================================== ============================================================== + +Querying Arrow-Flight +--------------------- + +The Flight connector provides a schema for each supported *database*. +An example for MariaDB is shown below. +To see the available schemas, run ``SHOW SCHEMAS``:: + + SHOW SCHEMAS FROM arrowmariadb; + +To view the tables in the MariaDB database named ``user``, +run ``SHOW TABLES``:: + + SHOW TABLES FROM arrowmariadb.user; + +To see a list of the columns in the ``admin`` table in the ``user`` database, +use either of the following commands:: + + DESCRIBE arrowmariadb.user.admin; + SHOW COLUMNS FROM arrowmariadb.user.admin; + +Finally, you can access the ``admin`` table in the ``user`` database:: + + SELECT * FROM arrowmariadb.user.admin; + +If you used a different name for your catalog properties file, use +that catalog name instead of ``arrowmariadb`` in the above examples. + + +Flight Connector Limitations +---------------------------------- + +* SELECT and DESCRIBE queries are supported by this connector template. Implementing modules can add support for additional features. + +* The Flight connector can query only those data sources that are supported by the Flight server. +
+ +* The user should have the Flight server running for the Flight connector to work.