From b7ddabbc2f8266af4a8b743a478b9c62eab354c6 Mon Sep 17 00:00:00 2001
From: Jack Klamer
Date: Mon, 6 May 2024 19:02:08 -0500
Subject: [PATCH] Refactor block building decoder creation to connector
 specific class

---
 .../hive/formats/avro/AvroFileReader.java     |  12 +-
 .../hive/formats}/avro/AvroHiveConstants.java |   2 +-
 .../hive/formats/avro/AvroPageDataReader.java | 865 +-----------------
 .../avro/AvroPagePositionDataWriter.java      |   2 +
 .../formats/avro/AvroTypeBlockHandler.java    |  48 +
 .../hive/formats/avro/AvroTypeManager.java    |  22 -
 .../hive/formats/avro/AvroTypeUtils.java      |  68 +-
 .../avro/BaseAvroTypeBlockHandlerImpls.java   | 412 +++++++++
 .../formats/avro/BlockBuildingDecoder.java    |  25 +
 .../avro/HiveAvroTypeBlockHandler.java        | 429 +++++++++
 .../formats/avro/HiveAvroTypeManager.java     | 125 +++
 ...ativeLogicalTypesAvroTypeBlockHandler.java | 293 ++++++
 .../NativeLogicalTypesAvroTypeManager.java    | 280 +++---
 .../formats/avro/RowBlockBuildingDecoder.java | 170 ++++
 .../formats/avro/model/AvroLogicalType.java   |  50 +-
 .../formats/avro/model/AvroReadAction.java    |  84 +-
 .../hive/formats/avro/model/DoubleRead.java   |   7 +-
 .../hive/formats/avro/model/FloatRead.java    |   7 +-
 .../hive/formats/avro/model/LongRead.java     |   7 +-
 .../model/SkipFieldRecordFieldReadAction.java | 125 ++-
 .../avro/BaseAvroTypeBlockHandler.java        |  44 +
 .../formats/avro/NoOpAvroTypeManager.java     |  20 -
 .../trino/hive/formats/avro/TestAvroBase.java |   4 +-
 ...ataReaderWithAvroNativeTypeManagement.java |   8 +-
 ...tAvroPageDataReaderWithoutTypeManager.java | 154 +---
 ...tAvroPageDataWriterWithoutTypeManager.java |  10 +-
 .../avro/TestHiveAvroTypeBlockHandler.java    | 171 ++++
 .../hive/avro/AvroFileWriterFactory.java      |   6 +-
 .../plugin/hive/avro/AvroHiveFileUtils.java   |  20 +-
 .../plugin/hive/avro/AvroHiveFileWriter.java  |   5 +-
 .../plugin/hive/avro/AvroPageSource.java      |   4 +-
 .../hive/avro/AvroPageSourceFactory.java      |   6 +-
 .../plugin/hive/avro/HiveAvroTypeManager.java | 267 ------
 .../hive/avro/TestAvroSchemaGeneration.java   |   2 +-
 34 files changed, 2131 insertions(+), 1623 deletions(-)
 rename {plugin/trino-hive/src/main/java/io/trino/plugin/hive => lib/trino-hive-formats/src/main/java/io/trino/hive/formats}/avro/AvroHiveConstants.java (97%)
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeBlockHandler.java
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandlerImpls.java
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BlockBuildingDecoder.java
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeBlockHandler.java
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeManager.java
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeBlockHandler.java
 create mode 100644 lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/RowBlockBuildingDecoder.java
 create mode 100644 lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandler.java
 create mode 100644 lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestHiveAvroTypeBlockHandler.java
 delete mode 100644 plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java

diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java
index
480caf476258..406cfd67a337 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java @@ -44,23 +44,23 @@ public class AvroFileReader public AvroFileReader( TrinoInputFile inputFile, Schema schema, - AvroTypeManager avroTypeManager) + AvroTypeBlockHandler avroTypeBlockHandler) throws IOException, AvroTypeException { - this(inputFile, schema, avroTypeManager, 0, OptionalLong.empty()); + this(inputFile, schema, avroTypeBlockHandler, 0, OptionalLong.empty()); } public AvroFileReader( TrinoInputFile inputFile, Schema schema, - AvroTypeManager avroTypeManager, + AvroTypeBlockHandler avroTypeBlockHandler, long offset, OptionalLong length) throws IOException, AvroTypeException { requireNonNull(inputFile, "inputFile is null"); requireNonNull(schema, "schema is null"); - requireNonNull(avroTypeManager, "avroTypeManager is null"); + requireNonNull(avroTypeBlockHandler, "avroTypeBlockHandler is null"); long fileSize = inputFile.length(); verify(offset >= 0, "offset is negative"); @@ -69,7 +69,7 @@ public AvroFileReader( end = length.stream().map(l -> l + offset).findFirst(); end.ifPresent(endLong -> verify(endLong <= fileSize, "offset plus length is greater than data size")); input = new TrinoDataInputStream(inputFile.newStream()); - dataReader = new AvroPageDataReader(schema, avroTypeManager); + dataReader = new AvroPageDataReader(schema, avroTypeBlockHandler); try { fileReader = new DataFileReader<>(new TrinoDataInputStreamAsAvroSeekableInput(input, fileSize), dataReader); fileReader.sync(offset); @@ -79,7 +79,7 @@ public AvroFileReader( // so the exception is wrapped in a runtime exception that must be unwrapped throw runtimeWrapper.getAvroTypeException(); } - avroTypeManager.configure(fileReader.getMetaKeys().stream().collect(toImmutableMap(Function.identity(), fileReader::getMeta))); + avroTypeBlockHandler.configure(fileReader.getMetaKeys().stream().collect(toImmutableMap(Function.identity(), fileReader::getMeta))); } public long getCompletedBytes() diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveConstants.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroHiveConstants.java similarity index 97% rename from plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveConstants.java rename to lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroHiveConstants.java index 8bbf18ad7c33..57d75f3e8be0 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveConstants.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroHiveConstants.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package io.trino.plugin.hive.avro; +package io.trino.hive.formats.avro; public final class AvroHiveConstants { diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java index 62e5e6a40d9c..418f6e1505b1 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java @@ -13,93 +13,39 @@ */ package io.trino.hive.formats.avro; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; -import io.trino.hive.formats.avro.model.ArrayReadAction; -import io.trino.hive.formats.avro.model.AvroReadAction; -import io.trino.hive.formats.avro.model.BooleanRead; -import io.trino.hive.formats.avro.model.BytesRead; -import io.trino.hive.formats.avro.model.DefaultValueFieldRecordFieldReadAction; -import io.trino.hive.formats.avro.model.DoubleRead; -import io.trino.hive.formats.avro.model.EnumReadAction; -import io.trino.hive.formats.avro.model.FixedRead; -import io.trino.hive.formats.avro.model.FloatRead; -import io.trino.hive.formats.avro.model.IntRead; -import io.trino.hive.formats.avro.model.LongRead; -import io.trino.hive.formats.avro.model.MapReadAction; -import io.trino.hive.formats.avro.model.NullRead; -import io.trino.hive.formats.avro.model.ReadErrorReadAction; -import io.trino.hive.formats.avro.model.ReadFieldAction; -import io.trino.hive.formats.avro.model.ReadingUnionReadAction; -import io.trino.hive.formats.avro.model.RecordFieldReadAction; -import io.trino.hive.formats.avro.model.RecordReadAction; -import io.trino.hive.formats.avro.model.SkipFieldRecordFieldReadAction; -import io.trino.hive.formats.avro.model.StringRead; -import io.trino.hive.formats.avro.model.WrittenUnionReadAction; import io.trino.spi.Page; import io.trino.spi.PageBuilder; -import io.trino.spi.block.ArrayBlockBuilder; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.MapBlockBuilder; -import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import org.apache.avro.Resolver; import org.apache.avro.Schema; -import org.apache.avro.generic.GenericData; -import org.apache.avro.io.BinaryDecoder; import org.apache.avro.io.DatumReader; import org.apache.avro.io.Decoder; -import org.apache.avro.io.DecoderFactory; -import org.apache.avro.io.Encoder; -import org.apache.avro.io.EncoderFactory; -import org.apache.avro.io.FastReaderBuilder; -import org.apache.avro.io.parsing.ResolvingGrammarGenerator; -import org.apache.avro.util.internal.Accessor; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.List; import java.util.Optional; -import java.util.function.BiConsumer; -import java.util.function.IntFunction; -import java.util.stream.IntStream; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; -import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_TYPE; -import static io.trino.hive.formats.avro.AvroTypeUtils.isSimpleNullableUnion; -import static io.trino.hive.formats.avro.AvroTypeUtils.typeFromAvro; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.RealType.REAL; -import static 
io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.spi.type.VarcharType.VARCHAR; -import static java.lang.Float.floatToRawIntBits; +import static io.trino.hive.formats.avro.AvroTypeUtils.verifyNoCircularReferences; import static java.util.Objects.requireNonNull; public class AvroPageDataReader implements DatumReader> { - // same limit as org.apache.avro.io.BinaryDecoder - private static final long MAX_ARRAY_SIZE = (long) Integer.MAX_VALUE - 8L; - private final Schema readerSchema; private Schema writerSchema; private final PageBuilder pageBuilder; private RowBlockBuildingDecoder rowBlockBuildingDecoder; - private final AvroTypeManager typeManager; + private final AvroTypeBlockHandler typeManager; - public AvroPageDataReader(Schema readerSchema, AvroTypeManager typeManager) + public AvroPageDataReader(Schema readerSchema, AvroTypeBlockHandler typeManager) throws AvroTypeException { this.readerSchema = requireNonNull(readerSchema, "readerSchema is null"); writerSchema = this.readerSchema; this.typeManager = requireNonNull(typeManager, "typeManager is null"); + verifyNoCircularReferences(readerSchema); try { - Type readerSchemaType = typeFromAvro(this.readerSchema, typeManager); + Type readerSchemaType = typeManager.typeFor(readerSchema); verify(readerSchemaType instanceof RowType, "Root Avro type must be a row"); pageBuilder = new PageBuilder(readerSchemaType.getTypeParameters()); initialize(); @@ -158,807 +104,6 @@ public Optional flush() return Optional.empty(); } - private abstract static class BlockBuildingDecoder - { - protected abstract void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException; - } - - private static BlockBuildingDecoder createBlockBuildingDecoderForAction(AvroReadAction action, AvroTypeManager typeManager) - throws AvroTypeException - { - Optional> consumer = typeManager.overrideBuildingFunctionForSchema(action.readSchema()); - if (consumer.isPresent()) { - return new UserDefinedBlockBuildingDecoder(action.readSchema(), action.writeSchema(), consumer.get()); - } - return switch (action) { - case NullRead __ -> NullBlockBuildingDecoder.INSTANCE; - case BooleanRead __ -> BooleanBlockBuildingDecoder.INSTANCE; - case IntRead __ -> IntBlockBuildingDecoder.INSTANCE; - case LongRead longRead -> new LongBlockBuildingDecoder(longRead.getLongDecoder()); - case FloatRead floatRead -> new FloatBlockBuildingDecoder(floatRead.getFloatDecoder()); - case DoubleRead doubleRead -> new DoubleBlockBuildingDecoder(doubleRead.getDoubleDecoder()); - case BytesRead __ -> BytesBlockBuildingDecoder.INSTANCE; - case FixedRead __ -> new FixedBlockBuildingDecoder(action.readSchema().getFixedSize()); - case StringRead __ -> StringBlockBuildingDecoder.INSTANCE; - case ArrayReadAction arrayReadAction -> new ArrayBlockBuildingDecoder(arrayReadAction, typeManager); - case MapReadAction mapReadAction -> new MapBlockBuildingDecoder(mapReadAction, typeManager); - case EnumReadAction enumReadAction -> new EnumBlockBuildingDecoder(enumReadAction); - case RecordReadAction recordReadAction -> new RowBlockBuildingDecoder(recordReadAction, typeManager); - case WrittenUnionReadAction writtenUnionReadAction -> { - if (writtenUnionReadAction.readSchema().getType() == Schema.Type.UNION && !isSimpleNullableUnion(writtenUnionReadAction.readSchema())) { - yield new WriterUnionCoercedIntoRowBlockBuildingDecoder(writtenUnionReadAction, typeManager); - } - else { - // reading a union with non-union or nullable union, optimistically try to create the reader, will fail at read time 
with any underlying issues - yield new WriterUnionBlockBuildingDecoder(writtenUnionReadAction, typeManager); - } - } - case ReadingUnionReadAction readingUnionReadAction -> { - if (isSimpleNullableUnion(readingUnionReadAction.readSchema())) { - yield createBlockBuildingDecoderForAction(readingUnionReadAction.actualAction(), typeManager); - } - else { - yield new ReaderUnionCoercedIntoRowBlockBuildingDecoder(readingUnionReadAction, typeManager); - } - } - case ReadErrorReadAction readErrorReadAction -> new TypeErrorThrower(readErrorReadAction); - }; - } - - private static class TypeErrorThrower - extends BlockBuildingDecoder - { - private final ReadErrorReadAction action; - - public TypeErrorThrower(ReadErrorReadAction action) - { - this.action = requireNonNull(action, "action is null"); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - throw new IOException(new AvroTypeException("Resolution action returned with error " + action)); - } - } - - // Different plugins may have different Avro Schema to Type mappings - // that are currently transforming GenericDatumReader returned objects into their target type during the record reading process - // This block building decoder allows plugin writers to port that code directly and use within this reader - // This mechanism is used to enhance Avro longs into timestamp types according to schema metadata - private static class UserDefinedBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final BiConsumer userBuilderFunction; - private final DatumReader datumReader; - - public UserDefinedBlockBuildingDecoder(Schema readerSchema, Schema writerSchema, BiConsumer userBuilderFunction) - throws AvroTypeException - { - requireNonNull(readerSchema, "readerSchema is null"); - requireNonNull(writerSchema, "writerSchema is null"); - try { - FastReaderBuilder fastReaderBuilder = new FastReaderBuilder(new GenericData()); - datumReader = fastReaderBuilder.createDatumReader(writerSchema, readerSchema); - } - catch (IOException ioException) { - // IOException only thrown when default encoded in schema is unable to be re-serialized into bytes with proper typing - // translate into type exception - throw new AvroTypeException("Unable to decode default value in schema " + readerSchema, ioException); - } - this.userBuilderFunction = requireNonNull(userBuilderFunction, "userBuilderFunction is null"); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - userBuilderFunction.accept(builder, datumReader.read(null, decoder)); - } - } - - private static class NullBlockBuildingDecoder - extends BlockBuildingDecoder - { - private static final NullBlockBuildingDecoder INSTANCE = new NullBlockBuildingDecoder(); - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - decoder.readNull(); - builder.appendNull(); - } - } - - private static class BooleanBlockBuildingDecoder - extends BlockBuildingDecoder - { - private static final BooleanBlockBuildingDecoder INSTANCE = new BooleanBlockBuildingDecoder(); - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - BOOLEAN.writeBoolean(builder, decoder.readBoolean()); - } - } - - private static class IntBlockBuildingDecoder - extends BlockBuildingDecoder - { - private static final IntBlockBuildingDecoder INSTANCE = new IntBlockBuildingDecoder(); - - @Override - protected void 
decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - INTEGER.writeLong(builder, decoder.readInt()); - } - } - - private static class LongBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final LongIoFunction extractLong; - - public LongBlockBuildingDecoder(LongIoFunction extractLong) - { - this.extractLong = requireNonNull(extractLong, "extractLong is null"); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - BIGINT.writeLong(builder, extractLong.apply(decoder)); - } - } - - private static class FloatBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final FloatIoFunction extractFloat; - - public FloatBlockBuildingDecoder(FloatIoFunction extractFloat) - { - this.extractFloat = requireNonNull(extractFloat, "extractFloat is null"); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - REAL.writeLong(builder, floatToRawIntBits(extractFloat.apply(decoder))); - } - } - - private static class DoubleBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final DoubleIoFunction extractDouble; - - public DoubleBlockBuildingDecoder(DoubleIoFunction extractDouble) - { - this.extractDouble = requireNonNull(extractDouble, "extractDouble is null"); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - DOUBLE.writeDouble(builder, extractDouble.apply(decoder)); - } - } - - private static class StringBlockBuildingDecoder - extends BlockBuildingDecoder - { - private static final StringBlockBuildingDecoder INSTANCE = new StringBlockBuildingDecoder(); - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - // it is only possible to read String type when underlying write type is String or Bytes - // both have the same encoding, so coercion is a no-op - long size = decoder.readLong(); - if (size > MAX_ARRAY_SIZE) { - throw new IOException("Unable to read avro String with size greater than %s. Found String size: %s".formatted(MAX_ARRAY_SIZE, size)); - } - byte[] bytes = new byte[(int) size]; - decoder.readFixed(bytes); - VARCHAR.writeSlice(builder, Slices.wrappedBuffer(bytes)); - } - } - - private static class BytesBlockBuildingDecoder - extends BlockBuildingDecoder - { - private static final BytesBlockBuildingDecoder INSTANCE = new BytesBlockBuildingDecoder(); - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - // it is only possible to read Bytes type when underlying write type is String or Bytes - // both have the same encoding, so coercion is a no-op - long size = decoder.readLong(); - if (size > MAX_ARRAY_SIZE) { - throw new IOException("Unable to read avro Bytes with size greater than %s. 
Found Bytes size: %s".formatted(MAX_ARRAY_SIZE, size)); - } - byte[] bytes = new byte[(int) size]; - decoder.readFixed(bytes); - VARBINARY.writeSlice(builder, Slices.wrappedBuffer(bytes)); - } - } - - private static class FixedBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final int expectedSize; - - public FixedBlockBuildingDecoder(int expectedSize) - { - verify(expectedSize >= 0, "expected size must be greater than or equal to 0"); - this.expectedSize = expectedSize; - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - byte[] slice = new byte[expectedSize]; - decoder.readFixed(slice); - VARBINARY.writeSlice(builder, Slices.wrappedBuffer(slice)); - } - } - - private static class EnumBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final List symbols; - - public EnumBlockBuildingDecoder(EnumReadAction enumReadAction) - { - requireNonNull(enumReadAction, "action is null"); - symbols = enumReadAction.getSymbolIndex(); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - VARCHAR.writeSlice(builder, symbols.get(decoder.readEnum())); - } - } - - private static class ArrayBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final BlockBuildingDecoder elementBlockBuildingDecoder; - - public ArrayBlockBuildingDecoder(ArrayReadAction arrayReadAction, AvroTypeManager typeManager) - throws AvroTypeException - { - requireNonNull(arrayReadAction, "arrayReadAction is null"); - elementBlockBuildingDecoder = createBlockBuildingDecoderForAction(arrayReadAction.elementReadAction(), typeManager); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - ((ArrayBlockBuilder) builder).buildEntry(elementBuilder -> { - long elementsInBlock = decoder.readArrayStart(); - if (elementsInBlock > 0) { - do { - for (int i = 0; i < elementsInBlock; i++) { - elementBlockBuildingDecoder.decodeIntoBlock(decoder, elementBuilder); - } - } - while ((elementsInBlock = decoder.arrayNext()) > 0); - } - }); - } - } - - private static class MapBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final BlockBuildingDecoder keyBlockBuildingDecoder = new StringBlockBuildingDecoder(); - private final BlockBuildingDecoder valueBlockBuildingDecoder; - - public MapBlockBuildingDecoder(MapReadAction mapReadAction, AvroTypeManager typeManager) - throws AvroTypeException - { - requireNonNull(mapReadAction, "mapReadAction is null"); - valueBlockBuildingDecoder = createBlockBuildingDecoderForAction(mapReadAction.valueReadAction(), typeManager); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - ((MapBlockBuilder) builder).buildEntry((keyBuilder, valueBuilder) -> { - long entriesInBlock = decoder.readMapStart(); - // TODO need to filter out all but last value for key? 
- if (entriesInBlock > 0) { - do { - for (int i = 0; i < entriesInBlock; i++) { - keyBlockBuildingDecoder.decodeIntoBlock(decoder, keyBuilder); - valueBlockBuildingDecoder.decodeIntoBlock(decoder, valueBuilder); - } - } - while ((entriesInBlock = decoder.mapNext()) > 0); - } - }); - } - } - - public static class RowBlockBuildingDecoder - extends BlockBuildingDecoder - { - private final RowBuildingAction[] buildSteps; - - private RowBlockBuildingDecoder(Schema writeSchema, Schema readSchema, AvroTypeManager typeManager) - throws AvroTypeException - { - this(AvroReadAction.fromAction(Resolver.resolve(writeSchema, readSchema, new GenericData())), typeManager); - } - - private RowBlockBuildingDecoder(AvroReadAction action, AvroTypeManager typeManager) - throws AvroTypeException - - { - if (!(action instanceof RecordReadAction recordReadAction)) { - throw new AvroTypeException("Write and Read Schemas must be records when building a row block building decoder. Illegal action: " + action); - } - buildSteps = new RowBuildingAction[recordReadAction.fieldReadActions().size()]; - int i = 0; - for (RecordFieldReadAction fieldAction : recordReadAction.fieldReadActions()) { - buildSteps[i] = switch (fieldAction) { - case DefaultValueFieldRecordFieldReadAction defaultValueFieldRecordFieldReadAction -> new ConstantBlockAction( - getDefaultBlockBuilder(defaultValueFieldRecordFieldReadAction.fieldSchema(), - defaultValueFieldRecordFieldReadAction.defaultBytes(), typeManager), - defaultValueFieldRecordFieldReadAction.outputChannel()); - case ReadFieldAction readFieldAction -> new BuildIntoBlockAction(createBlockBuildingDecoderForAction(readFieldAction.readAction(), typeManager), - readFieldAction.outputChannel()); - case SkipFieldRecordFieldReadAction skipFieldRecordFieldReadAction -> new SkipSchemaBuildingAction(skipFieldRecordFieldReadAction.skipAction()); - }; - i++; - } - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - ((RowBlockBuilder) builder).buildEntry(fieldBuilders -> decodeIntoBlockProvided(decoder, fieldBuilders::get)); - } - - protected void decodeIntoPageBuilder(Decoder decoder, PageBuilder builder) - throws IOException - { - builder.declarePosition(); - decodeIntoBlockProvided(decoder, builder::getBlockBuilder); - } - - protected void decodeIntoBlockProvided(Decoder decoder, IntFunction fieldBlockBuilder) - throws IOException - { - for (RowBuildingAction buildStep : buildSteps) { - switch (buildStep) { - case SkipSchemaBuildingAction skipSchemaBuildingAction -> skipSchemaBuildingAction.skip(decoder); - case BuildIntoBlockAction buildIntoBlockAction -> buildIntoBlockAction.decode(decoder, fieldBlockBuilder); - case ConstantBlockAction constantBlockAction -> constantBlockAction.addConstant(fieldBlockBuilder); - } - } - } - - sealed interface RowBuildingAction - permits BuildIntoBlockAction, ConstantBlockAction, SkipSchemaBuildingAction - {} - - private static final class BuildIntoBlockAction - implements RowBuildingAction - { - private final BlockBuildingDecoder delegate; - private final int outputChannel; - - public BuildIntoBlockAction(BlockBuildingDecoder delegate, int outputChannel) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - checkArgument(outputChannel >= 0, "outputChannel must be positive"); - this.outputChannel = outputChannel; - } - - public void decode(Decoder decoder, IntFunction channelSelector) - throws IOException - { - delegate.decodeIntoBlock(decoder, channelSelector.apply(outputChannel)); 
- } - } - - protected static final class ConstantBlockAction - implements RowBuildingAction - { - private final IoConsumer addConstantFunction; - private final int outputChannel; - - public ConstantBlockAction(IoConsumer addConstantFunction, int outputChannel) - { - this.addConstantFunction = requireNonNull(addConstantFunction, "addConstantFunction is null"); - checkArgument(outputChannel >= 0, "outputChannel must be positive"); - this.outputChannel = outputChannel; - } - - public void addConstant(IntFunction channelSelector) - throws IOException - { - addConstantFunction.accept(channelSelector.apply(outputChannel)); - } - } - - public static final class SkipSchemaBuildingAction - implements RowBuildingAction - { - private final SkipAction skipAction; - - SkipSchemaBuildingAction(SkipAction skipAction) - { - this.skipAction = requireNonNull(skipAction, "skipAction is null"); - } - - public void skip(Decoder decoder) - throws IOException - { - skipAction.skip(decoder); - } - - @FunctionalInterface - public interface SkipAction - { - void skip(Decoder decoder) - throws IOException; - } - - public static SkipAction createSkipActionForSchema(Schema schema) - { - return switch (schema.getType()) { - case NULL -> Decoder::readNull; - case BOOLEAN -> Decoder::readBoolean; - case INT -> Decoder::readInt; - case LONG -> Decoder::readLong; - case FLOAT -> Decoder::readFloat; - case DOUBLE -> Decoder::readDouble; - case STRING -> Decoder::skipString; - case BYTES -> Decoder::skipBytes; - case ENUM -> Decoder::readEnum; - case FIXED -> { - int size = schema.getFixedSize(); - yield decoder -> decoder.skipFixed(size); - } - case ARRAY -> new ArraySkipAction(schema.getElementType()); - case MAP -> new MapSkipAction(schema.getValueType()); - case RECORD -> new RecordSkipAction(schema.getFields()); - case UNION -> new UnionSkipAction(schema.getTypes()); - }; - } - - private static class ArraySkipAction - implements SkipAction - { - private final SkipAction elementSkipAction; - - public ArraySkipAction(Schema elementSchema) - { - elementSkipAction = createSkipActionForSchema(requireNonNull(elementSchema, "elementSchema is null")); - } - - @Override - public void skip(Decoder decoder) - throws IOException - { - for (long i = decoder.skipArray(); i != 0; i = decoder.skipArray()) { - for (long j = 0; j < i; j++) { - elementSkipAction.skip(decoder); - } - } - } - } - - private static class MapSkipAction - implements SkipAction - { - private final SkipAction valueSkipAction; - - public MapSkipAction(Schema valueSchema) - { - valueSkipAction = createSkipActionForSchema(requireNonNull(valueSchema, "valueSchema is null")); - } - - @Override - public void skip(Decoder decoder) - throws IOException - { - for (long i = decoder.skipMap(); i != 0; i = decoder.skipMap()) { - for (long j = 0; j < i; j++) { - decoder.skipString(); // key - valueSkipAction.skip(decoder); // value - } - } - } - } - - private static class RecordSkipAction - implements SkipAction - { - private final SkipAction[] fieldSkips; - - public RecordSkipAction(List fields) - { - fieldSkips = new SkipAction[requireNonNull(fields, "fields is null").size()]; - for (int i = 0; i < fields.size(); i++) { - fieldSkips[i] = createSkipActionForSchema(fields.get(i).schema()); - } - } - - @Override - public void skip(Decoder decoder) - throws IOException - { - for (SkipAction fieldSkipAction : fieldSkips) { - fieldSkipAction.skip(decoder); - } - } - } - - private static class UnionSkipAction - implements SkipAction - { - private final SkipAction[] skipActions; - 
- private UnionSkipAction(List types) - { - skipActions = new SkipAction[requireNonNull(types, "types is null").size()]; - for (int i = 0; i < types.size(); i++) { - skipActions[i] = createSkipActionForSchema(types.get(i)); - } - } - - @Override - public void skip(Decoder decoder) - throws IOException - { - skipActions[decoder.readIndex()].skip(decoder); - } - } - } - } - - private static class WriterUnionBlockBuildingDecoder - extends BlockBuildingDecoder - { - protected final BlockBuildingDecoder[] blockBuildingDecoders; - - public WriterUnionBlockBuildingDecoder(WrittenUnionReadAction writtenUnionReadAction, AvroTypeManager typeManager) - throws AvroTypeException - { - requireNonNull(writtenUnionReadAction, "writerUnion is null"); - blockBuildingDecoders = new BlockBuildingDecoder[writtenUnionReadAction.writeOptionReadActions().size()]; - for (int i = 0; i < writtenUnionReadAction.writeOptionReadActions().size(); i++) { - blockBuildingDecoders[i] = createBlockBuildingDecoderForAction(writtenUnionReadAction.writeOptionReadActions().get(i), typeManager); - } - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - decodeIntoBlock(decoder.readIndex(), decoder, builder); - } - - protected void decodeIntoBlock(int blockBuilderIndex, Decoder decoder, BlockBuilder builder) - throws IOException - { - blockBuildingDecoders[blockBuilderIndex].decodeIntoBlock(decoder, builder); - } - } - - private static class WriterUnionCoercedIntoRowBlockBuildingDecoder - extends WriterUnionBlockBuildingDecoder - { - private final boolean readUnionEquiv; - private final int[] indexToChannel; - private final int totalChannels; - - public WriterUnionCoercedIntoRowBlockBuildingDecoder(WrittenUnionReadAction writtenUnionReadAction, AvroTypeManager avroTypeManager) - throws AvroTypeException - { - super(writtenUnionReadAction, avroTypeManager); - readUnionEquiv = writtenUnionReadAction.unionEqiv(); - List readSchemas = writtenUnionReadAction.readSchema().getTypes(); - checkArgument(readSchemas.size() == writtenUnionReadAction.writeOptionReadActions().size(), "each read schema must have resolvedAction For it"); - indexToChannel = getIndexToChannel(readSchemas); - totalChannels = (int) IntStream.of(indexToChannel).filter(i -> i >= 0).count(); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - int index = decoder.readIndex(); - if (readUnionEquiv) { - // if no output channel then the schema is null and the whole record can be null; - if (indexToChannel[index] < 0) { - NullBlockBuildingDecoder.INSTANCE.decodeIntoBlock(decoder, builder); - } - else { - // the index for the reader and writer are the same, so the channel for the index is used to select the field to populate - makeSingleRowWithTagAndAllFieldsNullButOne(indexToChannel[index], totalChannels, blockBuildingDecoders[index], decoder, builder); - } - } - else { - // delegate to ReaderUnionCoercedIntoRowBlockBuildingDecoder to get the output channel from the resolved action - decodeIntoBlock(index, decoder, builder); - } - } - - protected static void makeSingleRowWithTagAndAllFieldsNullButOne(int outputChannel, int totalChannels, BlockBuildingDecoder blockBuildingDecoder, Decoder decoder, BlockBuilder builder) - throws IOException - { - ((RowBlockBuilder) builder).buildEntry(fieldBuilders -> { - //add tag with channel - UNION_FIELD_TAG_TYPE.writeLong(fieldBuilders.getFirst(), outputChannel); - //add in null fields except one - for (int 
channel = 1; channel <= totalChannels; channel++) { - if (channel == outputChannel + 1) { - blockBuildingDecoder.decodeIntoBlock(decoder, fieldBuilders.get(channel)); - } - else { - fieldBuilders.get(channel).appendNull(); - } - } - }); - } - - protected static int[] getIndexToChannel(List schemas) - { - int[] indexToChannel = new int[schemas.size()]; - int outputChannel = 0; - for (int i = 0; i < indexToChannel.length; i++) { - if (schemas.get(i).getType() == Schema.Type.NULL) { - indexToChannel[i] = -1; - } - else { - indexToChannel[i] = outputChannel++; - } - } - return indexToChannel; - } - } - - private static class ReaderUnionCoercedIntoRowBlockBuildingDecoder - extends - BlockBuildingDecoder - { - private final BlockBuildingDecoder delegateBuilder; - private final int outputChannel; - private final int totalChannels; - - public ReaderUnionCoercedIntoRowBlockBuildingDecoder(ReadingUnionReadAction readingUnionReadAction, AvroTypeManager avroTypeManager) - throws AvroTypeException - { - requireNonNull(readingUnionReadAction, "readerUnion is null"); - requireNonNull(avroTypeManager, "avroTypeManger is null"); - int[] indexToChannel = WriterUnionCoercedIntoRowBlockBuildingDecoder.getIndexToChannel(readingUnionReadAction.readSchema().getTypes()); - outputChannel = indexToChannel[readingUnionReadAction.firstMatch()]; - delegateBuilder = createBlockBuildingDecoderForAction(readingUnionReadAction.actualAction(), avroTypeManager); - totalChannels = (int) IntStream.of(indexToChannel).filter(i -> i >= 0).count(); - } - - @Override - protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) - throws IOException - { - if (outputChannel < 0) { - // No outputChannel for Null schema in union, null out coerces struct - NullBlockBuildingDecoder.INSTANCE.decodeIntoBlock(decoder, builder); - } - else { - WriterUnionCoercedIntoRowBlockBuildingDecoder - .makeSingleRowWithTagAndAllFieldsNullButOne(outputChannel, totalChannels, delegateBuilder, decoder, builder); - } - } - } - - public static LongIoFunction getLongDecoderFunction(Schema writerSchema) - { - return switch (writerSchema.getType()) { - case INT -> Decoder::readInt; - case LONG -> Decoder::readLong; - default -> throw new IllegalArgumentException("Cannot promote type %s to long".formatted(writerSchema.getType())); - }; - } - - public static FloatIoFunction getFloatDecoderFunction(Schema writerSchema) - { - return switch (writerSchema.getType()) { - case INT -> Decoder::readInt; - case LONG -> Decoder::readLong; - case FLOAT -> Decoder::readFloat; - default -> throw new IllegalArgumentException("Cannot promote type %s to float".formatted(writerSchema.getType())); - }; - } - - public static DoubleIoFunction getDoubleDecoderFunction(Schema writerSchema) - { - return switch (writerSchema.getType()) { - case INT -> Decoder::readInt; - case LONG -> Decoder::readLong; - case FLOAT -> Decoder::readFloat; - case DOUBLE -> Decoder::readDouble; - default -> throw new IllegalArgumentException("Cannot promote type %s to double".formatted(writerSchema.getType())); - }; - } - - @FunctionalInterface - public interface LongIoFunction - { - long apply(A a) - throws IOException; - } - - @FunctionalInterface - public interface FloatIoFunction - { - float apply(A a) - throws IOException; - } - - @FunctionalInterface - public interface DoubleIoFunction - { - double apply(A a) - throws IOException; - } - - @FunctionalInterface - public interface IoConsumer - { - void accept(A a) - throws IOException; - } - - // Avro supports default values for reader 
record fields that are missing in the writer schema - // the bytes representing the default field value are passed to a block building decoder - // so that it can pack the block appropriately for the default type. - private static IoConsumer getDefaultBlockBuilder(Schema fieldSchema, byte[] defaultBytes, AvroTypeManager typeManager) - throws AvroTypeException - { - BlockBuildingDecoder buildingDecoder = createBlockBuildingDecoderForAction(AvroReadAction.fromAction(Resolver.resolve(fieldSchema, fieldSchema)), typeManager); - BinaryDecoder reuse = DecoderFactory.get().binaryDecoder(defaultBytes, null); - return blockBuilder -> buildingDecoder.decodeIntoBlock(DecoderFactory.get().binaryDecoder(defaultBytes, reuse), blockBuilder); - } - - public static byte[] getDefaultByes(Schema.Field field) - throws AvroTypeException - { - try { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - Encoder e = EncoderFactory.get().binaryEncoder(out, null); - ResolvingGrammarGenerator.encode(e, field.schema(), Accessor.defaultValue(field)); - e.flush(); - return out.toByteArray(); - } - catch (IOException exception) { - throw new AvroTypeException("Unable to encode to bytes for default value in field " + field, exception); - } - } - /** * Used for throwing {@link AvroTypeException} through interfaces that can not throw checked exceptions like DatumReader */ diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPagePositionDataWriter.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPagePositionDataWriter.java index 366b7b6b0d95..185396c0ecb5 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPagePositionDataWriter.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPagePositionDataWriter.java @@ -47,6 +47,7 @@ import static io.trino.hive.formats.avro.AvroTypeUtils.isSimpleNullableUnion; import static io.trino.hive.formats.avro.AvroTypeUtils.lowerCaseAllFieldsForWriter; import static io.trino.hive.formats.avro.AvroTypeUtils.unwrapNullableUnion; +import static io.trino.hive.formats.avro.AvroTypeUtils.verifyNoCircularReferences; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -69,6 +70,7 @@ public AvroPagePositionDataWriter(Schema schema, AvroTypeManager avroTypeManager throws AvroTypeException { this.schema = requireNonNull(schema, "schema is null"); + verifyNoCircularReferences(schema); pageBlockPositionEncoder = new RecordBlockPositionEncoder(schema, avroTypeManager, channelNames, channelTypes); checkInvariants(); } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeBlockHandler.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeBlockHandler.java new file mode 100644 index 000000000000..d872475fe67e --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeBlockHandler.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.hive.formats.avro;
+
+import io.trino.hive.formats.avro.model.AvroReadAction;
+import io.trino.spi.type.Type;
+import org.apache.avro.Schema;
+
+import java.util.Map;
+
+public interface AvroTypeBlockHandler
+{
+    /**
+     * Called when the block handler is reading out data from a data file such as in {@link AvroFileReader}
+     *
+     * This usage implies interior mutability of the block handler.
+     * The Avro library (and our usage of it) requires the block handler to be defined when opening up a file.
+     * However, Hive uses the metadata from the file to define the behavior of the block handler.
+     * Instead of writing our own version of the Avro file reader, interior mutability is allowed.
+     * This will be called once before any data of the file is read, but after the Trino types are defined for the read schema.
+     * It will be called once per file.
+     *
+     * @param fileMetadata metadata from the file header
+     */
+    void configure(Map<String, byte[]> fileMetadata);
+
+    /**
+     * Returns the block type to build for this Avro schema.
+     *
+     * @throws AvroTypeException if there is no Type for the given schema, or the schema is misconfigured
+     */
+    Type typeFor(Schema schema)
+            throws AvroTypeException;
+
+    BlockBuildingDecoder blockBuildingDecoderFor(AvroReadAction readAction)
+            throws AvroTypeException;
+}
diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeManager.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeManager.java
index 0bf5d8f90b76..f8d577217d4d 100644
--- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeManager.java
+++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeManager.java
@@ -14,36 +14,14 @@ package io.trino.hive.formats.avro;
 
 import io.trino.spi.block.Block;
-import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.type.Type;
 import org.apache.avro.Schema;
 
-import java.util.Map;
 import java.util.Optional;
-import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 
 public interface AvroTypeManager
 {
-    /**
-     * Called when the type manager is reading out data from a data file such as in {@link AvroFileReader}
-     *
-     * @param fileMetadata metadata from the file header
-     */
-    void configure(Map<String, byte[]> fileMetadata);
-
-    Optional<Type> overrideTypeForSchema(Schema schema)
-            throws AvroTypeException;
-
-    /**
-     * Object provided by FastReader's deserialization with no conversions.
-     * Object class determined by Avro's standard generic data process
-     * BlockBuilder provided by Type returned above for the schema
-     * Possible to override for each primitive type as well.
-     */
-    Optional<BiConsumer<BlockBuilder, Object>> overrideBuildingFunctionForSchema(Schema schema)
-            throws AvroTypeException;
-
     /**
      * Extract and convert the object from the given block at the given position and return the Avro Generic Data form.
      * Type is either provided explicitly to the writer or derived from the schema using this interface.
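For illustration only (this sketch is not part of the patch, and the class name PassThroughAvroTypeBlockHandler is hypothetical): a minimal block handler can delegate both methods to the static base implementations in BaseAvroTypeBlockHandlerImpls, added later in this patch. The base implementations recurse back through the delegate, so a subclass that overrides typeFor or blockBuildingDecoderFor for a particular schema has its override applied at every nesting depth.

import io.trino.hive.formats.avro.model.AvroReadAction;
import io.trino.spi.type.Type;
import org.apache.avro.Schema;

import java.util.Map;

// Hypothetical pass-through handler, assuming only the AvroTypeBlockHandler and
// BaseAvroTypeBlockHandlerImpls APIs shown in this diff
public class PassThroughAvroTypeBlockHandler
        implements AvroTypeBlockHandler
{
    @Override
    public void configure(Map<String, byte[]> fileMetadata)
    {
        // no file metadata is needed for the base Avro-to-Trino mapping
    }

    @Override
    public Type typeFor(Schema schema)
            throws AvroTypeException
    {
        // pass `this` as the delegate so nested schemas route back through this handler
        return BaseAvroTypeBlockHandlerImpls.baseTypeFor(schema, this);
    }

    @Override
    public BlockBuildingDecoder blockBuildingDecoderFor(AvroReadAction readAction)
            throws AvroTypeException
    {
        return BaseAvroTypeBlockHandlerImpls.baseBlockBuildingDecoderFor(readAction, this);
    }
}

A caller would then open a file with the refactored constructor shown above, e.g. new AvroFileReader(inputFile, schema, new PassThroughAvroTypeBlockHandler()).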
diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeUtils.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeUtils.java index 4671e2d330d9..b389476e8b1e 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeUtils.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroTypeUtils.java @@ -13,79 +13,48 @@ */ package io.trino.hive.formats.avro; -import com.google.common.collect.ImmutableList; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.BigintType; -import io.trino.spi.type.BooleanType; -import io.trino.spi.type.DoubleType; -import io.trino.spi.type.IntegerType; -import io.trino.spi.type.MapType; -import io.trino.spi.type.RealType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeOperators; -import io.trino.spi.type.VarbinaryType; -import io.trino.spi.type.VarcharType; import org.apache.avro.Schema; import java.util.HashSet; import java.util.Locale; -import java.util.Optional; import java.util.Set; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; -import static io.trino.hive.formats.UnionToRowCoercionUtils.rowTypeForUnionOfTypes; import static java.util.function.Predicate.not; public final class AvroTypeUtils { private AvroTypeUtils() {} - public static Type typeFromAvro(Schema schema, AvroTypeManager avroTypeManager) + public static void verifyNoCircularReferences(Schema schema) throws AvroTypeException { - return typeFromAvro(schema, avroTypeManager, new HashSet<>()); + verifyNoCircularReferences(schema, new HashSet<>()); } - private static Type typeFromAvro(final Schema schema, AvroTypeManager avroTypeManager, Set enclosingRecords) + private static void verifyNoCircularReferences(Schema schema, Set enclosingRecords) throws AvroTypeException { - Optional customType = avroTypeManager.overrideTypeForSchema(schema); - if (customType.isPresent()) { - return customType.get(); - } - return switch (schema.getType()) { - case NULL -> throw new UnsupportedOperationException("No null column type support"); - case BOOLEAN -> BooleanType.BOOLEAN; - case INT -> IntegerType.INTEGER; - case LONG -> BigintType.BIGINT; - case FLOAT -> RealType.REAL; - case DOUBLE -> DoubleType.DOUBLE; - case ENUM, STRING -> VarcharType.VARCHAR; - case FIXED, BYTES -> VarbinaryType.VARBINARY; - case ARRAY -> new ArrayType(typeFromAvro(schema.getElementType(), avroTypeManager, enclosingRecords)); - case MAP -> new MapType(VarcharType.VARCHAR, typeFromAvro(schema.getValueType(), avroTypeManager, enclosingRecords), new TypeOperators()); + switch (schema.getType()) { + case NULL, BOOLEAN, INT, LONG, FLOAT, DOUBLE, ENUM, STRING, FIXED, BYTES -> {} //no-op + case ARRAY -> verifyNoCircularReferences(schema.getElementType(), enclosingRecords); + case MAP -> verifyNoCircularReferences(schema.getValueType(), enclosingRecords); case RECORD -> { if (!enclosingRecords.add(schema)) { - throw new UnsupportedOperationException("Unable to represent recursive avro schemas in Trino Type form"); + throw new AvroTypeException("Recursive Avro Schema not supported: " + schema.getFullName()); } - ImmutableList.Builder rowFieldTypes = ImmutableList.builder(); for (Schema.Field field : schema.getFields()) { - rowFieldTypes.add(new RowType.Field(Optional.of(field.name()), typeFromAvro(field.schema(), avroTypeManager, new 
HashSet<>(enclosingRecords)))); + verifyNoCircularReferences(field.schema(), new HashSet<>(enclosingRecords)); } - yield RowType.from(rowFieldTypes.build()); } case UNION -> { - if (isSimpleNullableUnion(schema)) { - yield typeFromAvro(unwrapNullableUnion(schema), avroTypeManager, enclosingRecords); - } - else { - yield rowTypeForUnion(schema, avroTypeManager, enclosingRecords); + for (Schema schemaOption : schema.getTypes()) { + verifyNoCircularReferences(schemaOption, enclosingRecords); } } - }; + } } public static boolean isSimpleNullableUnion(Schema schema) @@ -101,19 +70,6 @@ static Schema unwrapNullableUnion(Schema schema) return schema.getTypes().stream().filter(not(Schema::isNullable)).collect(onlyElement()); } - private static RowType rowTypeForUnion(Schema schema, AvroTypeManager avroTypeManager, Set enclosingRecords) - throws AvroTypeException - { - verify(schema.isUnion()); - ImmutableList.Builder unionTypes = ImmutableList.builder(); - for (Schema variant : schema.getTypes()) { - if (!variant.isNullable()) { - unionTypes.add(typeFromAvro(variant, avroTypeManager, enclosingRecords)); - } - } - return rowTypeForUnionOfTypes(unionTypes.build()); - } - public static SimpleUnionNullIndex getSimpleNullableUnionNullIndex(Schema schema) { verify(schema.isUnion(), "Schema must be union"); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandlerImpls.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandlerImpls.java new file mode 100644 index 000000000000..90c0c8557d97 --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandlerImpls.java @@ -0,0 +1,412 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.annotation.NotThreadSafe; +import io.trino.hive.formats.avro.model.ArrayReadAction; +import io.trino.hive.formats.avro.model.AvroReadAction; +import io.trino.hive.formats.avro.model.BooleanRead; +import io.trino.hive.formats.avro.model.BytesRead; +import io.trino.hive.formats.avro.model.DoubleRead; +import io.trino.hive.formats.avro.model.EnumReadAction; +import io.trino.hive.formats.avro.model.FixedRead; +import io.trino.hive.formats.avro.model.FloatRead; +import io.trino.hive.formats.avro.model.IntRead; +import io.trino.hive.formats.avro.model.LongRead; +import io.trino.hive.formats.avro.model.MapReadAction; +import io.trino.hive.formats.avro.model.NullRead; +import io.trino.hive.formats.avro.model.ReadErrorReadAction; +import io.trino.hive.formats.avro.model.ReadingUnionReadAction; +import io.trino.hive.formats.avro.model.RecordReadAction; +import io.trino.hive.formats.avro.model.StringRead; +import io.trino.hive.formats.avro.model.WrittenUnionReadAction; +import io.trino.spi.block.ArrayBlockBuilder; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import org.apache.avro.Schema; +import org.apache.avro.io.Decoder; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.hive.formats.avro.AvroTypeUtils.isSimpleNullableUnion; +import static io.trino.hive.formats.avro.AvroTypeUtils.unwrapNullableUnion; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.lang.Float.floatToRawIntBits; +import static java.util.Objects.requireNonNull; + +public class BaseAvroTypeBlockHandlerImpls +{ + // same limit as org.apache.avro.io.BinaryDecoder + static final long MAX_ARRAY_SIZE = (long) Integer.MAX_VALUE - 8L; + + private BaseAvroTypeBlockHandlerImpls() {} + + public static Type baseTypeFor(Schema schema, AvroTypeBlockHandler delegate) + throws AvroTypeException + { + return switch (schema.getType()) { + case NULL -> throw new UnsupportedOperationException("No null column type support"); + case BOOLEAN -> BooleanType.BOOLEAN; + case INT -> IntegerType.INTEGER; + case LONG -> BigintType.BIGINT; + case FLOAT -> RealType.REAL; + case DOUBLE -> DoubleType.DOUBLE; + case ENUM, STRING -> VarcharType.VARCHAR; + case FIXED, BYTES -> VarbinaryType.VARBINARY; + case ARRAY -> new ArrayType(delegate.typeFor(schema.getElementType())); + case MAP -> new MapType(VarcharType.VARCHAR, delegate.typeFor(schema.getValueType()), new TypeOperators()); + case RECORD -> { + ImmutableList.Builder rowFieldTypes = ImmutableList.builder(); + for 
(Schema.Field field : schema.getFields()) {
+                    rowFieldTypes.add(new RowType.Field(Optional.of(field.name()), delegate.typeFor(field.schema())));
+                }
+                yield RowType.from(rowFieldTypes.build());
+            }
+            case UNION -> {
+                if (isSimpleNullableUnion(schema)) {
+                    yield delegate.typeFor(unwrapNullableUnion(schema));
+                }
+                throw new AvroTypeException("Unable to read union with multiple non null types: " + schema);
+            }
+        };
+    }
+
+    public static BlockBuildingDecoder baseBlockBuildingDecoderFor(AvroReadAction action, AvroTypeBlockHandler delegate)
+            throws AvroTypeException
+    {
+        return switch (action) {
+            case NullRead __ -> NullBlockBuildingDecoder.INSTANCE;
+            case BooleanRead __ -> BooleanBlockBuildingDecoder.INSTANCE;
+            case IntRead __ -> IntBlockBuildingDecoder.INSTANCE;
+            case LongRead longRead -> new LongBlockBuildingDecoder(longRead.getLongDecoder());
+            case FloatRead floatRead -> new FloatBlockBuildingDecoder(floatRead.getFloatDecoder());
+            case DoubleRead doubleRead -> new DoubleBlockBuildingDecoder(doubleRead.getDoubleDecoder());
+            case BytesRead __ -> BytesBlockBuildingDecoder.INSTANCE;
+            case FixedRead __ -> new FixedBlockBuildingDecoder(action.readSchema().getFixedSize());
+            case StringRead __ -> StringBlockBuildingDecoder.INSTANCE;
+            case ArrayReadAction arrayReadAction -> new ArrayBlockBuildingDecoder(arrayReadAction, delegate);
+            case MapReadAction mapReadAction -> new MapBlockBuildingDecoder(mapReadAction, delegate);
+            case EnumReadAction enumReadAction -> new EnumBlockBuildingDecoder(enumReadAction);
+            case RecordReadAction recordReadAction -> new RowBlockBuildingDecoder(recordReadAction, delegate);
+            case ReadingUnionReadAction readingUnionReadAction -> {
+                if (!isSimpleNullableUnion(readingUnionReadAction.readSchema())) {
+                    throw new AvroTypeException("Unable to natively read into a non nullable union: " + readingUnionReadAction.readSchema());
+                }
+                yield delegate.blockBuildingDecoderFor(readingUnionReadAction.actualAction());
+            }
+            case WrittenUnionReadAction writtenUnionReadAction -> {
+                // if the write union is equivalent to the read union, the read union must still be a simple nullable union to be read natively
+                if (writtenUnionReadAction.unionEqiv() && !isSimpleNullableUnion(writtenUnionReadAction.readSchema())) {
+                    throw new AvroTypeException("Unable to natively read into a non nullable union: " + writtenUnionReadAction.readSchema());
+                }
+                yield new WriterUnionBlockBuildingDecoder(writtenUnionReadAction, delegate);
+            }
+            case ReadErrorReadAction readErrorReadAction -> throw new AvroTypeException("Incompatible read and write schema returned with error:\n " + readErrorReadAction);
+        };
+    }
+
+    public static class NullBlockBuildingDecoder
+            implements BlockBuildingDecoder
+    {
+        public static final NullBlockBuildingDecoder INSTANCE = new NullBlockBuildingDecoder();
+
+        private NullBlockBuildingDecoder() {}
+
+        @Override
+        public void decodeIntoBlock(Decoder decoder, BlockBuilder builder)
+                throws IOException
+        {
+            decoder.readNull();
+            builder.appendNull();
+        }
+    }
+
+    static class BooleanBlockBuildingDecoder
+            implements BlockBuildingDecoder
+    {
+        private static final BooleanBlockBuildingDecoder INSTANCE = new BooleanBlockBuildingDecoder();
+
+        @Override
+        public void decodeIntoBlock(Decoder decoder, BlockBuilder builder)
+                throws IOException
+        {
+            BOOLEAN.writeBoolean(builder, decoder.readBoolean());
+        }
+    }
+
+    public static class IntBlockBuildingDecoder
+            implements BlockBuildingDecoder
+    {
+        private static final IntBlockBuildingDecoder INSTANCE = new IntBlockBuildingDecoder();
+
+        @Override
+        public void decodeIntoBlock(Decoder decoder, BlockBuilder
builder) + throws IOException + { + INTEGER.writeLong(builder, decoder.readInt()); + } + } + + public static class LongBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final AvroReadAction.LongIoFunction extractLong; + + public LongBlockBuildingDecoder(AvroReadAction.LongIoFunction extractLong) + { + this.extractLong = requireNonNull(extractLong, "extractLong is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + BIGINT.writeLong(builder, extractLong.apply(decoder)); + } + } + + static class FloatBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final AvroReadAction.FloatIoFunction extractFloat; + + public FloatBlockBuildingDecoder(AvroReadAction.FloatIoFunction extractFloat) + { + this.extractFloat = requireNonNull(extractFloat, "extractFloat is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + REAL.writeLong(builder, floatToRawIntBits(extractFloat.apply(decoder))); + } + } + + static class DoubleBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final AvroReadAction.DoubleIoFunction extractDouble; + + public DoubleBlockBuildingDecoder(AvroReadAction.DoubleIoFunction extractDouble) + { + this.extractDouble = requireNonNull(extractDouble, "extractDouble is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + DOUBLE.writeDouble(builder, extractDouble.apply(decoder)); + } + } + + static class StringBlockBuildingDecoder + implements BlockBuildingDecoder + { + private static final StringBlockBuildingDecoder INSTANCE = new StringBlockBuildingDecoder(); + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + // it is only possible to read String type when underlying write type is String or Bytes + // both have the same encoding, so coercion is a no-op + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro String with size greater than %s. Found String size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + VARCHAR.writeSlice(builder, Slices.wrappedBuffer(bytes)); + } + } + + static class BytesBlockBuildingDecoder + implements BlockBuildingDecoder + { + private static final BytesBlockBuildingDecoder INSTANCE = new BytesBlockBuildingDecoder(); + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + // it is only possible to read Bytes type when underlying write type is String or Bytes + // both have the same encoding, so coercion is a no-op + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro Bytes with size greater than %s. 
Found Bytes size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + VARBINARY.writeSlice(builder, Slices.wrappedBuffer(bytes)); + } + } + + @NotThreadSafe + static class FixedBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final byte[] bytes; + + public FixedBlockBuildingDecoder(int expectedSize) + { + verify(expectedSize >= 0, "expected size must be greater than or equal to 0"); + bytes = new byte[expectedSize]; + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + decoder.readFixed(bytes); + VARBINARY.writeSlice(builder, Slices.wrappedBuffer(bytes)); + } + } + + static class EnumBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final List symbols; + + public EnumBlockBuildingDecoder(EnumReadAction enumReadAction) + { + requireNonNull(enumReadAction, "action is null"); + symbols = enumReadAction.getSymbolIndex(); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + VARCHAR.writeSlice(builder, symbols.get(decoder.readEnum())); + } + } + + static class ArrayBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final BlockBuildingDecoder elementBlockBuildingDecoder; + + public ArrayBlockBuildingDecoder(ArrayReadAction arrayReadAction, AvroTypeBlockHandler typeManager) + throws AvroTypeException + { + requireNonNull(arrayReadAction, "arrayReadAction is null"); + elementBlockBuildingDecoder = typeManager.blockBuildingDecoderFor(arrayReadAction.elementReadAction()); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + ((ArrayBlockBuilder) builder).buildEntry(elementBuilder -> { + long elementsInBlock = decoder.readArrayStart(); + if (elementsInBlock > 0) { + do { + for (int i = 0; i < elementsInBlock; i++) { + elementBlockBuildingDecoder.decodeIntoBlock(decoder, elementBuilder); + } + } + while ((elementsInBlock = decoder.arrayNext()) > 0); + } + }); + } + } + + static class MapBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final BlockBuildingDecoder keyBlockBuildingDecoder = new StringBlockBuildingDecoder(); + private final BlockBuildingDecoder valueBlockBuildingDecoder; + + public MapBlockBuildingDecoder(MapReadAction mapReadAction, AvroTypeBlockHandler typeManager) + throws AvroTypeException + { + requireNonNull(mapReadAction, "mapReadAction is null"); + valueBlockBuildingDecoder = typeManager.blockBuildingDecoderFor(mapReadAction.valueReadAction()); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + ((MapBlockBuilder) builder).buildEntry((keyBuilder, valueBuilder) -> { + long entriesInBlock = decoder.readMapStart(); + // TODO need to filter out all but last value for key? 
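+                // Avro writes maps as a sequence of blocks: each block carries an entry count followed by
+                // that many key/value pairs, and a zero count terminates the map. A minimal sketch of the
+                // loop shape used below, with hypothetical readKey/readValue helpers standing in for the
+                // key and value decoders:
+                //   long n = decoder.readMapStart();
+                //   while (n > 0) {
+                //       for (long i = 0; i < n; i++) { readKey(); readValue(); }
+                //       n = decoder.mapNext();
+                //   }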
+ if (entriesInBlock > 0) { + do { + for (int i = 0; i < entriesInBlock; i++) { + keyBlockBuildingDecoder.decodeIntoBlock(decoder, keyBuilder); + valueBlockBuildingDecoder.decodeIntoBlock(decoder, valueBuilder); + } + } + while ((entriesInBlock = decoder.mapNext()) > 0); + } + }); + } + } + + static class WriterUnionBlockBuildingDecoder + implements BlockBuildingDecoder + { + protected final BlockBuildingDecoder[] blockBuildingDecoders; + + public WriterUnionBlockBuildingDecoder(WrittenUnionReadAction writtenUnionReadAction, AvroTypeBlockHandler typeManager) + throws AvroTypeException + { + requireNonNull(writtenUnionReadAction, "writerUnion is null"); + blockBuildingDecoders = new BlockBuildingDecoder[writtenUnionReadAction.writeOptionReadActions().size()]; + for (int i = 0; i < writtenUnionReadAction.writeOptionReadActions().size(); i++) { + blockBuildingDecoders[i] = typeManager.blockBuildingDecoderFor(writtenUnionReadAction.writeOptionReadActions().get(i)); + } + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + decodeIntoBlock(decoder.readIndex(), decoder, builder); + } + + protected void decodeIntoBlock(int blockBuilderIndex, Decoder decoder, BlockBuilder builder) + throws IOException + { + blockBuildingDecoders[blockBuilderIndex].decodeIntoBlock(decoder, builder); + } + } +} diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BlockBuildingDecoder.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BlockBuildingDecoder.java new file mode 100644 index 000000000000..1e7d153febda --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/BlockBuildingDecoder.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.hive.formats.avro; + +import io.trino.spi.block.BlockBuilder; +import org.apache.avro.io.Decoder; + +import java.io.IOException; + +public interface BlockBuildingDecoder +{ + void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException; +} diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeBlockHandler.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeBlockHandler.java new file mode 100644 index 000000000000..64d0bbb9c5d2 --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeBlockHandler.java @@ -0,0 +1,429 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slices; +import io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.WriterUnionBlockBuildingDecoder; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.InvalidNativeAvroLogicalType; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.NoLogicalType; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.NonNativeAvroLogicalType; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.ValidNativeAvroLogicalType; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.ValidateLogicalTypeResult; +import io.trino.hive.formats.avro.model.ArrayReadAction; +import io.trino.hive.formats.avro.model.AvroLogicalType.BytesDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.DateLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.FixedDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.StringUUIDLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimeMicrosLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimeMillisLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMicrosLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMillisLogicalType; +import io.trino.hive.formats.avro.model.AvroReadAction; +import io.trino.hive.formats.avro.model.AvroReadAction.LongIoFunction; +import io.trino.hive.formats.avro.model.BooleanRead; +import io.trino.hive.formats.avro.model.BytesRead; +import io.trino.hive.formats.avro.model.DoubleRead; +import io.trino.hive.formats.avro.model.EnumReadAction; +import io.trino.hive.formats.avro.model.FixedRead; +import io.trino.hive.formats.avro.model.FloatRead; +import io.trino.hive.formats.avro.model.IntRead; +import io.trino.hive.formats.avro.model.LongRead; +import io.trino.hive.formats.avro.model.MapReadAction; +import io.trino.hive.formats.avro.model.NullRead; +import io.trino.hive.formats.avro.model.ReadErrorReadAction; +import io.trino.hive.formats.avro.model.ReadingUnionReadAction; +import io.trino.hive.formats.avro.model.RecordReadAction; +import io.trino.hive.formats.avro.model.StringRead; +import io.trino.hive.formats.avro.model.WrittenUnionReadAction; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.Chars; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.RowType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.Timestamps; +import io.trino.spi.type.Type; +import io.trino.spi.type.Varchars; +import org.apache.avro.Schema; +import org.apache.avro.io.Decoder; +import org.joda.time.DateTimeZone; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.ZoneId; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_TYPE; +import static io.trino.hive.formats.UnionToRowCoercionUtils.rowTypeForUnionOfTypes; +import static io.trino.hive.formats.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME; +import static 
io.trino.hive.formats.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME; +import static io.trino.hive.formats.avro.AvroTypeUtils.isSimpleNullableUnion; +import static io.trino.hive.formats.avro.AvroTypeUtils.unwrapNullableUnion; +import static io.trino.hive.formats.avro.AvroTypeUtils.verifyNoCircularReferences; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.MAX_ARRAY_SIZE; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.NullBlockBuildingDecoder.INSTANCE; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.baseBlockBuildingDecoderFor; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.baseTypeFor; +import static io.trino.hive.formats.avro.HiveAvroTypeManager.getHiveLogicalVarCharOrCharType; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeBlockHandler.getLogicalTypeBuildingFunction; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.getAvroLogicalTypeSpiType; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.validateLogicalType; +import static java.time.ZoneOffset.UTC; +import static java.util.Objects.requireNonNull; + +public class HiveAvroTypeBlockHandler + implements AvroTypeBlockHandler +{ + private final TimestampType hiveSessionTimestamp; + private final AtomicReference<ZoneId> convertToTimezone = new AtomicReference<>(UTC); + + public HiveAvroTypeBlockHandler(TimestampType hiveSessionTimestamp) + { + this.hiveSessionTimestamp = requireNonNull(hiveSessionTimestamp, "hiveSessionTimestamp is null"); + } + + @Override + public void configure(Map<String, byte[]> fileMetadata) + { + if (fileMetadata.containsKey(AvroHiveConstants.WRITER_TIME_ZONE)) { + convertToTimezone.set(ZoneId.of(new String(fileMetadata.get(AvroHiveConstants.WRITER_TIME_ZONE), StandardCharsets.UTF_8))); + } + else { + // legacy path allows this conversion to be skipped with {@link org.apache.hadoop.conf.Configuration} param + // currently no way to set that configuration in Trino + convertToTimezone.set(TimeZone.getDefault().toZoneId()); + } + } + + @Override + public Type typeFor(Schema schema) + throws AvroTypeException + { + verifyNoCircularReferences(schema); + if (schema.getType() == Schema.Type.NULL) { + // allows dereference when no base columns from the file are used + // BooleanType chosen rather arbitrarily to be stuffed with null + // in response to behavior defined by io.trino.tests.product.hive.TestAvroSchemaStrictness.testInvalidUnionDefaults + return BooleanType.BOOLEAN; + } + else if (schema.getType() == Schema.Type.UNION) { + if (isSimpleNullableUnion(schema)) { + return typeFor(unwrapNullableUnion(schema)); + } + else { + // coerce complex unions to structs + return rowTypeForUnion(schema); + } + } + ValidateLogicalTypeResult logicalTypeResult = validateLogicalType(schema); + return switch (logicalTypeResult) { + case NoLogicalType _ -> baseTypeFor(schema, this); + case InvalidNativeAvroLogicalType _ -> // Hive ignores these issues + baseTypeFor(schema, this); + case NonNativeAvroLogicalType nonNativeAvroLogicalType -> switch (nonNativeAvroLogicalType.getLogicalTypeName()) { + case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType); + // logical type we don't recognize, ignore + default -> baseTypeFor(schema, this); + }; + case ValidNativeAvroLogicalType validNativeAvroLogicalType -> switch (validNativeAvroLogicalType.getLogicalType()) { + case DateLogicalType dateLogicalType -> 
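+                // illustrative examples of this mapping: an Avro int with logicalType "date" surfaces as
+                // Trino DATE and an Avro bytes "decimal" as DECIMAL(precision, scale), while
+                // "timestamp-millis" below is mapped to the Hive session timestamp precision rather
+                // than a fixed TIMESTAMP_MILLIS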
getAvroLogicalTypeSpiType(dateLogicalType); + case BytesDecimalLogicalType bytesDecimalLogicalType -> getAvroLogicalTypeSpiType(bytesDecimalLogicalType); + case TimestampMillisLogicalType _ -> hiveSessionTimestamp; + // Other logical types ignored by hive/don't map to hive types + case FixedDecimalLogicalType _, TimeMicrosLogicalType _, TimeMillisLogicalType _, TimestampMicrosLogicalType _, StringUUIDLogicalType _ -> baseTypeFor(schema, this); + }; + }; + } + + private RowType rowTypeForUnion(Schema schema) + throws AvroTypeException + { + verify(schema.isUnion()); + ImmutableList.Builder unionTypes = ImmutableList.builder(); + for (Schema variant : schema.getTypes()) { + if (!variant.isNullable()) { + unionTypes.add(typeFor(variant)); + } + } + return rowTypeForUnionOfTypes(unionTypes.build()); + } + + @Override + public BlockBuildingDecoder blockBuildingDecoderFor(AvroReadAction readAction) + throws AvroTypeException + { + ValidateLogicalTypeResult logicalTypeResult = validateLogicalType(readAction.readSchema()); + return switch (logicalTypeResult) { + case NoLogicalType _ -> baseBlockBuildingDecoderWithUnionCoerceAndErrorDelays(readAction); + case InvalidNativeAvroLogicalType _ -> //Hive Ignores these issues + baseBlockBuildingDecoderWithUnionCoerceAndErrorDelays(readAction); + case NonNativeAvroLogicalType nonNativeAvroLogicalType -> switch (nonNativeAvroLogicalType.getLogicalTypeName()) { + case VARCHAR_TYPE_LOGICAL_NAME -> + new HiveVarcharTypeBlockBuildingDecoder(getHiveLogicalVarCharOrCharType(readAction.readSchema(), nonNativeAvroLogicalType)); + case CHAR_TYPE_LOGICAL_NAME -> + new HiveCharTypeBlockBuildingDecoder(getHiveLogicalVarCharOrCharType(readAction.readSchema(), nonNativeAvroLogicalType)); + // logical type we don't recognize, ignore + default -> baseBlockBuildingDecoderWithUnionCoerceAndErrorDelays(readAction); + }; + case ValidNativeAvroLogicalType validNativeAvroLogicalType -> switch (validNativeAvroLogicalType.getLogicalType()) { + case DateLogicalType dateLogicalType -> getLogicalTypeBuildingFunction(dateLogicalType, readAction); + case BytesDecimalLogicalType bytesDecimalLogicalType -> getLogicalTypeBuildingFunction(bytesDecimalLogicalType, readAction); + case TimestampMillisLogicalType _ -> new HiveCoercesedTimestampBlockBuildingDecoder(hiveSessionTimestamp, readAction, convertToTimezone); + // Hive only supports Bytes Decimal Type + case FixedDecimalLogicalType _ -> baseBlockBuildingDecoderWithUnionCoerceAndErrorDelays(readAction); + // Other logical types ignored by hive/don't map to hive types + case TimeMicrosLogicalType _, TimeMillisLogicalType _, TimestampMicrosLogicalType _, StringUUIDLogicalType _ -> baseBlockBuildingDecoderWithUnionCoerceAndErrorDelays(readAction); + }; + }; + } + + private BlockBuildingDecoder baseBlockBuildingDecoderWithUnionCoerceAndErrorDelays(AvroReadAction readAction) + throws AvroTypeException + { + return switch (readAction) { + case NullRead _, BooleanRead _, IntRead _, LongRead _, FloatRead _, DoubleRead _, StringRead _, BytesRead _, FixedRead _, ArrayReadAction _, EnumReadAction _, + MapReadAction _, RecordReadAction _ -> baseBlockBuildingDecoderFor(readAction, this); + case WrittenUnionReadAction writtenUnionReadAction -> { + if (writtenUnionReadAction.readSchema().getType() == Schema.Type.UNION && !isSimpleNullableUnion(writtenUnionReadAction.readSchema())) { + yield new WriterUnionCoercedIntoRowBlockBuildingDecoder(writtenUnionReadAction, this); + } + else { + // reading a union with non-union or nullable union, 
optimistically try to create the reader, will fail at read time with any underlying issues + yield new WriterUnionBlockBuildingDecoder(writtenUnionReadAction, this); + } + } + case ReadingUnionReadAction readingUnionReadAction -> { + if (isSimpleNullableUnion(readingUnionReadAction.readSchema())) { + yield blockBuildingDecoderFor(readingUnionReadAction.actualAction()); + } + else { + yield new ReaderUnionCoercedIntoRowBlockBuildingDecoder(readingUnionReadAction, this); + } + } + case ReadErrorReadAction readErrorReadAction -> new TypeErrorThrower(readErrorReadAction); + }; + } + + public static class WriterUnionCoercedIntoRowBlockBuildingDecoder + extends WriterUnionBlockBuildingDecoder + { + private final boolean readUnionEquiv; + private final int[] indexToChannel; + private final int totalChannels; + + public WriterUnionCoercedIntoRowBlockBuildingDecoder(WrittenUnionReadAction writtenUnionReadAction, AvroTypeBlockHandler avroTypeManager) + throws AvroTypeException + { + super(writtenUnionReadAction, avroTypeManager); + readUnionEquiv = writtenUnionReadAction.unionEqiv(); + List readSchemas = writtenUnionReadAction.readSchema().getTypes(); + checkArgument(readSchemas.size() == writtenUnionReadAction.writeOptionReadActions().size(), "each read schema must have resolvedAction For it"); + indexToChannel = getIndexToChannel(readSchemas); + totalChannels = (int) IntStream.of(indexToChannel).filter(i -> i >= 0).count(); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + int index = decoder.readIndex(); + if (readUnionEquiv) { + // if no output channel then the schema is null and the whole record can be null; + if (indexToChannel[index] < 0) { + INSTANCE.decodeIntoBlock(decoder, builder); + } + else { + // the index for the reader and writer are the same, so the channel for the index is used to select the field to populate + makeSingleRowWithTagAndAllFieldsNullButOne(indexToChannel[index], totalChannels, blockBuildingDecoders[index], decoder, builder); + } + } + else { + // delegate to ReaderUnionCoercedIntoRowBlockBuildingDecoder to get the output channel from the resolved action + decodeIntoBlock(index, decoder, builder); + } + } + + protected static void makeSingleRowWithTagAndAllFieldsNullButOne(int outputChannel, int totalChannels, BlockBuildingDecoder blockBuildingDecoder, Decoder decoder, BlockBuilder builder) + throws IOException + { + ((RowBlockBuilder) builder).buildEntry(fieldBuilders -> { + //add tag with channel + UNION_FIELD_TAG_TYPE.writeLong(fieldBuilders.getFirst(), outputChannel); + //add in null fields except one + for (int channel = 1; channel <= totalChannels; channel++) { + if (channel == outputChannel + 1) { + blockBuildingDecoder.decodeIntoBlock(decoder, fieldBuilders.get(channel)); + } + else { + fieldBuilders.get(channel).appendNull(); + } + } + }); + } + + protected static int[] getIndexToChannel(List schemas) + { + int[] indexToChannel = new int[schemas.size()]; + int outputChannel = 0; + for (int i = 0; i < indexToChannel.length; i++) { + if (schemas.get(i).getType() == Schema.Type.NULL) { + indexToChannel[i] = -1; + } + else { + indexToChannel[i] = outputChannel++; + } + } + return indexToChannel; + } + } + + public static class ReaderUnionCoercedIntoRowBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final BlockBuildingDecoder delegateBuilder; + private final int outputChannel; + private final int totalChannels; + + public 
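+        // union-to-row coercion sketch (illustrative): an Avro union ["null", "int", "string"] is read as
+        // a row of (tag, field0 integer, field1 varchar) -- the tag records which non-null branch was
+        // written, exactly one fieldN channel is populated per row, and a written null nulls the row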
ReaderUnionCoercedIntoRowBlockBuildingDecoder(ReadingUnionReadAction readingUnionReadAction, AvroTypeBlockHandler avroTypeManager) + throws AvroTypeException + { + requireNonNull(readingUnionReadAction, "readerUnion is null"); + requireNonNull(avroTypeManager, "avroTypeManager is null"); + int[] indexToChannel = WriterUnionCoercedIntoRowBlockBuildingDecoder.getIndexToChannel(readingUnionReadAction.readSchema().getTypes()); + outputChannel = indexToChannel[readingUnionReadAction.firstMatch()]; + delegateBuilder = avroTypeManager.blockBuildingDecoderFor(readingUnionReadAction.actualAction()); + totalChannels = (int) IntStream.of(indexToChannel).filter(i -> i >= 0).count(); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + if (outputChannel < 0) { + // no output channel for the null schema in the union, so null out the coerced struct + INSTANCE.decodeIntoBlock(decoder, builder); + } + else { + WriterUnionCoercedIntoRowBlockBuildingDecoder + .makeSingleRowWithTagAndAllFieldsNullButOne(outputChannel, totalChannels, delegateBuilder, decoder, builder); + } + } + } + + // A block building decoder that can delay errors until read time to allow the shape of the data to circumvent schema compatibility issues + // for example reading a ["null", "string"] with "string" is possible if the underlying union data is all strings + private static class TypeErrorThrower + implements BlockBuildingDecoder + { + private final ReadErrorReadAction action; + + public TypeErrorThrower(ReadErrorReadAction action) + { + this.action = requireNonNull(action, "action is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + throw new IOException(new AvroTypeException("Resolution action returned with error " + action)); + } + } + + private static class HiveVarcharTypeBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final Type type; + + public HiveVarcharTypeBlockBuildingDecoder(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro String with size greater than %s. Found String size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + type.writeSlice(builder, Varchars.truncateToLength(Slices.wrappedBuffer(bytes), type)); + } + } + + private static class HiveCharTypeBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final Type type; + + public HiveCharTypeBlockBuildingDecoder(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro String with size greater than %s. 
Found String size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + type.writeSlice(builder, Chars.truncateToLengthAndTrimSpaces(Slices.wrappedBuffer(bytes), type)); + } + } + + private static class HiveCoercesedTimestampBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final TimestampType hiveSessionTimestamp; + private final LongIoFunction longDecoder; + private final AtomicReference convertToTimezone; + + public HiveCoercesedTimestampBlockBuildingDecoder(TimestampType hiveSessionTimestamp, AvroReadAction avroReadAction, AtomicReference convertToTimezone) + { + checkArgument(avroReadAction instanceof LongRead, "TimestampMillis type can only be read from Long Avro type"); + this.hiveSessionTimestamp = requireNonNull(hiveSessionTimestamp, "hiveSessionTimestamp is null"); + longDecoder = ((LongRead) avroReadAction).getLongDecoder(); + this.convertToTimezone = requireNonNull(convertToTimezone, "convertToTimezone is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + long millisSinceEpochUTC = longDecoder.apply(decoder); + if (hiveSessionTimestamp.isShort()) { + hiveSessionTimestamp.writeLong(builder, DateTimeZone.forTimeZone(TimeZone.getTimeZone(convertToTimezone.get())).convertUTCToLocal(millisSinceEpochUTC) * Timestamps.MICROSECONDS_PER_MILLISECOND); + } + else { + LongTimestamp longTimestamp = new LongTimestamp(DateTimeZone.forTimeZone(TimeZone.getTimeZone(convertToTimezone.get())).convertUTCToLocal(millisSinceEpochUTC) * Timestamps.MICROSECONDS_PER_MILLISECOND, 0); + hiveSessionTimestamp.writeObject(builder, longTimestamp); + } + } + } +} diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeManager.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeManager.java new file mode 100644 index 000000000000..ea2f8aa2860b --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/HiveAvroTypeManager.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import io.airlift.slice.Slice; +import io.trino.hive.formats.avro.model.AvroLogicalType.BytesDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.DateLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMillisLogicalType; +import io.trino.spi.block.Block; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.Timestamps; +import io.trino.spi.type.Type; +import org.apache.avro.Schema; +import org.joda.time.DateTimeZone; + +import java.util.Optional; +import java.util.TimeZone; +import java.util.function.BiFunction; + +import static io.trino.hive.formats.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME; +import static io.trino.hive.formats.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME; +import static io.trino.hive.formats.avro.model.AvroLogicalType.DATE; +import static io.trino.hive.formats.avro.model.AvroLogicalType.DECIMAL; +import static io.trino.hive.formats.avro.model.AvroLogicalType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.Timestamps.roundDiv; +import static io.trino.spi.type.VarcharType.createVarcharType; + +public class HiveAvroTypeManager + extends NativeLogicalTypesAvroTypeManager +{ + @Override + public Optional> overrideBlockToAvroObject(Schema schema, Type type) + throws AvroTypeException + { + ValidateLogicalTypeResult result = validateLogicalType(schema); + return switch (result) { + case NoLogicalType ignored -> Optional.empty(); + case NonNativeAvroLogicalType nonNativeAvroLogicalType -> switch (nonNativeAvroLogicalType.getLogicalTypeName()) { + case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> { + Type expectedType = getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType); + if (!expectedType.equals(type)) { + throw new AvroTypeException("Type provided for column [%s] is incompatible with type for schema: %s".formatted(type, expectedType)); + } + yield Optional.of((block, pos) -> ((Slice) expectedType.getObject(block, pos)).toStringUtf8()); + } + default -> Optional.empty(); + }; + case InvalidNativeAvroLogicalType invalidNativeAvroLogicalType -> switch (invalidNativeAvroLogicalType.getLogicalTypeName()) { + case TIMESTAMP_MILLIS, DATE, DECIMAL -> throw invalidNativeAvroLogicalType.getCause(); + default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types + }; + case ValidNativeAvroLogicalType validNativeAvroLogicalType -> switch (validNativeAvroLogicalType.getLogicalType()) { + case TimestampMillisLogicalType __ -> { + if (!(type instanceof TimestampType timestampType)) { + throw new AvroTypeException("Can't represent avro logical type %s with Trino Type %s".formatted(validNativeAvroLogicalType.getLogicalType(), type)); + } + if (timestampType.isShort()) { + yield Optional.of((block, pos) -> { + long millis = roundDiv(timestampType.getLong(block, pos), Timestamps.MICROSECONDS_PER_MILLISECOND); + // see org.apache.hadoop.hive.serde2.avro.AvroSerializer.serializePrimitive + return DateTimeZone.forTimeZone(TimeZone.getDefault()).convertLocalToUTC(millis, false); + }); + } + else { + yield Optional.of((block, pos) -> + { + SqlTimestamp timestamp = (SqlTimestamp) timestampType.getObject(block, pos); + // see org.apache.hadoop.hive.serde2.avro.AvroSerializer.serializePrimitive + return DateTimeZone.forTimeZone(TimeZone.getDefault()).convertLocalToUTC(timestamp.getMillis(), false); + }); + } + } + case DateLogicalType __ 
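+            // dates and decimals need no Hive-specific write handling (a date is stored as its plain
+            // epoch-day count), so these two cases defer to the native conversions; only the timestamp
+            // case above needs the Hive local-to-UTC shift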
-> super.overrideBlockToAvroObject(schema, type); + case BytesDecimalLogicalType __ -> super.overrideBlockToAvroObject(schema, type); + default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types + }; + }; + } + + static Type getHiveLogicalVarCharOrCharType(Schema schema, NonNativeAvroLogicalType nonNativeAvroLogicalType) + throws AvroTypeException + { + if (schema.getType() != Schema.Type.STRING) { + throw new AvroTypeException("Unsupported Avro type for Hive Logical Type in schema " + schema); + } + Object maxLengthObject = schema.getObjectProp(AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP); + if (maxLengthObject == null) { + throw new AvroTypeException("Missing property maxLength in schema for Hive Type " + nonNativeAvroLogicalType.getLogicalTypeName()); + } + try { + int maxLength; + if (maxLengthObject instanceof String maxLengthString) { + maxLength = Integer.parseInt(maxLengthString); + } + else if (maxLengthObject instanceof Number maxLengthNumber) { + maxLength = maxLengthNumber.intValue(); + } + else { + throw new AvroTypeException("Unrecognized property type for " + AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP + " in schema " + schema); + } + if (nonNativeAvroLogicalType.getLogicalTypeName().equals(VARCHAR_TYPE_LOGICAL_NAME)) { + return createVarcharType(maxLength); + } + else { + return createCharType(maxLength); + } + } + catch (NumberFormatException numberFormatException) { + throw new AvroTypeException("Property maxLength not convertible to Integer in Hive Logical type schema " + schema); + } + } +} diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeBlockHandler.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeBlockHandler.java new file mode 100644 index 000000000000..0dabc8036e9f --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeBlockHandler.java @@ -0,0 +1,293 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import io.trino.annotation.NotThreadSafe; +import io.trino.hive.formats.avro.model.AvroLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.BytesDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.DateLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.FixedDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.StringUUIDLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimeMicrosLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimeMillisLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMicrosLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMillisLogicalType; +import io.trino.hive.formats.avro.model.AvroReadAction; +import io.trino.hive.formats.avro.model.AvroReadAction.LongIoFunction; +import io.trino.hive.formats.avro.model.BytesRead; +import io.trino.hive.formats.avro.model.FixedRead; +import io.trino.hive.formats.avro.model.IntRead; +import io.trino.hive.formats.avro.model.LongRead; +import io.trino.hive.formats.avro.model.StringRead; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.Int128; +import io.trino.spi.type.Timestamps; +import io.trino.spi.type.Type; +import org.apache.avro.Schema; +import org.apache.avro.io.Decoder; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.MAX_ARRAY_SIZE; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.baseBlockBuildingDecoderFor; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.baseTypeFor; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.fromBigEndian; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.validateAndLogIssues; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.TimeType.TIME_MICROS; +import static io.trino.spi.type.TimeType.TIME_MILLIS; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; +import static java.util.Objects.requireNonNull; + +/** + * This block handler class is designed to handle all Avro documented logical types + * and their reading as the appropriate Trino type. 
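+ * <p>
+ * A minimal usage sketch, assuming a {@code TrinoInputFile} {@code inputFile} and a reader
+ * {@code Schema} {@code schema} are already in scope:
+ * <pre>{@code
+ * AvroFileReader reader = new AvroFileReader(inputFile, schema, new NativeLogicalTypesAvroTypeBlockHandler());
+ * }</pre>
+ * Pages produced by such a reader carry date, decimal, time, timestamp, and UUID columns as the
+ * corresponding Trino types instead of their raw Avro primitive encodings.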
+ */ +public class NativeLogicalTypesAvroTypeBlockHandler + implements AvroTypeBlockHandler +{ + @Override + public void configure(Map fileMetadata) {} + + @Override + public Type typeFor(Schema schema) + throws AvroTypeException + { + Optional avroLogicalTypeTrinoType = validateAndLogIssues(schema) + .map(NativeLogicalTypesAvroTypeManager::getAvroLogicalTypeSpiType); + if (avroLogicalTypeTrinoType.isPresent()) { + return avroLogicalTypeTrinoType.get(); + } + return baseTypeFor(schema, this); + } + + @Override + public BlockBuildingDecoder blockBuildingDecoderFor(AvroReadAction readAction) + throws AvroTypeException + { + Optional avroLogicalType = validateAndLogIssues(readAction.readSchema()); + if (avroLogicalType.isPresent()) { + return getLogicalTypeBuildingFunction(avroLogicalType.get(), readAction); + } + return baseBlockBuildingDecoderFor(readAction, this); + } + + static BlockBuildingDecoder getLogicalTypeBuildingFunction(AvroLogicalType logicalType, AvroReadAction readAction) + { + return switch (logicalType) { + case DateLogicalType _ -> { + if (readAction instanceof IntRead) { + yield new DateBlockBuildingDecoder(); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case BytesDecimalLogicalType bytesDecimalLogicalType -> { + if (readAction instanceof BytesRead) { + yield new BytesDecimalBlockBuildingDecoder(DecimalType.createDecimalType(bytesDecimalLogicalType.precision(), bytesDecimalLogicalType.scale())); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case FixedDecimalLogicalType fixedDecimalLogicalType -> { + if (readAction instanceof FixedRead) { + yield new FixedDecimalBlockBuildingDecoder(DecimalType.createDecimalType(fixedDecimalLogicalType.precision(), fixedDecimalLogicalType.scale()), fixedDecimalLogicalType.fixedSize()); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TimeMillisLogicalType _ -> { + if (readAction instanceof IntRead) { + yield new TimeMillisBlockBuildingDecoder(); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TimeMicrosLogicalType _ -> { + if (readAction instanceof LongRead) { + yield new TimeMicrosBlockBuildingDecoder(); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TimestampMillisLogicalType _ -> { + if (readAction instanceof LongRead longRead) { + yield new TimestampMillisBlockBuildingDecoder(longRead); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TimestampMicrosLogicalType _ -> { + if (readAction instanceof LongRead longRead) { + yield new TimestampMicrosBlockBuildingDecoder(longRead); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case StringUUIDLogicalType _ -> { + if (readAction instanceof StringRead) { + yield new StringUUIDBlockBuildingDecoder(); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + }; + } + + public record DateBlockBuildingDecoder() + implements BlockBuildingDecoder + { + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + DATE.writeLong(builder, decoder.readInt()); + } + } + + public record BytesDecimalBlockBuildingDecoder(DecimalType decimalType) + implements BlockBuildingDecoder + { + public BytesDecimalBlockBuildingDecoder + { + requireNonNull(decimalType, "decimalType is null"); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + 
throws IOException + { + // it is only possible to read Bytes type when underlying write type is String or Bytes + // both have the same encoding, so coercion is a no-op + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro Bytes with size greater than %s. Found Bytes size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + if (decimalType.isShort()) { + decimalType.writeLong(builder, fromBigEndian(bytes)); + } + else { + decimalType.writeObject(builder, Int128.fromBigEndian(bytes)); + } + } + } + + @NotThreadSafe + public static class FixedDecimalBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final DecimalType decimalType; + private final byte[] bytes; + + public FixedDecimalBlockBuildingDecoder(DecimalType decimalType, int fixedSize) + { + this.decimalType = requireNonNull(decimalType, "decimalType is null"); + verify(fixedSize > 0, "fixedSize must be over 0"); + bytes = new byte[fixedSize]; + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + decoder.readFixed(bytes); + if (decimalType.isShort()) { + decimalType.writeLong(builder, fromBigEndian(bytes)); + } + else { + decimalType.writeObject(builder, Int128.fromBigEndian(bytes)); + } + } + } + + public record TimeMillisBlockBuildingDecoder() + implements BlockBuildingDecoder + { + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + TIME_MILLIS.writeLong(builder, ((long) decoder.readInt()) * Timestamps.PICOSECONDS_PER_MILLISECOND); + } + } + + public record TimeMicrosBlockBuildingDecoder() + implements BlockBuildingDecoder + { + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + TIME_MICROS.writeLong(builder, decoder.readLong() * Timestamps.PICOSECONDS_PER_MICROSECOND); + } + } + + public static class TimestampMillisBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final LongIoFunction longDecoder; + + public TimestampMillisBlockBuildingDecoder(LongRead longRead) + { + longDecoder = requireNonNull(longRead, "longRead is null").getLongDecoder(); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + TIMESTAMP_MILLIS.writeLong(builder, longDecoder.apply(decoder) * Timestamps.MICROSECONDS_PER_MILLISECOND); + } + } + + public static class TimestampMicrosBlockBuildingDecoder + implements BlockBuildingDecoder + { + private final LongIoFunction longDecoder; + + public TimestampMicrosBlockBuildingDecoder(LongRead longRead) + { + longDecoder = requireNonNull(longRead, "longRead is null").getLongDecoder(); + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + TIMESTAMP_MICROS.writeLong(builder, longDecoder.apply(decoder)); + } + } + + public record StringUUIDBlockBuildingDecoder() + implements BlockBuildingDecoder + { + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + // it is only possible to read String type when underlying write type is String or Bytes + // both have the same encoding, so coercion is a no-op + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro String with size greater than %s. 
Found String size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + // these bytes represent the string representation of the UUID, not the UUID bytes themselves + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + UUID.writeSlice(builder, javaUuidToTrinoUuid(java.util.UUID.fromString(new String(bytes, StandardCharsets.UTF_8)))); + } + } +} diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeManager.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeManager.java index 4f5f4f4aa627..44efb2e7af1c 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeManager.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/NativeLogicalTypesAvroTypeManager.java @@ -13,12 +13,19 @@ */ package io.trino.hive.formats.avro; -import com.google.common.base.VerifyException; import com.google.common.primitives.Longs; import io.airlift.log.Logger; import io.airlift.slice.Slice; +import io.trino.hive.formats.avro.model.AvroLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.BytesDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.DateLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.FixedDecimalLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.StringUUIDLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimeMicrosLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimeMillisLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMicrosLogicalType; +import io.trino.hive.formats.avro.model.AvroLogicalType.TimestampMillisLogicalType; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.DateType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; @@ -33,22 +40,28 @@ import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericData; -import org.apache.avro.generic.GenericFixed; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.Map; import java.util.Optional; -import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; import static com.google.common.base.Verify.verify; +import static io.trino.hive.formats.avro.model.AvroLogicalType.DATE; +import static io.trino.hive.formats.avro.model.AvroLogicalType.DECIMAL; +import static io.trino.hive.formats.avro.model.AvroLogicalType.LOCAL_TIMESTAMP_MICROS; +import static io.trino.hive.formats.avro.model.AvroLogicalType.LOCAL_TIMESTAMP_MILLIS; +import static io.trino.hive.formats.avro.model.AvroLogicalType.TIMESTAMP_MICROS; +import static io.trino.hive.formats.avro.model.AvroLogicalType.TIMESTAMP_MILLIS; +import static io.trino.hive.formats.avro.model.AvroLogicalType.TIME_MICROS; +import static io.trino.hive.formats.avro.model.AvroLogicalType.TIME_MILLIS; +import static io.trino.hive.formats.avro.model.AvroLogicalType.UUID; +import static io.trino.hive.formats.avro.model.AvroLogicalType.fromAvroLogicalType; import static io.trino.spi.type.Timestamps.roundDiv; -import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; import static java.util.Objects.requireNonNull; import static 
org.apache.avro.LogicalTypes.fromSchemaIgnoreInvalid; @@ -61,165 +74,120 @@ public class NativeLogicalTypesAvroTypeManager { private static final Logger log = Logger.get(NativeLogicalTypesAvroTypeManager.class); - public static final Schema TIMESTAMP_MILLIS_SCHEMA; - public static final Schema TIMESTAMP_MICROS_SCHEMA; public static final Schema DATE_SCHEMA; public static final Schema TIME_MILLIS_SCHEMA; public static final Schema TIME_MICROS_SCHEMA; + public static final Schema TIMESTAMP_MILLIS_SCHEMA; + public static final Schema TIMESTAMP_MICROS_SCHEMA; public static final Schema UUID_SCHEMA; - // Copied from org.apache.avro.LogicalTypes - protected static final String DECIMAL = "decimal"; - protected static final String UUID = "uuid"; - protected static final String DATE = "date"; - protected static final String TIME_MILLIS = "time-millis"; - protected static final String TIME_MICROS = "time-micros"; - protected static final String TIMESTAMP_MILLIS = "timestamp-millis"; - protected static final String TIMESTAMP_MICROS = "timestamp-micros"; - protected static final String LOCAL_TIMESTAMP_MILLIS = "local-timestamp-millis"; - protected static final String LOCAL_TIMESTAMP_MICROS = "local-timestamp-micros"; - static { - TIMESTAMP_MILLIS_SCHEMA = SchemaBuilder.builder().longType(); - LogicalTypes.timestampMillis().addToSchema(TIMESTAMP_MILLIS_SCHEMA); - TIMESTAMP_MICROS_SCHEMA = SchemaBuilder.builder().longType(); - LogicalTypes.timestampMicros().addToSchema(TIMESTAMP_MICROS_SCHEMA); DATE_SCHEMA = Schema.create(Schema.Type.INT); LogicalTypes.date().addToSchema(DATE_SCHEMA); TIME_MILLIS_SCHEMA = Schema.create(Schema.Type.INT); LogicalTypes.timeMillis().addToSchema(TIME_MILLIS_SCHEMA); TIME_MICROS_SCHEMA = Schema.create(Schema.Type.LONG); LogicalTypes.timeMicros().addToSchema(TIME_MICROS_SCHEMA); + TIMESTAMP_MILLIS_SCHEMA = SchemaBuilder.builder().longType(); + LogicalTypes.timestampMillis().addToSchema(TIMESTAMP_MILLIS_SCHEMA); + TIMESTAMP_MICROS_SCHEMA = SchemaBuilder.builder().longType(); + LogicalTypes.timestampMicros().addToSchema(TIMESTAMP_MICROS_SCHEMA); UUID_SCHEMA = Schema.create(Schema.Type.STRING); LogicalTypes.uuid().addToSchema(UUID_SCHEMA); } - @Override - public void configure(Map fileMetadata) {} - - @Override - public Optional overrideTypeForSchema(Schema schema) - throws AvroTypeException - { - return validateAndLogIssues(schema).map(NativeLogicalTypesAvroTypeManager::getAvroLogicalTypeSpiType); - } - - @Override - public Optional> overrideBuildingFunctionForSchema(Schema schema) - throws AvroTypeException - { - return validateAndLogIssues(schema).map(logicalType -> getLogicalTypeBuildingFunction(logicalType, schema)); - } - @Override public Optional> overrideBlockToAvroObject(Schema schema, Type type) throws AvroTypeException { - Optional logicalType = validateAndLogIssues(schema); + Optional logicalType = validateAndLogIssues(schema); if (logicalType.isEmpty()) { return Optional.empty(); } return Optional.of(getAvroFunction(logicalType.get(), schema, type)); } - private static Type getAvroLogicalTypeSpiType(LogicalType logicalType) + static Type getAvroLogicalTypeSpiType(AvroLogicalType avroLogicalType) + { + return switch (avroLogicalType) { + case DateLogicalType __ -> DateType.DATE; + case BytesDecimalLogicalType bytesDecimalLogicalType -> DecimalType.createDecimalType(bytesDecimalLogicalType.precision(), bytesDecimalLogicalType.scale()); + case FixedDecimalLogicalType fixedDecimalLogicalType -> DecimalType.createDecimalType(fixedDecimalLogicalType.precision(), 
fixedDecimalLogicalType.scale()); + case TimeMillisLogicalType __ -> TimeType.TIME_MILLIS; + case TimeMicrosLogicalType __ -> TimeType.TIME_MICROS; + case TimestampMillisLogicalType __ -> TimestampType.TIMESTAMP_MILLIS; + case TimestampMicrosLogicalType __ -> TimestampType.TIMESTAMP_MICROS; + case StringUUIDLogicalType __ -> UuidType.UUID; + }; + } + + static Type getAvroLogicalTypeSpiType(LogicalType logicalType) { return switch (logicalType.getName()) { - case TIMESTAMP_MILLIS -> TimestampType.TIMESTAMP_MILLIS; - case TIMESTAMP_MICROS -> TimestampType.TIMESTAMP_MICROS; + case DATE -> DateType.DATE; case DECIMAL -> { LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType; yield DecimalType.createDecimalType(decimal.getPrecision(), decimal.getScale()); } - case DATE -> DateType.DATE; case TIME_MILLIS -> TimeType.TIME_MILLIS; case TIME_MICROS -> TimeType.TIME_MICROS; + case TIMESTAMP_MILLIS -> TimestampType.TIMESTAMP_MILLIS; + case TIMESTAMP_MICROS -> TimestampType.TIMESTAMP_MICROS; case UUID -> UuidType.UUID; default -> throw new IllegalStateException("Unreachable unfiltered logical type"); }; } - private static BiConsumer getLogicalTypeBuildingFunction(LogicalType logicalType, Schema schema) + private static BiFunction getAvroFunction(AvroLogicalType logicalType, Schema schema, Type type) + throws AvroTypeException { - return switch (logicalType.getName()) { - case TIMESTAMP_MILLIS -> { - if (schema.getType() == Schema.Type.LONG) { - yield (builder, obj) -> { - Long l = (Long) obj; - TimestampType.TIMESTAMP_MILLIS.writeLong(builder, l * Timestamps.MICROSECONDS_PER_MILLISECOND); - }; + return switch (logicalType) { + case DateLogicalType _ -> { + if (type != DateType.DATE) { + throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType, type)); } - throw new IllegalStateException("Unreachable unfiltered logical type"); + yield DateType.DATE::getLong; } - case TIMESTAMP_MICROS -> { - if (schema.getType() == Schema.Type.LONG) { - yield (builder, obj) -> { - Long l = (Long) obj; - TimestampType.TIMESTAMP_MICROS.writeLong(builder, l); - }; + case BytesDecimalLogicalType _ -> { + DecimalType decimalType = (DecimalType) getAvroLogicalTypeSpiType(logicalType); + if (decimalType.isShort()) { + yield (block, pos) -> ByteBuffer.wrap(Longs.toByteArray(decimalType.getLong(block, pos))); + } + else { + yield (block, pos) -> ByteBuffer.wrap(((Int128) decimalType.getObject(block, pos)).toBigEndianBytes()); } - throw new IllegalStateException("Unreachable unfiltered logical type"); } - case DECIMAL -> { - LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType; - DecimalType decimalType = DecimalType.createDecimalType(decimal.getPrecision(), decimal.getScale()); - Function byteExtract = switch (schema.getType()) { - case BYTES -> // This is only safe because we don't reuse byte buffer objects which means each gets sized exactly for the bytes contained - (obj) -> ((ByteBuffer) obj).array(); - case FIXED -> (obj) -> ((GenericFixed) obj).bytes(); - default -> throw new IllegalStateException("Unreachable unfiltered logical type"); - }; + case FixedDecimalLogicalType _ -> { + DecimalType decimalType = (DecimalType) getAvroLogicalTypeSpiType(logicalType); + Function wrapBytes = bytes -> new GenericData.Fixed(schema, fitBigEndianValueToByteArraySize(bytes, schema.getFixedSize())); if (decimalType.isShort()) { - yield (builder, obj) -> decimalType.writeLong(builder, fromBigEndian(byteExtract.apply(obj))); + yield (block, pos) -> 
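+                        // worked example (illustrative), for decimal(10, 2) stored as fixed(5): unscaled
+                        // value 12345 -> Longs.toByteArray gives 8 big-endian two's-complement bytes, and
+                        // fitBigEndianValueToByteArraySize trims or sign-extends them to the schema's
+                        // 5-byte fixed size before wrapping as a GenericData.Fixed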
wrapBytes.apply(Longs.toByteArray(decimalType.getLong(block, pos))); } else { - yield (builder, obj) -> decimalType.writeObject(builder, Int128.fromBigEndian(byteExtract.apply(obj))); + yield (block, pos) -> wrapBytes.apply(((Int128) decimalType.getObject(block, pos)).toBigEndianBytes()); } } - case DATE -> { - if (schema.getType() == Schema.Type.INT) { - yield (builder, obj) -> { - Integer i = (Integer) obj; - DateType.DATE.writeLong(builder, i.longValue()); - }; + case TimeMillisLogicalType _ -> { + if (!(type instanceof TimeType timeType)) { + throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType, type)); } - throw new IllegalStateException("Unreachable unfiltered logical type"); - } - case TIME_MILLIS -> { - if (schema.getType() == Schema.Type.INT) { - yield (builder, obj) -> { - Integer i = (Integer) obj; - TimeType.TIME_MILLIS.writeLong(builder, i.longValue() * Timestamps.PICOSECONDS_PER_MILLISECOND); - }; + if (timeType.getPrecision() > 3) { + throw new AvroTypeException("Can't write out Avro logical time-millis from Trino Time Type with precision %s".formatted(timeType.getPrecision())); } - throw new IllegalStateException("Unreachable unfiltered logical type"); + yield (block, pos) -> roundDiv(timeType.getLong(block, pos), Timestamps.PICOSECONDS_PER_MILLISECOND); } - case TIME_MICROS -> { - if (schema.getType() == Schema.Type.LONG) { - yield (builder, obj) -> { - Long i = (Long) obj; - TimeType.TIME_MICROS.writeLong(builder, i * Timestamps.PICOSECONDS_PER_MICROSECOND); - }; + case TimeMicrosLogicalType _ -> { + if (!(type instanceof TimeType timeType)) { + throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType, type)); } - throw new IllegalStateException("Unreachable unfiltered logical type"); - } - case UUID -> { - if (schema.getType() == Schema.Type.STRING) { - yield (builder, obj) -> UuidType.UUID.writeSlice(builder, javaUuidToTrinoUuid(java.util.UUID.fromString(obj.toString()))); + if (timeType.getPrecision() > 6) { + throw new AvroTypeException("Can't write out Avro logical time-micros from Trino Time Type with precision %s".formatted(timeType.getPrecision())); } - throw new IllegalStateException("Unreachable unfiltered logical type"); + yield (block, pos) -> roundDiv(timeType.getLong(block, pos), Timestamps.PICOSECONDS_PER_MICROSECOND); } - default -> throw new IllegalStateException("Unreachable unfiltered logical type"); - }; - } - - private static BiFunction getAvroFunction(LogicalType logicalType, Schema schema, Type type) - throws AvroTypeException - { - return switch (logicalType.getName()) { - case TIMESTAMP_MILLIS -> { + case TimestampMillisLogicalType _ -> { if (!(type instanceof TimestampType timestampType)) { - throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType.getName(), type)); + throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType, type)); } if (timestampType.isShort()) { yield (block, integer) -> timestampType.getLong(block, integer) / Timestamps.MICROSECONDS_PER_MILLISECOND; @@ -232,9 +200,9 @@ private static BiFunction getAvroFunction(LogicalType lo }; } } - case TIMESTAMP_MICROS -> { + case TimestampMicrosLogicalType _ -> { if (!(type instanceof TimestampType timestampType)) { - throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType.getName(), type)); + throw new AvroTypeException("Can't 
represent Avro logical type %s with Trino Type %s".formatted(logicalType, type)); } if (timestampType.isShort()) { // Don't use method reference because it causes an NPE in errorprone @@ -248,76 +216,36 @@ private static BiFunction getAvroFunction(LogicalType lo }; } } - case DECIMAL -> { - DecimalType decimalType = (DecimalType) getAvroLogicalTypeSpiType(logicalType); - Function wrapBytes = switch (schema.getType()) { - case BYTES -> ByteBuffer::wrap; - case FIXED -> bytes -> new GenericData.Fixed(schema, fitBigEndianValueToByteArraySize(bytes, schema.getFixedSize())); - default -> throw new VerifyException("Unreachable unfiltered logical type"); - }; - if (decimalType.isShort()) { - yield (block, pos) -> wrapBytes.apply(Longs.toByteArray(decimalType.getLong(block, pos))); - } - else { - yield (block, pos) -> wrapBytes.apply(((Int128) decimalType.getObject(block, pos)).toBigEndianBytes()); - } - } - case DATE -> { - if (type != DateType.DATE) { - throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType.getName(), type)); - } - yield DateType.DATE::getLong; - } - case TIME_MILLIS -> { - if (!(type instanceof TimeType timeType)) { - throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType.getName(), type)); - } - if (timeType.getPrecision() > 3) { - throw new AvroTypeException("Can't write out Avro logical time-millis from Trino Time Type with precision %s".formatted(timeType.getPrecision())); - } - yield (block, pos) -> roundDiv(timeType.getLong(block, pos), Timestamps.PICOSECONDS_PER_MILLISECOND); - } - case TIME_MICROS -> { - if (!(type instanceof TimeType timeType)) { - throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType.getName(), type)); - } - if (timeType.getPrecision() > 6) { - throw new AvroTypeException("Can't write out Avro logical time-millis from Trino Time Type with precision %s".formatted(timeType.getPrecision())); - } - yield (block, pos) -> roundDiv(timeType.getLong(block, pos), Timestamps.PICOSECONDS_PER_MICROSECOND); - } - case UUID -> { + case StringUUIDLogicalType _ -> { if (!(type instanceof UuidType uuidType)) { - throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType.getName(), type)); + throw new AvroTypeException("Can't represent Avro logical type %s with Trino Type %s".formatted(logicalType, type)); } yield (block, pos) -> trinoUuidToJavaUuid((Slice) uuidType.getObject(block, pos)).toString(); } - default -> throw new VerifyException("Unreachable unfiltered logical type"); }; } - private Optional validateAndLogIssues(Schema schema) + static Optional validateAndLogIssues(Schema schema) { - // TODO replace with switch sealed class syntax when stable - ValidateLogicalTypeResult logicalTypeResult = validateLogicalType(schema); - if (logicalTypeResult instanceof NoLogicalType ignored) { - return Optional.empty(); - } - if (logicalTypeResult instanceof NonNativeAvroLogicalType ignored) { - log.debug("Unrecognized logical type " + schema); - return Optional.empty(); - } - if (logicalTypeResult instanceof InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) { - log.debug(invalidNativeAvroLogicalType.getCause(), "Invalidly configured native Avro logical type"); - return Optional.empty(); - } - if (logicalTypeResult instanceof ValidNativeAvroLogicalType validNativeAvroLogicalType) { - return Optional.of(validNativeAvroLogicalType.getLogicalType()); + switch 
(validateLogicalType(schema)) { + case NoLogicalType ignored -> { + return Optional.empty(); + } + case NonNativeAvroLogicalType ignored -> { + log.debug("Unrecognized logical type " + schema); + return Optional.empty(); + } + case InvalidNativeAvroLogicalType invalidNativeAvroLogicalType -> { + log.debug(invalidNativeAvroLogicalType.getCause(), "Invalidly configured native Avro logical type"); + return Optional.empty(); + } + case ValidNativeAvroLogicalType validNativeAvroLogicalType -> { + return Optional.of(validNativeAvroLogicalType.getLogicalType()); + } } - throw new IllegalStateException("Unhandled validate logical type result"); } - protected static ValidateLogicalTypeResult validateLogicalType(Schema schema) + public static ValidateLogicalTypeResult validateLogicalType(Schema schema) { final String typeName = schema.getProp(LogicalType.LOGICAL_TYPE_PROP); if (typeName == null) { @@ -325,7 +253,7 @@ protected static ValidateLogicalTypeResult validateLogicalType(Schema schema) } LogicalType logicalType; switch (typeName) { - case TIMESTAMP_MILLIS, TIMESTAMP_MICROS, DECIMAL, DATE, TIME_MILLIS, TIME_MICROS, UUID: + case DATE, DECIMAL, TIME_MILLIS, TIME_MICROS, TIMESTAMP_MILLIS, TIMESTAMP_MICROS, UUID: logicalType = fromSchemaIgnoreInvalid(schema); break; case LOCAL_TIMESTAMP_MICROS, LOCAL_TIMESTAMP_MILLIS: @@ -342,14 +270,14 @@ protected static ValidateLogicalTypeResult validateLogicalType(Schema schema) catch (RuntimeException e) { return new InvalidNativeAvroLogicalType(typeName, e); } - return new ValidNativeAvroLogicalType(logicalType); + return new ValidNativeAvroLogicalType(fromAvroLogicalType(logicalType, schema)); } else { return new NonNativeAvroLogicalType(typeName); } } - protected abstract static sealed class ValidateLogicalTypeResult + public abstract static sealed class ValidateLogicalTypeResult permits NoLogicalType, NonNativeAvroLogicalType, InvalidNativeAvroLogicalType, ValidNativeAvroLogicalType {} protected static final class NoLogicalType @@ -397,14 +325,14 @@ public RuntimeException getCause() protected static final class ValidNativeAvroLogicalType extends ValidateLogicalTypeResult { - private final LogicalType logicalType; + private final AvroLogicalType logicalType; - public ValidNativeAvroLogicalType(LogicalType logicalType) + public ValidNativeAvroLogicalType(AvroLogicalType logicalType) { this.logicalType = requireNonNull(logicalType, "logicalType is null"); } - public LogicalType getLogicalType() + public AvroLogicalType getLogicalType() { return logicalType; } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/RowBlockBuildingDecoder.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/RowBlockBuildingDecoder.java new file mode 100644 index 000000000000..2e068b3bdcb0 --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/RowBlockBuildingDecoder.java @@ -0,0 +1,170 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import io.trino.hive.formats.avro.model.AvroReadAction; +import io.trino.hive.formats.avro.model.DefaultValueFieldRecordFieldReadAction; +import io.trino.hive.formats.avro.model.ReadFieldAction; +import io.trino.hive.formats.avro.model.RecordFieldReadAction; +import io.trino.hive.formats.avro.model.RecordReadAction; +import io.trino.hive.formats.avro.model.SkipFieldRecordFieldReadAction; +import io.trino.hive.formats.avro.model.SkipFieldRecordFieldReadAction.SkipAction; +import io.trino.spi.PageBuilder; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RowBlockBuilder; +import org.apache.avro.Resolver; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; + +import java.io.IOException; +import java.util.function.IntFunction; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class RowBlockBuildingDecoder + implements BlockBuildingDecoder +{ + private final RowBuildingAction[] buildSteps; + + public RowBlockBuildingDecoder(Schema writeSchema, Schema readSchema, AvroTypeBlockHandler typeManager) + throws AvroTypeException + { + this(AvroReadAction.fromAction(Resolver.resolve(writeSchema, readSchema, new GenericData())), typeManager); + } + + public RowBlockBuildingDecoder(AvroReadAction action, AvroTypeBlockHandler typeManager) + throws AvroTypeException + { + if (!(action instanceof RecordReadAction recordReadAction)) { + throw new AvroTypeException("Write and Read Schemas must be records when building a row block building decoder. Illegal action: " + action); + } + buildSteps = new RowBuildingAction[recordReadAction.fieldReadActions().size()]; + int i = 0; + for (RecordFieldReadAction fieldAction : recordReadAction.fieldReadActions()) { + buildSteps[i] = switch (fieldAction) { + case DefaultValueFieldRecordFieldReadAction defaultValueFieldRecordFieldReadAction -> new ConstantBlockAction( + getDefaultBlockBuilder(defaultValueFieldRecordFieldReadAction.fieldSchema(), + defaultValueFieldRecordFieldReadAction.defaultBytes(), typeManager), + defaultValueFieldRecordFieldReadAction.outputChannel()); + case ReadFieldAction readFieldAction -> new BuildIntoBlockAction(typeManager.blockBuildingDecoderFor(readFieldAction.readAction()), + readFieldAction.outputChannel()); + case SkipFieldRecordFieldReadAction skipFieldRecordFieldReadAction -> new SkipSchemaBuildingAction(skipFieldRecordFieldReadAction.skipAction()); + }; + i++; + } + } + + @Override + public void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + ((RowBlockBuilder) builder).buildEntry(fieldBuilders -> decodeIntoBlockProvided(decoder, fieldBuilders::get)); + } + + protected void decodeIntoPageBuilder(Decoder decoder, PageBuilder builder) + throws IOException + { + builder.declarePosition(); + decodeIntoBlockProvided(decoder, builder::getBlockBuilder); + } + + protected void decodeIntoBlockProvided(Decoder decoder, IntFunction<BlockBuilder> fieldBlockBuilder) + throws IOException + { + for (RowBuildingAction buildStep : buildSteps) { + switch (buildStep) { + case SkipSchemaBuildingAction skipSchemaBuildingAction -> skipSchemaBuildingAction.skip(decoder); + case BuildIntoBlockAction buildIntoBlockAction -> buildIntoBlockAction.decode(decoder, fieldBlockBuilder); + case ConstantBlockAction constantBlockAction -> 
constantBlockAction.addConstant(fieldBlockBuilder); + } + } + } + + sealed interface RowBuildingAction + permits BuildIntoBlockAction, ConstantBlockAction, SkipSchemaBuildingAction + {} + + private static final class BuildIntoBlockAction + implements RowBuildingAction + { + private final BlockBuildingDecoder delegate; + private final int outputChannel; + + public BuildIntoBlockAction(BlockBuildingDecoder delegate, int outputChannel) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + checkArgument(outputChannel >= 0, "outputChannel must be non-negative"); + this.outputChannel = outputChannel; + } + + public void decode(Decoder decoder, IntFunction<BlockBuilder> channelSelector) + throws IOException + { + delegate.decodeIntoBlock(decoder, channelSelector.apply(outputChannel)); + } + } + + protected static final class ConstantBlockAction + implements RowBuildingAction + { + private final AvroReadAction.IoConsumer<BlockBuilder> addConstantFunction; + private final int outputChannel; + + public ConstantBlockAction(AvroReadAction.IoConsumer<BlockBuilder> addConstantFunction, int outputChannel) + { + this.addConstantFunction = requireNonNull(addConstantFunction, "addConstantFunction is null"); + checkArgument(outputChannel >= 0, "outputChannel must be non-negative"); + this.outputChannel = outputChannel; + } + + public void addConstant(IntFunction<BlockBuilder> channelSelector) + throws IOException + { + addConstantFunction.accept(channelSelector.apply(outputChannel)); + } + } + + public static final class SkipSchemaBuildingAction + implements RowBuildingAction + { + private final SkipAction skipAction; + + SkipSchemaBuildingAction(SkipAction skipAction) + { + this.skipAction = requireNonNull(skipAction, "skipAction is null"); + } + + public void skip(Decoder decoder) + throws IOException + { + skipAction.skip(decoder); + } + } + + // Avro supports default values for reader record fields that are missing in the writer schema. + // The bytes representing the default field value are passed to a block building decoder + // so that it can pack the block appropriately for the default type. 
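+ // Note: the single BinaryDecoder below is re-initialized over the same default bytes on each call, so the default value is decoded anew for every row that needs it.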
+ private static AvroReadAction.IoConsumer<BlockBuilder> getDefaultBlockBuilder(Schema fieldSchema, byte[] defaultBytes, AvroTypeBlockHandler typeManager) + throws AvroTypeException + { + BlockBuildingDecoder buildingDecoder = typeManager.blockBuildingDecoderFor(AvroReadAction.fromAction(Resolver.resolve(fieldSchema, fieldSchema))); + BinaryDecoder reuse = DecoderFactory.get().binaryDecoder(defaultBytes, null); + return blockBuilder -> buildingDecoder.decodeIntoBlock(DecoderFactory.get().binaryDecoder(defaultBytes, reuse), blockBuilder); + } +} diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroLogicalType.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroLogicalType.java index 8ddc7d1f5a46..11c9327f0e1e 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroLogicalType.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroLogicalType.java @@ -20,37 +20,32 @@ // Avro logical types supported in Trino natively cross product with supported Avro schema encoding public sealed interface AvroLogicalType permits - AvroLogicalType.TimestampMillisLogicalType, - AvroLogicalType.TimestampMicrosLogicalType, + AvroLogicalType.DateLogicalType, AvroLogicalType.BytesDecimalLogicalType, AvroLogicalType.FixedDecimalLogicalType, - AvroLogicalType.DateLogicalType, AvroLogicalType.TimeMillisLogicalType, AvroLogicalType.TimeMicrosLogicalType, + AvroLogicalType.TimestampMillisLogicalType, + AvroLogicalType.TimestampMicrosLogicalType, AvroLogicalType.StringUUIDLogicalType { // Copied from org.apache.avro.LogicalTypes - String DECIMAL = "decimal"; - String UUID = "uuid"; String DATE = "date"; + String DECIMAL = "decimal"; String TIME_MILLIS = "time-millis"; String TIME_MICROS = "time-micros"; String TIMESTAMP_MILLIS = "timestamp-millis"; String TIMESTAMP_MICROS = "timestamp-micros"; String LOCAL_TIMESTAMP_MILLIS = "local-timestamp-millis"; String LOCAL_TIMESTAMP_MICROS = "local-timestamp-micros"; + String UUID = "uuid"; static AvroLogicalType fromAvroLogicalType(LogicalType logicalType, Schema schema) { switch (logicalType.getName()) { - case TIMESTAMP_MILLIS -> { - if (schema.getType() == Schema.Type.LONG) { - return new TimestampMillisLogicalType(); - } - } - case TIMESTAMP_MICROS -> { - if (schema.getType() == Schema.Type.LONG) { - return new TimestampMicrosLogicalType(); + case DATE -> { + if (schema.getType() == Schema.Type.INT) { + return new DateLogicalType(); } } case DECIMAL -> { @@ -65,11 +60,6 @@ static AvroLogicalType fromAvroLogicalType(LogicalType logicalType, Schema schem default -> throw new IllegalArgumentException("Unsupported Logical Type %s and schema pairing %s".formatted(logicalType.getName(), schema)); } } - case DATE -> { - if (schema.getType() == Schema.Type.INT) { - return new DateLogicalType(); - } - } case TIME_MILLIS -> { if (schema.getType() == Schema.Type.INT) { return new TimeMillisLogicalType(); @@ -80,6 +70,16 @@ static AvroLogicalType fromAvroLogicalType(LogicalType logicalType, Schema schem return new TimeMicrosLogicalType(); } } + case TIMESTAMP_MILLIS -> { + if (schema.getType() == Schema.Type.LONG) { + return new TimestampMillisLogicalType(); + } + } + case TIMESTAMP_MICROS -> { + if (schema.getType() == Schema.Type.LONG) { + return new TimestampMicrosLogicalType(); + } + } case UUID -> { if (schema.getType() == Schema.Type.STRING) { return new StringUUIDLogicalType(); @@ -89,10 +89,7 @@ static AvroLogicalType fromAvroLogicalType(LogicalType logicalType, Schema 
schem throw new IllegalArgumentException("Unsupported Logical Type %s and schema pairing %s".formatted(logicalType.getName(), schema)); } - record TimestampMillisLogicalType() - implements AvroLogicalType {} - - record TimestampMicrosLogicalType() + record DateLogicalType() implements AvroLogicalType {} record BytesDecimalLogicalType(int precision, int scale) @@ -101,15 +98,18 @@ record BytesDecimalLogicalType(int precision, int scale) record FixedDecimalLogicalType(int precision, int scale, int fixedSize) implements AvroLogicalType {} - record DateLogicalType() - implements AvroLogicalType {} - record TimeMillisLogicalType() implements AvroLogicalType {} record TimeMicrosLogicalType() implements AvroLogicalType {} + record TimestampMillisLogicalType() + implements AvroLogicalType {} + + record TimestampMicrosLogicalType() + implements AvroLogicalType {} + record StringUUIDLogicalType() implements AvroLogicalType {} } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroReadAction.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroReadAction.java index 3424e926c3a0..917725ca54f9 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroReadAction.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/AvroReadAction.java @@ -19,13 +19,20 @@ import io.trino.hive.formats.avro.AvroTypeException; import org.apache.avro.Resolver; import org.apache.avro.Schema; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.io.parsing.ResolvingGrammarGenerator; +import org.apache.avro.util.internal.Accessor; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.util.Arrays; import java.util.List; import static com.google.common.base.Verify.verify; -import static io.trino.hive.formats.avro.AvroPageDataReader.RowBlockBuildingDecoder.SkipSchemaBuildingAction.createSkipActionForSchema; -import static io.trino.hive.formats.avro.AvroPageDataReader.getDefaultByes; +import static io.trino.hive.formats.avro.model.AvroReadAction.getDefaultByes; +import static io.trino.hive.formats.avro.model.SkipFieldRecordFieldReadAction.createSkipActionForSchema; import static java.util.Objects.requireNonNull; public sealed interface AvroReadAction @@ -47,6 +54,51 @@ public sealed interface AvroReadAction WrittenUnionReadAction, ReadErrorReadAction { + static byte[] getDefaultByes(Schema.Field field) + throws AvroTypeException + { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Encoder e = EncoderFactory.get().binaryEncoder(out, null); + ResolvingGrammarGenerator.encode(e, field.schema(), Accessor.defaultValue(field)); + e.flush(); + return out.toByteArray(); + } + catch (IOException exception) { + throw new AvroTypeException("Unable to encode to bytes for default value in field " + field, exception); + } + } + + static LongIoFunction<Decoder> getLongDecoderFunction(Schema writerSchema) + { + return switch (writerSchema.getType()) { + case INT -> Decoder::readInt; + case LONG -> Decoder::readLong; + default -> throw new IllegalArgumentException("Cannot promote type %s to long".formatted(writerSchema.getType())); + }; + } + + static FloatIoFunction<Decoder> getFloatDecoderFunction(Schema writerSchema) + { + return switch (writerSchema.getType()) { + case INT -> Decoder::readInt; + case LONG -> Decoder::readLong; + case FLOAT -> Decoder::readFloat; + default -> throw new IllegalArgumentException("Cannot 
promote type %s to float".formatted(writerSchema.getType())); + }; + } + + static DoubleIoFunction<Decoder> getDoubleDecoderFunction(Schema writerSchema) + { + return switch (writerSchema.getType()) { + case INT -> Decoder::readInt; + case LONG -> Decoder::readLong; + case FLOAT -> Decoder::readFloat; + case DOUBLE -> Decoder::readDouble; + default -> throw new IllegalArgumentException("Cannot promote type %s to double".formatted(writerSchema.getType())); + }; + } + Schema readSchema(); Schema writeSchema(); @@ -180,4 +232,32 @@ private static ReadingUnionReadAction fromReaderUnionAction(Resolver.ReaderUnion { return new ReadingUnionReadAction(readerUnion.reader, readerUnion.writer, readerUnion.firstMatch, fromAction(readerUnion.actualAction)); } + + @FunctionalInterface + interface LongIoFunction<A> + { + long apply(A a) + throws IOException; + } + + @FunctionalInterface + interface FloatIoFunction<A> + { + float apply(A a) + throws IOException; + } + + @FunctionalInterface + interface DoubleIoFunction<A> + { + double apply(A a) + throws IOException; + } + + @FunctionalInterface + interface IoConsumer<A> + { + void accept(A a) + throws IOException; + } } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/DoubleRead.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/DoubleRead.java index e4a972b9cefc..c060ec8a635b 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/DoubleRead.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/DoubleRead.java @@ -13,12 +13,11 @@ */ package io.trino.hive.formats.avro.model; -import io.trino.hive.formats.avro.AvroPageDataReader; import org.apache.avro.Schema; import org.apache.avro.io.Decoder; import static com.google.common.base.Verify.verify; -import static io.trino.hive.formats.avro.AvroPageDataReader.getDoubleDecoderFunction; +import static io.trino.hive.formats.avro.model.AvroReadAction.getDoubleDecoderFunction; import static java.util.Objects.requireNonNull; public final class DoubleRead @@ -26,7 +25,7 @@ public final class DoubleRead { private final Schema readSchema; private final Schema writeSchema; - private final AvroPageDataReader.DoubleIoFunction<Decoder> doubleDecoder; + private final DoubleIoFunction<Decoder> doubleDecoder; DoubleRead(Schema readSchema, Schema writeSchema) { @@ -48,7 +47,7 @@ public Schema writeSchema() return writeSchema; } - public AvroPageDataReader.DoubleIoFunction<Decoder> getDoubleDecoder() + public DoubleIoFunction<Decoder> getDoubleDecoder() { return doubleDecoder; } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/FloatRead.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/FloatRead.java index 66f415e55d16..ccef23879efc 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/FloatRead.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/FloatRead.java @@ -13,12 +13,11 @@ */ package io.trino.hive.formats.avro.model; -import io.trino.hive.formats.avro.AvroPageDataReader; import org.apache.avro.Schema; import org.apache.avro.io.Decoder; import static com.google.common.base.Verify.verify; -import static io.trino.hive.formats.avro.AvroPageDataReader.getFloatDecoderFunction; +import static io.trino.hive.formats.avro.model.AvroReadAction.getFloatDecoderFunction; import static java.util.Objects.requireNonNull; public final class FloatRead @@ -26,7 +25,7 @@ public final class FloatRead { private final Schema readSchema; private final Schema writeSchema; 
- private final AvroPageDataReader.FloatIoFunction<Decoder> floatDecoder; + private final FloatIoFunction<Decoder> floatDecoder; FloatRead(Schema readSchema, Schema writeSchema) { @@ -48,7 +47,7 @@ public Schema writeSchema() return writeSchema; } - public AvroPageDataReader.FloatIoFunction<Decoder> getFloatDecoder() + public FloatIoFunction<Decoder> getFloatDecoder() { return floatDecoder; } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/LongRead.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/LongRead.java index 5ed0d6ff12e3..9767ecfc512d 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/LongRead.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/LongRead.java @@ -13,12 +13,11 @@ */ package io.trino.hive.formats.avro.model; -import io.trino.hive.formats.avro.AvroPageDataReader; import org.apache.avro.Schema; import org.apache.avro.io.Decoder; import static com.google.common.base.Verify.verify; -import static io.trino.hive.formats.avro.AvroPageDataReader.getLongDecoderFunction; +import static io.trino.hive.formats.avro.model.AvroReadAction.getLongDecoderFunction; import static java.util.Objects.requireNonNull; public final class LongRead @@ -26,7 +25,7 @@ public final class LongRead { private final Schema readSchema; private final Schema writeSchema; - private final AvroPageDataReader.LongIoFunction<Decoder> longDecoder; + private final LongIoFunction<Decoder> longDecoder; LongRead(Schema readSchema, Schema writeSchema) { @@ -48,7 +47,7 @@ public Schema writeSchema() return writeSchema; } - public AvroPageDataReader.LongIoFunction<Decoder> getLongDecoder() + public LongIoFunction<Decoder> getLongDecoder() { return longDecoder; } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/SkipFieldRecordFieldReadAction.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/SkipFieldRecordFieldReadAction.java index 621b22a78eaa..cc1029d77365 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/SkipFieldRecordFieldReadAction.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/model/SkipFieldRecordFieldReadAction.java @@ -13,7 +13,11 @@ */ package io.trino.hive.formats.avro.model; -import io.trino.hive.formats.avro.AvroPageDataReader.RowBlockBuildingDecoder.SkipSchemaBuildingAction.SkipAction; +import org.apache.avro.Schema; +import org.apache.avro.io.Decoder; + +import java.io.IOException; +import java.util.List; import static java.util.Objects.requireNonNull; @@ -24,4 +28,123 @@ public record SkipFieldRecordFieldReadAction(SkipAction skipAction) { requireNonNull(skipAction, "skipAction is null"); } + + @FunctionalInterface + public interface SkipAction + { + void skip(Decoder decoder) + throws IOException; + } + + public static SkipAction createSkipActionForSchema(Schema schema) + { + return switch (schema.getType()) { + case NULL -> Decoder::readNull; + case BOOLEAN -> Decoder::readBoolean; + case INT -> Decoder::readInt; + case LONG -> Decoder::readLong; + case FLOAT -> Decoder::readFloat; + case DOUBLE -> Decoder::readDouble; + case STRING -> Decoder::skipString; + case BYTES -> Decoder::skipBytes; + case ENUM -> Decoder::readEnum; + case FIXED -> { + int size = schema.getFixedSize(); + yield decoder -> decoder.skipFixed(size); + } + case ARRAY -> new ArraySkipAction(schema.getElementType()); + case MAP -> new MapSkipAction(schema.getValueType()); + case RECORD -> new RecordSkipAction(schema.getFields()); + case UNION -> new 
UnionSkipAction(schema.getTypes()); + }; + } + + private static class ArraySkipAction + implements SkipAction + { + private final SkipAction elementSkipAction; + + public ArraySkipAction(Schema elementSchema) + { + elementSkipAction = createSkipActionForSchema(requireNonNull(elementSchema, "elementSchema is null")); + } + + @Override + public void skip(Decoder decoder) + throws IOException + { + for (long i = decoder.skipArray(); i != 0; i = decoder.skipArray()) { + for (long j = 0; j < i; j++) { + elementSkipAction.skip(decoder); + } + } + } + } + + private static class MapSkipAction + implements SkipFieldRecordFieldReadAction.SkipAction + { + private final SkipFieldRecordFieldReadAction.SkipAction valueSkipAction; + + public MapSkipAction(Schema valueSchema) + { + valueSkipAction = createSkipActionForSchema(requireNonNull(valueSchema, "valueSchema is null")); + } + + @Override + public void skip(Decoder decoder) + throws IOException + { + for (long i = decoder.skipMap(); i != 0; i = decoder.skipMap()) { + for (long j = 0; j < i; j++) { + decoder.skipString(); // key + valueSkipAction.skip(decoder); // value + } + } + } + } + + private static class RecordSkipAction + implements SkipFieldRecordFieldReadAction.SkipAction + { + private final SkipFieldRecordFieldReadAction.SkipAction[] fieldSkips; + + public RecordSkipAction(List<Schema.Field> fields) + { + fieldSkips = new SkipFieldRecordFieldReadAction.SkipAction[requireNonNull(fields, "fields is null").size()]; + for (int i = 0; i < fields.size(); i++) { + fieldSkips[i] = createSkipActionForSchema(fields.get(i).schema()); + } + } + + @Override + public void skip(Decoder decoder) + throws IOException + { + for (SkipFieldRecordFieldReadAction.SkipAction fieldSkipAction : fieldSkips) { + fieldSkipAction.skip(decoder); + } + } + } + + private static class UnionSkipAction + implements SkipFieldRecordFieldReadAction.SkipAction + { + private final SkipFieldRecordFieldReadAction.SkipAction[] skipActions; + + private UnionSkipAction(List<Schema> types) + { + skipActions = new SkipFieldRecordFieldReadAction.SkipAction[requireNonNull(types, "types is null").size()]; + for (int i = 0; i < types.size(); i++) { + skipActions[i] = createSkipActionForSchema(types.get(i)); + } + } + + @Override + public void skip(Decoder decoder) + throws IOException + { + skipActions[decoder.readIndex()].skip(decoder); + } + } } diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandler.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandler.java new file mode 100644 index 000000000000..a1aab978e4c7 --- /dev/null +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/BaseAvroTypeBlockHandler.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import io.trino.hive.formats.avro.model.AvroReadAction; +import io.trino.spi.type.Type; +import org.apache.avro.Schema; + +import java.util.Map; + +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.baseBlockBuildingDecoderFor; +import static io.trino.hive.formats.avro.BaseAvroTypeBlockHandlerImpls.baseTypeFor; + +public class BaseAvroTypeBlockHandler + implements AvroTypeBlockHandler +{ + @Override + public void configure(Map<String, byte[]> fileMetadata) {} + + @Override + public Type typeFor(Schema schema) + throws AvroTypeException + { + return baseTypeFor(schema, this); + } + + @Override + public BlockBuildingDecoder blockBuildingDecoderFor(AvroReadAction readAction) + throws AvroTypeException + { + return baseBlockBuildingDecoderFor(readAction, this); + } +} diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/NoOpAvroTypeManager.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/NoOpAvroTypeManager.java index cd7ccc22082e..445ca0df6e70 100644 --- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/NoOpAvroTypeManager.java +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/NoOpAvroTypeManager.java @@ -14,13 +14,10 @@ package io.trino.hive.formats.avro; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.Type; import org.apache.avro.Schema; -import java.util.Map; import java.util.Optional; -import java.util.function.BiConsumer; import java.util.function.BiFunction; public class NoOpAvroTypeManager @@ -30,23 +27,6 @@ public class NoOpAvroTypeManager private NoOpAvroTypeManager() {} - @Override - public void configure(Map<String, byte[]> fileMetadata) {} - - @Override - public Optional<Type> overrideTypeForSchema(Schema schema) - throws AvroTypeException - { - return Optional.empty(); - } - - @Override - public Optional<BiConsumer<BlockBuilder, Object>> overrideBuildingFunctionForSchema(Schema schema) - throws AvroTypeException - { - return Optional.empty(); - } - @Override public Optional<BiFunction<Block, Integer, Object>> overrideBlockToAvroObject(Schema schema, Type type) throws AvroTypeException diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroBase.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroBase.java index 77f35c66966c..02473d28ae4e 100644 --- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroBase.java +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroBase.java @@ -303,7 +303,7 @@ private void testSerdeCycles(Schema schema, AvroCompressionKind compressionKind) try (AvroFileReader fileReader = new AvroFileReader( createWrittenFileWithData(schema, testRecordsExpected.build(), temp1), schema, - NoOpAvroTypeManager.INSTANCE)) { + new BaseAvroTypeBlockHandler())) { while (fileReader.hasNext()) { pages.add(fileReader.next()); } @@ -316,7 +316,7 @@ private void testSerdeCycles(Schema schema, AvroCompressionKind compressionKind) compressionKind, ImmutableMap.of(), schema.getFields().stream().map(Schema.Field::name).collect(toImmutableList()), - AvroTypeUtils.typeFromAvro(schema, NoOpAvroTypeManager.INSTANCE).getTypeParameters(), false)) { + new BaseAvroTypeBlockHandler().typeFor(schema).getTypeParameters(), false)) { for (Page p : pages.build()) { fileWriter.write(p); } diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java 
b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java index b3131a2cd4d1..70439c083f79 100644 --- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java @@ -181,7 +181,7 @@ public void testTypesSimple() throws IOException, AvroTypeException { TrinoInputFile input = createWrittenFileWithData(ALL_SUPPORTED_TYPES_SCHEMA, ImmutableList.of(ALL_SUPPORTED_TYPES_GENERIC_RECORD)); - try (AvroFileReader pageIterator = new AvroFileReader(input, ALL_SUPPORTED_TYPES_SCHEMA, new NativeLogicalTypesAvroTypeManager())) { + try (AvroFileReader pageIterator = new AvroFileReader(input, ALL_SUPPORTED_TYPES_SCHEMA, new NativeLogicalTypesAvroTypeBlockHandler())) { while (pageIterator.hasNext()) { Page p = pageIterator.next(); assertIsAllSupportedTypePage(p); @@ -213,7 +213,7 @@ public void testWithDefaults() .endRecord(); TrinoInputFile input = createWrittenFileWithSchema(10, writeSchema); - try (AvroFileReader avroFileReader = new AvroFileReader(input, schema, new NativeLogicalTypesAvroTypeManager())) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, schema, new NativeLogicalTypesAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -251,14 +251,14 @@ public void testWriting() AvroCompressionKind.NULL, ImmutableMap.of(), ALL_SUPPORTED_TYPES_SCHEMA.getFields().stream().map(Schema.Field::name).collect(toImmutableList()), - AvroTypeUtils.typeFromAvro(ALL_SUPPORTED_TYPES_SCHEMA, new NativeLogicalTypesAvroTypeManager()).getTypeParameters(), false)) { + new NativeLogicalTypesAvroTypeBlockHandler().typeFor(ALL_SUPPORTED_TYPES_SCHEMA).getTypeParameters(), false)) { fileWriter.write(ALL_SUPPORTED_PAGE); } try (AvroFileReader fileReader = new AvroFileReader( trinoLocalFilesystem.newInputFile(testLocation), ALL_SUPPORTED_TYPES_SCHEMA, - new NativeLogicalTypesAvroTypeManager())) { + new NativeLogicalTypesAvroTypeBlockHandler())) { assertThat(fileReader.hasNext()).isTrue(); assertIsAllSupportedTypePage(fileReader.next()); assertThat(fileReader.hasNext()).isFalse(); diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java index 98caf74ace74..d10a7cc80aaa 100644 --- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; import io.airlift.slice.Slice; import io.trino.filesystem.TrinoInputFile; import io.trino.spi.Page; @@ -23,12 +22,10 @@ import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.MapBlock; -import io.trino.spi.block.RowBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.VarcharType; import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; -import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.RandomData; import 
org.junit.jupiter.api.Test; @@ -44,7 +41,6 @@ import static io.trino.block.BlockAssertions.assertBlockEquals; import static io.trino.block.BlockAssertions.createStringsBlock; -import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static org.assertj.core.api.Assertions.assertThat; @@ -60,7 +56,7 @@ public void testAllTypesSimple() throws IOException, AvroTypeException { TrinoInputFile input = createWrittenFileWithData(ALL_TYPES_RECORD_SCHEMA, ImmutableList.of(ALL_TYPES_GENERIC_RECORD)); - try (AvroFileReader avroFileReader = new AvroFileReader(input, ALL_TYPES_RECORD_SCHEMA, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, ALL_TYPES_RECORD_SCHEMA, new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -81,7 +77,7 @@ public void testSchemaWithSkips() int count = ThreadLocalRandom.current().nextInt(10000, 100000); TrinoInputFile input = createWrittenFileWithSchema(count, ALL_TYPES_RECORD_SCHEMA); - try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchema, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchema, new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -108,7 +104,7 @@ public void testSchemaWithDefaults() int count = ThreadLocalRandom.current().nextInt(10000, 100000); TrinoInputFile input = createWrittenFileWithSchema(count, SIMPLE_RECORD_SCHEMA); - try (AvroFileReader avroFileReader = new AvroFileReader(input, readerSchema, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, readerSchema, new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -134,7 +130,7 @@ public void testSchemaWithReorders() { Schema writerSchema = reorderSchema(ALL_TYPES_RECORD_SCHEMA); TrinoInputFile input = createWrittenFileWithData(writerSchema, ImmutableList.of(reorderGenericRecord(writerSchema, ALL_TYPES_GENERIC_RECORD))); - try (AvroFileReader avroFileReader = new AvroFileReader(input, ALL_TYPES_RECORD_SCHEMA, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, ALL_TYPES_RECORD_SCHEMA, new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -178,7 +174,7 @@ public void testPromotions() int count = ThreadLocalRandom.current().nextInt(10000, 100000); TrinoInputFile input = createWrittenFileWithSchema(count, writeSchemaBuilder.endRecord()); - try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchemaBuilder.endRecord(), NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchemaBuilder.endRecord(), new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -212,7 +208,7 @@ public void testEnum() //test superset TrinoInputFile input = createWrittenFileWithData(base, ImmutableList.of(expected)); - try (AvroFileReader avroFileReader = new AvroFileReader(input, superSchema, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, superSchema, new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -225,7 +221,7 @@ public 
void testEnum() //test reordered input = createWrittenFileWithData(base, ImmutableList.of(expected)); - try (AvroFileReader avroFileReader = new AvroFileReader(input, reorderdSchema, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(input, reorderdSchema, new BaseAvroTypeBlockHandler())) { int totalRecords = 0; while (avroFileReader.hasNext()) { Page p = avroFileReader.next(); @@ -237,140 +233,6 @@ public void testEnum() } } - @Test - public void testCoercionOfUnionToStruct() - throws IOException, AvroTypeException - { - Schema complexUnion = Schema.createUnion(Schema.create(Schema.Type.INT), Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL)); - - Schema readSchema = SchemaBuilder.builder() - .record("testComplexUnions") - .fields() - .name("readStraightUp") - .type(complexUnion) - .noDefault() - .name("readFromReverse") - .type(complexUnion) - .noDefault() - .name("readFromDefault") - .type(complexUnion) - .withDefault(42) - .endRecord(); - - Schema writeSchema = SchemaBuilder.builder() - .record("testComplexUnions") - .fields() - .name("readStraightUp") - .type(complexUnion) - .noDefault() - .name("readFromReverse") - .type(Schema.createUnion(Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT))) - .noDefault() - .endRecord(); - - GenericRecord stringsOnly = new GenericData.Record(writeSchema); - stringsOnly.put("readStraightUp", "I am in column 0 field 1"); - stringsOnly.put("readFromReverse", "I am in column 1 field 1"); - - GenericRecord ints = new GenericData.Record(writeSchema); - ints.put("readStraightUp", 5); - ints.put("readFromReverse", 21); - - GenericRecord nulls = new GenericData.Record(writeSchema); - nulls.put("readStraightUp", null); - nulls.put("readFromReverse", null); - - TrinoInputFile input = createWrittenFileWithData(writeSchema, ImmutableList.of(stringsOnly, ints, nulls)); - try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchema, NoOpAvroTypeManager.INSTANCE)) { - int totalRecords = 0; - while (avroFileReader.hasNext()) { - Page p = avroFileReader.next(); - assertThat(p.getPositionCount()).withFailMessage("Page Batch should be at least 3").isEqualTo(3); - //check first column - //check first column first row coerced struct - RowBlock readStraightUpStringsOnly = (RowBlock) p.getBlock(0).getSingleValueBlock(0); - assertThat(readStraightUpStringsOnly.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields - assertThat(readStraightUpStringsOnly.getFieldBlocks().get(1).isNull(0)).isTrue(); // int field null - assertThat(VARCHAR.getObjectValue(null, readStraightUpStringsOnly.getFieldBlocks().get(2), 0)).isEqualTo("I am in column 0 field 1"); //string field expected value - // check first column second row coerced struct - RowBlock readStraightUpInts = (RowBlock) p.getBlock(0).getSingleValueBlock(1); - assertThat(readStraightUpInts.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields - assertThat(readStraightUpInts.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null - assertThat(INTEGER.getObjectValue(null, readStraightUpInts.getFieldBlocks().get(1), 0)).isEqualTo(5); - - //check first column third row is null - assertThat(p.getBlock(0).isNull(2)).isTrue(); - //check second column - //check second column first row coerced struct - RowBlock readFromReverseStringsOnly = (RowBlock) p.getBlock(1).getSingleValueBlock(0); - assertThat(readFromReverseStringsOnly.getFieldBlocks().size()).isEqualTo(3); // 
tag, int and string block fields - assertThat(readFromReverseStringsOnly.getFieldBlocks().get(1).isNull(0)).isTrue(); // int field null - assertThat(VARCHAR.getObjectValue(null, readFromReverseStringsOnly.getFieldBlocks().get(2), 0)).isEqualTo("I am in column 1 field 1"); - //check second column second row coerced struct - RowBlock readFromReverseUpInts = (RowBlock) p.getBlock(1).getSingleValueBlock(1); - assertThat(readFromReverseUpInts.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields - assertThat(readFromReverseUpInts.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null - assertThat(INTEGER.getObjectValue(null, readFromReverseUpInts.getFieldBlocks().get(1), 0)).isEqualTo(21); - //check second column third row is null - assertThat(p.getBlock(1).isNull(2)).isTrue(); - - //check third column (default of 42 always) - //check third column first row coerced struct - RowBlock readFromDefaultStringsOnly = (RowBlock) p.getBlock(2).getSingleValueBlock(0); - assertThat(readFromDefaultStringsOnly.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields - assertThat(readFromDefaultStringsOnly.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null - assertThat(INTEGER.getObjectValue(null, readFromDefaultStringsOnly.getFieldBlocks().get(1), 0)).isEqualTo(42); - //check third column second row coerced struct - RowBlock readFromDefaultInts = (RowBlock) p.getBlock(2).getSingleValueBlock(1); - assertThat(readFromDefaultInts.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields - assertThat(readFromDefaultInts.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null - assertThat(INTEGER.getObjectValue(null, readFromDefaultInts.getFieldBlocks().get(1), 0)).isEqualTo(42); - //check third column third row coerced struct - RowBlock readFromDefaultNulls = (RowBlock) p.getBlock(2).getSingleValueBlock(2); - assertThat(readFromDefaultNulls.getFieldBlocks().size()).isEqualTo(3); // int and string block fields - assertThat(readFromDefaultNulls.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null - assertThat(INTEGER.getObjectValue(null, readFromDefaultNulls.getFieldBlocks().get(1), 0)).isEqualTo(42); - - totalRecords += p.getPositionCount(); - } - assertThat(totalRecords).isEqualTo(3); - } - } - - @Test - public void testRead3UnionWith2UnionDataWith2Union() - throws IOException, AvroTypeException - { - Schema twoUnion = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT)); - Schema threeUnion = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT), Schema.create(Schema.Type.STRING)); - - Schema twoUnionRecord = SchemaBuilder.builder() - .record("aRecord") - .fields() - .name("aField") - .type(twoUnion) - .noDefault() - .endRecord(); - - Schema threeUnionRecord = SchemaBuilder.builder() - .record("aRecord") - .fields() - .name("aField") - .type(threeUnion) - .noDefault() - .endRecord(); - - // write a file with the 3 union schema, using 2 union data - TrinoInputFile inputFile = createWrittenFileWithData(threeUnionRecord, ImmutableList.copyOf(Iterables.transform(new RandomData(twoUnionRecord, 1000), object -> (GenericRecord) object))); - - //read the file with the 2 union schema and ensure that no error thrown - try (AvroFileReader avroFileReader = new AvroFileReader(inputFile, twoUnionRecord, NoOpAvroTypeManager.INSTANCE)) { - while (avroFileReader.hasNext()) { - assertThat(avroFileReader.next()).isNotNull(); - } - } - } - @Test public void 
testReadUnionWithNonUnionAllCoercions() throws IOException, AvroTypeException @@ -397,7 +259,7 @@ public void testReadUnionWithNonUnionAllCoercions() TrinoInputFile inputFile = createWrittenFileWithSchema(1000, unionRecord); //read the file with the non-union schema and ensure that no error thrown - try (AvroFileReader avroFileReader = new AvroFileReader(inputFile, nonUnionRecord, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader avroFileReader = new AvroFileReader(inputFile, nonUnionRecord, new BaseAvroTypeBlockHandler())) { while (avroFileReader.hasNext()) { assertThat(avroFileReader.next()).isNotNull(); } diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataWriterWithoutTypeManager.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataWriterWithoutTypeManager.java index f8c098dd6361..c2ebc1ccf936 100644 --- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataWriterWithoutTypeManager.java +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataWriterWithoutTypeManager.java @@ -75,11 +75,11 @@ private void testAllTypesWriting(Schema writeSchema) AvroCompressionKind.NULL, ImmutableMap.of(), ALL_TYPES_RECORD_SCHEMA.getFields().stream().map(Schema.Field::name).collect(toImmutableList()), - AvroTypeUtils.typeFromAvro(ALL_TYPES_RECORD_SCHEMA, NoOpAvroTypeManager.INSTANCE).getTypeParameters(), false)) { + new BaseAvroTypeBlockHandler().typeFor(ALL_TYPES_RECORD_SCHEMA).getTypeParameters(), false)) { fileWriter.write(ALL_TYPES_PAGE); } - try (AvroFileReader fileReader = new AvroFileReader(trinoLocalFilesystem.newInputFile(tempTestLocation), ALL_TYPES_RECORD_SCHEMA, NoOpAvroTypeManager.INSTANCE)) { + try (AvroFileReader fileReader = new AvroFileReader(trinoLocalFilesystem.newInputFile(tempTestLocation), ALL_TYPES_RECORD_SCHEMA, new BaseAvroTypeBlockHandler())) { assertThat(fileReader.hasNext()).isTrue(); assertIsAllTypesPage(fileReader.next()); assertThat(fileReader.hasNext()).isFalse(); @@ -140,14 +140,14 @@ public void testRLEAndDictionaryBlocks() AvroCompressionKind.NULL, ImmutableMap.of(), testBlocksSchema.getFields().stream().map(Schema.Field::name).collect(toImmutableList()), - AvroTypeUtils.typeFromAvro(testBlocksSchema, NoOpAvroTypeManager.INSTANCE).getTypeParameters(), false)) { + new BaseAvroTypeBlockHandler().typeFor(testBlocksSchema).getTypeParameters(), false)) { avroFileWriter.write(toWrite); } try (AvroFileReader avroFileReader = new AvroFileReader( trinoLocalFilesystem.newInputFile(testLocation), testBlocksSchema, - NoOpAvroTypeManager.INSTANCE)) { + new BaseAvroTypeBlockHandler())) { assertThat(avroFileReader.hasNext()).isTrue(); Page readPage = avroFileReader.next(); assertThat(INTEGER.getInt(readPage.getBlock(0), 0)).isEqualTo(2); @@ -211,7 +211,7 @@ public void testBlockUpcasting() try (AvroFileReader avroFileReader = new AvroFileReader( trinoLocalFilesystem.newInputFile(testLocation), testCastingSchema, - NoOpAvroTypeManager.INSTANCE)) { + new BaseAvroTypeBlockHandler())) { assertThat(avroFileReader.hasNext()).isTrue(); Page readPage = avroFileReader.next(); assertThat(INTEGER.getInt(readPage.getBlock(0), 0)).isEqualTo(1); diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestHiveAvroTypeBlockHandler.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestHiveAvroTypeBlockHandler.java new file mode 100644 index 000000000000..d31f5b2867e5 --- /dev/null +++ 
b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestHiveAvroTypeBlockHandler.java @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.hive.formats.avro; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.trino.filesystem.TrinoInputFile; +import io.trino.spi.Page; +import io.trino.spi.block.RowBlock; +import io.trino.spi.type.TimestampType; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.util.RandomData; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestHiveAvroTypeBlockHandler + extends TestAvroBase +{ + @Test + public void testRead3UnionWith2UnionDataWith2Union() + throws IOException, AvroTypeException + { + Schema twoUnion = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT)); + Schema threeUnion = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT), Schema.create(Schema.Type.STRING)); + + Schema twoUnionRecord = SchemaBuilder.builder() + .record("aRecord") + .fields() + .name("aField") + .type(twoUnion) + .noDefault() + .endRecord(); + + Schema threeUnionRecord = SchemaBuilder.builder() + .record("aRecord") + .fields() + .name("aField") + .type(threeUnion) + .noDefault() + .endRecord(); + + // write a file with the 3 union schema, using 2 union data + TrinoInputFile inputFile = createWrittenFileWithData(threeUnionRecord, ImmutableList.copyOf(Iterables.transform(new RandomData(twoUnionRecord, 1000), object -> (GenericRecord) object))); + + //read the file with the 2 union schema and ensure that no error thrown + try (AvroFileReader avroFileReader = new AvroFileReader(inputFile, twoUnionRecord, new HiveAvroTypeBlockHandler(TimestampType.TIMESTAMP_MILLIS))) { + while (avroFileReader.hasNext()) { + assertThat(avroFileReader.next()).isNotNull(); + } + } + } + + @Test + public void testCoercionOfUnionToStruct() + throws IOException, AvroTypeException + { + Schema complexUnion = Schema.createUnion(Schema.create(Schema.Type.INT), Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL)); + + Schema readSchema = SchemaBuilder.builder() + .record("testComplexUnions") + .fields() + .name("readStraightUp") + .type(complexUnion) + .noDefault() + .name("readFromReverse") + .type(complexUnion) + .noDefault() + .name("readFromDefault") + .type(complexUnion) + .withDefault(42) + .endRecord(); + + Schema writeSchema = SchemaBuilder.builder() + .record("testComplexUnions") + .fields() + .name("readStraightUp") + .type(complexUnion) + .noDefault() + .name("readFromReverse") + .type(Schema.createUnion(Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL), 
Schema.create(Schema.Type.INT))) + .noDefault() + .endRecord(); + + GenericRecord stringsOnly = new GenericData.Record(writeSchema); + stringsOnly.put("readStraightUp", "I am in column 0 field 1"); + stringsOnly.put("readFromReverse", "I am in column 1 field 1"); + + GenericRecord ints = new GenericData.Record(writeSchema); + ints.put("readStraightUp", 5); + ints.put("readFromReverse", 21); + + GenericRecord nulls = new GenericData.Record(writeSchema); + nulls.put("readStraightUp", null); + nulls.put("readFromReverse", null); + + TrinoInputFile input = createWrittenFileWithData(writeSchema, ImmutableList.of(stringsOnly, ints, nulls)); + try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchema, new HiveAvroTypeBlockHandler(TimestampType.TIMESTAMP_MILLIS))) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + assertThat(p.getPositionCount()).withFailMessage("Page Batch should be at least 3").isEqualTo(3); + //check first column + //check first column first row coerced struct + RowBlock readStraightUpStringsOnly = (RowBlock) p.getBlock(0).getSingleValueBlock(0); + assertThat(readStraightUpStringsOnly.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readStraightUpStringsOnly.getFieldBlocks().get(1).isNull(0)).isTrue(); // int field null + assertThat(VARCHAR.getObjectValue(null, readStraightUpStringsOnly.getFieldBlocks().get(2), 0)).isEqualTo("I am in column 0 field 1"); //string field expected value + // check first column second row coerced struct + RowBlock readStraightUpInts = (RowBlock) p.getBlock(0).getSingleValueBlock(1); + assertThat(readStraightUpInts.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readStraightUpInts.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readStraightUpInts.getFieldBlocks().get(1), 0)).isEqualTo(5); + + //check first column third row is null + assertThat(p.getBlock(0).isNull(2)).isTrue(); + //check second column + //check second column first row coerced struct + RowBlock readFromReverseStringsOnly = (RowBlock) p.getBlock(1).getSingleValueBlock(0); + assertThat(readFromReverseStringsOnly.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromReverseStringsOnly.getFieldBlocks().get(1).isNull(0)).isTrue(); // int field null + assertThat(VARCHAR.getObjectValue(null, readFromReverseStringsOnly.getFieldBlocks().get(2), 0)).isEqualTo("I am in column 1 field 1"); + //check second column second row coerced struct + RowBlock readFromReverseUpInts = (RowBlock) p.getBlock(1).getSingleValueBlock(1); + assertThat(readFromReverseUpInts.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromReverseUpInts.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromReverseUpInts.getFieldBlocks().get(1), 0)).isEqualTo(21); + //check second column third row is null + assertThat(p.getBlock(1).isNull(2)).isTrue(); + + //check third column (default of 42 always) + //check third column first row coerced struct + RowBlock readFromDefaultStringsOnly = (RowBlock) p.getBlock(2).getSingleValueBlock(0); + assertThat(readFromDefaultStringsOnly.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromDefaultStringsOnly.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null + 
assertThat(INTEGER.getObjectValue(null, readFromDefaultStringsOnly.getFieldBlocks().get(1), 0)).isEqualTo(42); + //check third column second row coerced struct + RowBlock readFromDefaultInts = (RowBlock) p.getBlock(2).getSingleValueBlock(1); + assertThat(readFromDefaultInts.getFieldBlocks().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromDefaultInts.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromDefaultInts.getFieldBlocks().get(1), 0)).isEqualTo(42); + //check third column third row coerced struct + RowBlock readFromDefaultNulls = (RowBlock) p.getBlock(2).getSingleValueBlock(2); + assertThat(readFromDefaultNulls.getFieldBlocks().size()).isEqualTo(3); // int and string block fields + assertThat(readFromDefaultNulls.getFieldBlocks().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromDefaultNulls.getFieldBlocks().get(1), 0)).isEqualTo(42); + + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(3); + } + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroFileWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroFileWriterFactory.java index ddca493a7b2d..07c986d363a6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroFileWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroFileWriterFactory.java @@ -21,6 +21,8 @@ import io.trino.filesystem.TrinoFileSystemFactory; import io.trino.filesystem.TrinoOutputFile; import io.trino.hive.formats.avro.AvroCompressionKind; +import io.trino.hive.formats.avro.HiveAvroTypeBlockHandler; +import io.trino.hive.formats.avro.HiveAvroTypeManager; import io.trino.memory.context.AggregatedMemoryContext; import io.trino.plugin.hive.FileWriter; import io.trino.plugin.hive.HiveCompressionCodec; @@ -52,6 +54,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; import static io.trino.plugin.hive.util.HiveUtil.getColumnNames; import static io.trino.plugin.hive.util.HiveUtil.getColumnTypes; +import static io.trino.spi.type.TimestampType.createTimestampType; import static java.util.Objects.requireNonNull; public class AvroFileWriterFactory @@ -120,7 +123,8 @@ public Optional createFileWriter( outputFile.create(outputStreamMemoryContext), outputStreamMemoryContext, fileSchema, - new HiveAvroTypeManager(hiveTimestampPrecision), + new HiveAvroTypeManager(), + new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), rollbackAction, inputColumnNames, inputColumnTypes, diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java index b5ca4bfc242f..0dd25d22f51f 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java @@ -45,17 +45,17 @@ import java.util.concurrent.atomic.AtomicInteger; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.hive.formats.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME; +import static io.trino.hive.formats.avro.AvroHiveConstants.SCHEMA_DOC; +import static io.trino.hive.formats.avro.AvroHiveConstants.SCHEMA_LITERAL; +import static io.trino.hive.formats.avro.AvroHiveConstants.SCHEMA_NAME; +import static 
+import static io.trino.hive.formats.avro.AvroHiveConstants.SCHEMA_NONE;
+import static io.trino.hive.formats.avro.AvroHiveConstants.SCHEMA_URL;
+import static io.trino.hive.formats.avro.AvroHiveConstants.TABLE_NAME;
+import static io.trino.hive.formats.avro.AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP;
+import static io.trino.hive.formats.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME;
 import static io.trino.plugin.hive.HiveMetadata.TABLE_COMMENT;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_DOC;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_LITERAL;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_NAME;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_NAMESPACE;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_NONE;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_URL;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.TABLE_NAME;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME;
 import static io.trino.plugin.hive.util.HiveUtil.getColumnNames;
 import static io.trino.plugin.hive.util.HiveUtil.getColumnTypes;
 import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMN_COMMENTS;
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileWriter.java
index f520c368e787..5b8c25d47f86 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileWriter.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileWriter.java
@@ -17,9 +17,9 @@ import com.google.common.io.CountingOutputStream;
 import io.trino.hive.formats.avro.AvroCompressionKind;
 import io.trino.hive.formats.avro.AvroFileWriter;
+import io.trino.hive.formats.avro.AvroTypeBlockHandler;
 import io.trino.hive.formats.avro.AvroTypeException;
 import io.trino.hive.formats.avro.AvroTypeManager;
-import io.trino.hive.formats.avro.AvroTypeUtils;
 import io.trino.memory.context.AggregatedMemoryContext;
 import io.trino.plugin.hive.FileWriter;
 import io.trino.spi.Page;
@@ -63,6 +63,7 @@ public AvroHiveFileWriter(
             AggregatedMemoryContext outputStreamMemoryContext,
             Schema fileSchema,
             AvroTypeManager typeManager,
+            AvroTypeBlockHandler avroTypeBlockHandler,
             Closeable rollbackAction,
             List<String> inputColumnNames,
             List<Type> inputColumnTypes,
@@ -89,7 +90,7 @@ public AvroHiveFileWriter(
         ImmutableList.Builder<Block> blocks = ImmutableList.builder();
         for (Map.Entry<String, Schema.Field> entry : fields.entrySet()) {
             outputColumnNames.add(entry.getKey().toLowerCase(Locale.ENGLISH));
-            Type type = AvroTypeUtils.typeFromAvro(entry.getValue().schema(), typeManager);
+            Type type = avroTypeBlockHandler.typeFor(entry.getValue().schema());
             outputColumnTypes.add(type);
             blocks.add(type.createBlockBuilder(null, 1).appendNull().build());
         }
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java
index 40398b749374..54503fc61779 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java
@@ -16,8 +16,8 @@ import io.airlift.units.DataSize;
 import io.trino.filesystem.TrinoInputFile;
 import io.trino.hive.formats.avro.AvroFileReader;
+import io.trino.hive.formats.avro.AvroTypeBlockHandler;
 import io.trino.hive.formats.avro.AvroTypeException;
-import io.trino.hive.formats.avro.AvroTypeManager;
 import io.trino.spi.Page;
 import io.trino.spi.TrinoException;
 import io.trino.spi.connector.ConnectorPageSource;
@@ -41,7 +41,7 @@ public class AvroPageSource
     public AvroPageSource(
             TrinoInputFile inputFile,
             Schema schema,
-            AvroTypeManager avroTypeManager,
+            AvroTypeBlockHandler avroTypeManager,
             long offset,
             long length)
             throws IOException, AvroTypeException
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java
index 492517707313..525654d81ef4 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java
@@ -23,6 +23,7 @@ import io.trino.filesystem.TrinoInputStream;
 import io.trino.filesystem.memory.MemoryInputFile;
 import io.trino.hive.formats.avro.AvroTypeException;
+import io.trino.hive.formats.avro.HiveAvroTypeBlockHandler;
 import io.trino.plugin.hive.AcidInfo;
 import io.trino.plugin.hive.HiveColumnHandle;
 import io.trino.plugin.hive.HivePageSourceFactory;
@@ -62,6 +63,7 @@ import static io.trino.plugin.hive.avro.AvroHiveFileUtils.wrapInUnionWithNull;
 import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName;
 import static io.trino.plugin.hive.util.HiveUtil.splitError;
+import static io.trino.spi.type.TimestampType.createTimestampType;
 import static java.lang.Math.min;
 import static java.util.Objects.requireNonNull;
 
@@ -162,7 +164,7 @@ public Optional<ReaderPageSource> createPageSource(
             nullSchema = nullSchema.name(notAColumnName).type(Schema.create(Schema.Type.NULL)).withDefault(null);
         }
         try {
-            return Optional.of(noProjectionAdaptation(new AvroPageSource(inputFile, nullSchema.endRecord(), new HiveAvroTypeManager(hiveTimestampPrecision), start, length)));
+            return Optional.of(noProjectionAdaptation(new AvroPageSource(inputFile, nullSchema.endRecord(), new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length)));
         }
         catch (IOException e) {
             throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e);
@@ -173,7 +175,7 @@
         }
         try {
-            return Optional.of(new ReaderPageSource(new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeManager(hiveTimestampPrecision), start, length), readerProjections));
+            return Optional.of(new ReaderPageSource(new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length), readerProjections));
         }
         catch (IOException e) {
             throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e);
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java
deleted file mode 100644
index 9c1bf3af943d..000000000000
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java
+++ /dev/null
@@ -1,267 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.hive.avro;
-
-import io.airlift.slice.Slice;
-import io.airlift.slice.Slices;
-import io.trino.hive.formats.avro.AvroTypeException;
-import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager;
-import io.trino.plugin.hive.HiveTimestampPrecision;
-import io.trino.spi.block.Block;
-import io.trino.spi.block.BlockBuilder;
-import io.trino.spi.type.BooleanType;
-import io.trino.spi.type.Chars;
-import io.trino.spi.type.LongTimestamp;
-import io.trino.spi.type.SqlTimestamp;
-import io.trino.spi.type.TimestampType;
-import io.trino.spi.type.Timestamps;
-import io.trino.spi.type.Type;
-import io.trino.spi.type.Varchars;
-import org.apache.avro.Schema;
-import org.joda.time.DateTimeZone;
-
-import java.nio.charset.StandardCharsets;
-import java.time.ZoneId;
-import java.util.Map;
-import java.util.Optional;
-import java.util.TimeZone;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.BiConsumer;
-import java.util.function.BiFunction;
-
-import static io.trino.plugin.hive.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME;
-import static io.trino.spi.type.CharType.createCharType;
-import static io.trino.spi.type.TimestampType.createTimestampType;
-import static io.trino.spi.type.Timestamps.roundDiv;
-import static io.trino.spi.type.VarcharType.createVarcharType;
-import static java.time.ZoneOffset.UTC;
-import static java.util.Objects.requireNonNull;
-
-public class HiveAvroTypeManager
-        extends NativeLogicalTypesAvroTypeManager
-{
-    private final AtomicReference<ZoneId> convertToTimezone = new AtomicReference<>(UTC);
-    private final TimestampType hiveSessionTimestamp;
-
-    public HiveAvroTypeManager(HiveTimestampPrecision hiveTimestampPrecision)
-    {
-        hiveSessionTimestamp = createTimestampType(requireNonNull(hiveTimestampPrecision, "hiveTimestampPrecision is null").getPrecision());
-    }
-
-    @Override
-    public void configure(Map<String, byte[]> fileMetadata)
-    {
-        if (fileMetadata.containsKey(AvroHiveConstants.WRITER_TIME_ZONE)) {
-            convertToTimezone.set(ZoneId.of(new String(fileMetadata.get(AvroHiveConstants.WRITER_TIME_ZONE), StandardCharsets.UTF_8)));
-        }
-        else {
-            // legacy path allows this conversion to be skipped with {@link org.apache.hadoop.conf.Configuration} param
-            // currently no way to set that configuration in Trino
-            convertToTimezone.set(TimeZone.getDefault().toZoneId());
-        }
-    }
-
-    @Override
-    public Optional<Type> overrideTypeForSchema(Schema schema)
-            throws AvroTypeException
-    {
-        if (schema.getType() == Schema.Type.NULL) {
-            // allows for dereference when no base columns from file used
-            // BooleanType chosen rather arbitrarily to be stuffed with null
-            // in response to behavior defined by io.trino.tests.product.hive.TestAvroSchemaStrictness.testInvalidUnionDefaults
-            return Optional.of(BooleanType.BOOLEAN);
-        }
-        ValidateLogicalTypeResult result = validateLogicalType(schema);
-        // mapped in from HiveType translator
-        // TODO replace with sealed class case match syntax when stable
-        if (result instanceof NativeLogicalTypesAvroTypeManager.NoLogicalType ignored) {
-            return Optional.empty();
-        }
-        if (result instanceof NonNativeAvroLogicalType nonNativeAvroLogicalType) {
-            return switch (nonNativeAvroLogicalType.getLogicalTypeName()) {
-                case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> Optional.of(getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType));
-                default -> Optional.empty();
-            };
-        }
-        if (result instanceof NativeLogicalTypesAvroTypeManager.InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) {
-            return switch (invalidNativeAvroLogicalType.getLogicalTypeName()) {
-                case TIMESTAMP_MILLIS, DATE, DECIMAL -> throw invalidNativeAvroLogicalType.getCause();
-                default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types
-            };
-        }
-        if (result instanceof NativeLogicalTypesAvroTypeManager.ValidNativeAvroLogicalType validNativeAvroLogicalType) {
-            return switch (validNativeAvroLogicalType.getLogicalType().getName()) {
-                case DATE -> super.overrideTypeForSchema(schema);
-                case TIMESTAMP_MILLIS -> Optional.of(hiveSessionTimestamp);
-                case DECIMAL -> {
-                    if (schema.getType() == Schema.Type.FIXED) {
-                        // for backwards compatibility
-                        throw new AvroTypeException("Hive does not support fixed decimal types");
-                    }
-                    yield super.overrideTypeForSchema(schema);
-                }
-                default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types
-            };
-        }
-        throw new IllegalStateException("Unhandled validate logical type result");
-    }
-
-    @Override
-    public Optional<BiConsumer<BlockBuilder, Object>> overrideBuildingFunctionForSchema(Schema schema)
-            throws AvroTypeException
-    {
-        ValidateLogicalTypeResult result = validateLogicalType(schema);
-        // TODO replace with sealed class case match syntax when stable
-        if (result instanceof NativeLogicalTypesAvroTypeManager.NoLogicalType ignored) {
-            return Optional.empty();
-        }
-        if (result instanceof NonNativeAvroLogicalType nonNativeAvroLogicalType) {
-            return switch (nonNativeAvroLogicalType.getLogicalTypeName()) {
-                case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> {
-                    Type type = getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType);
-                    if (nonNativeAvroLogicalType.getLogicalTypeName().equals(VARCHAR_TYPE_LOGICAL_NAME)) {
-                        yield Optional.of((blockBuilder, obj) -> {
-                            type.writeSlice(blockBuilder, Varchars.truncateToLength(Slices.utf8Slice(obj.toString()), type));
-                        });
-                    }
-                    else {
-                        yield Optional.of((blockBuilder, obj) -> {
-                            type.writeSlice(blockBuilder, Chars.truncateToLengthAndTrimSpaces(Slices.utf8Slice(obj.toString()), type));
-                        });
-                    }
-                }
-                default -> Optional.empty();
-            };
-        }
-        if (result instanceof NativeLogicalTypesAvroTypeManager.InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) {
-            return switch (invalidNativeAvroLogicalType.getLogicalTypeName()) {
-                case TIMESTAMP_MILLIS, DATE, DECIMAL -> throw invalidNativeAvroLogicalType.getCause();
-                default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types
-            };
-        }
-        if (result instanceof NativeLogicalTypesAvroTypeManager.ValidNativeAvroLogicalType validNativeAvroLogicalType) {
-            return switch (validNativeAvroLogicalType.getLogicalType().getName()) {
-                case TIMESTAMP_MILLIS -> {
-                    if (hiveSessionTimestamp.isShort()) {
-                        yield Optional.of((blockBuilder, obj) -> {
-                            Long millisSinceEpochUTC = (Long) obj;
-                            hiveSessionTimestamp.writeLong(blockBuilder, DateTimeZone.forTimeZone(TimeZone.getTimeZone(convertToTimezone.get())).convertUTCToLocal(millisSinceEpochUTC) * Timestamps.MICROSECONDS_PER_MILLISECOND);
-                        });
-                    }
-                    else {
-                        yield Optional.of((blockBuilder, obj) -> {
-                            Long millisSinceEpochUTC = (Long) obj;
-                            LongTimestamp longTimestamp = new LongTimestamp(DateTimeZone.forTimeZone(TimeZone.getTimeZone(convertToTimezone.get())).convertUTCToLocal(millisSinceEpochUTC) * Timestamps.MICROSECONDS_PER_MILLISECOND, 0);
-                            hiveSessionTimestamp.writeObject(blockBuilder, longTimestamp);
-                        });
-                    }
-                }
-                case DATE, DECIMAL -> super.overrideBuildingFunctionForSchema(schema);
-                default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types
-            };
-        }
-        throw new IllegalStateException("Unhandled validate logical type result");
-    }
-
-    @Override
-    public Optional<BiFunction<Block, Integer, Object>> overrideBlockToAvroObject(Schema schema, Type type)
-            throws AvroTypeException
-    {
-        ValidateLogicalTypeResult result = validateLogicalType(schema);
-        // TODO replace with sealed class case match syntax when stable
-        if (result instanceof NativeLogicalTypesAvroTypeManager.NoLogicalType ignored) {
-            return Optional.empty();
-        }
-        if (result instanceof NonNativeAvroLogicalType nonNativeAvroLogicalType) {
-            return switch (nonNativeAvroLogicalType.getLogicalTypeName()) {
-                case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> {
-                    Type expectedType = getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType);
-                    if (!expectedType.equals(type)) {
-                        throw new AvroTypeException("Type provided for column [%s] is incompatible with type for schema: %s".formatted(type, expectedType));
-                    }
-                    yield Optional.of((block, pos) -> ((Slice) expectedType.getObject(block, pos)).toStringUtf8());
-                }
-                default -> Optional.empty();
-            };
-        }
-        if (result instanceof NativeLogicalTypesAvroTypeManager.InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) {
-            return switch (invalidNativeAvroLogicalType.getLogicalTypeName()) {
-                case TIMESTAMP_MILLIS, DATE, DECIMAL -> throw invalidNativeAvroLogicalType.getCause();
-                default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types
-            };
-        }
-        if (result instanceof NativeLogicalTypesAvroTypeManager.ValidNativeAvroLogicalType validNativeAvroLogicalType) {
-            return switch (validNativeAvroLogicalType.getLogicalType().getName()) {
-                case TIMESTAMP_MILLIS -> {
-                    if (!(type instanceof TimestampType timestampType)) {
-                        throw new AvroTypeException("Can't represent avro logical type %s with Trino Type %s".formatted(validNativeAvroLogicalType.getLogicalType().getName(), type));
-                    }
-                    if (timestampType.isShort()) {
-                        yield Optional.of((block, pos) -> {
-                            long millis = roundDiv(timestampType.getLong(block, pos), Timestamps.MICROSECONDS_PER_MILLISECOND);
-                            // see org.apache.hadoop.hive.serde2.avro.AvroSerializer.serializePrimitive
-                            return DateTimeZone.forTimeZone(TimeZone.getDefault()).convertLocalToUTC(millis, false);
-                        });
-                    }
-                    else {
-                        yield Optional.of((block, pos) ->
-                        {
-                            SqlTimestamp timestamp = (SqlTimestamp) timestampType.getObject(block, pos);
-                            // see org.apache.hadoop.hive.serde2.avro.AvroSerializer.serializePrimitive
-                            return DateTimeZone.forTimeZone(TimeZone.getDefault()).convertLocalToUTC(timestamp.getMillis(), false);
-                        });
-                    }
-                }
-                case DATE, DECIMAL -> super.overrideBlockToAvroObject(schema, type);
-                default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types
-            };
-        }
-        throw new IllegalStateException("Unhandled validate logical type result");
-    }
-
-    private static Type getHiveLogicalVarCharOrCharType(Schema schema, NonNativeAvroLogicalType nonNativeAvroLogicalType)
-            throws AvroTypeException
-    {
-        if (schema.getType() != Schema.Type.STRING) {
-            throw new AvroTypeException("Unsupported Avro type for Hive Logical Type in schema " + schema);
-        }
-        Object maxLengthObject = schema.getObjectProp(VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP);
-        if (maxLengthObject == null) {
-            throw new AvroTypeException("Missing property maxLength in schema for Hive Type " + nonNativeAvroLogicalType.getLogicalTypeName());
-        }
-        try {
-            int maxLength;
-            if (maxLengthObject instanceof String maxLengthString) {
-                maxLength = Integer.parseInt(maxLengthString);
-            }
-            else if (maxLengthObject instanceof Number maxLengthNumber) {
-                maxLength = maxLengthNumber.intValue();
-            }
-            else {
-                throw new AvroTypeException("Unrecognized property type for " + VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP + " in schema " + schema);
-            }
-            if (nonNativeAvroLogicalType.getLogicalTypeName().equals(VARCHAR_TYPE_LOGICAL_NAME)) {
-                return createVarcharType(maxLength);
-            }
-            else {
-                return createCharType(maxLength);
-            }
-        }
-        catch (NumberFormatException numberFormatException) {
-            throw new AvroTypeException("Property maxLength not convertible to Integer in Hive Logical type schema " + schema);
-        }
-    }
-}
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/avro/TestAvroSchemaGeneration.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/avro/TestAvroSchemaGeneration.java
index cdc380572c74..bd1ab558bb0f 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/avro/TestAvroSchemaGeneration.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/avro/TestAvroSchemaGeneration.java
@@ -29,9 +29,9 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static io.trino.hive.formats.avro.AvroHiveConstants.TABLE_NAME;
 import static io.trino.plugin.hive.HiveType.HIVE_STRING;
 import static io.trino.plugin.hive.HiveType.toHiveType;
-import static io.trino.plugin.hive.avro.AvroHiveConstants.TABLE_NAME;
 import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMNS;
 import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMN_TYPES;
 import static io.trino.spi.type.VarcharType.VARCHAR;
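
Usage note (not part of the patch itself): after this refactor, the read path takes an AvroTypeBlockHandler where it previously took an AvroTypeManager. A minimal sketch of reading a file through the new API, mirroring the test above; the inputFile and schema are assumed to be supplied by the caller, and countRows is a hypothetical helper name, not code from this change:

    import io.trino.filesystem.TrinoInputFile;
    import io.trino.hive.formats.avro.AvroFileReader;
    import io.trino.hive.formats.avro.HiveAvroTypeBlockHandler;
    import io.trino.spi.Page;
    import io.trino.spi.type.TimestampType;
    import org.apache.avro.Schema;

    // Hypothetical helper: decode an Avro file into Trino Pages using the
    // Hive block handler introduced by this patch, and count the rows read.
    static long countRows(TrinoInputFile inputFile, Schema schema)
            throws Exception
    {
        try (AvroFileReader reader = new AvroFileReader(inputFile, schema, new HiveAvroTypeBlockHandler(TimestampType.TIMESTAMP_MILLIS))) {
            long rows = 0;
            while (reader.hasNext()) {
                Page page = reader.next();
                rows += page.getPositionCount();
            }
            return rows;
        }
    }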