From 4dde5a74bb2e85d9005bfef21c02897291846261 Mon Sep 17 00:00:00 2001 From: Allison Portis Date: Fri, 16 Feb 2024 13:27:45 -0800 Subject: [PATCH 01/13] [Kernel] Support getting snapshot by version (#2607) --- .../src/main/java/io/delta/kernel/Table.java | 10 +++ .../io/delta/kernel/internal/DeltaErrors.java | 45 ++++++++++ .../io/delta/kernel/internal/TableImpl.java | 6 ++ .../internal/snapshot/SnapshotManager.java | 90 +++++++++++++------ .../internal/SnapshotManagerSuite.scala | 52 ++++++++--- .../defaults/DeltaTableReadsSuite.scala | 90 +++++++++++++++++++ .../kernel/defaults/utils/TestUtils.scala | 12 ++- 7 files changed, 262 insertions(+), 43 deletions(-) create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Table.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Table.java index ae70f6741eb..00a8db763e1 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Table.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Table.java @@ -56,4 +56,14 @@ Snapshot getLatestSnapshot(TableClient tableClient) * @return the table path */ String getPath(); + + /** + * Get the snapshot at the given {@code versionId}. + * + * @param tableClient {@link TableClient} instance to use in Delta Kernel. + * @param versionId snapshot version to retrieve + * @return an instance of {@link Snapshot} + */ + Snapshot getSnapshotAtVersion(TableClient tableClient, long versionId) + throws TableNotFoundException; } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java new file mode 100644 index 00000000000..af8510d5c39 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java @@ -0,0 +1,45 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.internal; + +public final class DeltaErrors { + private DeltaErrors() {} + + // TODO update to be user-facing exception with future exception framework + // (see delta-io/delta#2231) & document in method docs as needed (Table::getSnapshotAtVersion) + public static RuntimeException nonReconstructableStateException( + String tablePath, long version) { + String message = String.format( + "%s: Unable to reconstruct state at version %s as the transaction log has been " + + "truncated due to manual deletion or the log retention policy and checkpoint " + + "retention policy.", + tablePath, + version); + return new RuntimeException(message); + } + + // TODO update to be user-facing exception with future exception framework + // (see delta-io/delta#2231) & document in method docs as needed (Table::getSnapshotAtVersion) + public static RuntimeException nonExistentVersionException( + String tablePath, long versionToLoad, long latestVersion) { + String message = String.format( + "%s: Trying to load a non-existent version %s. 
The latest version available is %s", + tablePath, + versionToLoad, + latestVersion); + return new RuntimeException(message); + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableImpl.java index c9f7c157dcf..9bd37772f17 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableImpl.java @@ -60,4 +60,10 @@ public Snapshot getLatestSnapshot(TableClient tableClient) throws TableNotFoundE public String getPath() { return tablePath; } + + @Override + public Snapshot getSnapshotAtVersion(TableClient tableClient, long versionId) + throws TableNotFoundException { + return snapshotManager.getSnapshotAt(tableClient, versionId); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/snapshot/SnapshotManager.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/snapshot/SnapshotManager.java index dcaba0ce529..ca820a3f6eb 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/snapshot/SnapshotManager.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/snapshot/SnapshotManager.java @@ -33,6 +33,7 @@ import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.FileStatus; +import io.delta.kernel.internal.DeltaErrors; import io.delta.kernel.internal.SnapshotImpl; import io.delta.kernel.internal.checkpoints.CheckpointInstance; import io.delta.kernel.internal.checkpoints.CheckpointMetaData; @@ -110,6 +111,28 @@ public Snapshot buildLatestSnapshot(TableClient tableClient) return getSnapshotAtInit(tableClient); } + /** + * Construct the snapshot for the given table at the version provided. + * + * @param tableClient Instance of {@link TableClient} to use. + * @param version The snapshot version to construct + * @return a {@link Snapshot} of the table at version {@code version} + * @throws TableNotFoundException + */ + public Snapshot getSnapshotAt( + TableClient tableClient, + Long version) throws TableNotFoundException { + + Optional logSegmentOpt = getLogSegmentForVersion( + tableClient, + Optional.empty(), /* startCheckpointOpt */ + Optional.of(version) /* versionToLoadOpt */); + + return logSegmentOpt + .map(logSegment -> createSnapshot(logSegment, tableClient)) + .orElseThrow(() -> new TableNotFoundException(dataPath.toString())); + } + //////////////////// // Helper Methods // //////////////////// @@ -182,10 +205,12 @@ private Optional> listFromOrNone( * Returns the delta files and checkpoint files starting from the given `startVersion`. * `versionToLoad` is an optional parameter to set the max bound. It's usually used to load a * table snapshot for a specific version. + * If no delta or checkpoint files exist below the versionToLoad and at least one delta file + * exists, throws an exception that the state is not reconstructable. * * @param startVersion the version to start. Inclusive. * @param versionToLoad the optional parameter to set the max version we should return. - * Inclusive. + * Inclusive. Must be >= startVersion if provided. * @return Some array of files found (possibly empty, if no usable commit files are present), or * None if the listing returned no files at all. 
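     * <p>For example: if the log contains delta files for versions 0 through 20 with a
     * checkpoint at version 10, a call with {@code startVersion=10} and
     * {@code versionToLoad=Optional.of(15)} returns the checkpoint file(s) for version 10
     * plus the delta files for versions 10 through 15. If instead the earliest surviving
     * file is for version 20, the same call throws the non-reconstructable-state exception
     * described above.</p>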
*/ @@ -193,6 +218,14 @@ protected final Optional> listDeltaAndCheckpointFiles( TableClient tableClient, long startVersion, Optional versionToLoad) { + versionToLoad.ifPresent(v -> + checkArgument( + v >= startVersion, + String.format( + "versionToLoad=%s provided is less than startVersion=%s", + v, + startVersion) + )); logger.debug("startVersion: {}, versionToLoad: {}", startVersion, versionToLoad); return listFromOrNone( @@ -221,6 +254,13 @@ protected final Optional> listDeltaAndCheckpointFiles( .orElse(true); if (!versionWithinRange) { + // If we haven't taken any files yet and the first file we see is greater + // than the versionToLoad then the versionToLoad is not reconstructable + // from the existing logs + if (output.isEmpty()) { + throw DeltaErrors.nonReconstructableStateException( + dataPath.toString(), versionToLoad.get()); + } break; } @@ -314,18 +354,17 @@ public Optional getLogSegmentForVersion( TableClient tableClient, Optional startCheckpoint, Optional versionToLoad) { - // List from the starting checkpoint. If a checkpoint doesn't exist, this will still return - // deltaVersion=0. - // TODO when implementing time-travel don't list from startCheckpoint if - // startCheckpoint > versionToLoad + // Only use startCheckpoint if it is <= versionToLoad + Optional startCheckpointToUse = startCheckpoint + .filter(v -> !versionToLoad.isPresent() || v <= versionToLoad.get()); final Optional> newFiles = listDeltaAndCheckpointFiles( tableClient, - startCheckpoint.orElse(0L), + startCheckpointToUse.orElse(0L), // List from 0 if no starting checkpoint versionToLoad); return getLogSegmentForVersion( tableClient, - startCheckpoint, + startCheckpointToUse, versionToLoad, newFiles); } @@ -369,7 +408,7 @@ protected Optional getLogSegmentForVersion( // We can't construct a snapshot because the directory contained no usable commit // files... but we can't return Optional.empty either, because it was not truly empty. throw new RuntimeException( - String.format("Empty directory: %s", logPath) + String.format("No delta files found in the directory: %s", logPath) ); } else if (newFiles.isEmpty()) { // The directory may be deleted and recreated and we may have stale state in our @@ -475,6 +514,21 @@ protected Optional getLogSegmentForVersion( Arrays.toString(deltaVersionsAfterCheckpoint.toArray()) )); + final long newVersion = deltaVersionsAfterCheckpoint.isEmpty() ? + newCheckpointOpt.get().version : deltaVersionsAfterCheckpoint.getLast(); + + // In the case where `deltasAfterCheckpoint` is empty, `deltas` should still not be empty, + // they may just be before the checkpoint version unless we have a bug in log cleanup. + if (deltas.isEmpty()) { + throw new IllegalStateException( + String.format("Could not find any delta files for version %s", newVersion) + ); + } + + versionToLoadOpt.filter(v -> v != newVersion).ifPresent(v -> { + throw DeltaErrors.nonExistentVersionException(dataPath.toString(), v, newVersion); + }); + // We may just be getting a checkpoint file after the filtering if (!deltaVersionsAfterCheckpoint.isEmpty()) { if (deltaVersionsAfterCheckpoint.getFirst() != newCheckpointVersion + 1) { @@ -492,26 +546,6 @@ protected Optional getLogSegmentForVersion( versionToLoadOpt); } - // TODO: double check newCheckpointOpt.get() won't error out - - final long newVersion = deltaVersionsAfterCheckpoint.isEmpty() ? 
- newCheckpointOpt.get().version : deltaVersionsAfterCheckpoint.getLast(); - - // In the case where `deltasAfterCheckpoint` is empty, `deltas` should still not be empty, - // they may just be before the checkpoint version unless we have a bug in log cleanup. - if (deltas.isEmpty()) { - throw new IllegalStateException( - String.format("Could not find any delta files for version %s", newVersion) - ); - } - - if (versionToLoadOpt.map(v -> v != newVersion).orElse(false)) { - throw new IllegalStateException( - String.format("Trying to load a non-existent version %s", - versionToLoadOpt.get()) - ); - } - final long lastCommitTimestamp = deltas.get(deltas.size() - 1).getModificationTime(); final List newCheckpointFiles = newCheckpointOpt.map(newCheckpoint -> { diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala index 2893c1ef11b..f87a06d7f59 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala @@ -338,6 +338,12 @@ class SnapshotManagerSuite extends AnyFunSuite { checkpointVersions = Seq(10, 20), versionToLoad = Optional.of(15) ) + testWithSingularAndMultipartCheckpoint( + deltaVersions = (0L to 20L), + checkpointVersions = Seq(10, 20), + startCheckpoint = Optional.of(10), + versionToLoad = Optional.of(15) + ) testWithSingularAndMultipartCheckpoint( deltaVersions = (0L to 20L), checkpointVersions = Seq(10, 20), @@ -393,24 +399,29 @@ class SnapshotManagerSuite extends AnyFunSuite { .map(FileStatus.of(_, 10, 10)) testExpectedError[RuntimeException]( files, - // TODO error message is misleading (delta-io/delta#2283) - expectedErrorMessageContains = "Empty directory: /fake/path/to/table/_delta_log" + expectedErrorMessageContains = + "No delta files found in the directory: /fake/path/to/table/_delta_log" + ) + testExpectedError[RuntimeException]( + files, + versionToLoad = Optional.of(5), + expectedErrorMessageContains = + "No delta files found in the directory: /fake/path/to/table/_delta_log" ) } test("getLogSegmentForVersion: versionToLoad higher than possible") { - // TODO throw more informative error message (delta-io/delta#2283) - testExpectedError[IllegalArgumentException]( + testExpectedError[RuntimeException]( files = deltaFileStatuses(Seq(0)), versionToLoad = Optional.of(15), expectedErrorMessageContains = - "Did not get the last delta file version 15 to compute Snapshot" + "Trying to load a non-existent version 15. The latest version available is 0" ) - testExpectedError[IllegalArgumentException]( + testExpectedError[RuntimeException]( files = deltaFileStatuses((10L until 13L)) ++ singularCheckpointFileStatuses(Seq(10)), versionToLoad = Optional.of(15), expectedErrorMessageContains = - "Did not get the last delta file version 15 to compute Snapshot" + "Trying to load a non-existent version 15. 
The latest version available is 12" ) } @@ -460,17 +471,16 @@ class SnapshotManagerSuite extends AnyFunSuite { test("getLogSegmentForVersion: versionToLoad not constructable from history") { val files = deltaFileStatuses(20L until 25L) ++ singularCheckpointFileStatuses(Seq(20)) - // TODO this error message is misleading (delta-io/delta#2283) testExpectedError[RuntimeException]( files, versionToLoad = Optional.of(15), - expectedErrorMessageContains = "Empty directory: /fake/path/to/table/_delta_log" + expectedErrorMessageContains = "Unable to reconstruct state at version 15" ) testExpectedError[RuntimeException]( files, startCheckpoint = Optional.of(20), versionToLoad = Optional.of(15), - expectedErrorMessageContains = "Empty directory: /fake/path/to/table/_delta_log" + expectedErrorMessageContains = "Unable to reconstruct state at version 15" ) } @@ -491,6 +501,26 @@ class SnapshotManagerSuite extends AnyFunSuite { } } + test("getLogSegmentForVersion: corrupt listing with missing log files") { + // checkpoint(10), 010.json, 011.json, 013.json + val fileList = deltaFileStatuses(Seq(10L, 11L)) ++ deltaFileStatuses(Seq(13L)) ++ + singularCheckpointFileStatuses(Seq(10L)) + testExpectedError[RuntimeException]( + fileList, + expectedErrorMessageContains = "Versions ([11, 13]) are not continuous" + ) + testExpectedError[RuntimeException]( + fileList, + startCheckpoint = Optional.of(10), + expectedErrorMessageContains = "Versions ([11, 13]) are not continuous" + ) + testExpectedError[RuntimeException]( + fileList, + versionToLoad = Optional.of(13), + expectedErrorMessageContains = "Versions ([11, 13]) are not continuous" + ) + } + // TODO address the inconsistent behaviors and throw better error messages for corrupt listings? // (delta-io/delta#2283) test("getLogSegmentForVersion: corrupt listing 000.json...009.json + checkpoint(10)") { @@ -498,7 +528,7 @@ class SnapshotManagerSuite extends AnyFunSuite { /* ---------- version to load is 15 (greater than latest checkpoint/delta file) ---------- */ // (?) 
different error messages - testExpectedError[IllegalStateException]( + testExpectedError[RuntimeException]( fileList, versionToLoad = Optional.of(15), expectedErrorMessageContains = "Trying to load a non-existent version 15" diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala index f6409fbe99e..6eae58cf782 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala @@ -460,4 +460,94 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { } assert(e.getMessage.contains("Unsupported reader protocol version")) } + + ////////////////////////////////////////////////////////////////////////////////// + // getSnapshotAtVersion end-to-end tests (log segment tests in SnapshotManagerSuite) + ////////////////////////////////////////////////////////////////////////////////// + + test("getSnapshotAtVersion: basic end-to-end read") { + withTempDir { tempDir => + val path = tempDir.getCanonicalPath + (0 to 10).foreach { i => + spark.range(i*10, i*10 + 10).write + .format("delta") + .mode("append") + .save(path) + } + // Read a checkpoint version + checkTable( + path = path, + expectedAnswer = (0L to 99L).map(TestRow(_)), + version = Some(9), + expectedVersion = Some(9) + ) + // Read a JSON version + checkTable( + path = path, + expectedAnswer = (0L to 89L).map(TestRow(_)), + version = Some(8), + expectedVersion = Some(8) + ) + // Read the current version + checkTable( + path = path, + expectedAnswer = (0L to 109L).map(TestRow(_)), + version = Some(10), + expectedVersion = Some(10) + ) + // Cannot read a version that does not exist + val e = intercept[RuntimeException] { + Table.forPath(defaultTableClient, path) + .getSnapshotAtVersion(defaultTableClient, 11) + } + assert(e.getMessage.contains( + "Trying to load a non-existent version 11. 
The latest version available is 10")) + } + } + + test("getSnapshotAtVersion: end-to-end test with truncated delta log") { + withTempDir { tempDir => + val tablePath = tempDir.getCanonicalPath + // Write versions [0, 10] (inclusive) including a checkpoint + (0 to 10).foreach { i => + spark.range(i*10, i*10 + 10).write + .format("delta") + .mode("append") + .save(tablePath) + } + val log = org.apache.spark.sql.delta.DeltaLog.forTable( + spark, new org.apache.hadoop.fs.Path(tablePath)) + // Delete the log files for versions 0-9, truncating the table history to version 10 + (0 to 9).foreach { i => + val jsonFile = org.apache.spark.sql.delta.util.FileNames.deltaFile(log.logPath, i) + new File(new org.apache.hadoop.fs.Path(log.logPath, jsonFile).toUri).delete() + } + // Create version 11 that overwrites the whole table + spark.range(50).write + .format("delta") + .mode("overwrite") + .save(tablePath) + + // Cannot read a version that has been truncated + val e = intercept[RuntimeException] { + Table.forPath(defaultTableClient, tablePath) + .getSnapshotAtVersion(defaultTableClient, 9) + } + assert(e.getMessage.contains("Unable to reconstruct state at version 9")) + // Can read version 10 + checkTable( + path = tablePath, + expectedAnswer = (0L to 109L).map(TestRow(_)), + version = Some(10), + expectedVersion = Some(10) + ) + // Can read version 11 + checkTable( + path = tablePath, + expectedAnswer = (0L until 50L).map(TestRow(_)), + version = Some(11), + expectedVersion = Some(11) + ) + } + } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index aa07f69f92a..0c63a0dfba1 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -107,9 +107,9 @@ trait TestUtils extends Assertions with SQLHelper { testFunc(tablePath) } - def latestSnapshot(path: String): Snapshot = { - Table.forPath(defaultTableClient, path) - .getLatestSnapshot(defaultTableClient) + def latestSnapshot(path: String, tableClient: TableClient = defaultTableClient): Snapshot = { + Table.forPath(tableClient, path) + .getLatestSnapshot(tableClient) } def collectScanFileRows(scan: Scan, tableClient: TableClient = defaultTableClient): Seq[Row] = { @@ -251,11 +251,15 @@ trait TestUtils extends Assertions with SQLHelper { tableClient: TableClient = defaultTableClient, expectedSchema: StructType = null, filter: Predicate = null, + version: Option[Long] = None, expectedRemainingFilter: Predicate = null, expectedVersion: Option[Long] = None ): Unit = { - val snapshot = latestSnapshot(path) + val snapshot = version.map { v => + Table.forPath(tableClient, path) + .getSnapshotAtVersion(tableClient, v) + }.getOrElse(latestSnapshot(path, tableClient)) val readSchema = if (readCols == null) { null From 8313f0270246aca9d5a77c2b80f3d60b53e5bae6 Mon Sep 17 00:00:00 2001 From: Tai Le <49281946+tlm365@users.noreply.github.com> Date: Tue, 20 Feb 2024 01:51:13 +0700 Subject: [PATCH 02/13] [Kernel][Java to Scala test conversion] Convert `TestDeltaTableReads` written in Java to Scala Resolves #2637 Signed-off-by: Tai Le Manh --- .../defaults/integration/BaseIntegration.java | 228 ------------------ .../integration/TestDeltaTableReads.java | 153 ------------ .../defaults/DeltaTableReadsSuite.scala | 74 ++++++ 3 files changed, 74 insertions(+), 381 deletions(-) delete mode 100644 
kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java delete mode 100644 kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java deleted file mode 100644 index 98990ba7506..00000000000 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.delta.kernel.defaults.integration; - -import java.util.*; - -import org.apache.hadoop.conf.Configuration; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import io.delta.kernel.Scan; -import io.delta.kernel.Snapshot; -import io.delta.kernel.Table; -import io.delta.kernel.client.TableClient; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.FilteredColumnarBatch; -import io.delta.kernel.data.Row; -import io.delta.kernel.types.*; -import io.delta.kernel.utils.CloseableIterator; -import io.delta.kernel.utils.FileStatus; - -import io.delta.kernel.internal.InternalScanFileUtils; -import io.delta.kernel.internal.data.ScanStateRow; -import static io.delta.kernel.internal.util.Utils.singletonCloseableIterator; - -import io.delta.kernel.defaults.client.DefaultTableClient; -import io.delta.kernel.defaults.utils.DefaultKernelTestUtils; - -/** - * Base class containing utility method to write integration tests that read data from - * Delta tables using the Kernel APIs. - */ -public abstract class BaseIntegration { - protected TableClient tableClient = DefaultTableClient.create( - new Configuration() { - { - // Set the batch sizes to small so that we get to test the multiple batch scenarios. 
- set("delta.kernel.default.parquet.reader.batch-size", "2"); - set("delta.kernel.default.json.reader.batch-size", "2"); - } - }); - - protected Table table(String path) throws Exception { - return Table.forPath(tableClient, path); - } - - protected Snapshot snapshot(String path) throws Exception { - return table(path).getLatestSnapshot(tableClient); - } - - protected List readSnapshot(StructType readSchema, Snapshot snapshot) - throws Exception { - Scan scan = snapshot.getScanBuilder(tableClient) - .withReadSchema(tableClient, readSchema) - .build(); - - Row scanState = scan.getScanState(tableClient); - CloseableIterator scanFileIter = scan.getScanFiles(tableClient); - - return readScanFiles(scanState, scanFileIter); - } - - protected List readScanFiles( - Row scanState, - CloseableIterator scanFilesBatchIter) throws Exception { - List dataBatches = new ArrayList<>(); - try { - StructType physicalReadSchema = - ScanStateRow.getPhysicalDataReadSchema(tableClient, scanState); - while (scanFilesBatchIter.hasNext()) { - FilteredColumnarBatch scanFilesBatch = scanFilesBatchIter.next(); - try (CloseableIterator scanFileRows = scanFilesBatch.getRows()) { - while (scanFileRows.hasNext()) { - Row scanFileRow = scanFileRows.next(); - FileStatus fileStatus = InternalScanFileUtils.getAddFileStatus(scanFileRow); - CloseableIterator physicalDataIter = - tableClient.getParquetHandler() - .readParquetFiles( - singletonCloseableIterator(fileStatus), - physicalReadSchema, - Optional.empty()); - try (CloseableIterator transformedData = - Scan.transformPhysicalData( - tableClient, - scanState, - scanFileRow, - physicalDataIter)) { - while (transformedData.hasNext()) { - FilteredColumnarBatch filteredData = transformedData.next(); - assertFalse(filteredData.getSelectionVector().isPresent()); - dataBatches.add(filteredData.getData()); - } - } - } - } - } - } finally { - scanFilesBatchIter.close(); - } - - return dataBatches; - } - - protected void compareEqualUnorderd(ColumnarBatch expDataBatch, - List actDataBatches) { - Set expDataRowsMatched = new HashSet<>(); - for (int actDataBatchIdx = 0; actDataBatchIdx < actDataBatches.size(); actDataBatchIdx++) { - ColumnarBatch actDataBatch = actDataBatches.get(actDataBatchIdx); - - assertEquals(expDataBatch.getSchema(), actDataBatch.getSchema()); - - for (int actRowIdx = 0; actRowIdx < actDataBatch.getSize(); actRowIdx++) { - boolean matched = false; - for (int expRowIdx = 0; expRowIdx < expDataBatch.getSize(); expRowIdx++) { - // If the row is already matched by another record, don't match again - if (expDataRowsMatched.contains(expRowIdx)) { - continue; - } - - matched = compareRows(expDataBatch, expRowIdx, actDataBatch, actRowIdx); - if (matched) { - expDataRowsMatched.add(expRowIdx); - break; - } - } - assertTrue("Actual data contain a row that is not expected", matched); - } - } - - assertEquals( - "An expected row is not present in the actual data output", - expDataBatch.getSize(), - expDataRowsMatched.size()); - } - - protected boolean compareRows( - ColumnarBatch expDataBatch, - int expRowId, - ColumnarBatch actDataBatch, - int actRowId) { - StructType readSchema = expDataBatch.getSchema(); - - for (int fieldId = 0; fieldId < readSchema.length(); fieldId++) { - DataType fieldDataType = readSchema.at(fieldId).getDataType(); - - ColumnVector expDataVector = expDataBatch.getColumnVector(fieldId); - ColumnVector actDataVector = actDataBatch.getColumnVector(fieldId); - - Object expObject = DefaultKernelTestUtils.getValueAsObject(expDataVector, expRowId); - Object 
actObject = DefaultKernelTestUtils.getValueAsObject(actDataVector, actRowId); - boolean matched = compareObjects(fieldDataType, expObject, actObject); - if (!matched) { - return false; - } - } - - return true; - } - - protected boolean compareRows(Row exp, Row act) { - assertEquals(exp.getSchema(), act.getSchema()); - for (int fieldId = 0; fieldId < exp.getSchema().length(); fieldId++) { - DataType fileDataType = exp.getSchema().at(fieldId).getDataType(); - - Object expObject = DefaultKernelTestUtils.getValueAsObject(exp, fieldId); - Object actObject = DefaultKernelTestUtils.getValueAsObject(act, fieldId); - boolean matched = compareObjects(fileDataType, expObject, actObject); - if (!matched) { - return false; - } - } - return true; - } - - protected boolean compareArrays(ArrayType dataType, List exp, List act) { - assertEquals(exp.size(), act.size()); - for (int i = 0; i < exp.size(); i++) { - boolean matched = compareObjects(dataType.getElementType(), exp.get(i), act.get(i)); - if (!matched) { - return false; - } - } - return true; - } - - protected boolean compareMaps(MapType dataType, Map exp, Map act) { - assertEquals(exp.size(), act.size()); - Set> expEntrySet = exp.entrySet(); - for (Map.Entry expEntry : expEntrySet) { - // TODO: this doesn't work for key types that don't have equals/hashCode implemented. - K expKey = expEntry.getKey(); - V expValue = expEntry.getValue(); - V actValue = act.get(expKey); - boolean matched = compareObjects(dataType.getValueType(), expValue, actValue); - if (!matched) { - return false; - } - } - return true; - } - - protected boolean compareObjects(DataType dataType, Object exp, Object act) { - boolean matched = Objects.deepEquals(exp, act); - if (dataType instanceof StructType) { - matched = compareRows((Row) exp, (Row) act); - } else if (dataType instanceof ArrayType) { - matched = compareArrays((ArrayType) dataType, (List) exp, (List) act); - } else if (dataType instanceof MapType) { - matched = compareMaps((MapType) dataType, (Map) exp, (Map) act); - } - return matched; - } -} diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java deleted file mode 100644 index 5f60843e9cd..00000000000 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.delta.kernel.defaults.integration; - -import java.math.BigDecimal; -import java.util.List; - -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import static io.delta.golden.GoldenTableUtils.goldenTablePath; - -import io.delta.kernel.Snapshot; -import io.delta.kernel.client.TableClient; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.types.*; - -import io.delta.kernel.defaults.client.DefaultTableClient; -import io.delta.kernel.defaults.integration.DataBuilderUtils.TestColumnBatchBuilder; -import static io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getTestResourceFilePath; - -/** - * Test reading Delta lake tables end to end using the Kernel APIs and default {@link TableClient} - * implementation ({@link DefaultTableClient}) - *

- * It uses golden tables generated using the source code here: - * https://github.com/delta-io/delta/blob/master/connectors/golden-tables/src/test/scala/io/delta - * /golden/GoldenTables.scala - */ -public class TestDeltaTableReads - extends BaseIntegration { - @Rule - public ExpectedException expectedEx = ExpectedException.none(); - - @Test - public void tablePrimitives() - throws Exception { - String tablePath = goldenTablePath("data-reader-primitives"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = snapshot.getSchema(tableClient); - - List actualData = readSnapshot(readSchema, snapshot); - - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema) - .addAllNullsRow(); - - for (int i = 0; i < 10; i++) { - builder = builder.addRow( - i, - (long) i, - (byte) i, - (short) i, - i % 2 == 0, - (float) i, - (double) i, - String.valueOf(i), - new byte[] {(byte) i, (byte) i,}, - new BigDecimal(i) - ); - } - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } - - @Test - public void tableWithCheckpoint() - throws Exception { - String tablePath = getTestResourceFilePath("basic-with-checkpoint"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = snapshot.getSchema(tableClient); - - List actualData = readSnapshot(readSchema, snapshot); - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema); - for (int i = 0; i < 150; i++) { - builder = builder.addRow((long) i); - } - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } - - @Test - public void tableWithNameColumnMappingMode() - throws Exception { - String tablePath = getTestResourceFilePath("data-reader-primitives-column-mapping-name"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = snapshot.getSchema(tableClient); - - List actualData = readSnapshot(readSchema, snapshot); - - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema) - .addAllNullsRow(); - - for (int i = 0; i < 10; i++) { - builder = builder.addRow( - i, - (long) i, - (byte) i, - (short) i, - i % 2 == 0, - (float) i, - (double) i, - String.valueOf(i), - new byte[] {(byte) i, (byte) i}, - new BigDecimal(i) - ); - } - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } - - @Test - public void partitionedTableWithColumnMapping() - throws Exception { - String tablePath = - getTestResourceFilePath("data-reader-partition-values-column-mapping-name"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = new StructType() - // partition fields - .add("as_int", IntegerType.INTEGER) - .add("as_double", DoubleType.DOUBLE) - // data fields - .add("value", StringType.STRING); - - List actualData = readSnapshot(readSchema, snapshot); - - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema); - - for (int i = 0; i < 2; i++) { - builder = builder.addRow(i, (double) i, String.valueOf(i)); - } - - builder = builder.addRow(null, null, "2"); - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } -} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala index 6eae58cf782..edddb40bd3d 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala 
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import io.delta.golden.GoldenTableUtils.goldenTablePath import io.delta.kernel.{Table, TableNotFoundException} import io.delta.kernel.defaults.internal.DefaultKernelUtils +import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getTestResourceFilePath import io.delta.kernel.defaults.utils.{TestRow, TestUtils} import io.delta.kernel.internal.util.InternalUtils.daysSinceEpoch import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils @@ -550,4 +551,77 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { ) } } + + test("table primitives") { + val expectedAnswer = (0 to 10).map { + case 10 => TestRow(null, null, null, null, null, null, null, null, null, null) + case i => TestRow( + i, + i.toLong, + i.toByte, + i.toShort, + i % 2 == 0, + i.toFloat, + i.toDouble, + i.toString, + Array[Byte](i.toByte, i.toByte), + new BigDecimal(i) + ) + } + + checkTable( + path = goldenTablePath("data-reader-primitives"), + expectedAnswer = expectedAnswer + ) + } + + test("table with checkpoint") { + checkTable( + path = getTestResourceFilePath("basic-with-checkpoint"), + expectedAnswer = (0 until 150).map(i => TestRow(i.toLong)) + ) + } + + test("table with name column mapping mode") { + val expectedAnswer = (0 to 10).map { + case 10 => TestRow(null, null, null, null, null, null, null, null, null, null) + case i => TestRow( + i, + i.toLong, + i.toByte, + i.toShort, + i % 2 == 0, + i.toFloat, + i.toDouble, + i.toString, + Array[Byte](i.toByte, i.toByte), + new BigDecimal(i) + ) + } + + checkTable( + path = getTestResourceFilePath("data-reader-primitives-column-mapping-name"), + expectedAnswer = expectedAnswer + ) + } + + test("partitioned table with column mapping") { + val expectedAnswer = (0 to 2).map { + case 2 => TestRow(null, null, "2") + case i => TestRow(i, i.toDouble, i.toString) + } + val readCols = Seq( + // partition fields + "as_int", + "as_double", + // data fields + "value" + ) + + checkTable( + path = getTestResourceFilePath("data-reader-partition-values-column-mapping-name"), + readCols = readCols, + expectedAnswer = expectedAnswer + ) + } } From efc0e34dd907e257f84eabb8d43f3e0346859cc9 Mon Sep 17 00:00:00 2001 From: Tai Le <49281946+tlm365@users.noreply.github.com> Date: Tue, 20 Feb 2024 02:34:15 +0700 Subject: [PATCH 03/13] [Kernel][Expressions] Adds the `IS_NULL` expression Resolve #2632. Adds the `IS NULL` expression. Signed-off-by: Tai Le Manh --- .../delta/kernel/expressions/Predicate.java | 6 ++++++ .../DefaultExpressionEvaluator.java | 20 +++++++++++++++++++ .../expressions/ExpressionVisitor.java | 4 ++++ .../DefaultExpressionEvaluatorSuite.scala | 12 +++++++++++ 4 files changed, 42 insertions(+) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java index 2eeb98ebeac..969eeaf7bc9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java @@ -97,6 +97,12 @@ *

 *    <li>Since version: 3.1.0</li>
 *   </ul>
 *  </li>
+ *  <li>Name: <code>IS_NULL</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>expr IS NULL</code></li>
+ *    <li>Since version: 3.2.0</li>
+ *   </ul>
+ *  </li>
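 *  <p>(For example, an engine could construct this predicate as
 *  {@code new Predicate("IS_NULL", new Column("colName"))}, mirroring the existing
 *  {@code IS_NOT_NULL} entry above; the column name here is illustrative.)</p>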
  • * * * @since 3.0.0 diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index dc52e1509f8..ee76c3f5be4 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -29,6 +29,7 @@ import static io.delta.kernel.internal.util.ExpressionUtils.getLeft; import static io.delta.kernel.internal.util.ExpressionUtils.getRight; +import static io.delta.kernel.internal.util.ExpressionUtils.getUnaryChild; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; @@ -232,6 +233,15 @@ ExpressionTransformResult visitIsNotNull(Predicate predicate) { ); } + @Override + ExpressionTransformResult visitIsNull(Predicate predicate) { + Expression child = visit(getUnaryChild(predicate)).expression; + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), + BooleanType.BOOLEAN + ); + } + @Override ExpressionTransformResult visitCoalesce(ScalarExpression coalesce) { List children = coalesce.getChildren().stream() @@ -513,6 +523,16 @@ ColumnVector visitIsNotNull(Predicate predicate) { ); } + @Override + ColumnVector visitIsNull(Predicate predicate) { + ColumnVector childResult = visit(getUnaryChild(predicate)); + return booleanWrapperVector( + childResult, + rowId -> childResult.isNullAt(rowId), + rowId -> false + ); + } + @Override ColumnVector visitCoalesce(ScalarExpression coalesce) { List childResults = coalesce.getChildren() diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index d9dc2037beb..bd219f55fda 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -55,6 +55,8 @@ abstract class ExpressionVisitor { abstract R visitIsNotNull(Predicate predicate); + abstract R visitIsNull(Predicate predicate); + abstract R visitCoalesce(ScalarExpression ifNull); final R visit(Expression expression) { @@ -99,6 +101,8 @@ private R visitScalarExpression(ScalarExpression expression) { return visitNot(new Predicate(name, children)); case "IS_NOT_NULL": return visitIsNotNull(new Predicate(name, children)); + case "IS_NULL": + return visitIsNull(new Predicate(name, children)); case "COALESCE": return visitCoalesce(expression); default: diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index 8ef56db7a23..01eba16fa26 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -241,6 +241,18 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa 
checkBooleanVectors(actOutputVector, expOutputVector) } + test("evaluate expression: is null") { + val childColumn = booleanVector(Seq[BooleanJ](true, false, null)) + + val schema = new StructType().add("child", BooleanType.BOOLEAN) + val batch = new DefaultColumnarBatch(childColumn.getSize, schema, Array(childColumn)) + + val isNullExpression = new Predicate("IS_NULL", new Column("child")) + val expOutputVector = booleanVector(Seq[BooleanJ](false, false, true)) + val actOutputVector = evaluator(schema, isNullExpression, BooleanType.BOOLEAN).eval(batch) + checkBooleanVectors(actOutputVector, expOutputVector) + } + test("evaluate expression: coalesce") { val col1 = booleanVector(Seq[BooleanJ](true, null, null, null)) val col2 = booleanVector(Seq[BooleanJ](false, false, null, null)) From 25c44838b4b3457bff6cc010860fe4f2412cf8cd Mon Sep 17 00:00:00 2001 From: Tai Le <49281946+tlm365@users.noreply.github.com> Date: Wed, 21 Feb 2024 01:14:05 +0700 Subject: [PATCH 04/13] [Kernel][Java to Scala test conversion] Convert TestDefaultFileSystemClient written in Java to Scala Resolves #2640 --- .../client/DefaultFileSystemClientSuite.scala | 65 +++++++++++++++++++ .../kernel/defaults/utils/TestUtils.scala | 19 +++++- 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/client/DefaultFileSystemClientSuite.scala diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/client/DefaultFileSystemClientSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/client/DefaultFileSystemClientSuite.scala new file mode 100644 index 00000000000..b69183ac0b3 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/client/DefaultFileSystemClientSuite.scala @@ -0,0 +1,65 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.client + +import java.io.FileNotFoundException + +import scala.collection.mutable.ArrayBuffer + +import io.delta.kernel.defaults.utils.TestUtils +import org.scalatest.funsuite.AnyFunSuite + +class DefaultFileSystemClientSuite extends AnyFunSuite with TestUtils { + + val fsClient = defaultTableClient.getFileSystemClient + + test("list from file") { + val basePath = fsClient.resolvePath(getTestResourceFilePath("json-files")) + val listFrom = fsClient.resolvePath(getTestResourceFilePath("json-files/2.json")) + + val actListOutput = new ArrayBuffer[String]() + val files = fsClient.listFrom(listFrom) + try { + fsClient.listFrom(listFrom).forEach(f => actListOutput += f.getPath) + } + finally if (files != null) { + files.close() + } + + val expListOutput = Seq(basePath + "/2.json", basePath + "/3.json") + + assert(expListOutput === actListOutput) + } + + test("list from non-existent file") { + intercept[FileNotFoundException] { + fsClient.listFrom("file:/non-existentfileTable/01.json") + } + } + + test("resolve path") { + val inputPath = getTestResourceFilePath("json-files") + val resolvedPath = fsClient.resolvePath(inputPath) + + assert("file:" + inputPath === resolvedPath) + } + + test("resolve path on non-existent file") { + intercept[FileNotFoundException] { + fsClient.resolvePath("/non-existentfileTable/01.json") + } + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 0c63a0dfba1..00b47d3625f 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -15,12 +15,14 @@ */ package io.delta.kernel.defaults.utils -import java.io.File +import java.io.{File, FileNotFoundException} import java.math.{BigDecimal => BigDecimalJ} import java.nio.file.Files import java.util.{Optional, TimeZone, UUID} + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer + import io.delta.golden.GoldenTableUtils import io.delta.kernel.{Scan, Snapshot, Table} import io.delta.kernel.client.TableClient @@ -102,6 +104,10 @@ trait TestUtils extends Assertions with SQLHelper { } } + implicit object ResourceLoader { + lazy val classLoader: ClassLoader = ResourceLoader.getClass.getClassLoader + } + def withGoldenTable(tableName: String)(testFunc: String => Unit): Unit = { val tablePath = GoldenTableUtils.goldenTablePath(tableName) testFunc(tablePath) @@ -575,4 +581,15 @@ trait TestUtils extends Assertions with SQLHelper { }) } } + + /** + * Returns a URI encoded path of the resource. + */ + def getTestResourceFilePath(resourcePath: String): String = { + val resource = ResourceLoader.classLoader.getResource(resourcePath) + if (resource == null) { + throw new FileNotFoundException("resource not found") + } + resource.getFile + } } From 360e066a9653e5486544c57f2c728b62d5bc3bd1 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Wed, 14 Feb 2024 21:20:50 +0100 Subject: [PATCH 05/13] Factor logic to collect files to REORG out of OPTIMIZE #### Which Delta project/connector is this regarding? -Spark - [ ] Standalone - [ ] Flink - [ ] Kernel - [ ] Other (fill in here) ## Description This is a plain refactor of REORG TABLE / OPTIMIZE to allow for better extendability and adding new types of REORG TABLE operations in the future. 
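For context, a rough sketch of the user-facing commands that exercise this code path (the table name `events` is illustrative, and the exact UPGRADE UNIFORM spelling follows the Delta documentation rather than this diff):
```
// Rewrite files that contain soft-deleted rows (DVs) or dropped columns.
spark.sql("REORG TABLE events APPLY (PURGE)")

// Rewrite files so they are Iceberg compatible at the requested compat version.
spark.sql("REORG TABLE events APPLY (UPGRADE UNIFORM (ICEBERG_COMPAT_VERSION = 2))")
```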
The REORG operation currently supports: - PURGE: remove soft deleted rows (DVs) and dropped columns. - UPGRADE UNIFORM: rewrite files to be iceberg compatible. More operations can be used in the future to allow dropping table features, in particular for column mapping: rewrite files to have the physical column names match the logical column name. This a plain refactor without functional changes. Closes delta-io/delta#2616 GitOrigin-RevId: b8e8ad4d148201a33b1fb173ebcfe4ad8b8407ef --- .../commands/DeltaReorgTableCommand.scala | 73 ++++++++++++++----- .../delta/commands/OptimizeTableCommand.scala | 45 ++++-------- .../ReorgTableForUpgradeUniformHelper.scala | 11 +-- .../spark/sql/delta/hooks/AutoCompact.scala | 2 +- 4 files changed, 72 insertions(+), 59 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaReorgTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaReorgTableCommand.scala index 59016b3ff77..c21efc5ddf0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaReorgTableCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DeltaReorgTableCommand.scala @@ -16,13 +16,10 @@ package org.apache.spark.sql.delta.commands -import org.apache.spark.sql.delta.catalog.DeltaTableV2 -import org.apache.spark.sql.delta.sources.DeltaSourceUtils +import org.apache.spark.sql.delta.actions.AddFile import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.analysis.ResolvedTable import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LeafCommand, LogicalPlan, UnaryCommand} -import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} object DeltaReorgTableMode extends Enumeration { val PURGE, UNIFORM_ICEBERG = Value @@ -47,7 +44,7 @@ case class DeltaReorgTable( } /** - * The PURGE command. + * The REORG TABLE command. 
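 *
 * Picks the concrete [[DeltaReorgOperation]] (purge or uniform upgrade) based on the parsed
 * reorg spec, then runs OPTIMIZE with minFileSize = 0 and maxDeletedRowsRatio = 0 so that
 * file selection is driven entirely by that reorg operation.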
*/ case class DeltaReorgTableCommand( target: LogicalPlan, @@ -60,30 +57,66 @@ case class DeltaReorgTableCommand( override val otherCopyArgs: Seq[AnyRef] = predicates :: Nil - override def optimizeByReorg( - sparkSession: SparkSession, - isPurge: Boolean, - icebergCompatVersion: Option[Int]): Seq[Row] = { + override def optimizeByReorg(sparkSession: SparkSession): Seq[Row] = { val command = OptimizeTableCommand( target, predicates, optimizeContext = DeltaOptimizeContext( - isPurge = isPurge, + reorg = Some(reorgOperation), minFileSize = Some(0L), - maxDeletedRowsRatio = Some(0d), - icebergCompatVersion = icebergCompatVersion - ) + maxDeletedRowsRatio = Some(0d)) )(zOrderBy = Nil) command.run(sparkSession) } - override def run(sparkSession: SparkSession): Seq[Row] = { - reorgTableSpec match { - case DeltaReorgTableSpec(DeltaReorgTableMode.PURGE, None) => - optimizeByReorg(sparkSession, isPurge = true, icebergCompatVersion = None) - case DeltaReorgTableSpec(DeltaReorgTableMode.UNIFORM_ICEBERG, Some(icebergCompatVersion)) => - val table = getDeltaTable(target, "REORG") - upgradeUniformIcebergCompatVersion(table, sparkSession, icebergCompatVersion) + override def run(sparkSession: SparkSession): Seq[Row] = reorgTableSpec match { + case DeltaReorgTableSpec(DeltaReorgTableMode.PURGE, None) => + optimizeByReorg(sparkSession) + case DeltaReorgTableSpec(DeltaReorgTableMode.UNIFORM_ICEBERG, Some(icebergCompatVersion)) => + val table = getDeltaTable(target, "REORG") + upgradeUniformIcebergCompatVersion(table, sparkSession, icebergCompatVersion) + } + + protected def reorgOperation: DeltaReorgOperation = reorgTableSpec match { + case DeltaReorgTableSpec(DeltaReorgTableMode.PURGE, None) => + new DeltaPurgeOperation() + case DeltaReorgTableSpec(DeltaReorgTableMode.UNIFORM_ICEBERG, Some(icebergCompatVersion)) => + new DeltaUpgradeUniformOperation(icebergCompatVersion) + } +} + +/** + * Defines a Reorg operation to be applied during optimize. + */ +sealed trait DeltaReorgOperation { + /** + * Collects files that need to be processed by the reorg operation from the list of candidate + * files. + */ + def filterFilesToReorg(files: Seq[AddFile]): Seq[AddFile] +} + +/** + * Reorg operation to purge files with soft deleted rows. + */ +class DeltaPurgeOperation extends DeltaReorgOperation { + override def filterFilesToReorg(files: Seq[AddFile]): Seq[AddFile] = + files.filter { file => + (file.deletionVector != null && file.numPhysicalRecords.isEmpty) || + file.numDeletedRecords > 0L + } +} + +/** + * Reorg operation to upgrade the iceberg compatibility version of a table. 
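 * A file is selected for rewrite when it carries no tags at all or when its
 * ICEBERG_COMPAT_VERSION tag does not match the requested target version.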
+ */ +class DeltaUpgradeUniformOperation(icebergCompatVersion: Int) extends DeltaReorgOperation { + override def filterFilesToReorg(files: Seq[AddFile]): Seq[AddFile] = { + def shouldRewriteToBeIcebergCompatible(file: AddFile): Boolean = { + if (file.tags == null) return true + val icebergCompatVersion = file.tags.getOrElse(AddFile.Tags.ICEBERG_COMPAT_VERSION.name, "0") + !icebergCompatVersion.exists(_.toString == icebergCompatVersion) } + files.filter(shouldRewriteToBeIcebergCompatible) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala index 2e84ec52460..5e474423716 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala @@ -187,30 +187,24 @@ case class OptimizeTableCommand( /** * Stored all runtime context information that can control the execution of optimize. * - * @param isPurge Whether the rewriting task is only for purging soft-deleted data instead of - * for compaction. If [[isPurge]] is true, only files with DVs will be selected - * for compaction. + * @param reorg The REORG operation that triggered the rewriting task, if any. * @param minFileSize Files which are smaller than this threshold will be selected for compaction. * If not specified, [[DeltaSQLConf.DELTA_OPTIMIZE_MIN_FILE_SIZE]] will be used. - * This parameter must be set to `0` when [[isPurge]] is true. + * This parameter must be set to `0` when [[reorg]] is set. * @param maxDeletedRowsRatio Files with a ratio of soft-deleted rows to the total rows larger than * this threshold will be rewritten by the OPTIMIZE command. If not * specified, [[DeltaSQLConf.DELTA_OPTIMIZE_MAX_DELETED_ROWS_RATIO]] - * will be used. This parameter must be set to `0` when [[isPurge]] is - * true. - * @param icebergCompatVersion The iceberg compatibility version used to rewrite data for - * uniform tables. + * will be used. This parameter must be set to `0` when [[reorg]] is set. 
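 *
 * For example, [[DeltaReorgTableCommand]] builds this context as
 * DeltaOptimizeContext(reorg = Some(reorgOperation), minFileSize = Some(0L),
 * maxDeletedRowsRatio = Some(0d)); any other combination with [[reorg]] set fails the
 * require check below.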
*/ case class DeltaOptimizeContext( - isPurge: Boolean = false, + reorg: Option[DeltaReorgOperation] = None, minFileSize: Option[Long] = None, maxFileSize: Option[Long] = None, - maxDeletedRowsRatio: Option[Double] = None, - icebergCompatVersion: Option[Int] = None) { - if (isPurge || icebergCompatVersion.isDefined) { + maxDeletedRowsRatio: Option[Double] = None) { + if (reorg.nonEmpty) { require( minFileSize.contains(0L) && maxDeletedRowsRatio.contains(0d), - "minFileSize and maxDeletedRowsRatio must be 0 when running PURGE.") + "minFileSize and maxDeletedRowsRatio must be 0 when running REORG TABLE.") } } @@ -269,7 +263,10 @@ class OptimizeExecutor( val candidateFiles = txn.filterFiles(partitionPredicate, keepNumRecords = true) val partitionSchema = txn.metadata.partitionSchema - val filesToProcess = pruneCandidateFileList(minFileSize, maxDeletedRowsRatio, candidateFiles) + val filesToProcess = optimizeContext.reorg match { + case Some(reorgOperation) => reorgOperation.filterFilesToReorg(candidateFiles) + case None => filterCandidateFileList(minFileSize, maxDeletedRowsRatio, candidateFiles) + } val partitionsToCompact = filesToProcess.groupBy(_.partitionValues).toSeq val jobs = groupFilesIntoBins(partitionsToCompact, maxFileSize) @@ -344,7 +341,7 @@ class OptimizeExecutor( * Helper method to prune the list of selected files based on fileSize and ratio of * deleted rows according to the deletion vector in [[AddFile]]. */ - private def pruneCandidateFileList( + private def filterCandidateFileList( minFileSize: Long, maxDeletedRowsRatio: Double, files: Seq[AddFile]): Seq[AddFile] = { // Select all files in case of multi-dimensional clustering @@ -358,18 +355,9 @@ class OptimizeExecutor( file.deletedToPhysicalRecordsRatio.getOrElse(0d) > maxDeletedRowsRatio } - def shouldRewriteToBeIcebergCompatible(file: AddFile): Boolean = { - if (optimizeContext.icebergCompatVersion.isEmpty) return false - if (file.tags == null) return true - val icebergCompatVersion = file.tags.getOrElse(AddFile.Tags.ICEBERG_COMPAT_VERSION.name, "0") - !optimizeContext.icebergCompatVersion.exists(_.toString == icebergCompatVersion) - } - - // Select files that are small, have too many deleted rows, - // or need to be made iceberg compatible + // Select files that are small or have too many deleted rows files.filter( - addFile => addFile.size < minFileSize || shouldCompactBecauseOfDeletedRows(addFile) || - shouldRewriteToBeIcebergCompatible(addFile)) + addFile => addFile.size < minFileSize || shouldCompactBecauseOfDeletedRows(addFile)) } /** @@ -414,8 +402,7 @@ class OptimizeExecutor( bins.filter { bin => bin.size > 1 || // bin has more than one file or - (bin.size == 1 && bin(0).deletionVector != null) || // single file in the bin has a DV or - (bin.size == 1 && optimizeContext.icebergCompatVersion.isDefined) || // uniform reorg + bin.size == 1 && optimizeContext.reorg.nonEmpty || // always rewrite files during reorg isMultiDimClustering // multi-clustering }.map(b => (partition, b)) } @@ -511,7 +498,7 @@ class OptimizeExecutor( /** Create the appropriate [[Operation]] object for txn commit history */ private def getOperation(): Operation = { - if (optimizeContext.isPurge) { + if (optimizeContext.reorg.nonEmpty) { DeltaOperations.Reorg(partitionPredicate) } else { DeltaOperations.Optimize(partitionPredicate, clusteringColumns, auto = isAutoCompact) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/ReorgTableForUpgradeUniformHelper.scala 
b/spark/src/main/scala/org/apache/spark/sql/delta/commands/ReorgTableForUpgradeUniformHelper.scala index c15831b761a..79964431b39 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/ReorgTableForUpgradeUniformHelper.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/ReorgTableForUpgradeUniformHelper.scala @@ -49,10 +49,7 @@ trait ReorgTableForUpgradeUniformHelper extends DeltaLogging { /** * Helper function to rewrite the table. Implemented by Reorg Table Command. */ - def optimizeByReorg( - sparkSession: SparkSession, - isPurge: Boolean, - icebergCompatVersion: Option[Int]): Seq[Row] + def optimizeByReorg(sparkSession: SparkSession): Seq[Row] /** * Helper function to update the table icebergCompat properties. @@ -172,11 +169,7 @@ trait ReorgTableForUpgradeUniformHelper extends DeltaLogging { logInfo(s"Reorg Table ${target.tableIdentifier} to iceberg compat version = " + s"$targetIcebergCompatVersion need rewrite data files.") val metrics = try { - optimizeByReorg( - sparkSession, - isPurge = false, - icebergCompatVersion = Some(targetIcebergCompatVersion) - ) + optimizeByReorg(sparkSession) } catch { case NonFatal(e) => throw DeltaErrors.icebergCompatDataFileRewriteFailedException( diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/hooks/AutoCompact.scala b/spark/src/main/scala/org/apache/spark/sql/delta/hooks/AutoCompact.scala index 13fb836879a..72404d2ea77 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/hooks/AutoCompact.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/hooks/AutoCompact.scala @@ -197,7 +197,7 @@ trait AutoCompactBase extends PostCommitHook with DeltaLogging { recordDeltaOperation(deltaLog, s"$opType.execute") { val txn = deltaLog.startTransaction(catalogTable) val optimizeContext = DeltaOptimizeContext( - isPurge = false, + reorg = None, minFileSizeOpt, maxFileSizeOpt, maxDeletedRowsRatio = maxDeletedRowsRatio From 0ee57b79e54574cf6827553129ce4f248e309099 Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Thu, 15 Feb 2024 12:29:04 +0100 Subject: [PATCH 06/13] [Spark] Handle NullType in normalizeColumnNames() The sanity check in normalizeColumnNamesInDataType() introduced by that change is a bit too restrictive, and fails to handle NullType correctly. Closes delta-io/delta#2634 GitOrigin-RevId: faaf3d981c57ef3ceb4081e0bc94d457359fc9d8 --- .../scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index b028f97a977..c1fdf654da6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -256,6 +256,10 @@ def normalizeColumnNamesInDataType( keyType = normalizedKeyType, valueType = normalizedValueType ) + case (_: NullType, _) => + // When schema evolution adds a new column during MERGE, it can be represented with + // a NullType in the schema of the data written by the MERGE. + sourceDataType case _ => if (Utils.isTesting) { assert(sourceDataType == tableDataType, From 66d0c54bde1bb77c37ff61f9edfd62c2fd381fd4 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Thu, 15 Feb 2024 20:12:12 +0100 Subject: [PATCH 07/13] Support map and arrays in ALTER TABLE CHANGE COLUMN #### Which Delta project/connector is this regarding? 
-Spark - [ ] Standalone - [ ] Flink - [ ] Kernel - [ ] Other (fill in here) ## Description This change addresses an issue where trying to change the key/value of a map or element of an array with ALTER TABLE CHANGE COLUMN would succeed while doing nothing. In addition, a proper error is now thrown when trying to add or drop the key or value of a map or element of an array. - Added tests to `DeltaAlterTableTests` to cover changing maps and arrays in ALTER TABLE CHANGE COLUMN. - Added tests to `SchemaUtilsSuite` and `DeltaDropColumnSuite` to cover the updated error when trying to add/drop map key/value or array element. ## This PR introduces the following *user-facing* changes Changing the type of the key or value of a map or of the elements of an array now fails if the type change isn't supported (= anything except setting the same type or moving between char, varchar, string): ``` CREATE TABLE table (m map) USING DELTA; ALTER TABLE table CHANGE COLUMN m.key key long; -- Fails with DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL, previously succeeded while applying no change. ``` Similarly, adding a comment now also fails. The error when trying to add or drop a map key/value or array element field is updated: ``` CREATE TABLE table (m map) USING DELTA; ALTER TABLE table ADD COLUMN m.key long; -- Now fails with DELTA_ADD_COLUMN_PARENT_NOT_STRUCT instead of IllegalArgumentException: Don't know where to add the column m.key" ``` Closes delta-io/delta#2615 GitOrigin-RevId: e9d4ba42cefaf7be7e70d948075312922059cde0 --- .../resources/error/delta-error-classes.json | 12 + .../apache/spark/sql/delta/DeltaErrors.scala | 45 +++- .../commands/alterDeltaTableCommands.scala | 89 +++++-- .../spark/sql/delta/schema/SchemaUtils.scala | 55 ++-- .../delta/stats/StatisticsCollection.scala | 6 +- .../sql/delta/DeltaAlterTableTests.scala | 239 +++++++++++++++++- .../sql/delta/DeltaDropColumnSuite.scala | 23 ++ .../spark/sql/delta/DeltaErrorsSuite.scala | 18 +- .../sql/delta/schema/SchemaUtilsSuite.scala | 46 ++++ 9 files changed, 466 insertions(+), 67 deletions(-) diff --git a/spark/src/main/resources/error/delta-error-classes.json b/spark/src/main/resources/error/delta-error-classes.json index db61c2a074f..4a1ed314db9 100644 --- a/spark/src/main/resources/error/delta-error-classes.json +++ b/spark/src/main/resources/error/delta-error-classes.json @@ -2254,6 +2254,12 @@ ], "sqlState" : "0AKDC" }, + "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP" : { + "message" : [ + "ALTER TABLE CHANGE COLUMN is not supported for changing column from to " + ], + "sqlState" : "0AKDC" + }, "DELTA_UNSUPPORTED_ALTER_TABLE_REPLACE_COL_OP" : { "message" : [ "Unsupported ALTER TABLE REPLACE COLUMNS operation. Reason:
    ", @@ -2323,6 +2329,12 @@ ], "sqlState" : "0AKDC" }, + "DELTA_UNSUPPORTED_COMMENT_MAP_ARRAY" : { + "message" : [ + "Can't add a comment to . Adding a comment to a map key/value or array element is not supported." + ], + "sqlState" : "0AKDC" + }, "DELTA_UNSUPPORTED_DATA_TYPES" : { "message" : [ "Found columns using unsupported data types: . You can set '' to 'false' to disable the type check. Disabling this type check may allow users to create unsupported Delta tables and should only be used when trying to read/write legacy tables." diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala index 3042d4ee443..0105a77b977 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala @@ -610,10 +610,37 @@ trait DeltaErrorsBase ) } - def alterTableChangeColumnException(oldColumns: String, newColumns: String): Throwable = { - new AnalysisException( - "ALTER TABLE CHANGE COLUMN is not supported for changing column " + oldColumns + " to " - + newColumns) + def addCommentToMapArrayException(fieldPath: String): Throwable = { + new DeltaAnalysisException( + errorClass = "DELTA_UNSUPPORTED_COMMENT_MAP_ARRAY", + messageParameters = Array(fieldPath) + ) + } + + def alterTableChangeColumnException( + fieldPath: String, + oldField: StructField, + newField: StructField): Throwable = { + def fieldToString(field: StructField): String = + field.dataType.sql + (if (!field.nullable) " NOT NULL" else "") + + new DeltaAnalysisException( + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + messageParameters = Array( + fieldPath, + fieldToString(oldField), + fieldToString(newField)) + ) + } + + def alterTableReplaceColumnsException( + oldSchema: StructType, + newSchema: StructType, + reason: String): Throwable = { + new DeltaAnalysisException( + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_REPLACE_COL_OP", + messageParameters = Array(reason, formatSchema(oldSchema), formatSchema(newSchema)) + ) } def cannotWriteIntoView(table: TableIdentifier): Throwable = { @@ -687,16 +714,6 @@ trait DeltaErrorsBase messageParameters = Array(source, targetType, target, tableName)) } - def alterTableReplaceColumnsException( - oldSchema: StructType, - newSchema: StructType, - reason: String): Throwable = { - new DeltaAnalysisException( - errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_REPLACE_COL_OP", - messageParameters = Array(reason, formatSchema(oldSchema), formatSchema(newSchema)) - ) - } - def ambiguousPartitionColumnException( columnName: String, colMatches: Seq[StructField]): Throwable = { new DeltaAnalysisException( diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala index 4b9221eb0cd..82d6539cb60 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils import org.apache.spark.sql.delta.catalog.DeltaTableV2 import org.apache.spark.sql.delta.constraints.{CharVarcharConstraint, Constraints} import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils} -import org.apache.spark.sql.delta.schema.SchemaUtils.transformColumnsStructs +import 
org.apache.spark.sql.delta.schema.SchemaUtils.transformSchema import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.stats.StatisticsCollection import org.apache.hadoop.fs.Path @@ -529,7 +529,7 @@ case class AlterTableChangeColumnDeltaCommand( // Verify that the columnName provided actually exists in the schema SchemaUtils.findColumnPosition(columnPath :+ columnName, oldSchema, resolver) - val newSchema = transformColumnsStructs(oldSchema, Some(columnName)) { + val newSchema = transformSchema(oldSchema, Some(columnName)) { case (`columnPath`, struct @ StructType(fields), _) => val oldColumn = struct(columnName) verifyColumnChange(sparkSession, struct(columnName), resolver, txn) @@ -561,11 +561,27 @@ case class AlterTableChangeColumnDeltaCommand( } // Reorder new field to correct position if necessary - colPosition.map { position => + StructType(colPosition.map { position => reorderFieldList(struct, newFieldList, newField, position, resolver) - }.getOrElse(newFieldList.toSeq) + }.getOrElse(newFieldList.toSeq)) - case (_, _ @ StructType(fields), _) => fields + case (`columnPath`, m: MapType, _) if columnName == "key" => + val originalField = StructField(columnName, m.keyType, nullable = false) + verifyMapArrayChange(sparkSession, originalField, resolver, txn) + m.copy(keyType = SchemaUtils.changeDataType(m.keyType, newColumn.dataType, resolver)) + + case (`columnPath`, m: MapType, _) if columnName == "value" => + val originalField = StructField(columnName, m.valueType, nullable = m.valueContainsNull) + verifyMapArrayChange(sparkSession, originalField, resolver, txn) + m.copy(valueType = SchemaUtils.changeDataType(m.valueType, newColumn.dataType, resolver)) + + case (`columnPath`, a: ArrayType, _) if columnName == "element" => + val originalField = StructField(columnName, a.elementType, nullable = a.containsNull) + verifyMapArrayChange(sparkSession, originalField, resolver, txn) + a.copy(elementType = + SchemaUtils.changeDataType(a.elementType, newColumn.dataType, resolver)) + + case (_, other @ (_: StructType | _: ArrayType | _: MapType), _) => other } // update `partitionColumns` if the changed column is a partition column @@ -685,15 +701,18 @@ case class AlterTableChangeColumnDeltaCommand( // first (original data type is already normalized as we store char/varchar as string type with // special metadata in the Delta log), then apply Delta-specific checks. 
val newType = CharVarcharUtils.replaceCharVarcharWithString(newColumn.dataType) - if (SchemaUtils.canChangeDataType(originalField.dataType, newType, resolver, - txn.metadata.columnMappingMode, columnPath :+ originalField.name).nonEmpty) { + if (SchemaUtils.canChangeDataType( + originalField.dataType, + newType, + resolver, + txn.metadata.columnMappingMode, + columnPath :+ originalField.name + ).nonEmpty) { throw DeltaErrors.alterTableChangeColumnException( - s"'${UnresolvedAttribute(columnPath :+ originalField.name).name}' with type " + - s"'${originalField.dataType}" + - s" (nullable = ${originalField.nullable})'", - s"'${UnresolvedAttribute(Seq(newColumn.name)).name}' with type " + - s"'$newType" + - s" (nullable = ${newColumn.nullable})'") + fieldPath = UnresolvedAttribute(columnPath :+ originalField.name).name, + oldField = originalField, + newField = newColumn + ) } if (columnName != newColumn.name) { @@ -704,13 +723,36 @@ case class AlterTableChangeColumnDeltaCommand( if (originalField.nullable && !newColumn.nullable) { throw DeltaErrors.alterTableChangeColumnException( - s"'${UnresolvedAttribute(columnPath :+ originalField.name).name}' with type " + - s"'${originalField.dataType}" + - s" (nullable = ${originalField.nullable})'", - s"'${UnresolvedAttribute(Seq(newColumn.name)).name}' with type " + - s"'${newColumn.dataType}" + - s" (nullable = ${newColumn.nullable})'") + fieldPath = UnresolvedAttribute(columnPath :+ originalField.name).name, + oldField = originalField, + newField = newColumn + ) + } + } + + /** + * Verify whether replacing the original map key/value or array element with a new data type is a + * valid operation. + * + * @param originalField the original map key/value or array element to update. + */ + private def verifyMapArrayChange(spark: SparkSession, originalField: StructField, + resolver: Resolver, txn: OptimisticTransaction): Unit = { + // Map key/value and array element can't have comments. + if (newColumn.getComment().nonEmpty) { + throw DeltaErrors.addCommentToMapArrayException( + fieldPath = UnresolvedAttribute(columnPath :+ columnName).name + ) + } + // Changing the nullability of map key/value or array element isn't supported. 
+ if (originalField.nullable != newColumn.nullable) { + throw DeltaErrors.alterTableChangeColumnException( + fieldPath = UnresolvedAttribute(columnPath :+ originalField.name).name, + oldField = originalField, + newField = newColumn + ) } + verifyColumnChange(spark, originalField, resolver, txn) } } @@ -738,8 +780,13 @@ case class AlterTableReplaceColumnsDeltaCommand( val resolver = sparkSession.sessionState.conf.resolver val changingSchema = StructType(columns) - SchemaUtils.canChangeDataType(existingSchema, changingSchema, resolver, - txn.metadata.columnMappingMode, failOnAmbiguousChanges = true).foreach { operation => + SchemaUtils.canChangeDataType( + existingSchema, + changingSchema, + resolver, + txn.metadata.columnMappingMode, + failOnAmbiguousChanges = true + ).foreach { operation => throw DeltaErrors.alterTableReplaceColumnsException( existingSchema, changingSchema, operation) } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index c1fdf654da6..426a2f78d9a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -789,7 +789,9 @@ def normalizeColumnNamesInDataType( */ def addColumn(schema: StructType, column: StructField, position: Seq[Int]): StructType = { def addColumnInChild(parent: DataType, column: StructField, position: Seq[Int]): DataType = { - require(position.nonEmpty, s"Don't know where to add the column $column") + if (position.isEmpty) { + throw DeltaErrors.addColumnParentNotStructException(column, parent) + } parent match { case struct: StructType => addColumn(struct, column, position) @@ -857,7 +859,9 @@ def normalizeColumnNamesInDataType( */ def dropColumn(schema: StructType, position: Seq[Int]): (StructType, StructField) = { def dropColumnInChild(parent: DataType, position: Seq[Int]): (DataType, StructField) = { - require(position.nonEmpty, s"Don't know where to drop the column") + if (position.isEmpty) { + throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(parent) + } parent match { case struct: StructType => dropColumn(struct, position) @@ -1014,38 +1018,51 @@ def normalizeColumnNamesInDataType( } /** - * Transform (nested) columns in a schema. Runs the transform function on all nested StructTypes - * - * If `colName` is defined, we also check if the struct to process contains the column name. - * + * Runs the transform function `tf` on all nested StructTypes, MapTypes and ArrayTypes in the + * schema. + * If `colName` is defined, the transform function is only applied to all the fields with the + * given name. There may be multiple matches if nested fields with the same name exist in the + * schema, it is the responsibility of the caller to check the full field path before transforming + * a field. * @param schema to transform. * @param colName Optional name to match for * @param tf function to apply on the StructType. * @return the transformed schema. 
*/ - def transformColumnsStructs( + def transformSchema( schema: StructType, colName: Option[String] = None)( - tf: (Seq[String], StructType, Resolver) => Seq[StructField]): StructType = { + tf: (Seq[String], DataType, Resolver) => DataType): StructType = { def transform[E <: DataType](path: Seq[String], dt: E): E = { val newDt = dt match { case struct @ StructType(fields) => - val newFields = if (colName.isEmpty || fields.exists(f => colName.contains(f.name))) { - tf(path, struct, DELTA_COL_RESOLVER) + val newStruct = if (colName.isEmpty || fields.exists(f => colName.contains(f.name))) { + tf(path, struct, DELTA_COL_RESOLVER).asInstanceOf[StructType] } else { - fields.toSeq + struct } - StructType(newFields.map { field => + StructType(newStruct.fields.map { field => field.copy(dataType = transform(path :+ field.name, field.dataType)) }) - case ArrayType(elementType, containsNull) => - ArrayType(transform(path :+ "element", elementType), containsNull) - case MapType(keyType, valueType, valueContainsNull) => - MapType( - transform(path :+ "key", keyType), - transform(path :+ "value", valueType), - valueContainsNull) + case array: ArrayType => + val newArray = + if (colName.isEmpty || colName.contains("element")) { + tf(path, array, DELTA_COL_RESOLVER).asInstanceOf[ArrayType] + } else { + array + } + newArray.copy(elementType = transform(path :+ "element", newArray.elementType)) + case map: MapType => + val newMap = + if (colName.isEmpty || colName.contains("key") || colName.contains("value")) { + tf(path, map, DELTA_COL_RESOLVER).asInstanceOf[MapType] + } else { + map + } + newMap.copy( + keyType = transform(path :+ "key", newMap.keyType), + valueType = transform(path :+ "value", newMap.valueType)) case other => other } newDt.asInstanceOf[E] diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala index d07f2fbb50f..6105c1b446d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.delta.commands.DeletionVectorUtils import org.apache.spark.sql.delta.commands.DeltaCommand import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils} -import org.apache.spark.sql.delta.schema.SchemaUtils.transformColumnsStructs +import org.apache.spark.sql.delta.schema.SchemaUtils.transformSchema import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.stats.DeltaStatistics._ import org.apache.spark.sql.delta.stats.StatisticsCollection.getIndexedColumns @@ -490,12 +490,12 @@ object StatisticsCollection extends DeltaCommand { SchemaUtils.findColumnPosition(columnFullPath, schema) // Delta statistics columns must be data skipping type. 
val (prefixPath, columnName) = columnFullPath.splitAt(columnFullPath.size - 1) - transformColumnsStructs(schema, Some(columnName.head)) { + transformSchema(schema, Some(columnName.head)) { case (`prefixPath`, struct @ StructType(_), _) => val columnField = struct(columnName.head) validateDataSkippingType(columnAttribute.name, columnField.dataType, visitedColumns) struct - case (_, s: StructType, _) => s + case (_, other, _) => other } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala index 37ebabcf75e..e9a1886cde6 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaAlterTableTests.scala @@ -918,6 +918,85 @@ trait DeltaAlterTableTests extends DeltaAlterTableTestBase { } } + ddlTest("CHANGE COLUMN - (unsupported) add a comment to key/value of a MapType") { + val df = Seq((1, 1), (2, 2)).toDF("v1", "v2") + .withColumn("a", map('v1, 'v2)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.key COMMENT 'a comment'") + }, + errorClass = "DELTA_UNSUPPORTED_COMMENT_MAP_ARRAY", + parameters = Map("fieldPath" -> "a.key") + ) + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.value COMMENT 'a comment'") + }, + errorClass = "DELTA_UNSUPPORTED_COMMENT_MAP_ARRAY", + parameters = Map("fieldPath" -> "a.value") + ) + } + } + + ddlTest("CHANGE COLUMN - (unsupported) add a comment to element of an array") { + val df = Seq(1, 2).toDF("v1") + .withColumn("a", array('v1)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.element COMMENT 'a comment'") + }, + errorClass = "DELTA_UNSUPPORTED_COMMENT_MAP_ARRAY", + parameters = Map("fieldPath" -> "a.element") + ) + } + } + + ddlTest("RENAME COLUMN - (unsupported) rename key/value of a MapType") { + val df = Seq((1, 1), (2, 2)).toDF("v1", "v2") + .withColumn("a", map('v1, 'v2)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[AnalysisException] { + sql(s"ALTER TABLE $tableName RENAME COLUMN a.key TO key2") + }, + errorClass = "INVALID_FIELD_NAME", + parameters = Map( + "fieldName" -> "`a`.`key2`", + "path" -> "`a`" + ) + ) + checkError( + exception = intercept[AnalysisException] { + sql(s"ALTER TABLE $tableName RENAME COLUMN a.value TO value2") + }, + errorClass = "INVALID_FIELD_NAME", + parameters = Map( + "fieldName" -> "`a`.`value2`", + "path" -> "`a`" + ) + ) + } + } + + ddlTest("RENAME COLUMN - (unsupported) rename element of an array") { + val df = Seq(1, 2).toDF("v1") + .withColumn("a", array('v1)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[AnalysisException] { + sql(s"ALTER TABLE $tableName RENAME COLUMN a.element TO element2") + }, + errorClass = "INVALID_FIELD_NAME", + parameters = Map( + "fieldName" -> "`a`.`element2`", + "path" -> "`a`" + ) + ) + } + } + ddlTest("CHANGE COLUMN - change name") { withDeltaTable(Seq((1, "a"), (2, "b")).toDF("v1", "v2")) { tableName => @@ -927,11 +1006,17 @@ trait DeltaAlterTableTests extends DeltaAlterTableTestBase { ddlTest("CHANGE COLUMN - incompatible") { withDeltaTable(Seq((1, "a"), (2, "b")).toDF("v1", "v2")) { tableName => - - assertNotSupported( - s"ALTER TABLE $tableName CHANGE COLUMN v1 v1 long", - "'v1' 
with type 'IntegerType (nullable = true)'", - "'v1' with type 'LongType (nullable = true)'") + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN v1 v1 long") + }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "v1", + "oldField" -> "INT", + "newField" -> "BIGINT" + ) + ) } } @@ -939,11 +1024,71 @@ trait DeltaAlterTableTests extends DeltaAlterTableTestBase { val df = Seq((1, "a"), (2, "b")).toDF("v1", "v2") .withColumn("struct", struct("v1", "v2")) withDeltaTable(df) { tableName => + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN struct.v1 v1 long") + }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "struct.v1", + "oldField" -> "INT", + "newField" -> "BIGINT" + ) + ) + } + } - assertNotSupported( - s"ALTER TABLE $tableName CHANGE COLUMN struct.v1 v1 long", - "'struct.v1' with type 'IntegerType (nullable = true)'", - "'v1' with type 'LongType (nullable = true)'") + ddlTest("CHANGE COLUMN - (unsupported) change type of key of a MapType") { + val df = Seq((1, 1), (2, 2)).toDF("v1", "v2") + .withColumn("a", map('v1, 'v2)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.key key long") + }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "a.key", + "oldField" -> "INT NOT NULL", + "newField" -> "BIGINT NOT NULL" + ) + ) + } + } + + ddlTest("CHANGE COLUMN - (unsupported) change type of value of a MapType") { + val df = Seq((1, 1), (2, 2)).toDF("v1", "v2") + .withColumn("a", map('v1, 'v2)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.value value long") + }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "a.value", + "oldField" -> "INT", + "newField" -> "BIGINT" + ) + ) + } + } + + ddlTest("CHANGE COLUMN - (unsupported) change type of element of an ArrayType") { + val df = Seq(1).toDF("v1") + .withColumn("a", array('v1)) + withDeltaTable(df) { tableName => + checkError( + exception = intercept[DeltaAnalysisException] { + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.element element long") + }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "a.element", + "oldField" -> "INT", + "newField" -> "BIGINT" + ) + ) } } @@ -1216,6 +1361,58 @@ trait DeltaAlterTableTests extends DeltaAlterTableTestBase { } } + test("CHANGE COLUMN - (unsupported) change nullability of map key/value and array element") { + val df = Seq((1, 1), (2, 2)) + .toDF("key", "value") + .withColumn("m", map(col("key"), col("value"))) + .withColumn("a", array(col("value"))) + + withDeltaTable(df) { tableName => + val schema = spark.read.table(tableName).schema + assert(schema("m").dataType === + MapType(IntegerType, IntegerType, valueContainsNull = true)) + assert(schema("a").dataType === + ArrayType(IntegerType, containsNull = true)) + + // No-op actions are allowed - map keys are always non-nullable. + sql(s"ALTER TABLE $tableName CHANGE COLUMN m.key SET NOT NULL") + sql(s"ALTER TABLE $tableName CHANGE COLUMN m.value DROP NOT NULL") + sql(s"ALTER TABLE $tableName CHANGE COLUMN a.element DROP NOT NULL") + + // Changing the nullability of map/array fields is not allowed. 
+ var statement = s"ALTER TABLE $tableName CHANGE COLUMN m.key DROP NOT NULL" + checkError( + exception = intercept[AnalysisException] { sql(statement) }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "m.key", + "oldField" -> "INT NOT NULL", + "newField" -> "INT" + ) + ) + + statement = s"ALTER TABLE $tableName CHANGE COLUMN m.value SET NOT NULL" + checkError( + exception = intercept[AnalysisException] { sql(statement) }, + errorClass = "_LEGACY_ERROR_TEMP_2330", + parameters = Map( + "fieldName" -> "m.value" + ), + context = ExpectedContext(statement, 0, statement.length - 1) + ) + + statement = s"ALTER TABLE $tableName CHANGE COLUMN a.element SET NOT NULL" + checkError( + exception = intercept[AnalysisException] { sql(statement) }, + errorClass = "_LEGACY_ERROR_TEMP_2330", + parameters = Map( + "fieldName" -> "a.element" + ), + context = ExpectedContext(statement, 0, statement.length - 1) + ) + } + } + ddlTest("CHANGE COLUMN - change name (nested)") { val df = Seq((1, "a"), (2, "b")).toDF("v1", "v2") .withColumn("struct", struct("v1", "v2")) @@ -1380,6 +1577,30 @@ trait DeltaAlterTableTests extends DeltaAlterTableTestBase { } } + test("CHANGE COLUMN: allow to change map key from char to string type") { + withTable("t") { + sql("CREATE TABLE t(i STRING, m map) USING delta") + sql("ALTER TABLE t CHANGE COLUMN m.key TYPE STRING") + assert(spark.table("t").schema(1).dataType === MapType(StringType, IntegerType)) + } + } + + test("CHANGE COLUMN: allow to change map value from char to string type") { + withTable("t") { + sql("CREATE TABLE t(i STRING, m map) USING delta") + sql("ALTER TABLE t CHANGE COLUMN m.value TYPE STRING") + assert(spark.table("t").schema(1).dataType === MapType(IntegerType, StringType)) + } + } + + test("CHANGE COLUMN: allow to change array element from char to string type") { + withTable("t") { + sql("CREATE TABLE t(i STRING, a array) USING delta") + sql("ALTER TABLE t CHANGE COLUMN a.element TYPE STRING") + assert(spark.table("t").schema(1).dataType === ArrayType(StringType)) + } + } + private def checkColType(f: StructField, dt: DataType): Unit = { assert(f.dataType == CharVarcharUtils.replaceCharVarcharWithString(dt)) assert(CharVarcharUtils.getRawType(f.metadata).contains(dt)) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDropColumnSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDropColumnSuite.scala index fb0d07b692a..f5c47893df6 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDropColumnSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDropColumnSuite.scala @@ -437,4 +437,27 @@ class DeltaDropColumnSuite extends QueryTest initialColumnType = "map>>", fieldToDrop = "value.element.b", updatedColumnType = "map>>") + + test("can't drop map key/value or array element") { + withTable("delta_test") { + sql( + s""" + |CREATE TABLE delta_test (m map, a array) + |USING delta + |TBLPROPERTIES (${DeltaConfigs.COLUMN_MAPPING_MODE.key} = 'name') + """.stripMargin) + for { + field <- Seq("m.key", "m.value", "a.element") + } + checkError( + exception = intercept[AnalysisException] { + sql(s"ALTER TABLE delta_test DROP COLUMN $field") + }, + errorClass = "DELTA_UNSUPPORTED_DROP_NESTED_COLUMN_FROM_NON_STRUCT_TYPE", + parameters = Map( + "struct" -> "IntegerType" + ) + ) + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala index 
87d76ac0fd2..35ac48daf41 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala @@ -56,7 +56,7 @@ import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} -import org.apache.spark.sql.types.{CalendarIntervalType, DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampNTZType} +import org.apache.spark.sql.types._ trait DeltaErrorsSuiteBase extends QueryTest @@ -1006,6 +1006,22 @@ trait DeltaErrorsSuiteBase |""".stripMargin )) } + { + checkError( + exception = intercept[DeltaAnalysisException] { + throw DeltaErrors.alterTableChangeColumnException( + fieldPath = "a.b.c", + oldField = StructField("c", IntegerType), + newField = StructField("c", LongType)) + }, + errorClass = "DELTA_UNSUPPORTED_ALTER_TABLE_CHANGE_COL_OP", + parameters = Map( + "fieldPath" -> "a.b.c", + "oldField" -> "INT", + "newField" -> "BIGINT" + ) + ) + } { val s1 = StructType(Seq(StructField("c0", IntegerType))) val s2 = StructType(Seq(StructField("c0", StringType))) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala index 8ac22cce922..1bf52857391 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala @@ -1095,6 +1095,14 @@ class SchemaUtilsSuite extends QueryTest keyType = new StructType().add(k), valueType = new StructType().add(v).add(x)))) + // Adding to map key/value. + expectFailure("parent is not a structtype") { + SchemaUtils.addColumn(schema, x, Seq(0, MAP_KEY_INDEX)) + } + expectFailure("parent is not a structtype") { + SchemaUtils.addColumn(schema, x, Seq(0, MAP_VALUE_INDEX)) + } + // Invalid map access. expectFailure("parent is not a structtype") { SchemaUtils.addColumn(schema, x, Seq(0, MAP_KEY_INDEX - 1, 0)) } @@ -1143,6 +1151,13 @@ class SchemaUtilsSuite extends QueryTest assert(SchemaUtils.addColumn(schema(), x, Seq(0, MAP_VALUE_INDEX, MAP_VALUE_INDEX, 1)) === schema(vv = new StructType().add("vv", IntegerType).add(x))) + // Adding to map key/value. + expectFailure("parent is not a structtype") { + SchemaUtils.addColumn(schema(), x, Seq(0, MAP_KEY_INDEX, MAP_KEY_INDEX)) + } + expectFailure("parent is not a structtype") { + SchemaUtils.addColumn(schema(), x, Seq(0, MAP_KEY_INDEX, MAP_VALUE_INDEX)) + } // Invalid map access. expectFailure("parent is not a structtype") { SchemaUtils.addColumn(schema(), x, Seq(0, MAP_KEY_INDEX, MAP_KEY_INDEX - 1, 0)) @@ -1172,6 +1187,10 @@ class SchemaUtilsSuite extends QueryTest assert(SchemaUtils.addColumn(schema, x, Seq(0, ARRAY_ELEMENT_INDEX, 1)) === new StructType().add("a", ArrayType(new StructType().add(e).add(x)))) + // Adding to array element. + expectFailure("parent is not a structtype") { + SchemaUtils.addColumn(schema, x, Seq(0, ARRAY_ELEMENT_INDEX)) + } // Invalid array access. 
expectFailure("Incorrectly accessing an ArrayType") { SchemaUtils.addColumn(schema, x, Seq(0, ARRAY_ELEMENT_INDEX - 1, 0)) @@ -1195,6 +1214,10 @@ class SchemaUtilsSuite extends QueryTest assert(SchemaUtils.addColumn(schema, x, Seq(0, ARRAY_ELEMENT_INDEX, ARRAY_ELEMENT_INDEX, 1)) === new StructType().add("a", ArrayType(ArrayType(new StructType().add(e).add(x))))) + // Adding to array element. + expectFailure("parent is not a structtype") { + SchemaUtils.addColumn(schema, x, Seq(0, ARRAY_ELEMENT_INDEX, ARRAY_ELEMENT_INDEX)) + } // Invalid array access. expectFailure("Incorrectly accessing an ArrayType") { SchemaUtils.addColumn(schema, x, Seq(0, ARRAY_ELEMENT_INDEX, ARRAY_ELEMENT_INDEX - 1, 0)) @@ -1307,6 +1330,14 @@ class SchemaUtilsSuite extends QueryTest valueType = new StructType().add(c))), d)) + // Dropping map key/value. + expectFailure("can only drop nested columns from structtype") { + SchemaUtils.dropColumn(schema, Seq(0, MAP_KEY_INDEX)) + } + expectFailure("can only drop nested columns from structtype") { + SchemaUtils.dropColumn(schema, Seq(0, MAP_VALUE_INDEX)) + } + // Invalid map access. expectFailure("can only drop nested columns from structtype") { SchemaUtils.dropColumn(schema, Seq(0, MAP_KEY_INDEX - 1, 0)) } @@ -1373,6 +1404,13 @@ class SchemaUtilsSuite extends QueryTest initialSchema = schema(vv = new StructType().add("vv", IntegerType).add(a)), position = Seq(0, MAP_VALUE_INDEX, MAP_VALUE_INDEX, 1)) + // Dropping map key/value. + expectFailure("can only drop nested columns from structtype") { + SchemaUtils.dropColumn(schema(), Seq(0, MAP_KEY_INDEX, MAP_KEY_INDEX)) + } + expectFailure("can only drop nested columns from structtype") { + SchemaUtils.dropColumn(schema(), Seq(0, MAP_KEY_INDEX, MAP_VALUE_INDEX)) + } // Invalid map access. expectFailure("can only drop nested columns from structtype") { SchemaUtils.dropColumn(schema(), Seq(0, MAP_KEY_INDEX, MAP_KEY_INDEX - 1, 0)) @@ -1402,6 +1440,10 @@ class SchemaUtilsSuite extends QueryTest assert(SchemaUtils.dropColumn(schema, Seq(0, ARRAY_ELEMENT_INDEX, 1)) === (new StructType().add("a", ArrayType(new StructType().add(e))), f)) + // Dropping array element. + expectFailure("can only drop nested columns from structtype") { + SchemaUtils.dropColumn(schema, Seq(0, ARRAY_ELEMENT_INDEX)) + } // Invalid array access. expectFailure("Incorrectly accessing an ArrayType") { SchemaUtils.dropColumn(schema, Seq(0, ARRAY_ELEMENT_INDEX - 1, 0)) @@ -1425,6 +1467,10 @@ class SchemaUtilsSuite extends QueryTest assert(SchemaUtils.dropColumn(schema, Seq(0, ARRAY_ELEMENT_INDEX, ARRAY_ELEMENT_INDEX, 1)) === (new StructType().add("a", ArrayType(ArrayType(new StructType().add(e)))), f)) + // Dropping array element. + expectFailure("can only drop nested columns from structtype") { + SchemaUtils.dropColumn(schema, Seq(0, ARRAY_ELEMENT_INDEX, ARRAY_ELEMENT_INDEX)) + } // Invalid array access. 
expectFailure("Incorrectly accessing an ArrayType") { SchemaUtils.dropColumn(schema, Seq(0, ARRAY_ELEMENT_INDEX, ARRAY_ELEMENT_INDEX - 1, 0)) From d6482c4440903e22a3855b973aa38f24b7ca284e Mon Sep 17 00:00:00 2001 From: Hao Jiang Date: Thu, 15 Feb 2024 14:12:29 -0800 Subject: [PATCH 08/13] Add example for IcebergCompatV2 and REORG This PR add an example explaining how to use REORG APPLY UniForm command to enable IcebergCompatV2 and UniForm Closes delta-io/delta#2500 GitOrigin-RevId: 23ccd5bac7d95977530646fcf5ba5d53a25d2734 --- .../main/scala/example/IcebergCompatV2.scala | 87 +++++++++++++++++++ .../src/main/scala/example/UniForm.scala | 1 + 2 files changed, 88 insertions(+) create mode 100644 examples/scala/src/main/scala/example/IcebergCompatV2.scala diff --git a/examples/scala/src/main/scala/example/IcebergCompatV2.scala b/examples/scala/src/main/scala/example/IcebergCompatV2.scala new file mode 100644 index 00000000000..9609f3ee7c6 --- /dev/null +++ b/examples/scala/src/main/scala/example/IcebergCompatV2.scala @@ -0,0 +1,87 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example + +import java.io.{File, IOException} +import java.net.ServerSocket + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.SparkSession +/** + * This example relies on an external Hive metastore (HMS) instance to run. + * + * A standalone HMS can be created using the following docker command. + * ************************************************************ + * docker run -d -p 9083:9083 --env SERVICE_NAME=metastore \ + * --name metastore-standalone apache/hive:4.0.0-beta-1 + * ************************************************************ + * The URL of this standalone HMS is thrift://localhost:9083 + * + * By default this hms will use `/opt/hive/data/warehouse` as warehouse path. + * Please make sure this path exists or change it prior to running the example. + */ +object IcebergCompatV2 { + + def main(args: Array[String]): Unit = { + // Update this according to the metastore config + val port = 9083 + val warehousePath = "/opt/hive/data/warehouse/" + + if (!UniForm.hmsReady(port)) { + print("HMS not available. 
Exit.") + return + } + + val testTableName = "uniform_table3" + FileUtils.deleteDirectory(new File(s"${warehousePath}${testTableName}")) + + val deltaSpark = SparkSession + .builder() + .appName("UniForm-Delta") + .master("local[*]") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") + .config("hive.metastore.uris", s"thrift://localhost:$port") + .config("spark.sql.catalogImplementation", "hive") + .getOrCreate() + + deltaSpark.sql(s"DROP TABLE IF EXISTS ${testTableName}") + deltaSpark.sql( + s"""CREATE TABLE `${testTableName}` + | (id INT, ts TIMESTAMP, array_data array, map_data map) + | using DELTA""".stripMargin) + deltaSpark.sql( + s""" + |INSERT INTO `$testTableName` (id, ts, array_data, map_data) + | VALUES (123, '2024-01-01 00:00:00', array(2, 3, 4, 5), map(3, 6, 8, 7))""".stripMargin) + deltaSpark.sql( + s"""REORG TABLE `$testTableName` APPLY (UPGRADE UNIFORM + | (ICEBERG_COMPAT_VERSION = 2))""".stripMargin) + + val icebergSpark = SparkSession.builder() + .master("local[*]") + .appName("UniForm-Iceberg") + .config("spark.sql.extensions", + "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") + .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") + .config("hive.metastore.uris", s"thrift://localhost:$port") + .config("spark.sql.catalogImplementation", "hive") + .getOrCreate() + + icebergSpark.sql(s"SELECT * FROM ${testTableName}").show() + } +} diff --git a/examples/scala/src/main/scala/example/UniForm.scala b/examples/scala/src/main/scala/example/UniForm.scala index 3b7d0b01b8b..291f9101bcf 100644 --- a/examples/scala/src/main/scala/example/UniForm.scala +++ b/examples/scala/src/main/scala/example/UniForm.scala @@ -66,6 +66,7 @@ object UniForm { s"""CREATE TABLE `${testTableName}` (col1 INT) using DELTA |TBLPROPERTIES ( | 'delta.columnMapping.mode' = 'name', + | 'delta.enableIcebergCompatV1' = 'true', | 'delta.universalFormat.enabledFormats' = 'iceberg' |)""".stripMargin) deltaSpark.sql(s"INSERT INTO `$testTableName` VALUES (123)") From 70e527bd29eb640f9829cb16c0cda9b4c138c4a6 Mon Sep 17 00:00:00 2001 From: Jing Zhan Date: Fri, 16 Feb 2024 09:37:40 -0800 Subject: [PATCH 09/13] Add a config flag for partition change check in DeltaSource Add a config for partition change check in Delta Source. Users can turn on or turn off the check by changing the config. 
Closes delta-io/delta#2618 GitOrigin-RevId: cbfc621d5f07e01b2ee60a048fb15e1fa80a9322 --- .../spark/DeltaFormatSharingSourceSuite.scala | 107 +++++++++++++++++- .../spark/sql/delta/DeltaColumnMapping.scala | 7 +- .../spark/sql/delta/schema/SchemaUtils.scala | 3 +- .../sql/delta/sources/DeltaSQLConf.scala | 11 ++ .../spark/sql/delta/sources/DeltaSource.scala | 15 ++- 5 files changed, 135 insertions(+), 8 deletions(-) diff --git a/sharing/src/test/scala/io/delta/sharing/spark/DeltaFormatSharingSourceSuite.scala b/sharing/src/test/scala/io/delta/sharing/spark/DeltaFormatSharingSourceSuite.scala index 16f37f9ae89..4f7f3d45c4a 100644 --- a/sharing/src/test/scala/io/delta/sharing/spark/DeltaFormatSharingSourceSuite.scala +++ b/sharing/src/test/scala/io/delta/sharing/spark/DeltaFormatSharingSourceSuite.scala @@ -16,12 +16,14 @@ package io.delta.sharing.spark +import org.apache.spark.sql.delta.DeltaIllegalStateException import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.delta.DeltaOptions.{ IGNORE_CHANGES_OPTION, IGNORE_DELETES_OPTION, SKIP_CHANGE_COMMITS_OPTION } +import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.test.DeltaSQLCommandTest import io.delta.sharing.client.DeltaSharingRestClient import io.delta.sharing.client.model.{Table => DeltaSharingTable} @@ -30,7 +32,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkEnv import org.apache.spark.sql.Row import org.apache.spark.sql.delta.sharing.DeltaSharingTestSparkUtils -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{ DateType, @@ -898,4 +901,106 @@ class DeltaFormatSharingSourceSuite } } } + + Seq( + ("add a partition column", Seq("part"), Seq("is_even", "part")), + ("change partition order", Seq("part", "is_even"), Seq("is_even", "part")), + ("different partition column", Seq("part"), Seq("is_even")) + ).foreach { case (repartitionTestCase, initPartitionCols, overwritePartitionCols) => + test("deltaSharing - repartition delta source should fail by default " + + s"unless unsafe flag is set - $repartitionTestCase") { + withTempDirs { (inputDir, outputDir, checkpointDir) => + + val deltaTableName = "basic_delta_table_partition_check" + withTable(deltaTableName) { + spark.sql( + s"""CREATE TABLE $deltaTableName (id LONG, part INT, is_even BOOLEAN) + |USING DELTA PARTITIONED BY (${initPartitionCols.mkString(", ")}) + |""".stripMargin + ) + val sharedTableName = "shared_streaming_table_partition_check_" + + s"${repartitionTestCase.replace(' ', '_')}" + prepareMockedClientMetadata(deltaTableName, sharedTableName) + val profileFile = prepareProfileFile(inputDir) + val tablePath = profileFile.getCanonicalPath + s"#share1.default.$sharedTableName" + + withSQLConf(getDeltaSharingClassesSQLConf.toSeq: _*) { + + def processAllAvailableInStream(startingVersion: Int): Unit = { + val q = spark.readStream + .format("deltaSharing") + .option("responseFormat", "delta") + .option("skipChangeCommits", "true") + .option("startingVersion", startingVersion) + .load(tablePath) + .writeStream + .format("delta") + .option("checkpointLocation", checkpointDir.toString) + .start(outputDir.toString) + + try { + q.processAllAvailable() + } finally { + q.stop() + } + } + + spark.range(10).withColumn("part", lit(1)) + .withColumn("is_even", $"id" % 2 === 0).write + .format("delta").partitionBy(initPartitionCols: _*) 
+ .mode("append") + .saveAsTable(deltaTableName) + spark.range(2).withColumn("part", lit(2)) + .withColumn("is_even", $"id" % 2 === 0).write + .format("delta").partitionBy(initPartitionCols: _*) + .mode("append").saveAsTable(deltaTableName) + spark.range(10).withColumn("part", lit(1)) + .withColumn("is_even", $"id" % 2 === 0).write + .format("delta").partitionBy(overwritePartitionCols: _*) + .option("overwriteSchema", "true").mode("overwrite") + .saveAsTable(deltaTableName) + spark.range(2).withColumn("part", lit(2)) + .withColumn("is_even", $"id" % 2 === 0).write + .format("delta").partitionBy(overwritePartitionCols: _*) + .mode("append").saveAsTable(deltaTableName) + + prepareMockedClientAndFileSystemResultForStreaming( + deltaTable = deltaTableName, + sharedTable = sharedTableName, + startingVersion = 0L, + endingVersion = 4L + ) + prepareMockedClientGetTableVersion(deltaTableName, sharedTableName) + var e = intercept[StreamingQueryException] { + processAllAvailableInStream(0) + } + assert(e.getCause.asInstanceOf[DeltaIllegalStateException].getErrorClass + == "DELTA_SCHEMA_CHANGED_WITH_STARTING_OPTIONS") + assert(e.getMessage.contains("Detected schema change in version 3")) + + // delta table created using sql with specified partition col + // will construct their initial snapshot on the initial definition + prepareMockedClientAndFileSystemResultForStreaming( + deltaTable = deltaTableName, + sharedTable = sharedTableName, + startingVersion = 4L, + endingVersion = 4L + ) + prepareMockedClientGetTableVersion(deltaTableName, sharedTableName) + e = intercept[StreamingQueryException] { + processAllAvailableInStream(4) + } + assert(e.getMessage.contains("Detected schema change in version 4")) + + // Streaming query made progress without throwing error when + // unsafe flag is set to true + withSQLConf( + DeltaSQLConf.DELTA_STREAMING_UNSAFE_READ_ON_PARTITION_COLUMN_CHANGE.key -> "true") { + processAllAvailableInStream(0) + } + } + } + } + } + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala index 30850d9f94e..3fc69bb578c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala @@ -628,14 +628,17 @@ trait DeltaColumnMappingBase extends DeltaLogging { * As of now, `newMetadata` is column mapping read compatible with `oldMetadata` if * no rename column or drop column has happened in-between. 
*/ - def hasNoColumnMappingSchemaChanges(newMetadata: Metadata, oldMetadata: Metadata): Boolean = { + def hasNoColumnMappingSchemaChanges(newMetadata: Metadata, oldMetadata: Metadata, + allowUnsafeReadOnPartitionChanges: Boolean = false): Boolean = { // Helper function to check no column mapping schema change and no repartition def hasNoColMappingAndRepartitionSchemaChange( newMetadata: Metadata, oldMetadata: Metadata): Boolean = { isRenameColumnOperation(newMetadata, oldMetadata) || isDropColumnOperation(newMetadata, oldMetadata) || !SchemaUtils.isPartitionCompatible( - newMetadata.partitionColumns, oldMetadata.partitionColumns) + // if allow unsafe row read for partition change, ignore the check + if (allowUnsafeReadOnPartitionChanges) Seq.empty else newMetadata.partitionColumns, + if (allowUnsafeReadOnPartitionChanges) Seq.empty else oldMetadata.partitionColumns) } val (oldMode, newMode) = (oldMetadata.columnMappingMode, newMetadata.columnMappingMode) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index 426a2f78d9a..eaaaab87d07 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -341,8 +341,7 @@ def normalizeColumnNamesInDataType( def isPartitionCompatible( newPartitionColumns: Seq[String] = Seq.empty, oldPartitionColumns: Seq[String] = Seq.empty): Boolean = { - (newPartitionColumns.isEmpty && oldPartitionColumns.isEmpty) || - (newPartitionColumns == oldPartitionColumns) + newPartitionColumns == oldPartitionColumns } /** diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala index bba5463b3bd..ae08e203d59 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala @@ -1179,6 +1179,17 @@ trait DeltaSQLConfBase { .booleanConf .createWithDefault(false) + val DELTA_STREAMING_UNSAFE_READ_ON_PARTITION_COLUMN_CHANGE = + buildConf("streaming.unsafeReadOnPartitionColumnChanges.enabled") + .doc( + "Streaming read on Delta table with partition column overwrite " + + "(e.g. changing partition column) is currently blocked due to potential data loss. " + + "However, existing users may use this flag to force unblock " + + "if they'd like to take the risk.") + .internal() + .booleanConf + .createWithDefault(false) + val DELTA_STREAMING_ENABLE_SCHEMA_TRACKING = buildConf("streaming.schemaTracking.enabled") .doc( diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSource.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSource.scala index 2b138fc6947..160cf1440e0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSource.scala @@ -124,6 +124,11 @@ trait DeltaSourceBase extends Source unsafeFlagEnabled } + protected lazy val allowUnsafeStreamingReadOnPartitionColumnChanges: Boolean = + spark.sessionState.conf.getConf( + DeltaSQLConf.DELTA_STREAMING_UNSAFE_READ_ON_PARTITION_COLUMN_CHANGE + ) + /** * Flag that allows user to disable the read-compatibility check during stream start which * protects against an corner case in which verifyStreamHygiene could not detect. 
@@ -601,7 +606,8 @@ trait DeltaSourceBase extends Source if (!allowUnsafeStreamingReadOnColumnMappingSchemaChanges) { assert(!trackingMetadataChange, "should not check schema change while tracking it") - if (!DeltaColumnMapping.hasNoColumnMappingSchemaChanges(newMetadata, oldMetadata)) { + if (!DeltaColumnMapping.hasNoColumnMappingSchemaChanges(newMetadata, oldMetadata, + allowUnsafeStreamingReadOnPartitionColumnChanges)) { throw DeltaErrors.blockStreamingReadsWithIncompatibleColumnMappingSchemaChanges( spark, oldMetadata.schema, @@ -642,8 +648,11 @@ trait DeltaSourceBase extends Source isStreamingFromColumnMappingTable && allowUnsafeStreamingReadOnColumnMappingSchemaChanges && backfilling, - newPartitionColumns = newMetadata.partitionColumns, - oldPartitionColumns = oldMetadata.partitionColumns + // Partition column change will be ignored if user enable the unsafe flag + newPartitionColumns = if (allowUnsafeStreamingReadOnPartitionColumnChanges) Seq.empty + else newMetadata.partitionColumns, + oldPartitionColumns = if (allowUnsafeStreamingReadOnPartitionColumnChanges) Seq.empty + else oldMetadata.partitionColumns )) { // Only schema change later than the current read snapshot/schema can be retried, in other // words, backfills could never be retryable, because we have no way to refresh From 622bd3257182332b5ed3ca2d2146301b347c7c58 Mon Sep 17 00:00:00 2001 From: Thang Long Vu Date: Fri, 16 Feb 2024 19:05:31 +0100 Subject: [PATCH 10/13] Add tests to check the behaviour of the different combinations of CREATE and REPLACE with row IDs. Add tests to check the behaviour of the different combinations of `CREATE` and `REPLACE` with row IDs. Closes https://github.com/delta-io/delta/pull/2642 GitOrigin-RevId: b8bddb63afa616416be01211a5167639e3f044d9 --- .../rowid/RowIdCreateReplaceTableSuite.scala | 212 ++++++++++++++++++ .../sql/delta/rowid/RowIdTestUtils.scala | 7 + 2 files changed, 219 insertions(+) create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdCreateReplaceTableSuite.scala diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdCreateReplaceTableSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdCreateReplaceTableSuite.scala new file mode 100644 index 00000000000..3cfe6ac404a --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdCreateReplaceTableSuite.scala @@ -0,0 +1,212 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.delta.rowid + +import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog} +import org.apache.spark.sql.delta.RowId.extractHighWatermark +import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils.TABLE_FEATURES_MIN_WRITER_VERSION + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.test.SharedSparkSession + +class RowIdCreateReplaceTableSuite extends QueryTest + with SharedSparkSession with RowIdTestUtils { + + private val numSourceRows = 50 + + test("Create or replace table with values list") { + withRowTrackingEnabled(enabled = true) { + withTable("target") { + writeTargetTestData(withRowIds = true) + val (log, snapshot) = DeltaLog.forTableWithSnapshot(spark, TableIdentifier("target")) + + val highWaterMarkBefore = extractHighWatermark(snapshot).get + createReplaceTargetTable( + commandName = "CREATE OR REPLACE", + query = "SELECT * FROM VALUES (0, 0), (1, 1)") + + assertHighWatermarkIsCorrectAfterUpdate( + log, highWaterMarkBefore, expectedNumRecordsWritten = 2) + assertRowIdsAreLargerThanValue(log, highWaterMarkBefore) + } + } + } + + test("Create or replace table with other delta table") { + withRowTrackingEnabled(enabled = true) { + withTable("source", "target") { + writeTargetTestData(withRowIds = true) + + writeSourceTestData(withRowIds = true) + val (log, snapshot) = DeltaLog.forTableWithSnapshot(spark, TableIdentifier("target")) + + val highWaterMarkBefore = extractHighWatermark(snapshot).get + createReplaceTargetTable(commandName = "CREATE OR REPLACE", query = "SELECT * FROM source") + + assertHighWatermarkIsCorrectAfterUpdate( + log, highWaterMarkBefore, expectedNumRecordsWritten = numSourceRows) + assertRowIdsAreLargerThanValue(log, highWaterMarkBefore) + } + } + } + + test("Replace table with values list") { + withRowTrackingEnabled(enabled = true) { + withTable("target") { + writeTargetTestData(withRowIds = true) + val (log, snapshot) = DeltaLog.forTableWithSnapshot(spark, TableIdentifier("target")) + + val highWaterMarkBefore = extractHighWatermark(snapshot).get + createReplaceTargetTable(commandName = "REPLACE", query = "SELECT * FROM VALUES (0), (1)") + + assertHighWatermarkIsCorrectAfterUpdate( + log, highWaterMarkBefore, expectedNumRecordsWritten = 2) + assertRowIdsAreLargerThanValue(log, highWaterMarkBefore) + } + } + } + + test("Replace table with another delta table") { + withRowTrackingEnabled(enabled = true) { + withTable("source", "target") { + writeTargetTestData(withRowIds = true) + val log = DeltaLog.forTable(spark, TableIdentifier("target")) + + writeSourceTestData(withRowIds = true) + + val highWaterMarkBefore = extractHighWatermark(log.update()).get + createReplaceTargetTable(commandName = "REPLACE", query = "SELECT * FROM source") + + assertHighWatermarkIsCorrectAfterUpdate( + log, highWaterMarkBefore, expectedNumRecordsWritten = numSourceRows) + assertRowIdsAreLargerThanValue(log, highWaterMarkBefore) + } + } + } + + test("Replace table with row IDs with table without row IDs assigns new row IDs") { + withTable("source", "target") { + writeTargetTestData(withRowIds = true) + val log = DeltaLog.forTable(spark, TableIdentifier("target")) + + writeSourceTestData(withRowIds = false) + + val highWaterMarkBefore = extractHighWatermark(log.update()).get + withRowTrackingEnabled(enabled = false) { + createReplaceTargetTable(commandName = "REPLACE", query = "SELECT * FROM source") + } + + assertHighWatermarkIsCorrectAfterUpdate( + log, 
highWaterMarkBefore, expectedNumRecordsWritten = numSourceRows) + } + } + + test("Replacing a table without row IDs with row IDs enabled assigns new row IDs") { + withTable("source", "target") { + writeTargetTestData(withRowIds = false) + writeSourceTestData(withRowIds = true) + + val log = DeltaLog.forTable(spark, TableIdentifier("target")) + assertRowIdsAreNotSet(log) + + withRowTrackingEnabled(enabled = true) { + createReplaceTargetTable( + commandName = "REPLACE", + query = "SELECT * FROM source", + tblProperties = s"'$rowTrackingFeatureName' = 'supported'" :: + s"'delta.minWriterVersion' = $TABLE_FEATURES_MIN_WRITER_VERSION" :: Nil) + } + + assertRowIdsAreValid(log) + } + } + + test("CREATE OR REPLACE on existing table without row IDs assigns new row IDs when enabling " + + "row IDs") { + withTable("target") { + writeTargetTestData(withRowIds = false) + + val log = DeltaLog.forTable(spark, TableIdentifier("target")) + assertRowIdsAreNotSet(log) + + withRowTrackingEnabled(enabled = true) { + createReplaceTargetTable( + commandName = "CREATE OR REPLACE", + query = "SELECT * FROM VALUES (0), (1)", + tblProperties = s"${DeltaConfigs.ROW_TRACKING_ENABLED.key} = 'true'" :: Nil) + } + + assertRowIdsAreValid(log) + } + } + + test("CTAS assigns new row IDs when immediately enabling row IDs") { + withTable("target") { + createReplaceTargetTable( + commandName = "CREATE", + query = "SELECT * FROM VALUES (0), (1)", + tblProperties = s"${DeltaConfigs.ROW_TRACKING_ENABLED.key} = 'true'" :: Nil) + + val log = DeltaLog.forTable(spark, TableIdentifier("target")) + assertRowIdsAreValid(log) + } + } + + test("CTAS assigns new row IDs when row IDs are by default enabled") { + withTable("target") { + withSQLConf(DeltaConfigs.ROW_TRACKING_ENABLED.defaultTablePropertyKey -> "true") { + createReplaceTargetTable( + commandName = "CREATE", + query = "SELECT * FROM VALUES (0), (1)") + + val log = DeltaLog.forTable(spark, TableIdentifier("target")) + assertRowIdsAreValid(log) + } + } + } + + def createReplaceTargetTable( + commandName: String, query: String, tblProperties: Seq[String] = Seq.empty): Unit = { + val tblPropertiesStr = if (tblProperties.nonEmpty) { + s"TBLPROPERTIES ${tblProperties.mkString("(", ",", ")")}" + } else { + "" + } + sql( + s""" + |$commandName TABLE target + |USING delta + |$tblPropertiesStr + |AS $query + |""".stripMargin) + } + + def writeTargetTestData(withRowIds: Boolean): Unit = { + withRowTrackingEnabled(enabled = withRowIds) { + spark.range(start = 0, end = 100, step = 1, numPartitions = 10) + .write.format("delta").saveAsTable("target") + } + } + + def writeSourceTestData(withRowIds: Boolean): Unit = { + withRowTrackingEnabled(enabled = withRowIds) { + spark.range(start = 0, end = numSourceRows, step = 1, numPartitions = 10) + .write.format("delta").saveAsTable("source") + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala index 2852d608f0b..a625e57f598 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala @@ -78,4 +78,11 @@ trait RowIdTestUtils extends RowTrackingTestUtils with DeltaSQLCommandTest { val files = snapshot.allFiles.collect() assert(files.forall(_.baseRowId.isEmpty)) } + + def assertRowIdsAreLargerThanValue(log: DeltaLog, value: Long): Unit = { + log.update().allFiles.collect().foreach { f => + val minRowId = 
getRowIdRangeInclusive(f)._1 + assert(minRowId > value, s"${f.toString} has a row id smaller or equal than $value") + } + } } From 19f3a4fc95860feee9b4d5508bbdc42c99417459 Mon Sep 17 00:00:00 2001 From: Carmen Kwan Date: Tue, 20 Feb 2024 14:18:24 +0100 Subject: [PATCH 11/13] [Spark][TEST-ONLY] Add more test coverage for TRUNCATE HISTORY #### Which Delta project/connector is this regarding? -Spark - [ ] Standalone - [ ] Flink - [ ] Kernel - [ ] Other (fill in here) ## Description We are currently missing SQL tests for ALTER TABLE DROP FEATURE TRUNCATE HISTORY with non-path based table. We have tests for writer feature, but not for readerwriter feature that require the TRUNCATE HISTORY syntax. This PR addresses that gap. This is a test-only PR. ## Does this PR introduce _any_ user-facing changes? No. Closes delta-io/delta#2635 GitOrigin-RevId: f64a7226defe240145dca756b51c5f325a030841 --- .../PreDowngradeTableFeatureCommand.scala | 13 +++- .../sql/delta/DeltaProtocolVersionSuite.scala | 72 +++++++++++++++++++ 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/PreDowngradeTableFeatureCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/PreDowngradeTableFeatureCommand.scala index 8e48f5b7954..9e43ac3e6d1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/PreDowngradeTableFeatureCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/PreDowngradeTableFeatureCommand.scala @@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.sql.delta.catalog.DeltaTableV2 import org.apache.spark.sql.delta.commands.{AlterTableSetPropertiesDeltaCommand, AlterTableUnsetPropertiesDeltaCommand} import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.util.{Utils => DeltaUtils} /** * A base class for implementing a preparation command for removing table features. @@ -45,7 +46,10 @@ case class TestWriterFeaturePreDowngradeCommand(table: DeltaTableV2) // Make sure feature data/metadata exist before proceeding. if (TestRemovableWriterFeature.validateRemoval(table.initialSnapshot)) return false - recordDeltaEvent(table.deltaLog, "delta.test.TestWriterFeaturePreDowngradeCommand") + if (DeltaUtils.isTesting) { + recordDeltaEvent(table.deltaLog, "delta.test.TestWriterFeaturePreDowngradeCommand") + } + val properties = Seq(TestRemovableWriterFeature.TABLE_PROP_KEY) AlterTableUnsetPropertiesDeltaCommand(table, properties, ifExists = true).run(table.spark) true @@ -53,12 +57,17 @@ case class TestWriterFeaturePreDowngradeCommand(table: DeltaTableV2) } case class TestReaderWriterFeaturePreDowngradeCommand(table: DeltaTableV2) - extends PreDowngradeTableFeatureCommand { + extends PreDowngradeTableFeatureCommand + with DeltaLogging { // To remove the feature we only need to remove the table property. override def removeFeatureTracesIfNeeded(): Boolean = { // Make sure feature data/metadata exist before proceeding. 
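+    // (if removal is already valid there is nothing left to clean up, so report
+    // that no traces were removed)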
if (TestRemovableReaderWriterFeature.validateRemoval(table.initialSnapshot)) return false + if (DeltaUtils.isTesting) { + recordDeltaEvent(table.deltaLog, "delta.test.TestReaderWriterFeaturePreDowngradeCommand") + } + val properties = Seq(TestRemovableReaderWriterFeature.TABLE_PROP_KEY) AlterTableUnsetPropertiesDeltaCommand(table, properties, ifExists = true).run(table.spark) true diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaProtocolVersionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaProtocolVersionSuite.scala index 14b7ca78672..d55276989a3 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaProtocolVersionSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaProtocolVersionSuite.scala @@ -3193,6 +3193,78 @@ trait DeltaProtocolVersionSuiteBase extends QueryTest } } + for { + withCatalog <- BOOLEAN_DOMAIN + quoteWith <- if (withCatalog) Seq("none", "single", "backtick") else Seq("none") + } test(s"Drop feature DDL TRUNCATE HISTORY - withCatalog=$withCatalog, quoteWith=$quoteWith") { + withTempDir { dir => + val table: String = if (withCatalog) { + s"${spark.sessionState.catalog.getCurrentDatabase}.table" + } else { + s"delta.`${dir.getCanonicalPath}`" + } + if (withCatalog) sql(s"DROP TABLE IF EXISTS $table") + sql( + s"""CREATE TABLE $table (id bigint) USING delta + |TBLPROPERTIES ( + |delta.feature.${TestRemovableReaderWriterFeature.name} = 'supported', + |${TestRemovableReaderWriterFeature.TABLE_PROP_KEY} = "true", + |${DeltaConfigs.TABLE_FEATURE_DROP_TRUNCATE_HISTORY_LOG_RETENTION.key} = "0 hours" + |)""".stripMargin) + + // We need to use a Delta log object with the ManualClock created in this test instead of + // the default SystemClock. However, we can't pass the Delta log to use directly in the SQL + // command. Instead, we will + // 1. Clear the Delta log cache to remove the log associated with table creation. + // 2. Populate the Delta log cache with the Delta log object that has the ManualClock we + // want to use + // TODO(c27kwan): Refactor this and provide a better way to control clocks in Delta tests. + val clock = new ManualClock(System.currentTimeMillis()) + val deltaLog = if (withCatalog) { + val tableIdentifier = + TableIdentifier("table", Some(spark.sessionState.catalog.getCurrentDatabase)) + // We need to hack the Delta log cache with path based access to setup the right key. + val path = DeltaLog.forTable(spark, tableIdentifier, clock).dataPath + DeltaLog.clearCache() + DeltaLog.forTable(spark, path, clock) + } else { + DeltaLog.clearCache() + DeltaLog.forTable(spark, dir, clock) + } + + val protocol = deltaLog.update().protocol + assert(protocol === protocolWithReaderFeature(TestRemovableReaderWriterFeature)) + + val logs = Log4jUsageLogger.track { + val featureName = quoteWith match { + case "none" => s"${TestRemovableReaderWriterFeature.name}" + case "single" => s"'${TestRemovableReaderWriterFeature.name}'" + case "backtick" => s"`${TestRemovableReaderWriterFeature.name}`" + } + + // Expect an exception when dropping a reader writer feature on a table that + // still has traces of the feature. + intercept[DeltaTableFeatureException] { + sql(s"ALTER TABLE $table DROP FEATURE $featureName") + } + + // Move past retention period. + clock.advance(TimeUnit.HOURS.toMillis(1)) + + sql(s"ALTER TABLE $table DROP FEATURE $featureName TRUNCATE HISTORY") + assert(deltaLog.update().protocol === Protocol(1, 1)) + } + + // Validate the correct downgrade command was invoked. 
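+        // The pre-downgrade command records a usage log via recordDeltaEvent, so the
+        // logs tracked above should contain an event whose opType matches it.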
+ val expectedOpType = "delta.test.TestReaderWriterFeaturePreDowngradeCommand" + val blob = logs.collectFirst { + case r if r.metric == MetricDefinitions.EVENT_TAHOE.name && + r.tags.get("opType").contains(expectedOpType) => r.blob + } + assert(blob.nonEmpty, s"Expecting an '$expectedOpType' event but didn't see any.") + } + } + protected def testProtocolVersionDowngrade( initialMinReaderVersion: Int, initialMinWriterVersion: Int, From 3a972508db4531e0b8f6475799c3df4c566db47f Mon Sep 17 00:00:00 2001 From: Sumeet Varma Date: Tue, 20 Feb 2024 11:20:18 -0800 Subject: [PATCH 12/13] Add InMemoryCommitStore to test Managed Commit Backend An in-memory-commit-store that tracks per-table commits, backfills and validates various edge cases and unexpected scenarios. Closes delta-io/delta#2649 GitOrigin-RevId: 4da495caa6259501f16723ced0ca236ab8420044 --- .../AbstractBatchBackfillingCommitStore.scala | 16 +- .../sql/delta/managedcommit/CommitStore.scala | 2 +- .../managedcommit/InMemoryCommitStore.scala | 132 +++++++ .../InMemoryCommitStoreSuite.scala | 326 ++++++++++++++++++ 4 files changed, 473 insertions(+), 3 deletions(-) create mode 100644 spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStore.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStoreSuite.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitStore.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitStore.scala index 4c4e8265176..6cfbec736b5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitStore.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitStore.scala @@ -16,11 +16,13 @@ package org.apache.spark.sql.delta.managedcommit +import java.nio.file.FileAlreadyExistsException + import org.apache.spark.sql.delta.{DeltaLog, SerializableFileStatus} import org.apache.spark.sql.delta.storage.LogStore import org.apache.spark.sql.delta.util.FileNames import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.internal.Logging @@ -119,17 +121,27 @@ trait AbstractBatchBackfillingCommitStore extends CommitStore with Logging { fileStatus: FileStatus): Unit = { val targetFile = FileNames.deltaFile(logPath(tablePath), version) logInfo(s"Backfilling commit ${fileStatus.getPath} to ${targetFile.toString}") + val commitContentIterator = logStore.readAsIterator(fileStatus, hadoopConf) try { logStore.write( targetFile, - logStore.readAsIterator(fileStatus, hadoopConf), + commitContentIterator, overwrite = false, hadoopConf) + registerBackfill(tablePath, version, targetFile) } catch { case _: FileAlreadyExistsException => logInfo(s"The backfilled file $targetFile already exists.") + } finally { + commitContentIterator.close() } } + /** Callback to tell the CommitStore that all commits <= `untilVersion` are backfilled. 
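+   *
+   * @param tablePath path of the table whose commits have been backfilled
+   * @param untilVersion all commits with version <= this value are now backfilled
+   * @param deltaFile the delta file that was just written out to the table's _delta_log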
*/ + protected[delta] def registerBackfill( + tablePath: Path, + untilVersion: Long, + deltaFile: Path): Unit + protected def logPath(tablePath: Path): Path = new Path(tablePath, DeltaLog.LOG_DIR_NAME) } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitStore.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitStore.scala index f090b2f54de..5b0310bfccb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitStore.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitStore.scala @@ -43,7 +43,7 @@ case class Commit( * | yes | yes | physical conflict (allowed to rebase and retry) | */ class CommitFailedException( - retryable: Boolean, conflict: Boolean, message: String) extends Exception(message) + val retryable: Boolean, val conflict: Boolean, message: String) extends Exception(message) /** Response container for [[CommitStore.commit]] API */ case class CommitResponse(commit: Commit) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStore.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStore.scala new file mode 100644 index 00000000000..053febea0e9 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStore.scala @@ -0,0 +1,132 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.managedcommit + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable + +import org.apache.spark.sql.delta.SerializableFileStatus +import org.apache.spark.sql.delta.storage.LogStore +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +class InMemoryCommitStore(val batchSize: Long) extends AbstractBatchBackfillingCommitStore { + + private[managedcommit] class PerTableData { + // Map from version to Commit data + val commitsMap: mutable.SortedMap[Long, Commit] = mutable.SortedMap.empty + // We maintain maxCommitVersion explicitly since commitsMap might be empty + // if all commits for a table have been backfilled. 
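+    // It is advanced on every successful commit and is what commitImpl derives the
+    // next expected commit version from.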
+ var maxCommitVersion: Long = -1 + val lock: ReentrantReadWriteLock = new ReentrantReadWriteLock() + } + + private[managedcommit] val perTableMap = new ConcurrentHashMap[Path, PerTableData]() + + private[managedcommit] def withWriteLock[T](tablePath: Path)(operation: => T): T = { + val lock = perTableMap + .computeIfAbsent(tablePath, _ => new PerTableData()) // computeIfAbsent is atomic + .lock + .writeLock() + lock.lock() + try { + operation + } finally { + lock.unlock() + } + } + + private[managedcommit] def withReadLock[T](tablePath: Path)(operation: => T): T = { + val lock = perTableMap + .computeIfAbsent(tablePath, _ => new PerTableData()) // computeIfAbsent is atomic + .lock + .readLock() + lock.lock() + try { + operation + } finally { + lock.unlock() + } + } + + /** + * This method acquires a write lock, validates the commit version is next in line, + * updates commit maps, and releases the lock. + * + * @throws CommitFailedException if the commit version is not the expected next version, + * indicating a version conflict. + */ + protected def commitImpl( + logStore: LogStore, + hadoopConf: Configuration, + tablePath: Path, + commitVersion: Long, + commitFile: FileStatus, + commitTimestamp: Long): CommitResponse = { + withWriteLock[CommitResponse](tablePath) { + val tableData = perTableMap.get(tablePath) + val expectedVersion = tableData.maxCommitVersion + 1 + if (commitVersion != expectedVersion) { + throw new CommitFailedException( + retryable = commitVersion < expectedVersion, + conflict = commitVersion < expectedVersion, + s"Commit version $commitVersion is not valid. Expected version: $expectedVersion.") + } + + val commit = + Commit(commitVersion, SerializableFileStatus.fromStatus(commitFile), commitTimestamp) + tableData.commitsMap(commitVersion) = commit + tableData.maxCommitVersion = commitVersion + + logInfo(s"Added commit file ${commitFile.getPath} to commit-store.") + CommitResponse(commit) + } + } + + override def getCommits( + tablePath: Path, + startVersion: Long, + endVersion: Option[Long]): Seq[Commit] = { + withReadLock[Seq[Commit]](tablePath) { + val tableData = perTableMap.get(tablePath) + // Calculate the end version for the range, or use the last key if endVersion is not provided + val effectiveEndVersion = + endVersion.getOrElse(tableData.commitsMap.lastOption.map(_._1).getOrElse(startVersion)) + val commitsInRange = tableData.commitsMap.range(startVersion, effectiveEndVersion + 1) + commitsInRange.values.toSeq + } + } + + override protected[delta] def registerBackfill( + tablePath: Path, + untilVersion: Long, + deltaFile: Path): Unit = { + withWriteLock(tablePath) { + val tableData = perTableMap.get(tablePath) + if (untilVersion > tableData.maxCommitVersion) { + throw new IllegalArgumentException( + s"Unexpected backfill version: $untilVersion. 
" + + s"Max backfill version: ${tableData.maxCommitVersion}") + } + // Remove keys with versions less than or equal to 'untilVersion' + val versionsToRemove = tableData.commitsMap.keys.takeWhile(_ <= untilVersion).toList + versionsToRemove.foreach(tableData.commitsMap.remove) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStoreSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStoreSuite.scala new file mode 100644 index 00000000000..4cb0cc00246 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitStoreSuite.scala @@ -0,0 +1,326 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.managedcommit + +import java.io.File +import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.sql.delta.DeltaLog +import org.apache.spark.sql.delta.actions.CommitInfo +import org.apache.spark.sql.delta.storage.{LogStore, LogStoreProvider} +import org.apache.spark.sql.delta.util.FileNames +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +object InMemoryCommitStoreBuilder { + private val defaultBatchSize = 5L +} + +/** + * The InMemoryCommitStoreBuilder class is responsible for creating singleton instances of + * InMemoryCommitStore with the specified batchSize. 
+ */ +case class InMemoryCommitStoreBuilder( + batchSize: Long = InMemoryCommitStoreBuilder.defaultBatchSize) extends CommitStoreBuilder { + private lazy val inMemoryStore = new InMemoryCommitStore(batchSize) + + /** Name of the commit-store */ + def name: String = "InMemoryCommitStore" + + /** Returns a commit store based on the given conf */ + def build(conf: Map[String, String]): CommitStore = { + inMemoryStore + } +} + +class InMemoryCommitStoreSuite extends QueryTest + with SharedSparkSession + with LogStoreProvider { + + // scalastyle:off deltahadoopconfiguration + def sessionHadoopConf: Configuration = spark.sessionState.newHadoopConf() + // scalastyle:on deltahadoopconfiguration + + def store: LogStore = createLogStore(spark) + + private def withTempTableDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir() + val deltaLogDir = new File(dir, DeltaLog.LOG_DIR_NAME) + deltaLogDir.mkdir() + val commitLogDir = new File(deltaLogDir, FileNames.COMMIT_SUBDIR) + commitLogDir.mkdir() + try f(dir) + finally { + Utils.deleteRecursively(dir) + } + } + + protected def commit( + version: Long, + timestamp: Long, + cs: CommitStore, + tablePath: Path): Commit = { + val commitInfo = CommitInfo.empty(version = Some(version)).withTimestamp(timestamp) + cs.commit( + store, + sessionHadoopConf, + tablePath, + version, + Iterator(s"$version", s"$timestamp"), + UpdatedActions(commitInfo, None, None)).commit + } + + private def assertBackfilled( + version: Long, + tablePath: Path, + timestampOpt: Option[Long] = None): Unit = { + val logPath = new Path(tablePath, DeltaLog.LOG_DIR_NAME) + val delta = FileNames.deltaFile(logPath, version) + if (timestampOpt.isDefined) { + assert(store.read(delta, sessionHadoopConf) == Seq(s"$version", s"${timestampOpt.get}")) + } else { + assert(store.read(delta, sessionHadoopConf).take(1) == Seq(s"$version")) + } + } + + private def assertCommitFail( + currentVersion: Long, + expectedVersion: Long, + retryable: Boolean, + commitFunc: => Commit): Unit = { + val e = intercept[CommitFailedException] { + commitFunc + } + assert(e.retryable == retryable) + assert(e.conflict == retryable) + assert( + e.getMessage == + s"Commit version $currentVersion is not valid. 
Expected version: $expectedVersion.") + } + + private def assertInvariants( + tablePath: Path, + cs: InMemoryCommitStore, + commitTimestampsOpt: Option[Array[Long]] = None): Unit = { + val maxUntrackedVersion: Int = cs.withReadLock[Int](tablePath) { + val tableData = cs.perTableMap.get(tablePath) + if (tableData.commitsMap.isEmpty) { + tableData.maxCommitVersion.toInt + } else { + assert( + tableData.commitsMap.last._1 == tableData.maxCommitVersion, + s"Max version in commitMap ${tableData.commitsMap.last._1} must match max version in " + + s"maxCommitVersionMap $tableData.maxCommitVersion.") + val minVersion = tableData.commitsMap.head._1 + assert( + tableData.maxCommitVersion - minVersion + 1 == tableData.commitsMap.size, + "Commit map should have a contiguous range of unbackfilled commits.") + minVersion.toInt - 1 + } + } + (0 to maxUntrackedVersion).foreach { version => + assertBackfilled(version, tablePath, commitTimestampsOpt.map(_(version)))} + } + + test("in-memory-commit-store-builder works as expected") { + val builder1 = InMemoryCommitStoreBuilder(5) + val cs1 = builder1.build(Map.empty) + assert(cs1.isInstanceOf[InMemoryCommitStore]) + assert(cs1.asInstanceOf[InMemoryCommitStore].batchSize == 5) + + val cs1_again = builder1.build(Map.empty) + assert(cs1_again.isInstanceOf[InMemoryCommitStore]) + assert(cs1 == cs1_again) + + val builder2 = InMemoryCommitStoreBuilder(10) + val cs2 = builder2.build(Map.empty) + assert(cs2.isInstanceOf[InMemoryCommitStore]) + assert(cs2.asInstanceOf[InMemoryCommitStore].batchSize == 10) + assert(cs2 ne cs1) + + val builder3 = InMemoryCommitStoreBuilder(10) + val cs3 = builder3.build(Map.empty) + assert(cs3.isInstanceOf[InMemoryCommitStore]) + assert(cs3.asInstanceOf[InMemoryCommitStore].batchSize == 10) + assert(cs3 ne cs2) + } + + test("test basic commit and backfill functionality") { + withTempTableDir { tempDir => + val tablePath = new Path(tempDir.getCanonicalPath) + val cs = InMemoryCommitStoreBuilder(batchSize = 3).build(Map.empty) + + // Commit 0 is always immediately backfilled + val c0 = commit(0, 0, cs, tablePath) + assert(cs.getCommits(tablePath, 0) == Seq.empty) + assertBackfilled(0, tablePath, Some(0)) + + val c1 = commit(1, 1, cs, tablePath) + val c2 = commit(2, 2, cs, tablePath) + assert(cs.getCommits(tablePath, 0).takeRight(2) == Seq(c1, c2)) + + // All 3 commits are backfilled since batchSize == 3 + val c3 = commit(3, 3, cs, tablePath) + assert(cs.getCommits(tablePath, 0) == Seq.empty) + (1 to 3).foreach(i => assertBackfilled(i, tablePath, Some(i))) + + // Test that startVersion and endVersion are respected in getCommits + val c4 = commit(4, 4, cs, tablePath) + val c5 = commit(5, 5, cs, tablePath) + assert(cs.getCommits(tablePath, 4) == Seq(c4, c5)) + assert(cs.getCommits(tablePath, 4, Some(4)) == Seq(c4)) + assert(cs.getCommits(tablePath, 5) == Seq(c5)) + + // Commit [4, 6] are backfilled since batchSize == 3 + val c6 = commit(6, 6, cs, tablePath) + assert(cs.getCommits(tablePath, 0) == Seq.empty) + (4 to 6).foreach(i => assertBackfilled(i, tablePath, Some(i))) + assertInvariants(tablePath, cs.asInstanceOf[InMemoryCommitStore]) + } + } + + test("test basic commit and backfill functionality with 1 batch size") { + withTempTableDir { tempDir => + val tablePath = new Path(tempDir.getCanonicalPath) + val cs = InMemoryCommitStoreBuilder(batchSize = 1).build(Map.empty) + + // Test that all commits are immediately backfilled + (0 to 3).foreach { version => + commit(version, version, cs, tablePath) + assert(cs.getCommits(tablePath, 0) == 
Seq.empty) + assertBackfilled(version, tablePath, Some(version)) + } + + // Test that out-of-order backfill is rejected + intercept[IllegalArgumentException] { + cs.asInstanceOf[InMemoryCommitStore] + .registerBackfill(tablePath, 5, new Path("delta5.json")) + } + assertInvariants(tablePath, cs.asInstanceOf[InMemoryCommitStore]) + } + } + + test("test out-of-order commits are rejected") { + withTempTableDir { tempDir => + val tablePath = new Path(tempDir.getCanonicalPath) + val cs = InMemoryCommitStoreBuilder(batchSize = 5).build(Map.empty) + + // Anything other than version-0 should be rejected as the first commit + assertCommitFail(1, 0, retryable = false, commit(1, 0, cs, tablePath)) + + // Verify that conflict-checker rejects out-of-order commits. + (0 to 4).foreach(i => commit(i, i, cs, tablePath)) + assertCommitFail(0, 5, retryable = true, commit(0, 5, cs, tablePath)) + assertCommitFail(4, 5, retryable = true, commit(4, 6, cs, tablePath)) + + // Verify that the conflict-checker still works even when everything has been backfilled + commit(5, 5, cs, tablePath) + assert(cs.getCommits(tablePath, 0) == Seq.empty) + assertCommitFail(5, 6, retryable = true, commit(5, 5, cs, tablePath)) + assertCommitFail(7, 6, retryable = false, commit(7, 7, cs, tablePath)) + + assertInvariants(tablePath, cs.asInstanceOf[InMemoryCommitStore]) + } + } + + test("test out-of-order backfills are rejected") { + withTempTableDir { tempDir => + val tablePath = new Path(tempDir.getCanonicalPath) + val cs = InMemoryCommitStoreBuilder(batchSize = 5).build(Map.empty) + intercept[IllegalArgumentException] { + cs.asInstanceOf[InMemoryCommitStore].registerBackfill(tablePath, 0, new Path("delta0.json")) + } + (0 to 3).foreach(i => commit(i, i, cs, tablePath)) + + // Test that backfilling is idempotent for already-backfilled commits. + cs.asInstanceOf[InMemoryCommitStore].registerBackfill(tablePath, 2, new Path("delta2.json")) + cs.asInstanceOf[InMemoryCommitStore].registerBackfill(tablePath, 2, new Path("delta2.json")) + + // Test that backfilling uncommited commits fail. + intercept[IllegalArgumentException] { + cs.asInstanceOf[InMemoryCommitStore].registerBackfill(tablePath, 4, new Path("delta4.json")) + } + } + } + + test("should handle concurrent readers and writers") { + withTempTableDir { tempDir => + val tablePath = new Path(tempDir.getCanonicalPath) + val batchSize = 6 + val cs = InMemoryCommitStoreBuilder(batchSize).build(Map.empty) + + val numberOfWriters = 10 + val numberOfCommitsPerWriter = 10 + // scalastyle:off sparkThreadPools + val executor = Executors.newFixedThreadPool(numberOfWriters) + // scalastyle:on sparkThreadPools + val runningTimestamp = new AtomicInteger(0) + val commitFailedExceptions = new AtomicInteger(0) + val totalCommits = numberOfWriters * numberOfCommitsPerWriter + // actualCommits is used to determine version of next commit when getCommits is empty. 
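+      // (getCommits returns an empty result once every tracked commit has been
+      // backfilled, so writers fall back to this counter for the next version)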
+ val actualCommits = new AtomicInteger(0) + val commitTimestamp: Array[Long] = new Array[Long](totalCommits) + + try { + (0 until numberOfWriters).foreach { i => + executor.submit(new Runnable { + override def run(): Unit = { + var currentWriterCommits = 0 + while (currentWriterCommits < numberOfCommitsPerWriter) { + val nextVersion = + cs + .getCommits(tablePath, 0) + .lastOption.map(_.version + 1) + .getOrElse(actualCommits.get().toLong) + try { + val currentTimestamp = runningTimestamp.getAndIncrement() + val commitResponse = commit(nextVersion, currentTimestamp, cs, tablePath) + currentWriterCommits += 1 + actualCommits.getAndIncrement() + assert(commitResponse.commitTimestamp == currentTimestamp) + assert(commitResponse.version == nextVersion) + commitTimestamp(commitResponse.version.toInt) = commitResponse.commitTimestamp + } catch { + case e: CommitFailedException => + assert(e.conflict) + assert(e.retryable) + commitFailedExceptions.getAndIncrement() + } finally { + assertInvariants( + tablePath, + cs.asInstanceOf[InMemoryCommitStore], + Some(commitTimestamp)) + } + } + } + }) + } + + executor.shutdown() + executor.awaitTermination(15, TimeUnit.SECONDS) + } catch { + case e: InterruptedException => + fail("Test interrupted: " + e.getMessage) + } + } + } +} From 98fac5728f8efb297d3e68c1c6301871bfc93095 Mon Sep 17 00:00:00 2001 From: Venki Korukanti Date: Wed, 21 Feb 2024 13:29:59 -0800 Subject: [PATCH 13/13] [Kernel] Collect file statistics as part of writing Parquet files ## Description Add support for collecting statistics for columns as part of the Parquet file writing. ## How was this patch tested? Refactored existing tests to make them concise. Added tests for stats collection and verifying the stats using the Spark reader. Also added a few special cases around collecting stats when the input contains NaN, -0.0 or 0.0. 
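
For reference, the statistics are derived entirely from the Parquet footer: the per-row-group `Statistics` of each column chunk are merged, the row-group row counts are summed, and only then are the values decoded into kernel `Literal`s. The sketch below is a minimal, standalone illustration of that footer-merging step using the same parquet-mr APIs this patch relies on; the class and method names (`FooterStatsExample`, `printFileStats`) are made up for the example, and it omits the validity checks (`hasNonNullValue`, unset null counts, unsupported types) that `ParquetStatsReader` performs before trusting the stats.

```java
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.column.statistics.Statistics;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;

public class FooterStatsExample {
    // Prints file-level stats for one Parquet file by merging row-group statistics,
    // mirroring the approach ParquetStatsReader takes in this patch.
    public static void printFileStats(String file, Configuration conf) throws IOException {
        ParquetMetadata footer = ParquetFileReader.readFooter(conf, new Path(file));

        long rowCount = 0;
        Map<String, Statistics> merged = new HashMap<>();
        for (BlockMetaData block : footer.getBlocks()) {           // one entry per row group
            rowCount += block.getRowCount();
            for (ColumnChunkMetaData chunk : block.getColumns()) {
                String column = String.join(".", chunk.getPath().toArray());
                Statistics stats = chunk.getStatistics();
                Statistics existing = merged.get(column);
                if (existing == null) {
                    merged.put(column, stats);
                } else {
                    existing.mergeStatistics(stats);                // accumulate across row groups
                }
            }
        }

        System.out.println("numRecords = " + rowCount);
        for (Map.Entry<String, Statistics> e : merged.entrySet()) {
            Statistics s = e.getValue();
            String nulls = s.isNumNullsSet() ? String.valueOf(s.getNumNulls()) : "unknown";
            System.out.println(e.getKey() + ": min=" + s.genericGetMin()
                + ", max=" + s.genericGetMax() + ", nullCount=" + nulls);
        }
    }
}
```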
--- ...a657-3ba905ccae36-c000.snappy.parquet.crc} | Bin 232 -> 232 bytes .../_delta_log/.00000000000000000000.json.crc | Bin 0 -> 44 bytes .../_delta_log/00000000000000000000.json | 4 + ...d77-a657-3ba905ccae36-c000.snappy.parquet} | Bin 28420 -> 28420 bytes .../scala/io/delta/golden/GoldenTables.scala | 2 +- .../defaults/internal/DefaultKernelUtils.java | 37 +- .../internal/parquet/ParquetFileWriter.java | 21 +- .../internal/parquet/ParquetStatsReader.java | 242 ++++++++ .../parquet/ParquetFileWriterSuite.scala | 538 ++++++++++++------ .../defaults/utils/ExpressionTestUtils.scala | 10 +- .../kernel/defaults/utils/TestUtils.scala | 20 +- .../defaults/utils/VectorTestUtils.scala | 50 +- 12 files changed, 724 insertions(+), 200 deletions(-) rename connectors/golden-tables/src/main/resources/golden/parquet-all-types/{.part-00000-b47994d3-8795-4126-b3a4-2b56a6bdee04-c000.snappy.parquet.crc => .part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet.crc} (82%) create mode 100644 connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/.00000000000000000000.json.crc create mode 100644 connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/00000000000000000000.json rename connectors/golden-tables/src/main/resources/golden/parquet-all-types/{part-00000-b47994d3-8795-4126-b3a4-2b56a6bdee04-c000.snappy.parquet => part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet} (96%) create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java diff --git a/connectors/golden-tables/src/main/resources/golden/parquet-all-types/.part-00000-b47994d3-8795-4126-b3a4-2b56a6bdee04-c000.snappy.parquet.crc b/connectors/golden-tables/src/main/resources/golden/parquet-all-types/.part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet.crc similarity index 82% rename from connectors/golden-tables/src/main/resources/golden/parquet-all-types/.part-00000-b47994d3-8795-4126-b3a4-2b56a6bdee04-c000.snappy.parquet.crc rename to connectors/golden-tables/src/main/resources/golden/parquet-all-types/.part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet.crc index c21175853f30e702074ba491e36ef26cd2d70baf..654003416aff8e6f4bc3cf2a060c88cf7c2767e8 100644 GIT binary patch delta 47 zcmV+~0MP&F0q6mcz$nm4`maJS)aAaAkA~*s-CByhdl#ia(QQG#IDQ00(j`6=@vFq@ F+Rkc18PNa$ delta 47 zcmV+~0MP&F0q6mcz$ieW{EA7GjITmc5Rh-w*5pHZVfCd!(QQG#IDQ00(j`6=@vFz* F18+3A7c2k( diff --git a/connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/.00000000000000000000.json.crc b/connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/.00000000000000000000.json.crc new file mode 100644 index 0000000000000000000000000000000000000000..252a7d95a416025571c24162b40cf8bf1d52f40e GIT binary patch literal 44 zcmYc;N@ieSU}CtXXdFFfshXGEK8*_vYCkVcm+K4J^4GvLuJ&j=pXm*)=Jtks0AoiH AqyPW_ literal 0 HcmV?d00001 diff --git a/connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/00000000000000000000.json b/connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/00000000000000000000.json new file mode 100644 index 00000000000..2b752d2c10e --- /dev/null +++ b/connectors/golden-tables/src/main/resources/golden/parquet-all-types/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ 
+{"commitInfo":{"timestamp":1708108025792,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"200","numOutputBytes":"28420"},"engineInfo":"Apache-Spark/3.5.0 Delta-Lake/3.2.0-SNAPSHOT","txnId":"59b6eac4-ecda-4440-8268-a18cac973ba1"}} +{"metaData":{"id":"testId","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"ByteType\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}},{\"name\":\"ShortType\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}},{\"name\":\"IntegerType\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"LongType\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"FloatType\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DoubleType\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"decimal\",\"type\":\"decimal(10,2)\",\"nullable\":true,\"metadata\":{}},{\"name\":\"BooleanType\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"StringType\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"BinaryType\",\"type\":\"binary\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DateType\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}},{\"name\":\"TimestampType\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"nested_struct\",\"type\":{\"type\":\"struct\",\"fields\":[{\"name\":\"aa\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"ac\",\"type\":{\"type\":\"struct\",\"fields\":[{\"name\":\"aca\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]},\"nullable\":true,\"metadata\":{}}]},\"nullable\":true,\"metadata\":{}},{\"name\":\"array_of_prims\",\"type\":{\"type\":\"array\",\"elementType\":\"integer\",\"containsNull\":true},\"nullable\":true,\"metadata\":{}},{\"name\":\"array_of_arrays\",\"type\":{\"type\":\"array\",\"elementType\":{\"type\":\"array\",\"elementType\":\"integer\",\"containsNull\":true},\"containsNull\":true},\"nullable\":true,\"metadata\":{}},{\"name\":\"array_of_structs\",\"type\":{\"type\":\"array\",\"elementType\":{\"type\":\"struct\",\"fields\":[{\"name\":\"ab\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}]},\"containsNull\":true},\"nullable\":true,\"metadata\":{}},{\"name\":\"map_of_prims\",\"type\":{\"type\":\"map\",\"keyType\":\"integer\",\"valueType\":\"long\",\"valueContainsNull\":true},\"nullable\":true,\"metadata\":{}},{\"name\":\"map_of_rows\",\"type\":{\"type\":\"map\",\"keyType\":\"integer\",\"valueType\":{\"type\":\"struct\",\"fields\":[{\"name\":\"ab\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}]},\"valueContainsNull\":true},\"nullable\":true,\"metadata\":{}},{\"name\":\"map_of_arrays\",\"type\":{\"type\":\"map\",\"keyType\":\"long\",\"valueType\":{\"type\":\"array\",\"elementType\":\"integer\",\"containsNull\":true},\"valueContainsNull\":true},\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1708108023726}} +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} 
+{"add":{"path":"part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet","partitionValues":{},"size":28420,"modificationTime":1708108025717,"dataChange":true,"stats":"{\"numRecords\":200,\"minValues\":{\"ByteType\":-128,\"ShortType\":1,\"IntegerType\":1,\"LongType\":2,\"FloatType\":0.234,\"DoubleType\":234234.23,\"decimal\":123.52,\"StringType\":\"1\",\"DateType\":\"1970-01-01\",\"TimestampType\":\"1970-01-01T06:30:23.523Z\",\"nested_struct\":{\"aa\":\"1\",\"ac\":{\"aca\":1}}},\"maxValues\":{\"ByteType\":127,\"ShortType\":199,\"IntegerType\":199,\"LongType\":200,\"FloatType\":46.566,\"DoubleType\":4.661261177E7,\"decimal\":24580.48,\"StringType\":\"99\",\"DateType\":\"1970-02-16\",\"TimestampType\":\"1970-02-23T22:48:01.077Z\",\"nested_struct\":{\"aa\":\"99\",\"ac\":{\"aca\":199}}},\"nullCount\":{\"ByteType\":3,\"ShortType\":4,\"IntegerType\":9,\"LongType\":8,\"FloatType\":8,\"DoubleType\":4,\"decimal\":3,\"BooleanType\":3,\"StringType\":4,\"BinaryType\":4,\"DateType\":4,\"TimestampType\":4,\"nested_struct\":{\"aa\":14,\"ac\":{\"aca\":22}},\"array_of_prims\":8,\"array_of_arrays\":25,\"array_of_structs\":0,\"map_of_prims\":8,\"map_of_rows\":0,\"map_of_arrays\":7}}"}} diff --git a/connectors/golden-tables/src/main/resources/golden/parquet-all-types/part-00000-b47994d3-8795-4126-b3a4-2b56a6bdee04-c000.snappy.parquet b/connectors/golden-tables/src/main/resources/golden/parquet-all-types/part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet similarity index 96% rename from connectors/golden-tables/src/main/resources/golden/parquet-all-types/part-00000-b47994d3-8795-4126-b3a4-2b56a6bdee04-c000.snappy.parquet rename to connectors/golden-tables/src/main/resources/golden/parquet-all-types/part-00000-4b3cf091-231f-4d77-a657-3ba905ccae36-c000.snappy.parquet index c45ef0884d3c21ce37ddb59a9a30f44397be8650..3a4ace80ea920b488006f26078b4d4e5d9f6e7fd 100644 GIT binary patch delta 221 zcmZp<$Jla@aYIxLGY1>PHBeB zU|J{M6-<}JyEAdHP2L-y1fmoZJiwwk3BDkD^PYq_MkXLFlT;0&C(ley2XUEFn8953 z)CG*Dn^&Y>=P)D$UGEQAkNLF|;r?G_XuDO)@t&NliAjG&V@FFgHj}NlY~~ aGfXuxut-TsO4OA6BR=_Jp4Q|Cc`5(}7dnLi delta 221 zcmZp<$Jla@aYIxLGXopPHBeB zU|J{M6-<}JyE6fe+Z&$*q7)N6z@j+`z94$@o`g6?CI+_6GD+1Sdh*QVbP$&*g&E9c zPhG%hvUx@7bq;eQJ!1upq|(fs6ooX4l;qSDLkpv1lVroRq?FWTv*hF?V?zs5Q)BZa b(?kQqRI}6+vm{N)KjM=w=4nlSkf#CwGXy(& diff --git a/connectors/golden-tables/src/test/scala/io/delta/golden/GoldenTables.scala b/connectors/golden-tables/src/test/scala/io/delta/golden/GoldenTables.scala index 69522ea5081..cd734de340d 100644 --- a/connectors/golden-tables/src/test/scala/io/delta/golden/GoldenTables.scala +++ b/connectors/golden-tables/src/test/scala/io/delta/golden/GoldenTables.scala @@ -1171,7 +1171,7 @@ class GoldenTables extends QueryTest with SharedSparkSession { val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) df.repartition(1) .write - .format("parquet") + .format("delta") .mode("append") .save(tablePath) } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java index a9c98faefde..6a3f67300a3 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java @@ -18,12 +18,17 @@ import java.time.LocalDate; import java.util.concurrent.TimeUnit; +import io.delta.kernel.expressions.Column; +import 
io.delta.kernel.types.DataType; +import io.delta.kernel.types.StructType; + import io.delta.kernel.internal.util.Tuple2; public class DefaultKernelUtils { private static final LocalDate EPOCH = LocalDate.ofEpochDay(0); - private DefaultKernelUtils() {} + private DefaultKernelUtils() { + } ////////////////////////////////////////////////////////////////////////////////// // Below utils are adapted from org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -39,12 +44,12 @@ private DefaultKernelUtils() {} public static long fromJulianDay(int days, long nanos) { // use Long to avoid rounding errors return ((long) (days - JULIAN_DAY_OF_EPOCH)) * DateTimeConstants.MICROS_PER_DAY + - nanos / DateTimeConstants.NANOS_PER_MICROS; + nanos / DateTimeConstants.NANOS_PER_MICROS; } /** * Returns Julian day and remaining nanoseconds from the number of microseconds - * + *

    * Note: support timestamp since 4717 BC (without negative nanoseconds, compatible with Hive). */ public static Tuple2 toJulianDay(long micros) { @@ -87,4 +92,30 @@ public static class DateTimeConstants { public static final long NANOS_PER_MILLIS = MICROS_PER_MILLIS * NANOS_PER_MICROS; public static final long NANOS_PER_SECOND = MILLIS_PER_SECOND * NANOS_PER_MILLIS; } + + /** + * Search for the data type of the given column in the schema. + * + * @param schema the schema to search + * @param column the column whose data type is to be found + * @return the data type of the column + * @throws IllegalArgumentException if the column is not found in the schema + */ + public static DataType getDataType(StructType schema, Column column) { + DataType dataType = schema; + for (String part : column.getNames()) { + if (!(dataType instanceof StructType)) { + throw new IllegalArgumentException( + String.format("Cannot resolve column (%s) in schema: %s", column, schema)); + } + StructType structType = (StructType) dataType; + if (structType.fieldNames().contains(part)) { + dataType = structType.get(part).getDataType(); + } else { + throw new IllegalArgumentException( + String.format("Cannot resolve column (%s) in schema: %s", column, schema)); + } + } + return dataType; + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFileWriter.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFileWriter.java index 16843a0bfd1..f2da42f6702 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFileWriter.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFileWriter.java @@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; -import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.parquet.hadoop.ParquetOutputFormat; import org.apache.parquet.hadoop.ParquetWriter; @@ -41,6 +40,7 @@ import static io.delta.kernel.internal.util.Preconditions.checkArgument; import io.delta.kernel.defaults.internal.parquet.ParquetColumnWriters.ColumnWriter; +import static io.delta.kernel.defaults.internal.parquet.ParquetStatsReader.readDataFileStatistics; /** * Implements writing data given as {@link FilteredColumnarBatch} to Parquet files. @@ -351,13 +351,18 @@ private DataFileStatus constructDataFileStatus(String path, StructType dataSchem try { // Get the FileStatus to figure out the file size and modification time Path hadoopPath = new Path(path); - FileSystem hadoopFs = hadoopPath.getFileSystem(configuration); - FileStatus fileStatus = hadoopFs.getFileStatus(hadoopPath); - long fileSize = fileStatus.getLen(); - long modTime = fileStatus.getModificationTime(); - - // TODO: Stats computation is coming next. - return new DataFileStatus(path, fileSize, modTime, Optional.empty()); + FileStatus fileStatus = hadoopPath.getFileSystem(configuration) + .getFileStatus(hadoopPath); + Path resolvedPath = fileStatus.getPath(); + + DataFileStatistics stats = (statsColumns.isEmpty()) ? 
null : + readDataFileStatistics(resolvedPath, configuration, dataSchema, statsColumns); + + return new DataFileStatus( + resolvedPath.toString(), + fileStatus.getLen(), + fileStatus.getModificationTime(), + Optional.ofNullable(stats)); } catch (IOException ioe) { throw new UncheckedIOException("Failed to read the stats for: " + path, ioe); } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java new file mode 100644 index 00000000000..1e982ad25c2 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetStatsReader.java @@ -0,0 +1,242 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.parquet; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.*; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.function.UnaryOperator.identity; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.shaded.com.google.common.collect.ImmutableMultimap; +import org.apache.hadoop.shaded.com.google.common.collect.Multimap; +import org.apache.parquet.column.statistics.*; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.metadata.*; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; +import static org.apache.hadoop.shaded.com.google.common.collect.ImmutableMap.toImmutableMap; + +import io.delta.kernel.expressions.Column; +import io.delta.kernel.expressions.Literal; +import io.delta.kernel.types.*; +import io.delta.kernel.utils.DataFileStatistics; + +import static io.delta.kernel.internal.util.Preconditions.checkArgument; + +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.getDataType; + +/** + * Helper class to read statistics from Parquet files. + */ +public class ParquetStatsReader { + /** + * Read the statistics for the given Parquet file. + * + * @param parquetFilePath The path to the Parquet file. + * @param hadoopConf The Hadoop configuration to use for reading the file. + * @param dataSchema The schema of the Parquet file. Type info is used to decode + * statistics. + * @param statsColumns The columns for which statistics should be collected and returned. + * @return File/column level statistics as {@link DataFileStatistics} instance. 
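+     * @throws IOException if the Parquet file footer cannot be read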
+ */ + public static DataFileStatistics readDataFileStatistics( + Path parquetFilePath, + Configuration hadoopConf, + StructType dataSchema, + List statsColumns) throws IOException { + // Read the Parquet footer to compute the statistics + ParquetMetadata footer = ParquetFileReader.readFooter(hadoopConf, parquetFilePath); + ImmutableMultimap.Builder metadataForColumn = + ImmutableMultimap.builder(); + + long rowCount = 0; + for (BlockMetaData blockMetaData : footer.getBlocks()) { + rowCount += blockMetaData.getRowCount(); + for (ColumnChunkMetaData columnChunkMetaData : blockMetaData.getColumns()) { + Column column = new Column(columnChunkMetaData.getPath().toArray()); + metadataForColumn.put(column, columnChunkMetaData); + } + } + + return constructFileStats(metadataForColumn.build(), dataSchema, statsColumns, rowCount); + } + + /** + * Merge statistics from multiple rowgroups into a single set of statistics for each column. + * @return Stats for each column in the file as {@link DataFileStatistics}. + */ + private static DataFileStatistics constructFileStats( + Multimap metadataForColumn, + StructType dataSchema, + List statsColumns, + long rowCount) { + Map>> statsForColumn = + metadataForColumn.keySet().stream() + .collect(toImmutableMap( + identity(), + key -> mergeMetadataList(metadataForColumn.get(key)))); + + Map minValues = new HashMap<>(); + Map maxValues = new HashMap<>(); + Map nullCounts = new HashMap<>(); + for (Column statsColumn : statsColumns) { + Optional> stats = statsForColumn.get(statsColumn); + DataType columnType = getDataType(dataSchema, statsColumn); + if (stats == null || !stats.isPresent() || !isStatsSupportedDataType(columnType)) { + continue; + } + Statistics statistics = stats.get(); + + Long numNulls = statistics.isNumNullsSet() ? statistics.getNumNulls() : null; + nullCounts.put(statsColumn, numNulls); + + if (numNulls != null && rowCount == numNulls) { + // If all values are null, then min and max are also null + minValues.put(statsColumn, Literal.ofNull(columnType)); + maxValues.put(statsColumn, Literal.ofNull(columnType)); + continue; + } + + Literal minValue = decodeMinMaxStat(columnType, statistics, true /* decodeMin */); + minValues.put(statsColumn, minValue); + + Literal maxValue = decodeMinMaxStat(columnType, statistics, false /* decodeMin */); + maxValues.put(statsColumn, maxValue); + } + + return new DataFileStatistics(rowCount, minValues, maxValues, nullCounts); + } + + private static Literal decodeMinMaxStat( + DataType dataType, + Statistics statistics, + boolean decodeMin) { + Object statValue = decodeMin ? 
statistics.genericGetMin() : statistics.genericGetMax(); + if (statValue == null) { + return null; + } + + if (dataType instanceof BooleanType) { + return Literal.ofBoolean((Boolean) statValue); + } else if (dataType instanceof ByteType) { + return Literal.ofByte(((Number) statValue).byteValue()); + } else if (dataType instanceof ShortType) { + return Literal.ofShort(((Number) statValue).shortValue()); + } else if (dataType instanceof IntegerType) { + return Literal.ofInt(((Number) statValue).intValue()); + } else if (dataType instanceof LongType) { + return Literal.ofLong(((Number) statValue).longValue()); + } else if (dataType instanceof FloatType) { + return Literal.ofFloat(((Number) statValue).floatValue()); + } else if (dataType instanceof DoubleType) { + return Literal.ofDouble(((Number) statValue).doubleValue()); + } else if (dataType instanceof DecimalType) { + LogicalTypeAnnotation logicalType = statistics.type().getLogicalTypeAnnotation(); + checkArgument( + logicalType instanceof DecimalLogicalTypeAnnotation, + "Physical decimal column has invalid Parquet Logical Type: %s", logicalType); + int scale = ((DecimalLogicalTypeAnnotation) logicalType).getScale(); + + DecimalType decimalType = (DecimalType) dataType; + + // Check the scale is same in both the Delta data type and the Parquet Logical Type + checkArgument( + scale == decimalType.getScale(), + "Physical decimal type has different scale than the logical type: %s", scale); + + // Decimal is stored either as int, long or binary. Decode the stats accordingly. + BigDecimal decimalStatValue; + if (statistics instanceof IntStatistics) { + decimalStatValue = BigDecimal.valueOf((Integer) statValue).movePointLeft(scale); + } else if (statistics instanceof LongStatistics) { + decimalStatValue = BigDecimal.valueOf((Long) statValue).movePointLeft(scale); + } else if (statistics instanceof BinaryStatistics) { + BigInteger base = new BigInteger(getBinaryStat(statistics, decodeMin)); + decimalStatValue = new BigDecimal(base, scale); + } else { + throw new UnsupportedOperationException( + "Unsupported stats type for Decimal: " + statistics.getClass()); + } + return Literal.ofDecimal(decimalStatValue, decimalType.getPrecision(), + decimalType.getScale()); + } else if (dataType instanceof DateType) { + checkArgument( + statistics instanceof IntStatistics, + "Column with DATE type contained invalid statistics: %s", statistics); + return Literal.ofDate((Integer) statValue); // stats are stored as epoch days in Parquet + } else if (dataType instanceof StringType) { + byte[] binaryStat = getBinaryStat(statistics, decodeMin); + return Literal.ofString(new String(binaryStat, UTF_8)); + } else if (dataType instanceof BinaryType) { + return Literal.ofBinary(getBinaryStat(statistics, decodeMin)); + } + + throw new IllegalArgumentException("Unsupported stats data type: " + statValue); + } + + private static Optional> mergeMetadataList( + Collection metadataList) { + if (hasInvalidStatistics(metadataList)) { + return Optional.empty(); + } + + return metadataList.stream() + .>map(ColumnChunkMetaData::getStatistics) + .reduce((statsA, statsB) -> { + statsA.mergeStatistics(statsB); + return statsA; + }); + } + + private static boolean hasInvalidStatistics(Collection metadataList) { + // If any row group does not have stats collected, stats for the file will not be valid + return metadataList.stream().anyMatch(metadata -> { + Statistics stats = metadata.getStatistics(); + if (stats == null || stats.isEmpty() || !stats.isNumNullsSet()) { + return true; + } 
+ + // Columns with NaN values are marked by `hasNonNullValue` = false by the Parquet reader + // See issue: https://issues.apache.org/jira/browse/PARQUET-1246 + return !stats.hasNonNullValue() && + stats.getNumNulls() != metadata.getValueCount(); + }); + } + + private static boolean isStatsSupportedDataType(DataType dataType) { + return dataType instanceof BooleanType || + dataType instanceof ByteType || + dataType instanceof ShortType || + dataType instanceof IntegerType || + dataType instanceof LongType || + dataType instanceof FloatType || + dataType instanceof DoubleType || + dataType instanceof DecimalType || + dataType instanceof DateType || + dataType instanceof StringType || + dataType instanceof BinaryType; + // TODO: timestamp is complicated to handle because of the storage format (INT96 or INT64). + // Add support later. + } + + private static byte[] getBinaryStat(Statistics statistics, boolean decodeMin) { + return decodeMin ? statistics.getMinBytes() : statistics.getMaxBytes(); + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala index 7114bab2dbe..08ac809c368 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetFileWriterSuite.scala @@ -15,11 +15,11 @@ */ package io.delta.kernel.defaults.internal.parquet; -import io.delta.golden.GoldenTableUtils.goldenTableFile +import io.delta.golden.GoldenTableUtils.{goldenTableFile, goldenTablePath} import io.delta.kernel.Table import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch} -import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch -import io.delta.kernel.defaults.utils.{TestRow, TestUtils} +import io.delta.kernel.defaults.internal.DefaultKernelUtils +import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestRow, TestUtils, VectorTestUtils} import io.delta.kernel.expressions.{Column, Literal, Predicate} import io.delta.kernel.internal.util.ColumnMapping import io.delta.kernel.internal.util.ColumnMapping.convertToPhysicalSchema @@ -29,8 +29,10 @@ import io.delta.kernel.utils.{DataFileStatus, FileStatus} import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.spark.sql.{functions => sparkfn} import org.scalatest.funsuite.AnyFunSuite +import java.lang.{Double => DoubleJ, Float => FloatJ} import java.nio.file.{Files, Paths} import java.util.Optional import scala.collection.JavaConverters._ @@ -60,83 +62,140 @@ import scala.util.control.NonFatal * 4.2) read the new Parquet file(s) using the Spark Parquet reader and compare with (2) * 4.3) verify the stats returned in (3) are correct using the Spark Parquet reader */ -class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { - import ParquetFileWriterSuite._ - - Seq(200, 1000, 1048576).foreach { targetFileSize => - test(s"write all types - no stats - targetFileSize: $targetFileSize") { - withTempDir { tempPath => - val targetDir = tempPath.getAbsolutePath - - val dataToWrite = - readParquetUsingKernelAsColumnarBatches(ALL_TYPES_DATA, ALL_TYPES_FILE_SCHEMA) - .map(_.toFiltered) - - writeToParquetUsingKernel(dataToWrite, targetDir, targetFileSize) - - val expectedNumParquetFiles = targetFileSize match { - 
case 200 => 100 - case 1000 => 29 - case 1048576 => 1 - case _ => throw new IllegalArgumentException(s"Invalid targetFileSize: $targetFileSize") - } - assert(parquetFileCount(targetDir) === expectedNumParquetFiles) - - verify(targetDir, dataToWrite) - } +class ParquetFileWriterSuite extends AnyFunSuite + with TestUtils with VectorTestUtils with ExpressionTestUtils { + + Seq( + // Test cases reading and writing all types of data with or without stats collection + Seq((200, 100), (1024, 28), (1048576, 1)).map { + case (targetFileSize, expParquetFileCount) => + ( + "write all types (no stats)", // test name + "parquet-all-types", // input table where the data is read and written + targetFileSize, + expParquetFileCount, + 200, /* expected number of rows written to Parquet files */ + Option.empty[Predicate], // predicate for filtering what rows to write to parquet files + Seq.empty[Column], // list of columns to collect stats as part of the Parquet file write + 0 // how many columns have the stats collected from given list above + ) + }, + // Test cases reading and writing decimal types data with different precisions + // They trigger different paths in the Parquet writer as how decimal types are stored in Parquet + // based on the precision and scale. + Seq((1048576, 3), (2048576, 2)).map { + case (targetFileSize, expParquetFileCount) => + ( + "write decimal all types (with stats)", // test name + "parquet-decimal-type", + targetFileSize, + expParquetFileCount, + 99998, /* expected number of rows written to Parquet files */ + Option.empty[Predicate], // predicate for filtering what rows to write to parquet files + leafLevelPrimitiveColumns( + Seq.empty, tableSchema(goldenTablePath("parquet-decimal-type"))), + 4 // how many columns have the stats collected from given list above + ) + }, + // Test cases reading and writing data with field ids. This is for column mapping mode ID. + Seq((200, 3), (1024, 1)).map { + case (targetFileSize, expParquetFileCount) => + ( + "write data with field ids (no stats)", // test name + "table-with-columnmapping-mode-id", + targetFileSize, + expParquetFileCount, + 6, /* expected number of rows written to Parquet files */ + Option.empty[Predicate], // predicate for filtering what rows to write to parquet files + Seq.empty[Column], // list of columns to collect stats as part of the Parquet file write + 0 // how many columns have the stats collected from given list above + ) + }, + // Test cases reading and writing only a subset of data passing a predicate. 
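For the predicate-filtered cases that follow, the expected 77 rows can be derived from the input distribution documented in a comment this diff removes further down: byteType covers the range [-72, 127] (200 rows) with a null wherever the value is a multiple of 72, and the predicate keeps ByteType >= 50. A quick plain-Scala check of that arithmetic:

    // byteType domain: -72..127 (200 rows); rows where value % 72 == 0 are stored as null.
    val byteValues: Seq[Option[Int]] =
      (-72 to 127).map(v => if (v % 72 == 0) None else Some(v))

    assert(byteValues.size == 200)
    // ByteType >= 50 keeps 50..127 (78 values) minus the null at 72 => 77 rows.
    assert(byteValues.count(_.exists(_ >= 50)) == 77)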
+ Seq((200, 39), (1024, 11), (1048576, 1)).map { + case (targetFileSize, expParquetFileCount) => + ( + "write filtered all types (no stats)", // test name + "parquet-all-types", // input table where the data is read and written + targetFileSize, + expParquetFileCount, + 77, /* expected number of rows written to Parquet files */ + // predicate for filtering what input rows to write to parquet files + Some(greaterThanOrEqual(col("ByteType"), Literal.ofInt(50))), + Seq.empty[Column], // list of columns to collect stats as part of the Parquet file write + 0 // how many columns have the stats collected from given list above + ) + }, + // Test cases reading and writing all types of data WITH stats collection + Seq((200, 100), (1024, 28), (1048576, 1)).map { + case (targetFileSize, expParquetFileCount) => + ( + "write all types (with stats for all leaf-level columns)", // test name + "parquet-all-types", // input table where the data is read and written + targetFileSize, + expParquetFileCount, + 200, /* expected number of rows written to Parquet files */ + Option.empty[Predicate], // predicate for filtering what rows to write to parquet files + leafLevelPrimitiveColumns(Seq.empty, tableSchema(goldenTablePath("parquet-all-types"))), + 13 // how many columns have the stats collected from given list above + ) + }, + // Test cases reading and writing all types of data with a partial column set stats collection + Seq((200, 100), (1024, 28), (1048576, 1)).map { + case (targetFileSize, expParquetFileCount) => + ( + "write all types (with stats for a subset of leaf-level columns)", // test name + "parquet-all-types", // input table where the data is read and written + targetFileSize, + expParquetFileCount, + 200, /* expected number of rows written to Parquet files */ + Option.empty[Predicate], // predicate for filtering what rows to write to parquet files + Seq( + new Column("ByteType"), + new Column("DateType"), + new Column(Array("nested_struct", "aa")), + new Column(Array("nested_struct", "ac", "aca")), + new Column("TimestampType"), // stats are not collected for timestamp type YET. 
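Of the nine stats columns requested in this case, only the leaf-level primitive ones (ByteType, DateType, nested_struct.aa, nested_struct.ac.aca) yield file statistics; timestamp is skipped for now and struct/array/map columns have no per-file min/max, which is why the expected count is 4. A toy illustration of that filtering; the DType model below is illustrative, not the Kernel type system. The remaining entries of the stats-column list continue below.

    sealed trait DType
    case object IntT extends DType
    case object DateT extends DType
    case object TimestampT extends DType                  // stats not supported yet
    final case class StructT(fields: Map[String, DType]) extends DType
    final case class ArrayT(element: DType) extends DType
    final case class MapT(key: DType, value: DType) extends DType

    def statsEligible(t: DType): Boolean = t match {
      case TimestampT | _: StructT | _: ArrayT | _: MapT => false
      case _ => true                                      // leaf-level primitives only
    }

    // ByteType, DateType, nested_struct.aa, nested_struct.ac.aca, TimestampType,
    // nested_struct.ac, nested_struct, array_of_prims, map_of_prims
    val requested: Seq[DType] = Seq(IntT, DateT, IntT, IntT, TimestampT,
      StructT(Map.empty), StructT(Map.empty), ArrayT(IntT), MapT(IntT, IntT))

    assert(requested.count(statsEligible) == 4)           // matches the expected stats column count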
+ new Column(Array("nested_struct", "ac")), // stats are not collected for struct types + new Column("nested_struct"), // stats are not collected for struct types + new Column("array_of_prims"), // stats are not collected for array types + new Column("map_of_prims") // stats are not collected for map types + ), + 4 // how many columns have the stats collected from given list above + ) } - } + ).flatten.foreach { + case (name, input, fileSize, expFileCount, expRowCount, predicate, statsCols, expStatsColCnt) => + test(s"$name: targetFileSize=$fileSize, predicate=$predicate") { + withTempDir { tempPath => + val targetDir = tempPath.getAbsolutePath + + val inputLocation = goldenTablePath(input) + val schema = tableSchema(inputLocation) + + val physicalSchema = if (hasColumnMappingId(inputLocation)) { + convertToPhysicalSchema(schema, schema, ColumnMapping.COLUMN_MAPPING_MODE_ID) + } else { + schema + } - Seq(1048576, 2048576).foreach { targetFileSize => - test(s"decimal all types - no stats - targetFileSize: $targetFileSize") { - withTempDir { tempPath => - val targetDir = tempPath.getAbsolutePath + val dataToWrite = + readParquetUsingKernelAsColumnarBatches(inputLocation, physicalSchema) // read data + // Convert the schema of the data to the physical schema with field ids + .map(_.withNewSchema(physicalSchema)) + // convert the data to filtered columnar batches + .map(_.toFiltered(predicate)) - val dataToWrite = - readParquetUsingKernelAsColumnarBatches(DECIMAL_TYPES_DATA, DECIMAL_TYPES_FILE_SCHEMA) - .map(_.toFiltered) + val writeOutput = + writeToParquetUsingKernel(dataToWrite, targetDir, fileSize, statsCols) - writeToParquetUsingKernel(dataToWrite, targetDir, targetFileSize) + assert(parquetFileCount(targetDir) === expFileCount) + assert(parquetFileRowCount(targetDir) == expRowCount) - val expectedNumParquetFiles = targetFileSize match { - case 1048576 => 3 - case 2048576 => 2 - case _ => throw new IllegalArgumentException(s"Invalid targetFileSize: $targetFileSize") + verifyContent(targetDir, dataToWrite) + verifyStatsUsingSpark(targetDir, writeOutput, schema, statsCols, expStatsColCnt) } - assert(parquetFileCount(targetDir) === expectedNumParquetFiles) - - verify(targetDir, dataToWrite) } - } - } - - Seq(200, 1000, 1048576).foreach { targetFileSize => - test(s"write all types - filtered dataset, targetFileSize: $targetFileSize") { - withTempDir { tempPath => - val targetDir = tempPath.getAbsolutePath - - // byteValue is in the range [-72, 127] with null at every (value % 72 == 0) row - // File has total of 200 rows. - val predicate = new Predicate(">=", new Column("byteType"), Literal.ofInt(50)) - val expectedRowCount = 128 /* no. 
of positive values */ - 1 /* one */ - 50 /* val < 50 */ - val dataToWrite = - readParquetUsingKernelAsColumnarBatches(ALL_TYPES_DATA, ALL_TYPES_FILE_SCHEMA) - .map(_.toFiltered(predicate)) - - writeToParquetUsingKernel(dataToWrite, targetDir, targetFileSize = targetFileSize) - - val expectedNumParquetFiles = targetFileSize match { - case 200 => 39 - case 1000 => 11 - case 1048576 => 1 - case _ => throw new IllegalArgumentException(s"Invalid targetFileSize: $targetFileSize") - } - assert(parquetFileCount(targetDir) === expectedNumParquetFiles) - assert(parquetFileRowCount(targetDir) === expectedRowCount) - - verify(targetDir, dataToWrite) - } - } } test("columnar batches containing different schema") { @@ -144,65 +203,113 @@ class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { val targetDir = tempPath.getAbsolutePath // First batch with one column - val batch1 = new DefaultColumnarBatch( - /* size */ 10, - new StructType().add("col1", IntegerType.INTEGER), - Array(testColumnVector(10, IntegerType.INTEGER))) + val batch1 = columnarBatch(testColumnVector(10, IntegerType.INTEGER)) // Batch with two columns - val batch2 = new DefaultColumnarBatch( - /* size */ 10, - new StructType() - .add("col1", IntegerType.INTEGER) - .add("col2", LongType.LONG), - Array(testColumnVector(10, IntegerType.INTEGER), testColumnVector(10, LongType.LONG))) + val batch2 = columnarBatch( + testColumnVector(10, IntegerType.INTEGER), + testColumnVector(10, LongType.LONG)) // Batch with one column as first batch but different data type - val batch3 = new DefaultColumnarBatch( - /* size */ 10, - new StructType().add("col1", LongType.LONG), - Array(testColumnVector(10, LongType.LONG))) - - Seq(Seq(batch1, batch2), Seq(batch1, batch3)).foreach { - dataToWrite => - val e = intercept[IllegalArgumentException] { - writeToParquetUsingKernel(dataToWrite.map(_.toFiltered), targetDir) - } - assert(e.getMessage.contains("Input data has columnar batches with different schemas:")) + val batch3 = columnarBatch(testColumnVector(10, LongType.LONG)) + + Seq(Seq(batch1, batch2), Seq(batch1, batch3)).foreach { dataToWrite => + val e = intercept[IllegalArgumentException] { + writeToParquetUsingKernel(dataToWrite.map(_.toFiltered), targetDir) + } + assert(e.getMessage.contains("Input data has columnar batches with different schemas:")) } } } - test("write data with field ids") { - withTempDir { tempPath => - val targetDir = tempPath.getAbsolutePath - - val cmGoldenTable = goldenTableFile("table-with-columnmapping-mode-id").toString - val schema = tableSchema(cmGoldenTable) - - val dataToWrite = - readParquetUsingKernelAsColumnarBatches(cmGoldenTable, schema) - .map(_.toFiltered) - - // From the Delta schema, generate the physical schema that has field ids. - val physicalSchema = - convertToPhysicalSchema(schema, schema, ColumnMapping.COLUMN_MAPPING_MODE_ID) - - writeToParquetUsingKernel( - // Convert the schema of the data to the physical schema with field ids - dataToWrite.map(_.getData).map(_.withNewSchema(physicalSchema)).map(_.toFiltered), - targetDir) - - verifyFieldIds(targetDir, schema) - verify(targetDir, dataToWrite) + /** + * Tests to cover floating point comparison special cases in Parquet. 
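The different-schema test above pins down a writer precondition: all incoming batches must share one schema, since a Parquet file has exactly one schema. A minimal sketch of that validation; Batch below is a stand-in for FilteredColumnarBatch, not the Kernel API. The scaladoc for the floating point stats tests continues below.

    // Stand-in batch type: just a schema label and a row count.
    final case class Batch(schema: String, numRows: Int)

    // Mirrors the IllegalArgumentException asserted in the test above.
    def validateSameSchema(batches: Seq[Batch]): Unit = {
      val schemas = batches.map(_.schema).distinct
      require(schemas.size <= 1,
        s"Input data has columnar batches with different schemas: $schemas")
    }

    validateSameSchema(Seq(Batch("col1 INT", 10), Batch("col1 INT", 5)))       // passes
    // validateSameSchema(Seq(Batch("col1 INT", 10), Batch("col1 LONG", 10)))  // would throw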
+ * - https://issues.apache.org/jira/browse/PARQUET-1222 + * - Parquet doesn't collect stats if NaN is present in the column values + * - Min is written as -0.0 instead of 0.0 and max is written as 0.0 instead of -0.0 + */ + test("float/double type column stats collection") { + // Try writing different set of floating point values and verify the stats are correct + // (float values, double values, exp rowCount in files, exp stats (min, max, nullCount) + Seq( + ( // no stats collection as NaN is present + Seq(Float.NegativeInfinity, Float.MinValue, -1.0f, + -0.0f, 0.0f, 1.0f, null, Float.MaxValue, Float.PositiveInfinity, Float.NaN), + Seq(Double.NegativeInfinity, Double.MinValue, -1.0d, + -0.0d, 0.0d, 1.0d, null, Double.MaxValue, Double.PositiveInfinity, Double.NaN), + 10, + (null, null, null), + (null, null, null) + ), + ( // Min and max are infinities + Seq(Float.NegativeInfinity, Float.MinValue, -1.0f, + -0.0f, 0.0f, 1.0f, null, Float.MaxValue, Float.PositiveInfinity), + Seq(Double.NegativeInfinity, Double.MinValue, -1.0d, + -0.0d, 0.0d, 1.0d, null, Double.MaxValue, Double.PositiveInfinity), + 9, + (Float.NegativeInfinity, Float.PositiveInfinity, 1L), + (Double.NegativeInfinity, Double.PositiveInfinity, 1L) + ), + ( // no infinities or NaN - expect stats collected + Seq(Float.MinValue, -1.0f, -0.0f, 0.0f, 1.0f, null, Float.MaxValue), + Seq(Double.MinValue, -1.0d, -0.0d, 0.0d, 1.0d, null, Double.MaxValue), + 7, + (Float.MinValue, Float.MaxValue, 1L), + (Double.MinValue, Double.MaxValue, 1L) + ), + ( // Only negative numbers. Max is 0.0 instead of -0.0 to avoid PARQUET-1222 + Seq(Float.NegativeInfinity, Float.MinValue, -1.0f, -0.0f, null), + Seq(Double.NegativeInfinity, Double.MinValue, -1.0d, -0.0d, null), + 5, + (Float.NegativeInfinity, 0.0f, 1L), + (Double.NegativeInfinity, 0.0d, 1L) + ), + ( // Only positive numbers. 
Min is -0.0 instead of 0.0 to avoid PARQUET-1222 + Seq(0.0f, 1.0f, null, Float.MaxValue, Float.PositiveInfinity), + Seq(0.0d, 1.0d, null, Double.MaxValue, Double.PositiveInfinity), + 5, + (-0.0f, Float.PositiveInfinity, 1L), + (-0.0d, Double.PositiveInfinity, 1L) + ) + ).foreach { + case (floats: Seq[FloatJ], doubles: Seq[DoubleJ], expRowCount, expFltStats, expDblStats) => + withTempDir { tempPath => + val targetDir = tempPath.getAbsolutePath + val testBatch = columnarBatch(floatVector(floats), doubleVector(doubles)) + val dataToWrite = Seq(testBatch.toFiltered) + + val writeOutput = + writeToParquetUsingKernel( + dataToWrite, + targetDir, + statsColumns = Seq(col("col_0"), col("col_1"))) + + assert(parquetFileRowCount(targetDir) == expRowCount) + verifyContent(targetDir, dataToWrite) + + val stats = writeOutput.head.getStatistics.get() + + def getStats(column: String): (Object, Object, Object) = + ( + Option(stats.getMinValues.get(col(column))).map(_.getValue).orNull, + Option(stats.getMaxValues.get(col(column))).map(_.getValue).orNull, + Option(stats.getNullCounts.get(col(column))).orNull + ) + + assert(getStats("col_0") === expFltStats) + assert(getStats("col_1") === expDblStats) + } } } test(s"invalid target file size") { withTempDir { tempPath => val targetDir = tempPath.getAbsolutePath + val inputLocation = goldenTableFile("parquet-all-types").toString + val schema = tableSchema(inputLocation) + val dataToWrite = - readParquetUsingKernelAsColumnarBatches(DECIMAL_TYPES_DATA, DECIMAL_TYPES_FILE_SCHEMA) + readParquetUsingKernelAsColumnarBatches(inputLocation, schema) .map(_.toFiltered) Seq(-1, 0).foreach { targetFileSize => @@ -214,17 +321,23 @@ class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { } } - def verify(actualFileDir: String, expected: Seq[FilteredColumnarBatch]): Unit = { + /** + * Verify the contents of the Parquet files located in `actualFileDir` matches the + * `expected` data. Does two types of verifications. + * 1) Verify the data using the Kernel Parquet reader + * 2) Verify the data using the Spark Parquet reader + */ + def verifyContent(actualFileDir: String, expected: Seq[FilteredColumnarBatch]): Unit = { verifyFileMetadata(actualFileDir) - verifyUsingKernelReader(actualFileDir, expected) - verifyUsingSparkReader(actualFileDir, expected) + verifyContentUsingKernelReader(actualFileDir, expected) + verifyContentUsingSparkReader(actualFileDir, expected) } /** * Verify the data in the Parquet files located in `actualFileDir` matches the expected data. * Use Kernel Parquet reader to read the data from the Parquet files. */ - def verifyUsingKernelReader( + def verifyContentUsingKernelReader( actualFileDir: String, expected: Seq[FilteredColumnarBatch]): Unit = { @@ -244,7 +357,7 @@ class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { * Verify the data in the Parquet files located in `actualFileDir` matches the expected data. * Use Spark Parquet reader to read the data from the Parquet files. 
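The expected (min, max, nullCount) tuples above encode the PARQUET-1222 workaround spelled out in the comments: under IEEE 754 the two zeros compare equal, but the total order used for stats comparisons puts -0.0 below 0.0, so the safe values to publish are -0.0 for a zero min and 0.0 for a zero max. A short, self-contained illustration of why the two orderings disagree; widenMin/widenMax are illustrative helpers, not the writer's actual code. The scaladocs of the renamed verification helpers continue below.

    // Primitive comparison follows IEEE 754: the two zeros are equal.
    assert(-0.0f == 0.0f && -0.0d == 0.0d)

    // The total order used for stats puts -0.0 strictly below 0.0.
    assert(java.lang.Float.compare(-0.0f, 0.0f) < 0)
    assert(java.lang.Double.compare(-0.0d, 0.0d) < 0)

    // A conservative writer therefore widens zero endpoints so that predicates evaluated
    // under either ordering never skip matching rows.
    def widenMin(min: Float): Float = if (min == 0.0f) -0.0f else min   // 0.0 -> -0.0
    def widenMax(max: Float): Float = if (max == 0.0f) 0.0f else max    // -0.0 -> 0.0

    assert(java.lang.Float.compare(widenMin(0.0f), -0.0f) == 0)
    assert(java.lang.Float.compare(widenMax(-0.0f), 0.0f) == 0)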
*/ - def verifyUsingSparkReader( + def verifyContentUsingSparkReader( actualFileDir: String, expected: Seq[FilteredColumnarBatch]): Unit = { @@ -260,6 +373,73 @@ class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { checkAnswer(actualTestRows, expectedTestRows) } + def verifyStatsUsingSpark( + actualFileDir: String, + actualFileStatuses: Seq[DataFileStatus], + fileDataSchema: StructType, + statsColumns: Seq[Column], + expStatsColCount: Int): Unit = { + + if (statsColumns.isEmpty) return + + val actualStatsOutput = actualFileStatuses + .map { fileStatus => + // validate there are no more the expected number of stats columns + assert(fileStatus.getStatistics.isPresent) + assert(fileStatus.getStatistics.get().getMinValues.size() === expStatsColCount) + assert(fileStatus.getStatistics.get().getMaxValues.size() === expStatsColCount) + assert(fileStatus.getStatistics.get().getNullCounts.size() === expStatsColCount) + + // Convert to TestRow for comparison with the actual values computing using Spark. + fileStatus.toTestRow(statsColumns) + } + + if (expStatsColCount == 0) return + + // Use spark to fetch the stats from the parquet files use them as the expected statistics + // Compare them with the actual stats returned by the Kernel's Parquet writer. + val df = spark.read + .format("parquet") + .parquet(actualFileDir) + .to(fileDataSchema.toSpark) + .select( + sparkfn.col("*"), // select all columns from the parquet files + sparkfn.col("_metadata.file_path").as("path"), // select file path + sparkfn.col("_metadata.file_size").as("size"), // select file size + // select mod time and convert to millis + sparkfn.unix_timestamp( + sparkfn.col("_metadata.file_modification_time")).as("modificationTime") + ) + .groupBy("path", "size", "modificationTime") + + val nullStats = Seq(sparkfn.lit(null), sparkfn.lit(null), sparkfn.lit(null)) + + // Add the row count aggregation + val aggs = Seq(sparkfn.count(sparkfn.col("*")).as("rowCount")) ++ + // add agg for each stats column to get min, max and null count + statsColumns + .flatMap { statColumn => + val dataType = DefaultKernelUtils.getDataType(fileDataSchema, statColumn) + dataType match { + case _: TimestampType => nullStats // not yet supported + case _: StructType => nullStats // no concept of stats for struct types + case _: ArrayType => nullStats // no concept of stats for array types + case _: MapType => nullStats // no concept of stats for map types + case _ => // for all other types + val colName = statColumn.toPath + Seq( + sparkfn.min(colName).as("min_" + colName), + sparkfn.max(colName).as("max_" + colName), + sparkfn.sum(sparkfn.when( + sparkfn.col(colName).isNull, 1).otherwise(0)).as("nullCount_" + colName)) + } + } + + val expectedStatsOutput = df.agg(aggs.head, aggs.tail: _*).collect().map(TestRow(_)) + + checkAnswer(actualStatsOutput, expectedStatsOutput) + } + /** * Verify the metadata of the Parquet files in `targetDir` matches says it is written by Kernel. */ @@ -328,10 +508,10 @@ class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { * verify the data using the Kernel Parquet reader and Spark Parquet reader. 
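verifyStatsUsingSpark above rebuilds the expected statistics by grouping on Spark's hidden _metadata file columns and aggregating min, max and a null count (sum of when(col.isNull, 1)) per file and per stats column. The same aggregation over plain Scala collections, with made-up sample rows, shows the shape of what each group produces. The scaladoc for writeToParquetUsingKernel continues below.

    // Made-up sample rows: (filePath, columnValue); None models a null cell.
    val rows: Seq[(String, Option[Int])] = Seq(
      ("f1.parquet", Some(3)), ("f1.parquet", None), ("f1.parquet", Some(7)),
      ("f2.parquet", Some(-1)), ("f2.parquet", Some(4)))

    // Per file: (min, max, nullCount) -- the same triple the Spark agg yields per stats column.
    val statsPerFile: Map[String, (Option[Int], Option[Int], Long)] =
      rows.groupBy(_._1).map { case (path, fileRows) =>
        val values = fileRows.flatMap(_._2)
        val nullCount = fileRows.count(_._2.isEmpty).toLong  // sum(when(col.isNull, 1).otherwise(0))
        val minOpt = if (values.isEmpty) None else Some(values.min)
        val maxOpt = if (values.isEmpty) None else Some(values.max)
        path -> (minOpt, maxOpt, nullCount)
      }

    assert(statsPerFile("f1.parquet") == (Some(3), Some(7), 1L))
    assert(statsPerFile("f2.parquet") == (Some(-1), Some(4), 0L))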
*/ def writeToParquetUsingKernel( - filteredData: Seq[FilteredColumnarBatch], - location: String, - targetFileSize: Long = 1024 * 1024, - statsColumns: Seq[Column] = Seq.empty): Seq[DataFileStatus] = { + filteredData: Seq[FilteredColumnarBatch], + location: String, + targetFileSize: Long = 1024 * 1024, + statsColumns: Seq[Column] = Seq.empty): Seq[DataFileStatus] = { val parquetWriter = new ParquetFileWriter( configuration, new Path(location), targetFileSize, statsColumns.asJava) @@ -400,60 +580,64 @@ class ParquetFileWriterSuite extends AnyFunSuite with TestUtils { } } - def tableSchema(path: String): StructType = { + def tableSchema(path: String): StructType = Table.forPath(defaultTableClient, path) .getLatestSnapshot(defaultTableClient) .getSchema(defaultTableClient) + + def hasColumnMappingId(str: String): Boolean = { + val table = Table.forPath(defaultTableClient, str) + val schema = table.getLatestSnapshot(defaultTableClient).getSchema(defaultTableClient) + schema.fields().asScala.exists { field => + field.getMetadata.contains(ColumnMapping.COLUMN_MAPPING_ID_KEY) + } } -} -object ParquetFileWriterSuite { - // Parquet file containing data of all supported types and variations - val ALL_TYPES_DATA = goldenTableFile("parquet-all-types").toString - // Schema of the data in `ALL_TYPES_DATA` - val ALL_TYPES_FILE_SCHEMA = new StructType() - .add("byteType", ByteType.BYTE) - .add("shortType", ShortType.SHORT) - .add("integerType", IntegerType.INTEGER) - .add("longType", LongType.LONG) - .add("floatType", FloatType.FLOAT) - .add("doubleType", DoubleType.DOUBLE) - .add("decimal", new DecimalType(10, 2)) - .add("booleanType", BooleanType.BOOLEAN) - .add("stringType", StringType.STRING) - .add("binaryType", BinaryType.BINARY) - .add("dateType", DateType.DATE) - .add("timestampType", TimestampType.TIMESTAMP) - .add("nested_struct", - new StructType() - .add("aa", StringType.STRING) - .add("ac", - new StructType() - .add("aca", IntegerType.INTEGER))) - .add("array_of_prims", new ArrayType(IntegerType.INTEGER, true)) - .add("array_of_arrays", new ArrayType(new ArrayType(IntegerType.INTEGER, true), true)) - .add("array_of_structs", - new ArrayType( - new StructType() - .add("ab", LongType.LONG), true)) - .add("map_of_prims", new MapType(IntegerType.INTEGER, LongType.LONG, true)) - .add("map_of_rows", - new MapType( - IntegerType.INTEGER, - new StructType().add("ab", LongType.LONG), - true)) - .add("map_of_arrays", - new MapType( - LongType.LONG, - new ArrayType(IntegerType.INTEGER, true), - true)) - - // Parquet file containing all variations (int, long and fixed binary) Decimal type data - val DECIMAL_TYPES_DATA = goldenTableFile("parquet-decimal-type").toString - // Schema of the data in `DECIMAL_TYPES_DATA` - val DECIMAL_TYPES_FILE_SCHEMA = new StructType() - .add("id", IntegerType.INTEGER) - .add("col1", new DecimalType(5, 1)) // stored as int - .add("col2", new DecimalType(10, 5)) // stored as long - .add("col3", new DecimalType(20, 5)) // stored as fixed binary + /** Get the list of all leaf-level primitive column references in the given `structType` */ + def leafLevelPrimitiveColumns(basePath: Seq[String], structType: StructType): Seq[Column] = { + structType.fields.asScala.flatMap { + case field if field.getDataType.isInstanceOf[StructType] => + leafLevelPrimitiveColumns( + basePath :+ field.getName, + field.getDataType.asInstanceOf[StructType]) + case field if !field.getDataType.isInstanceOf[ArrayType] && + !field.getDataType.isInstanceOf[MapType] => + // for all primitive types + Seq(new 
Column((basePath :+ field.getName).asJava.toArray(new Array[String](0)))); + case _ => Seq.empty + } + } + + implicit class DataFileStatusOps(dataFileStatus: DataFileStatus) { + /** + * Convert the [[DataFileStatus]] to a [[TestRow]]. + * (path, size, modification time, numRecords, + * min_col1, max_col1, nullCount_col1 (..repeated for every stats column) + * ) + */ + def toTestRow(statsColumns: Seq[Column]): TestRow = { + val statsOpt = dataFileStatus.getStatistics + val record: Seq[Any] = { + dataFileStatus.getPath +: + dataFileStatus.getSize +: + // convert to seconds, Spark returns in seconds and we can compare at second level + (dataFileStatus.getModificationTime / 1000) +: + // Add the row count to the stats literals + (if (statsOpt.isPresent) statsOpt.get().getNumRecords else null) +: + statsColumns.flatMap { column => + if (statsOpt.isPresent) { + val stats = statsOpt.get() + Seq( + Option(stats.getMinValues.get(column)).map(_.getValue).orNull, + Option(stats.getMaxValues.get(column)).map(_.getValue).orNull, + Option(stats.getNullCounts.get(column)).orNull + ) + } else { + Seq(null, null, null) + } + } + } + TestRow(record: _*) + } + } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala index d0b594c92c6..82acece375b 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala @@ -15,7 +15,7 @@ */ package io.delta.kernel.defaults.utils -import io.delta.kernel.expressions.{Column, Expression, Predicate} +import io.delta.kernel.expressions._ /** Useful helper functions for creating expressions in tests */ trait ExpressionTestUtils { @@ -54,6 +54,14 @@ trait ExpressionTestUtils { new Column(name.split("\\.")) } + protected def and(left: Predicate, right: Predicate): And = { + new And(left, right) + } + + protected def or(left: Predicate, right: Predicate): Or = { + new Or(left, right) + } + /* ---------- NOT-YET SUPPORTED EXPRESSIONS ----------- */ /* diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 00b47d3625f..2f7e7f75951 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -29,7 +29,7 @@ import io.delta.kernel.client.TableClient import io.delta.kernel.data.{ColumnVector, ColumnarBatch, FilteredColumnarBatch, MapValue, Row} import io.delta.kernel.defaults.client.DefaultTableClient import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector -import io.delta.kernel.expressions.Predicate +import io.delta.kernel.expressions.{Column, Predicate} import io.delta.kernel.internal.InternalScanFileUtils import io.delta.kernel.internal.data.ScanStateRow import io.delta.kernel.internal.util.Utils.singletonCloseableIterator @@ -96,14 +96,22 @@ trait TestUtils extends Assertions with SQLHelper { new FilteredColumnarBatch(batch, Optional.empty()) } - def toFiltered(predicate: Predicate): FilteredColumnarBatch = { - val predicateEvaluator = defaultTableClient.getExpressionHandler - .getPredicateEvaluator(batch.getSchema, predicate) - val selVector = predicateEvaluator.eval(batch, Optional.empty()) - new 
FilteredColumnarBatch(batch, Optional.of(selVector)) + def toFiltered(predicate: Option[Predicate]): FilteredColumnarBatch = { + if (predicate.isEmpty) { + new FilteredColumnarBatch(batch, Optional.empty()) + } else { + val predicateEvaluator = defaultTableClient.getExpressionHandler + .getPredicateEvaluator(batch.getSchema, predicate.get) + val selVector = predicateEvaluator.eval(batch, Optional.empty()) + new FilteredColumnarBatch(batch, Optional.of(selVector)) + } } } + implicit class ColumnOps(column: Column) { + def toPath: String = column.getNames.mkString(".") + } + implicit object ResourceLoader { lazy val classLoader: ClassLoader = ResourceLoader.getClass.getClassLoader } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/VectorTestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/VectorTestUtils.scala index 772d69e426e..ee0de2bcad2 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/VectorTestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/VectorTestUtils.scala @@ -15,10 +15,10 @@ */ package io.delta.kernel.defaults.utils -import java.lang.{Boolean => BooleanJ} - -import io.delta.kernel.data.ColumnVector -import io.delta.kernel.types.{BooleanType, DataType, StringType} +import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ} +import io.delta.kernel.data.{ColumnVector, ColumnarBatch} +import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch +import io.delta.kernel.types._ trait VectorTestUtils { @@ -50,4 +50,46 @@ trait VectorTestUtils { } } + protected def floatVector(values: Seq[FloatJ]): ColumnVector = { + new ColumnVector { + override def getDataType: DataType = FloatType.FLOAT + + override def getSize: Int = values.length + + override def close(): Unit = {} + + override def isNullAt(rowId: Int): Boolean = (values(rowId) == null) + + override def getFloat(rowId: Int): Float = values(rowId) + } + } + + protected def doubleVector(values: Seq[DoubleJ]): ColumnVector = { + new ColumnVector { + override def getDataType: DataType = DoubleType.DOUBLE + + override def getSize: Int = values.length + + override def close(): Unit = {} + + override def isNullAt(rowId: Int): Boolean = (values(rowId) == null) + + override def getDouble(rowId: Int): Double = values(rowId) + } + } + + /** + * Returns a [[ColumnarBatch]] with each given vector is a top-level column col_i where i is + * the index of the vector in the input list. + */ + protected def columnarBatch(vectors: ColumnVector*): ColumnarBatch = { + val numRows = vectors.head.getSize + vectors.tail.foreach( + v => require(v.getSize == numRows, "All vectors should have the same size")) + + val schema = (0 until vectors.length) + .foldLeft(new StructType())((s, i) => s.add(s"col_$i", vectors(i).getDataType)) + + new DefaultColumnarBatch(numRows, schema, vectors.toArray) + } }
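Taken together, the helpers added above (floatVector, doubleVector, columnarBatch, and the Option-based toFiltered) make it easy to build small typed batches in tests. A usage sketch, assuming a suite on the kernel-defaults test classpath that mixes in these traits; the suite name and values are illustrative.

    import java.lang.{Double => DoubleJ, Float => FloatJ}
    import scala.collection.JavaConverters._

    import io.delta.kernel.expressions.Predicate
    import io.delta.kernel.defaults.utils.{TestUtils, VectorTestUtils}
    import org.scalatest.funsuite.AnyFunSuite

    class BatchHelpersUsageSuite extends AnyFunSuite with TestUtils with VectorTestUtils {

      test("build a two-column batch from test vectors") {
        val floats: Seq[FloatJ] = Seq(1.0f, null, Float.MaxValue)
        val doubles: Seq[DoubleJ] = Seq(-1.0d, 2.5d, null)

        // Columns are named positionally: col_0 (FLOAT), col_1 (DOUBLE).
        val batch = columnarBatch(floatVector(floats), doubleVector(doubles))
        assert(batch.getSize === 3)
        assert(batch.getSchema.fields().asScala.map(_.getName) === Seq("col_0", "col_1"))

        // toFiltered(None) keeps every row; Some(predicate) would compute a selection vector.
        val allRows = batch.toFiltered(Option.empty[Predicate])
        assert(allRows.getData.getSize === 3)
      }
    }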