Skip to content

Commit

Permalink
Create CommandVisitorBase
Browse files Browse the repository at this point in the history
  • Loading branch information
devinrsmith committed Nov 6, 2024
1 parent bb77291 commit a902aef
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import io.grpc.Status;
Expand Down Expand Up @@ -141,4 +142,83 @@ private static <T extends Message> T unpack(Any command, Class<T> clazz, String
.format("Invalid command, provided message cannot be unpacked as %s, %s", clazz.getName(), logId));
}
}

public static abstract class CommandVisitorBase<T> implements CommandVisitor<T> {
public abstract T visitDefault(Descriptor descriptor, Object command);

@Override
public T visit(CommandGetCatalogs command) {
return visitDefault(CommandGetCatalogs.getDescriptor(), command);
}

@Override
public T visit(CommandGetDbSchemas command) {
return visitDefault(CommandGetDbSchemas.getDescriptor(), command);
}

@Override
public T visit(CommandGetTableTypes command) {
return visitDefault(CommandGetTableTypes.getDescriptor(), command);
}

@Override
public T visit(CommandGetImportedKeys command) {
return visitDefault(CommandGetImportedKeys.getDescriptor(), command);
}

@Override
public T visit(CommandGetExportedKeys command) {
return visitDefault(CommandGetExportedKeys.getDescriptor(), command);
}

@Override
public T visit(CommandGetPrimaryKeys command) {
return visitDefault(CommandGetPrimaryKeys.getDescriptor(), command);
}

@Override
public T visit(CommandGetTables command) {
return visitDefault(CommandGetTables.getDescriptor(), command);
}

@Override
public T visit(CommandStatementQuery command) {
return visitDefault(CommandStatementQuery.getDescriptor(), command);
}

@Override
public T visit(CommandPreparedStatementQuery command) {
return visitDefault(CommandPreparedStatementQuery.getDescriptor(), command);
}

@Override
public T visit(CommandGetSqlInfo command) {
return visitDefault(CommandGetSqlInfo.getDescriptor(), command);
}

@Override
public T visit(CommandStatementUpdate command) {
return visitDefault(CommandStatementUpdate.getDescriptor(), command);
}

@Override
public T visit(CommandGetCrossReference command) {
return visitDefault(CommandGetCrossReference.getDescriptor(), command);
}

@Override
public T visit(CommandStatementSubstraitPlan command) {
return visitDefault(CommandStatementSubstraitPlan.getDescriptor(), command);
}

@Override
public T visit(CommandPreparedStatementUpdate command) {
return visitDefault(CommandPreparedStatementUpdate.getDescriptor(), command);
}

@Override
public T visit(CommandGetXdbcTypeInfo command) {
return visitDefault(CommandGetXdbcTypeInfo.getDescriptor(), command);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,14 @@
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetXdbcTypeInfo;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementSubstraitPlan;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate;
import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery;
import org.apache.arrow.vector.types.pojo.ArrowType.Utf8;
import org.apache.arrow.vector.types.pojo.Field;
Expand Down Expand Up @@ -164,22 +158,6 @@ public final class FlightSqlResolver implements ActionResolver, CommandResolver
@VisibleForTesting
static final Schema DATASET_SCHEMA_SENTINEL = new Schema(List.of(Field.nullable("DO_NOT_USE", Utf8.INSTANCE)));

// Need dense_union support to implement this.
private static final CommandHandler<CommandGetSqlInfo> GET_SQL_INFO_HANDLER =
new UnsupportedCommand<>(CommandGetSqlInfo.getDescriptor(), CommandGetSqlInfo.class);
private static final CommandHandler<CommandStatementUpdate> STATEMENT_UPDATE_HANDLER =
new UnsupportedCommand<>(CommandStatementUpdate.getDescriptor(), CommandStatementUpdate.class);
private static final CommandHandler<CommandGetCrossReference> GET_CROSS_REFERENCE_HANDLER =
new UnsupportedCommand<>(CommandGetCrossReference.getDescriptor(), CommandGetCrossReference.class);
private static final CommandHandler<CommandStatementSubstraitPlan> STATEMENT_SUBSTRAIT_PLAN_HANDLER =
new UnsupportedCommand<>(CommandStatementSubstraitPlan.getDescriptor(),
CommandStatementSubstraitPlan.class);
private static final CommandHandler<CommandPreparedStatementUpdate> PREPARED_STATEMENT_UPDATE_HANDLER =
new UnsupportedCommand<>(CommandPreparedStatementUpdate.getDescriptor(),
CommandPreparedStatementUpdate.class);
private static final CommandHandler<CommandGetXdbcTypeInfo> GET_XDBC_TYPE_INFO_HANDLER =
new UnsupportedCommand<>(CommandGetXdbcTypeInfo.getDescriptor(), CommandGetXdbcTypeInfo.class);

// Unable to depends on TicketRouter, would be a circular dependency atm (since TicketRouter depends on all the
// TicketResolvers).
// private final TicketRouter router;
Expand Down Expand Up @@ -300,7 +278,7 @@ public ExportObject<FlightInfo> flightInfoFor(
return FlightSqlCommandHelper.visit(descriptor, new GetFlightInfoImpl(session, descriptor), logId);
}

private class GetFlightInfoImpl implements FlightSqlCommandHelper.CommandVisitor<ExportObject<FlightInfo>> {
private class GetFlightInfoImpl extends FlightSqlCommandHelper.CommandVisitorBase<ExportObject<FlightInfo>> {
private final SessionState session;
private final FlightDescriptor descriptor;

Expand All @@ -309,6 +287,11 @@ public GetFlightInfoImpl(SessionState session, FlightDescriptor descriptor) {
this.descriptor = Objects.requireNonNull(descriptor);
}

@Override
public ExportObject<FlightInfo> visitDefault(Descriptor descriptor, Object command) {
return submit(new UnsupportedCommand<>(descriptor), command);
}

@Override
public ExportObject<FlightInfo> visit(CommandGetCatalogs command) {
return submit(CommandGetCatalogsConstants.HANDLER, command);
Expand Down Expand Up @@ -354,36 +337,6 @@ public ExportObject<FlightInfo> visit(CommandPreparedStatementQuery command) {
return submit(new CommandPreparedStatementQueryImpl(session), command);
}

@Override
public ExportObject<FlightInfo> visit(CommandGetSqlInfo command) {
return submit(GET_SQL_INFO_HANDLER, command);
}

@Override
public ExportObject<FlightInfo> visit(CommandStatementUpdate command) {
return submit(STATEMENT_UPDATE_HANDLER, command);
}

@Override
public ExportObject<FlightInfo> visit(CommandGetCrossReference command) {
return submit(GET_CROSS_REFERENCE_HANDLER, command);
}

@Override
public ExportObject<FlightInfo> visit(CommandStatementSubstraitPlan command) {
return submit(STATEMENT_SUBSTRAIT_PLAN_HANDLER, command);
}

@Override
public ExportObject<FlightInfo> visit(CommandPreparedStatementUpdate command) {
return submit(PREPARED_STATEMENT_UPDATE_HANDLER, command);
}

@Override
public ExportObject<FlightInfo> visit(CommandGetXdbcTypeInfo command) {
return submit(GET_XDBC_TYPE_INFO_HANDLER, command);
}

private <T> ExportObject<FlightInfo> submit(CommandHandler<T> handler, T command) {
return session.<FlightInfo>nonExport().submit(() -> getInfo(handler, command));
}
Expand Down Expand Up @@ -820,7 +773,7 @@ public Table resolve() {
private static final class UnsupportedCommand<T> implements CommandHandler<T>, TicketHandler {
private final Descriptor descriptor;

UnsupportedCommand(Descriptor descriptor, Class<? extends Message> clazz) {
UnsupportedCommand(Descriptor descriptor) {
this.descriptor = Objects.requireNonNull(descriptor);
}

Expand Down

0 comments on commit a902aef

Please sign in to comment.