From b7f6770b06a22e2c0b59a2a58e16688d08918c14 Mon Sep 17 00:00:00 2001 From: Will Raschkowski Date: Fri, 3 Jul 2020 18:05:19 +0200 Subject: [PATCH] Revert "[SPARK-25299] Put SPARK-25299 changes into master (#606)" This reverts commit e153f7363426bed666e88e9fe1461d91d016046b. --- .circleci/config.yml | 5 +- .../api/MapOutputWriterCommitMessage.java | 53 --- .../spark/shuffle/api/ShuffleBlockInfo.java | 91 ----- .../spark/shuffle/api/ShuffleDataIO.java | 53 --- .../shuffle/api/ShuffleDriverComponents.java | 39 -- .../api/ShuffleExecutorComponents.java | 91 ----- .../shuffle/api/ShuffleMapOutputWriter.java | 80 ---- .../shuffle/api/ShufflePartitionWriter.java | 98 ----- .../SingleSpillShuffleMapOutputWriter.java | 37 -- .../api/WritableByteChannelWrapper.java | 42 -- .../sort/BypassMergeSortShuffleWriter.java | 207 ++++------ .../shuffle/sort/UnsafeShuffleWriter.java | 364 ++++++++---------- .../sort/io/LocalDiskShuffleDataIO.java | 48 --- .../LocalDiskShuffleExecutorComponents.java | 126 ------ .../io/LocalDiskShuffleMapOutputWriter.java | 281 -------------- .../LocalDiskSingleSpillMapOutputWriter.java | 61 --- .../LocalDiskShuffleDriverComponents.java | 60 --- .../org/apache/spark/ContextCleaner.scala | 8 +- .../scala/org/apache/spark/Dependency.scala | 1 - .../org/apache/spark/MapOutputTracker.scala | 49 +-- .../scala/org/apache/spark/SparkContext.scala | 15 +- .../apache/spark/executor/TaskMetrics.scala | 12 +- .../spark/internal/config/package.scala | 9 +- .../apache/spark/scheduler/DAGScheduler.scala | 72 ++-- .../apache/spark/scheduler/MapStatus.scala | 63 +-- .../shuffle/BlockStoreShuffleReader.scala | 67 +--- .../spark/shuffle/FetchFailedException.scala | 13 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../shuffle/ShufflePartitionPairsWriter.scala | 135 ------- .../io/LocalDiskShuffleReadSupport.scala | 111 ------ .../shuffle/sort/SortShuffleManager.scala | 43 +-- .../shuffle/sort/SortShuffleWriter.scala | 26 +- .../org/apache/spark/storage/BlockId.scala | 9 +- .../apache/spark/storage/BlockManagerId.scala | 23 +- .../spark/storage/DiskBlockObjectWriter.scala | 6 +- .../storage/ShuffleBlockFetcherIterator.scala | 16 +- .../scala/org/apache/spark/util/Utils.scala | 30 +- .../util/collection/ExternalSorter.scala | 83 +--- .../spark/util/collection/PairsWriter.scala | 28 -- .../WritablePartitionedPairCollection.scala | 4 +- .../sort/UnsafeShuffleWriterSuite.java | 68 +--- .../spark/InternalAccumulatorSuite.scala | 3 +- .../apache/spark/MapOutputTrackerSuite.scala | 167 ++------ .../scala/org/apache/spark/ShuffleSuite.scala | 22 +- .../DAGSchedulerShufflePluginSuite.scala | 169 -------- .../spark/scheduler/DAGSchedulerSuite.scala | 48 +-- .../spark/scheduler/MapStatusSuite.scala | 18 +- .../scheduler/TaskSchedulerImplSuite.scala | 2 +- .../serializer/KryoSerializerSuite.scala | 3 +- .../BlockStoreShuffleReaderSuite.scala | 55 +-- .../ShuffleDriverComponentsSuite.scala | 84 ---- .../BypassMergeSortShuffleWriterSuite.scala | 149 +++---- .../shuffle/sort/SortShuffleWriterSuite.scala | 117 ------ ...LocalDiskShuffleMapOutputWriterSuite.scala | 161 -------- .../ShuffleBlockFetcherIteratorSuite.scala | 32 +- .../spark/sql/execution/ShuffledRowRDD.scala | 7 +- 56 files changed, 569 insertions(+), 3097 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java 
delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala diff --git a/.circleci/config.yml b/.circleci/config.yml index b0c027d6a4faf..1bf1e1f67f2a8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,6 +8,7 @@ defaults: &defaults TERM: dumb BUILD_SBT_CACHE: "/home/circleci/build-sbt-cache" + test-defaults: &test-defaults <<: *defaults environment: @@ -27,8 +28,8 @@ deployable-branches-and-tags: &deployable-branches-and-tags tags: only: /[0-9]+(?:\.[0-9]+){2,}-palantir\.[0-9]+(?:\.[0-9]+)*/ branches: - only: - - master + only: master + # Step templates diff --git a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java deleted file mode 100644 index 5a1c82499b715..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/MapOutputWriterCommitMessage.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.api; - -import java.util.Optional; - -import org.apache.spark.annotation.Private; -import org.apache.spark.storage.BlockManagerId; - -@Private -public final class MapOutputWriterCommitMessage { - - private final long[] partitionLengths; - private final Optional location; - - private MapOutputWriterCommitMessage( - long[] partitionLengths, Optional location) { - this.partitionLengths = partitionLengths; - this.location = location; - } - - public static MapOutputWriterCommitMessage of(long[] partitionLengths) { - return new MapOutputWriterCommitMessage(partitionLengths, Optional.empty()); - } - - public static MapOutputWriterCommitMessage of( - long[] partitionLengths, BlockManagerId location) { - return new MapOutputWriterCommitMessage(partitionLengths, Optional.of(location)); - } - - public long[] getPartitionLengths() { - return partitionLengths; - } - - public Optional getLocation() { - return location; - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java deleted file mode 100644 index 72a67c76f28b5..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleBlockInfo.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.util.Objects; - -import org.apache.spark.api.java.Optional; -import org.apache.spark.storage.BlockManagerId; - -/** - * :: Experimental :: - * An object defining the shuffle block and length metadata associated with the block. 
- * @since 3.0.0 - */ -public class ShuffleBlockInfo { - private final int shuffleId; - private final int mapId; - private final int reduceId; - private final long length; - private final long mapTaskAttemptId; - private final Optional shuffleLocation; - - public ShuffleBlockInfo( - int shuffleId, - int mapId, - int reduceId, - long length, - long mapTaskAttemptId, - Optional shuffleLocation) { - this.shuffleId = shuffleId; - this.mapId = mapId; - this.reduceId = reduceId; - this.length = length; - this.mapTaskAttemptId = mapTaskAttemptId; - this.shuffleLocation = shuffleLocation; - } - - public int getShuffleId() { - return shuffleId; - } - - public int getMapId() { - return mapId; - } - - public int getReduceId() { - return reduceId; - } - - public long getLength() { - return length; - } - - public long getMapTaskAttemptId() { - return mapTaskAttemptId; - } - - public Optional getShuffleLocation() { - return shuffleLocation; - } - - @Override - public boolean equals(Object other) { - return other instanceof ShuffleBlockInfo - && shuffleId == ((ShuffleBlockInfo) other).shuffleId - && mapId == ((ShuffleBlockInfo) other).mapId - && reduceId == ((ShuffleBlockInfo) other).reduceId - && length == ((ShuffleBlockInfo) other).length - && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation); - } - - @Override - public int hashCode() { - return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java deleted file mode 100644 index 5126f0c3577f8..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import org.apache.spark.annotation.Private; - -/** - * :: Private :: - * An interface for plugging in modules for storing and reading temporary shuffle data. - *
<p>
- * This is the root of a plugin system for storing shuffle bytes to arbitrary storage - * backends in the sort-based shuffle algorithm implemented by the - * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is - * needed instead of sort-based shuffle, one should implement - * {@link org.apache.spark.shuffle.ShuffleManager} instead. - *
<p>
- * A single instance of this module is loaded per process in the Spark application. - * The default implementation reads and writes shuffle data from the local disks of - * the executor, and is the implementation of shuffle file storage that has remained - * consistent throughout most of Spark's history. - *
<p>
- * Alternative implementations of shuffle data storage can be loaded via setting - * spark.shuffle.sort.io.plugin.class. - * @since 3.0.0 - */ -@Private -public interface ShuffleDataIO { - - String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; - - ShuffleDriverComponents driver(); - - /** - * Called once on executor processes to bootstrap the shuffle data storage modules that - * are only invoked on the executors. - */ - ShuffleExecutorComponents executor(); -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java deleted file mode 100644 index cbc59bc7b6a05..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.IOException; -import java.util.Map; - -public interface ShuffleDriverComponents { - - /** - * @return additional SparkConf values necessary for the executors. - */ - Map initializeApplication(); - - default void cleanupApplication() throws IOException {} - - default void registerShuffle(int shuffleId) throws IOException {} - - default void removeShuffle(int shuffleId, boolean blocking) throws IOException {} - - default boolean shouldUnregisterOutputOnHostOnFetchFailure() { - return false; - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java deleted file mode 100644 index 94c07009f3180..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
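For orientation, the ShuffleDataIO interface deleted above is the entry point of the plugin system its javadoc describes: one implementation is loaded per process via spark.shuffle.sort.io.plugin.class, and it supplies driver-side and executor-side components. A minimal sketch of a custom plugin follows; the class name is made up, and the single-argument SparkConf constructor is an assumption about how the plugin would be instantiated, not something stated in this patch.

    // Hypothetical skeleton of a ShuffleDataIO plugin (sketch only, not part of this patch).
    package com.example.shuffle;

    import org.apache.spark.SparkConf;
    import org.apache.spark.shuffle.api.ShuffleDataIO;
    import org.apache.spark.shuffle.api.ShuffleDriverComponents;
    import org.apache.spark.shuffle.api.ShuffleExecutorComponents;

    public final class ExampleShuffleDataIO implements ShuffleDataIO {
      private final SparkConf conf;

      // Assumed construction convention: the plugin receives the application's SparkConf.
      public ExampleShuffleDataIO(SparkConf conf) {
        this.conf = conf;
      }

      @Override
      public ShuffleDriverComponents driver() {
        // A real plugin returns its driver-side lifecycle hooks here
        // (initializeApplication, registerShuffle, removeShuffle, cleanupApplication).
        throw new UnsupportedOperationException("sketch only");
      }

      @Override
      public ShuffleExecutorComponents executor() {
        // A real plugin returns its executor-side hooks here
        // (initializeExecutor, createMapOutputWriter, getPartitionReaders).
        throw new UnsupportedOperationException("sketch only");
      }
    }

Such a class would be enabled by setting spark.shuffle.sort.io.plugin.class to its fully qualified name, per the ShuffleDataIO javadoc above.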
- */ - -package org.apache.spark.shuffle.api; - -import java.io.IOException; -import java.io.InputStream; -import java.util.Map; -import java.util.Optional; - -import org.apache.spark.annotation.Private; - -/** - * :: Private :: - * An interface for building shuffle support for Executors. - * - * @since 3.0.0 - */ -@Private -public interface ShuffleExecutorComponents { - - /** - * Called once per executor to bootstrap this module with state that is specific to - * that executor, specifically the application ID and executor ID. - */ - void initializeExecutor(String appId, String execId, Map extraConfigs); - - /** - * Called once per map task to create a writer that will be responsible for persisting all the - * partitioned bytes written by that map task. - * - * @param shuffleId Unique identifier for the shuffle the map task is a part of - * @param mapId Within the shuffle, the identifier of the map task - * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task - * with the same (shuffleId, mapId) pair can be distinguished by the - * different values of mapTaskAttemptId. - * @param numPartitions The number of partitions that will be written by the map task. Some of - * these partitions may be empty. - */ - ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) throws IOException; - - /** - * Returns an underlying {@link Iterable} that will iterate - * through shuffle data, given an iterable for the shuffle blocks to fetch. - */ - Iterable getPartitionReaders(Iterable blockMetadata) - throws IOException; - - default boolean shouldWrapPartitionReaderStream() { - return true; - } - - /** - * An optional extension for creating a map output writer that can optimize the transfer of a - * single partition file, as the entire result of a map task, to the backing store. - *
<p>
- * Most implementations should return the default {@link Optional#empty()} to indicate that - * they do not support this optimization. This primarily is for backwards-compatibility in - * preserving an optimization in the local disk shuffle storage implementation. - * - * @param shuffleId Unique identifier for the shuffle the map task is a part of - * @param mapId Within the shuffle, the identifier of the map task - * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task - * with the same (shuffleId, mapId) pair can be distinguished by the - * different values of mapTaskAttemptId. - */ - default Optional createSingleFileMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId) throws IOException { - return Optional.empty(); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java deleted file mode 100644 index 8fcc73ba3c9b2..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.IOException; - -import org.apache.spark.annotation.Private; - -/** - * :: Private :: - * A top-level writer that returns child writers for persisting the output of a map task, - * and then commits all of the writes as one atomic operation. - * - * @since 3.0.0 - */ -@Private -public interface ShuffleMapOutputWriter { - - /** - * Creates a writer that can open an output stream to persist bytes targeted for a given reduce - * partition id. - *
<p>
- * The chunk corresponds to bytes in the given reduce partition. This will not be called twice - * for the same partition within any given map task. The partition identifier will be in the - * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was - * provided upon the creation of this map output writer via - * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. - *
<p>
- * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each - * call to this method will be called with a reducePartitionId that is strictly greater than - * the reducePartitionIds given to any previous call to this method. This method is not - * guaranteed to be called for every partition id in the above described range. In particular, - * no guarantees are made as to whether or not this method will be called for empty partitions. - */ - ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException; - - /** - * Commits the writes done by all partition writers returned by all calls to this object's - * {@link #getPartitionWriter(int)}, and returns a bundle of metadata associated with the - * behavior of the write. - *
<p>
- * This should ensure that the writes conducted by this module's partition writers are - * available to downstream reduce tasks. If this method throws any exception, this module's - * {@link #abort(Throwable)} method will be invoked before propagating the exception. - *
<p>
- * This can also close any resources and clean up temporary state if necessary. - *
<p>
- * The returned array should contain two sets of metadata: - * - * 1. For each partition from (0) to (numPartitions - 1), the number of bytes written by - * the partition writer for that partition id. - * - * 2. If the partition data was stored on the local disk of this executor, also provide - * the block manager id where these bytes can be fetched from. - */ - MapOutputWriterCommitMessage commitAllPartitions() throws IOException; - - /** - * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. - *
<p>
- * This should invalidate the results of writing bytes. This can also close any resources and - * clean up temporary state if necessary. - */ - void abort(Throwable error) throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java deleted file mode 100644 index 928875156a70f..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.IOException; -import java.util.Optional; -import java.io.OutputStream; - -import org.apache.spark.annotation.Private; - -/** - * :: Private :: - * An interface for opening streams to persist partition bytes to a backing data store. - *
<p>
- * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle - * block. - * - * @since 3.0.0 - */ -@Private -public interface ShufflePartitionWriter { - - /** - * Open and return an {@link OutputStream} that can write bytes to the underlying - * data store. - *
<p>
- * This method will only be called once on this partition writer in the map task, to write the - * bytes to the partition. The output stream will only be used to write the bytes for this - * partition. The map task closes this output stream upon writing all the bytes for this - * block, or if the write fails for any reason. - *
<p>
- * Implementations that intend on combining the bytes for all the partitions written by this - * map task should reuse the same OutputStream instance across all the partition writers provided - * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that - * {@link OutputStream#close()} does not close the resource, since it will be reused across - * partition writes. The underlying resources should be cleaned up in - * {@link ShuffleMapOutputWriter#commitAllPartitions()} and - * {@link ShuffleMapOutputWriter#abort(Throwable)}. - */ - OutputStream openStream() throws IOException; - - /** - * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from - * input byte channels to the underlying shuffle data store. - *
<p>
- * This method will only be called once on this partition writer in the map task, to write the - * bytes to the partition. The channel will only be used to write the bytes for this - * partition. The map task closes this channel upon writing all the bytes for this - * block, or if the write fails for any reason. - *
<p>
- * Implementations that intend on combining the bytes for all the partitions written by this - * map task should reuse the same channel instance across all the partition writers provided - * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that - * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel - * will be reused across partition writes. The underlying resources should be cleaned up in - * {@link ShuffleMapOutputWriter#commitAllPartitions()} and - * {@link ShuffleMapOutputWriter#abort(Throwable)}. - *
<p>
- * This method is primarily for advanced optimizations where bytes can be copied from the input - * spill files to the output channel without copying data into memory. If such optimizations are - * not supported, the implementation should return {@link Optional#empty()}. By default, the - * implementation returns {@link Optional#empty()}. - *
<p>
- * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the - * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure - * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()}, - * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or - * {@link ShuffleMapOutputWriter#abort(Throwable)}. - */ - default Optional openChannelWrapper() throws IOException { - return Optional.empty(); - } - - /** - * Returns the number of bytes written either by this writer's output stream opened by - * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}. - *
<p>
- * This can be different from the number of bytes given by the caller. For example, the - * stream might compress or encrypt the bytes before persisting the data to the backing - * data store. - */ - long getNumBytesWritten(); -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java deleted file mode 100644 index bddb97bdf0d7e..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.File; -import java.io.IOException; - -import org.apache.spark.annotation.Private; - -/** - * Optional extension for partition writing that is optimized for transferring a single - * file to the backing store. - */ -@Private -public interface SingleSpillShuffleMapOutputWriter { - - /** - * Transfer a file that contains the bytes of all the partitions written by this map task. - */ - MapOutputWriterCommitMessage transferMapSpillFile( - File mapOutputFile, long[] partitionLengths) throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java deleted file mode 100644 index a204903008a51..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.api; - -import java.io.Closeable; -import java.nio.channels.WritableByteChannel; - -import org.apache.spark.annotation.Private; - -/** - * :: Private :: - * - * A thin wrapper around a {@link WritableByteChannel}. - *
<p>
- * This is primarily provided for the local disk shuffle implementation to provide a - * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes. - * - * @since 3.0.0 - */ -@Private -public interface WritableByteChannelWrapper extends Closeable { - - /** - * The underlying channel to write bytes into. - */ - WritableByteChannel channel(); -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 94ad5fc66185b..32b446785a9f0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -19,10 +19,8 @@ import java.io.File; import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; -import java.io.OutputStream; -import java.nio.channels.FileChannel; -import java.util.Optional; import javax.annotation.Nullable; import scala.None$; @@ -36,20 +34,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.internal.config.package$; import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.shuffle.api.WritableByteChannelWrapper; -import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -87,15 +81,14 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; - private final long mapTaskAttemptId; private final Serializer serializer; - private final ShuffleExecutorComponents shuffleExecutorComponents; + private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private MapOutputWriterCommitMessage commitMessage; + private long[] partitionLengths; /** * Are we in the process of stopping? 
Because map tasks can call stop() with success = true @@ -106,89 +99,79 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleWriter( BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, - long mapTaskAttemptId, SparkConf conf, - ShuffleWriteMetricsReporter writeMetrics, - ShuffleExecutorComponents shuffleExecutorComponents) { + ShuffleWriteMetricsReporter writeMetrics) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; - this.mapTaskAttemptId = mapTaskAttemptId; this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - this.shuffleExecutorComponents = shuffleExecutorComponents; + this.shuffleBlockResolver = shuffleBlockResolver; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents - .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); - try { - if (!records.hasNext()) { - commitMessage = mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply( - commitMessage.getLocation().orElse(null), - commitMessage.getPartitionLengths(), - mapTaskAttemptId); - return; - } - final SerializerInstance serInstance = serializer.newInstance(); - final long openStartTime = System.nanoTime(); - partitionWriters = new DiskBlockObjectWriter[numPartitions]; - partitionWriterSegments = new FileSegment[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - final Tuple2 tempShuffleBlockIdPlusFile = - blockManager.diskBlockManager().createTempShuffleBlock(); - final File file = tempShuffleBlockIdPlusFile._2(); - final BlockId blockId = tempShuffleBlockIdPlusFile._1(); - partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. 
- writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - - while (records.hasNext()) { - final Product2 record = records.next(); - final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); - } + if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } - for (int i = 0; i < numPartitions; i++) { - try (DiskBlockObjectWriter writer = partitionWriters[i]) { - partitionWriterSegments[i] = writer.commitAndGet(); - } + for (int i = 0; i < numPartitions; i++) { + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); } + } - commitMessage = writePartitionedData(mapOutputWriter); - mapStatus = MapStatus$.MODULE$.apply( - commitMessage.getLocation().orElse(null), - commitMessage.getPartitionLengths(), - mapTaskAttemptId); - } catch (Exception e) { - try { - mapOutputWriter.abort(e); - } catch (Exception e2) { - logger.error("Failed to abort the writer after failing to write map output.", e2); - e.addSuppressed(e2); + File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + File tmp = Utils.tempFileWith(output); + try { + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } - throw e; } + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting long[] getPartitionLengths() { - return commitMessage.getPartitionLengths(); + return partitionLengths; } /** @@ -196,75 +179,41 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). 
*/ - private MapOutputWriterCommitMessage writePartitionedData( - ShuffleMapOutputWriter mapOutputWriter) throws IOException { + private long[] writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file - if (partitionWriters != null) { - final long writeStartTime = System.nanoTime(); - try { - for (int i = 0; i < numPartitions; i++) { - final File file = partitionWriterSegments[i].file(); - ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); - if (file.exists()) { - if (transferToEnabled) { - // Using WritableByteChannelWrapper to make resource closing consistent between - // this implementation and UnsafeShuffleWriter. - Optional maybeOutputChannel = writer.openChannelWrapper(); - if (maybeOutputChannel.isPresent()) { - writePartitionedDataWithChannel(file, maybeOutputChannel.get()); - } else { - writePartitionedDataWithStream(file, writer); - } - } else { - writePartitionedDataWithStream(file, writer); - } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); - } - } - } - } finally { - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); - } - partitionWriters = null; + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; } - return mapOutputWriter.commitAllPartitions(); - } - private void writePartitionedDataWithChannel( - File file, - WritableByteChannelWrapper outputChannel) throws IOException { - boolean copyThrewException = true; + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; try { - FileInputStream in = new FileInputStream(file); - try (FileChannel inputChannel = in.getChannel()) { - Utils.copyFileStreamNIO( - inputChannel, outputChannel.channel(), 0L, inputChannel.size()); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - } finally { - Closeables.close(outputChannel, copyThrewException); - } - } - - private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer) - throws IOException { - boolean copyThrewException = true; - FileInputStream in = new FileInputStream(file); - OutputStream outputStream; - try { - outputStream = writer.openStream(); - try { - Utils.copyStream(in, outputStream, false, false); - copyThrewException = false; - } finally { - Closeables.close(outputStream, copyThrewException); + for (int i = 0; i < numPartitions; i++) { + final File file = partitionWriterSegments[i].file(); + if (file.exists()) { + final FileInputStream in = new FileInputStream(file); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } } + threwException = false; } finally { - Closeables.close(in, copyThrewException); + Closeables.close(out, threwException); + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } + partitionWriters = null; + return lengths; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index bc7c401e420c6..36081069b0e75 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -17,12 +17,9 @@ package org.apache.spark.shuffle.sort; -import java.nio.channels.Channels; -import java.util.Optional; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; import java.util.Iterator; import scala.Option; @@ -34,6 +31,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; +import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,6 +41,8 @@ import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -50,13 +50,8 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; -import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; @@ -70,14 +65,15 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @VisibleForTesting + static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; - private final ShuffleExecutorComponents shuffleExecutorComponents; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -85,6 +81,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; private final int initialSortBufferSize; private final int inputBufferSizeInBytes; + private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -106,15 +103,27 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream */ private boolean stopping = false; + private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { + + CloseAndFlushShieldOutputStream(OutputStream outputStream) { + super(outputStream); + } + + @Override + public void flush() { + // do nothing + } + } + public UnsafeShuffleWriter( BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, SerializedShuffleHandle handle, int mapId, TaskContext 
taskContext, SparkConf sparkConf, - ShuffleWriteMetricsReporter writeMetrics, - ShuffleExecutorComponents shuffleExecutorComponents) { + ShuffleWriteMetricsReporter writeMetrics) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -123,6 +132,7 @@ public UnsafeShuffleWriter( " reduce partitions"); } this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); @@ -130,7 +140,6 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; - this.shuffleExecutorComponents = shuffleExecutorComponents; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -138,6 +147,8 @@ public UnsafeShuffleWriter( (int) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); this.inputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.outputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -219,20 +230,26 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final MapOutputWriterCommitMessage commitMessage; + final long[] partitionLengths; + final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final File tmp = Utils.tempFileWith(output); try { - commitMessage = mergeSpills(spills); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && !spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); + try { + partitionLengths = mergeSpills(spills, tmp); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } } } + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } } - mapStatus = MapStatus$.MODULE$.apply( - commitMessage.getLocation().orElse(null), - commitMessage.getPartitionLengths(), - taskContext.attemptNumber()); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting @@ -264,166 +281,137 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. 
*/ - private MapOutputWriterCommitMessage mergeSpills(SpillInfo[] spills) throws IOException { - MapOutputWriterCommitMessage commitMessage; - if (spills.length == 0) { - final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents - .createMapOutputWriter( - shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); - return mapWriter.commitAllPartitions(); - } else if (spills.length == 1) { - Optional maybeSingleFileWriter = - shuffleExecutorComponents.createSingleFileMapOutputWriter( - shuffleId, mapId, taskContext.taskAttemptId()); - if (maybeSingleFileWriter.isPresent()) { - // Here, we don't need to perform any metrics updates because the bytes written to this - // output file would have already been counted as shuffle bytes written. - long[] partitionLengths = spills[0].partitionLengths; - return maybeSingleFileWriter.get().transferMapSpillFile( - spills[0].file, partitionLengths); - } else { - commitMessage = mergeSpillsUsingStandardWriter(spills); - } - } else { - commitMessage = mergeSpillsUsingStandardWriter(spills); - } - return commitMessage; - } - - private MapOutputWriterCommitMessage mergeSpillsUsingStandardWriter( - SpillInfo[] spills) throws IOException { - assert (spills.length >= 2); - MapOutputWriterCommitMessage commitMessage; + private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = - (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); + (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE()); final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); - final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents - .createMapOutputWriter( - shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); try { - // There are multiple spills to merge, so none of these spill files' lengths were counted - // towards our shuffle write count or shuffle write time. If we use the slow merge path, - // then the final output file's size won't necessarily be equal to the sum of the spill - // files' sizes. To guard against this case, we look at the output file's actual size when - // computing shuffle bytes written. - // - // We allow the individual merge methods to report their own IO times since different merge - // strategies use different IO techniques. We count IO during merge towards the shuffle - // write time, which appears to be consistent with the "not bypassing merge-sort" branch in - // ExternalSorter. - if (fastMergeEnabled && fastMergeIsSupported) { - // Compression is disabled or we are using an IO compression codec that supports - // decompression of concatenated compressed streams, so we can perform a fast spill merge - // that doesn't need to interpret the spilled bytes. 
- if (transferToEnabled && !encryptionEnabled) { - logger.debug("Using transferTo-based fast merge"); - mergeSpillsWithTransferTo(spills, mapWriter); + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled && !encryptionEnabled) { + logger.debug("Using transferTo-based fast merge"); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } } else { - logger.debug("Using fileStream-based fast merge"); - mergeSpillsWithFileStream(spills, mapWriter, null); + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); } - } else { - logger.debug("Using slow merge"); - mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incBytesWritten(outputFile.length()); + return partitionLengths; } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. 
- writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - commitMessage = mapWriter.commitAllPartitions(); - } catch (Exception e) { - try { - mapWriter.abort(e); - } catch (Exception e2) { - logger.warn("Failed to abort writing the map output.", e2); - e.addSuppressed(e2); + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); } throw e; } - return commitMessage; } /** * Merges spill files using Java FileStreams. This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], - * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec - * does not support concatenation of compressed data, when encryption is enabled, or when - * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * File)}, and it's mostly used in cases where the IO compression codec does not support + * concatenation of compressed data, when encryption is enabled, or when users have + * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. * This code path might also be faster in cases where individual partition size in a spill * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small * disk ios which is inefficient. In those case, Using large buffers for input and output * files helps reducing the number of disk ios, making the file merging faster. * * @param spills the spills to merge. - * @param mapWriter the map output writer to use for output. + * @param outputFile the file to write the merged data to. * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. */ - private void mergeSpillsWithFileStream( + private long[] mergeSpillsWithFileStream( SpillInfo[] spills, - ShuffleMapOutputWriter mapWriter, + File outputFile, @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; + final OutputStream bos = new BufferedOutputStream( + new FileOutputStream(outputFile), + outputBufferSizeInBytes); + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. 
+ final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); + boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new NioBufferedFileInputStream( - spills[i].file, - inputBufferSizeInBytes); + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewException = true; - ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - OutputStream partitionOutput = writer.openStream(); - try { - partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - boolean copySpillThrewException = true; - try { - partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream( - partitionInputStream); - } - ByteStreams.copy(partitionInputStream, partitionOutput); - copySpillThrewException = false; - } finally { - Closeables.close(partitionInputStream, copySpillThrewException); + final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() and flush() calls, so that we can close + // the higher level streams to make sure all data is really flushed and internal state is + // cleaned. 
+ OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + try { + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); } - copyThrewException = false; } - copyThrewException = false; - } finally { - Closeables.close(partitionOutput, copyThrewException); } - long numBytesWritten = writer.getNumBytesWritten(); - writeMetrics.incBytesWritten(numBytesWritten); + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); } threwException = false; } finally { @@ -432,7 +420,9 @@ private void mergeSpillsWithFileStream( for (InputStream stream : spillInputStreams) { Closeables.close(stream, threwException); } + Closeables.close(mergedFileOutputStream, threwException); } + return partitionLengths; } /** @@ -440,46 +430,54 @@ private void mergeSpillsWithFileStream( * This is only safe when the IO compression codec and serializer support concatenation of * serialized streams. * - * @param spills the spills to merge. - * @param mapWriter the map output writer to use for output. * @return the partition lengths in the merged file. */ - private void mergeSpillsWithTransferTo( - SpillInfo[] spills, - ShuffleMapOutputWriter mapWriter) throws IOException { + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. 
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; for (int partition = 0; partition < numPartitions; partition++) { - boolean copyThrewException = true; - ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); - WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper() - .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); - try { - for (int i = 0; i < spills.length; i++) { - long partitionLengthInSpill = spills[i].partitionLengths[partition]; - final FileChannel spillInputChannel = spillInputChannels[i]; - final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - resolvedChannel.channel(), - spillInputChannelPositions[i], - partitionLengthInSpill); - copyThrewException = false; - spillInputChannelPositions[i] += partitionLengthInSpill; - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); - } - } finally { - Closeables.close(resolvedChannel, copyThrewException); + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + Utils.copyFileStreamNIO( + spillInputChannel, + mergedFileOutputChannel, + spillInputChannelPositions[i], + partitionLengthInSpill); + spillInputChannelPositions[i] += partitionLengthInSpill; + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; } - long numBytes = writer.getNumBytesWritten(); - writeMetrics.incBytesWritten(numBytes); + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." 
+ ); } threwException = false; } finally { @@ -489,7 +487,9 @@ private void mergeSpillsWithTransferTo( assert(spillInputChannelPositions[i] == spills[i].file.length()); Closeables.close(spillInputChannels[i], threwException); } + Closeables.close(mergedFileOutputChannel, threwException); } + return partitionLengths; } @Override @@ -518,30 +518,4 @@ public Option stop(boolean success) { } } } - - private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { - try { - return writer.openStream(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { - private final WritableByteChannel channel; - - StreamFallbackChannelWrapper(OutputStream fallbackStream) { - this.channel = Channels.newChannel(fallbackStream); - } - - @Override - public WritableByteChannel channel() { - return channel; - } - - @Override - public void close() throws IOException { - channel.close(); - } - } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java deleted file mode 100644 index 77fcd34f962bf..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort.io; - -import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.api.ShuffleDriverComponents; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleDataIO; -import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents; - -/** - * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle - * storage and index file functionality that has historically been used from Spark 2.4 and earlier. 
- */ -public class LocalDiskShuffleDataIO implements ShuffleDataIO { - - private final SparkConf sparkConf; - - public LocalDiskShuffleDataIO(SparkConf sparkConf) { - this.sparkConf = sparkConf; - } - - @Override - public ShuffleDriverComponents driver() { - return new LocalDiskShuffleDriverComponents(); - } - - @Override - public ShuffleExecutorComponents executor() { - return new LocalDiskShuffleExecutorComponents(sparkConf); - } - -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java deleted file mode 100644 index c8d70d72eb02e..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort.io; - -import java.io.InputStream; -import java.util.Map; -import java.util.Optional; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.MapOutputTracker; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; -import org.apache.spark.serializer.SerializerManager; -import org.apache.spark.shuffle.api.ShuffleBlockInfo; -import org.apache.spark.shuffle.api.ShuffleExecutorComponents; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; -import org.apache.spark.shuffle.io.LocalDiskShuffleReadSupport; -import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockManagerId; - -public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { - - private final SparkConf sparkConf; - private LocalDiskShuffleReadSupport shuffleReadSupport; - private BlockManagerId shuffleServerId; - private BlockManager blockManager; - private IndexShuffleBlockResolver blockResolver; - - public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) { - this.sparkConf = sparkConf; - } - - @VisibleForTesting - public LocalDiskShuffleExecutorComponents( - SparkConf sparkConf, - BlockManager blockManager, - MapOutputTracker mapOutputTracker, - SerializerManager serializerManager, - IndexShuffleBlockResolver blockResolver, - BlockManagerId shuffleServerId) { - this.sparkConf = sparkConf; - this.blockManager = blockManager; - this.blockResolver = blockResolver; - this.shuffleServerId = shuffleServerId; - this.shuffleReadSupport = new LocalDiskShuffleReadSupport( - blockManager, mapOutputTracker, serializerManager, sparkConf); - } - - @Override - public void initializeExecutor(String appId, String execId, Map extraConfigs) { - blockManager 
= SparkEnv.get().blockManager(); - if (blockManager == null) { - throw new IllegalStateException("No blockManager available from the SparkEnv."); - } - shuffleServerId = blockManager.shuffleServerId(); - blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); - MapOutputTracker mapOutputTracker = SparkEnv.get().mapOutputTracker(); - SerializerManager serializerManager = SparkEnv.get().serializerManager(); - shuffleReadSupport = new LocalDiskShuffleReadSupport( - blockManager, mapOutputTracker, serializerManager, sparkConf); - } - - @Override - public ShuffleMapOutputWriter createMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId, - int numPartitions) { - if (blockResolver == null) { - throw new IllegalStateException( - "Executor components must be initialized before getting writers."); - } - return new LocalDiskShuffleMapOutputWriter( - shuffleId, - mapId, - numPartitions, - blockResolver, - shuffleServerId, - sparkConf); - } - - @Override - public Optional createSingleFileMapOutputWriter( - int shuffleId, - int mapId, - long mapTaskAttemptId) { - if (blockResolver == null) { - throw new IllegalStateException( - "Executor components must be initialized before getting writers."); - } - return Optional.of(new LocalDiskSingleSpillMapOutputWriter( - shuffleId, mapId, blockResolver, shuffleServerId)); - } - - @Override - public Iterable getPartitionReaders(Iterable blockMetadata) { - if (blockResolver == null) { - throw new IllegalStateException( - "Executor components must be initialized before getting readers."); - } - return shuffleReadSupport.getPartitionReaders(blockMetadata); - } - - @Override - public boolean shouldWrapPartitionReaderStream() { - return false; - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java deleted file mode 100644 index 064875420c473..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io; - -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; -import java.util.Optional; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; -import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.shuffle.api.WritableByteChannelWrapper; -import org.apache.spark.storage.BlockManagerId; -import org.apache.spark.internal.config.package$; -import org.apache.spark.util.Utils; - -/** - * Implementation of {@link ShuffleMapOutputWriter} that replicates the functionality of shuffle - * persisting shuffle data to local disk alongside index files, identical to Spark's historic - * canonical shuffle storage mechanism. - */ -public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { - - private static final Logger log = - LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); - - private final int shuffleId; - private final int mapId; - private final IndexShuffleBlockResolver blockResolver; - private final long[] partitionLengths; - private final int bufferSize; - private final BlockManagerId shuffleServerId; - private int lastPartitionId = -1; - private long currChannelPosition; - private long bytesWrittenToMergedFile = 0L; - - private final File outputFile; - private File outputTempFile; - private FileOutputStream outputFileStream; - private FileChannel outputFileChannel; - private BufferedOutputStream outputBufferedFileStream; - - public LocalDiskShuffleMapOutputWriter( - int shuffleId, - int mapId, - int numPartitions, - IndexShuffleBlockResolver blockResolver, - BlockManagerId shuffleServerId, - SparkConf sparkConf) { - this.shuffleId = shuffleId; - this.mapId = mapId; - this.blockResolver = blockResolver; - this.bufferSize = - (int) (long) sparkConf.get( - package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; - this.shuffleServerId = shuffleServerId; - this.partitionLengths = new long[numPartitions]; - this.outputFile = blockResolver.getDataFile(shuffleId, mapId); - this.outputTempFile = null; - } - - @Override - public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException { - if (reducePartitionId <= lastPartitionId) { - throw new IllegalArgumentException("Partitions should be requested in increasing order."); - } - lastPartitionId = reducePartitionId; - if (outputTempFile == null) { - outputTempFile = Utils.tempFileWith(outputFile); - } - if (outputFileChannel != null) { - currChannelPosition = outputFileChannel.position(); - } else { - currChannelPosition = 0L; - } - return new LocalDiskShufflePartitionWriter(reducePartitionId); - } - - @Override - public MapOutputWriterCommitMessage commitAllPartitions() throws IOException { - // Check the position after transferTo loop to see if it is in the right position and raise a - // exception if it is incorrect. The position will not be increased to the expected length - // after calling transferTo in kernel version 2.6.32. This issue is described at - // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. 
- if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) { - throw new IOException( - "Current position " + outputFileChannel.position() + " does not equal expected " + - "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " + - " kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " + - "to unexpected behavior when using transferTo. You can set " + - "spark.file.transferTo=false to disable this NIO feature."); - } - cleanUp(); - File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); - } - - @Override - public void abort(Throwable error) throws IOException { - cleanUp(); - if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { - log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); - } - } - - private void cleanUp() throws IOException { - if (outputBufferedFileStream != null) { - outputBufferedFileStream.close(); - } - if (outputFileChannel != null) { - outputFileChannel.close(); - } - if (outputFileStream != null) { - outputFileStream.close(); - } - } - - private void initStream() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - } - if (outputBufferedFileStream == null) { - outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); - } - } - - private void initChannel() throws IOException { - // This file needs to opened in append mode in order to work around a Linux kernel bug that - // affects transferTo; see SPARK-3948 for more details. - if (outputFileChannel == null) { - outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel(); - } - } - - private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter { - - private final int partitionId; - private PartitionWriterStream partStream = null; - private PartitionWriterChannel partChannel = null; - - private LocalDiskShufflePartitionWriter(int partitionId) { - this.partitionId = partitionId; - } - - @Override - public OutputStream openStream() throws IOException { - if (partStream == null) { - if (outputFileChannel != null) { - throw new IllegalStateException("Requested an output channel for a previous write but" + - " now an output stream has been requested. Should not be using both channels" + - " and streams to write."); - } - initStream(); - partStream = new PartitionWriterStream(partitionId); - } - return partStream; - } - - @Override - public Optional openChannelWrapper() throws IOException { - if (partChannel == null) { - if (partStream != null) { - throw new IllegalStateException("Requested an output stream for a previous write but" + - " now an output channel has been requested. 
Should not be using both channels" + - " and streams to write."); - } - initChannel(); - partChannel = new PartitionWriterChannel(partitionId); - } - return Optional.of(partChannel); - } - - @Override - public long getNumBytesWritten() { - if (partChannel != null) { - try { - return partChannel.getCount(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else if (partStream != null) { - return partStream.getCount(); - } else { - // Assume an empty partition if stream and channel are never created - return 0; - } - } - } - - private class PartitionWriterStream extends OutputStream { - private final int partitionId; - private int count = 0; - private boolean isClosed = false; - - PartitionWriterStream(int partitionId) { - this.partitionId = partitionId; - } - - public int getCount() { - return count; - } - - @Override - public void write(int b) throws IOException { - verifyNotClosed(); - outputBufferedFileStream.write(b); - count++; - } - - @Override - public void write(byte[] buf, int pos, int length) throws IOException { - verifyNotClosed(); - outputBufferedFileStream.write(buf, pos, length); - count += length; - } - - @Override - public void close() { - isClosed = true; - partitionLengths[partitionId] = count; - bytesWrittenToMergedFile += count; - } - - private void verifyNotClosed() { - if (isClosed) { - throw new IllegalStateException("Attempting to write to a closed block output stream."); - } - } - } - - private class PartitionWriterChannel implements WritableByteChannelWrapper { - - private final int partitionId; - - PartitionWriterChannel(int partitionId) { - this.partitionId = partitionId; - } - - public long getCount() throws IOException { - long writtenPosition = outputFileChannel.position(); - return writtenPosition - currChannelPosition; - } - - @Override - public WritableByteChannel channel() { - return outputFileChannel; - } - - @Override - public void close() throws IOException { - partitionLengths[partitionId] = getCount(); - bytesWrittenToMergedFile += partitionLengths[partitionId]; - } - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java deleted file mode 100644 index 219f9ee1296dd..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.io; - -import java.io.File; -import java.io.IOException; -import java.nio.file.Files; - -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.api.MapOutputWriterCommitMessage; -import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; -import org.apache.spark.storage.BlockManagerId; -import org.apache.spark.util.Utils; - -public class LocalDiskSingleSpillMapOutputWriter - implements SingleSpillShuffleMapOutputWriter { - - private final int shuffleId; - private final int mapId; - private final IndexShuffleBlockResolver blockResolver; - private final BlockManagerId shuffleServerId; - - public LocalDiskSingleSpillMapOutputWriter( - int shuffleId, - int mapId, - IndexShuffleBlockResolver blockResolver, - BlockManagerId shuffleServerId) { - this.shuffleId = shuffleId; - this.mapId = mapId; - this.blockResolver = blockResolver; - this.shuffleServerId = shuffleServerId; - } - - @Override - public MapOutputWriterCommitMessage transferMapSpillFile( - File mapSpillFile, - long[] partitionLengths) throws IOException { - // The map spill file already has the proper format, and it contains all of the partition data. - // So just transfer it directly to the destination without any merging. - File outputFile = blockResolver.getDataFile(shuffleId, mapId); - File tempFile = Utils.tempFileWith(outputFile); - Files.move(mapSpillFile.toPath(), tempFile.toPath()); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); - return MapOutputWriterCommitMessage.of(partitionLengths, shuffleServerId); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java deleted file mode 100644 index 183769274841c..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort.lifecycle; - -import java.util.Map; - -import com.google.common.collect.ImmutableMap; - -import org.apache.spark.SparkEnv; -import org.apache.spark.shuffle.api.ShuffleDriverComponents; -import org.apache.spark.internal.config.package$; -import org.apache.spark.storage.BlockManagerMaster; - -public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents { - - private BlockManagerMaster blockManagerMaster; - private boolean shouldUnregisterOutputOnHostOnFetchFailure; - - @Override - public Map initializeApplication() { - blockManagerMaster = SparkEnv.get().blockManager().master(); - this.shouldUnregisterOutputOnHostOnFetchFailure = - SparkEnv.get().blockManager().externalShuffleServiceEnabled() - && (boolean) SparkEnv.get().conf() - .get(package$.MODULE$.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE()); - return ImmutableMap.of(); - } - - @Override - public void removeShuffle(int shuffleId, boolean blocking) { - checkInitialized(); - blockManagerMaster.removeShuffle(shuffleId, blocking); - } - - @Override - public boolean shouldUnregisterOutputOnHostOnFetchFailure() { - return shouldUnregisterOutputOnHostOnFetchFailure; - } - - private void checkInitialized() { - if (blockManagerMaster == null) { - throw new IllegalStateException("Driver components must be initialized before using"); - } - } -} diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 98232380cc266..305ec46a364a2 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -27,7 +27,6 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils} /** @@ -59,9 +58,7 @@ private class CleanupTaskWeakReference( * to be processed when the associated object goes out of scope of the application. Actual * cleanup is performed in a separate daemon thread. 
*/ -private[spark] class ContextCleaner( - sc: SparkContext, - shuffleDriverComponents: ShuffleDriverComponents) extends Logging { +private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they @@ -225,7 +222,7 @@ private[spark] class ContextCleaner( try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) - shuffleDriverComponents.removeShuffle(shuffleId, blocking) + blockManagerMaster.removeShuffle(shuffleId, blocking) listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { @@ -273,6 +270,7 @@ private[spark] class ContextCleaner( } } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 8e4846e4b1ffd..fb051a8c0db8e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -96,7 +96,6 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( shuffleId, _rdd.partitions.length, this) _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) - _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a169192225d7e..a17a507a0cf04 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ /** @@ -148,8 +148,7 @@ private class ShuffleStatus(numPartitions: Int) { */ def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized { for (mapId <- 0 until mapStatuses.length) { - if (mapStatuses(mapId) != null && mapStatuses(mapId).location != null - && f(mapStatuses(mapId).location)) { + if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) { decrementNumAvailableOutputs(mapStatuses(mapId).location) mapStatuses(mapId) = null invalidateSerializedMapOutputStatusCache() @@ -218,21 +217,17 @@ private class ShuffleStatus(numPartitions: Int) { } private[this] def incrementNumAvailableOutputs(bmAddress: BlockManagerId): Unit = synchronized { - if (bmAddress != null) { - _numOutputsPerExecutorId(bmAddress.executorId) += 1 - } + _numOutputsPerExecutorId(bmAddress.executorId) += 1 _numAvailableOutputs += 1 } private[this] def decrementNumAvailableOutputs(bmAddress: BlockManagerId): Unit = synchronized { - if (bmAddress != null) { - assert(_numOutputsPerExecutorId(bmAddress.executorId) >= 1, - s"Tried to remove non-existent output from ${bmAddress.executorId}") - if (_numOutputsPerExecutorId(bmAddress.executorId) == 1) { - _numOutputsPerExecutorId.remove(bmAddress.executorId) - } else { - _numOutputsPerExecutorId(bmAddress.executorId) 
-= 1 - } + assert(_numOutputsPerExecutorId(bmAddress.executorId) >= 1, + s"Tried to remove non-existent output from ${bmAddress.executorId}") + if (_numOutputsPerExecutorId(bmAddress.executorId) == 1) { + _numOutputsPerExecutorId.remove(bmAddress.executorId) + } else { + _numOutputsPerExecutorId(bmAddress.executorId) -= 1 } _numAvailableOutputs -= 1 } @@ -327,7 +322,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -341,7 +336,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -668,7 +663,7 @@ private[spark] class MapOutputTrackerMaster( /** * Return a list of locations that each have fraction of map output greater than the specified - * threshold. Ignores shuffle blocks without location or executor id. + * threshold. * * @param shuffleId id of the shuffle * @param reducerId id of the reduce task @@ -697,8 +692,7 @@ private[spark] class MapOutputTrackerMaster( // array with null entries for each output, and registerMapOutputs, which populates it // with valid status entries. This is possible if one thread schedules a job which // depends on an RDD which is currently being computed by another thread. - // This also ignores locations that are not on executors. - if (status != null && status.location != null && status.location.executorId != null) { + if (status != null) { val blockSize = status.getSizeForBlock(reducerId) if (blockSize > 0) { locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize @@ -737,7 +731,7 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -774,7 +768,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. 
override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -965,9 +959,9 @@ private[spark] object MapOutputTracker extends SafeLogging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[Option[BlockManagerId], ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = "Missing an output location for shuffle" @@ -979,13 +973,8 @@ private[spark] object MapOutputTracker extends SafeLogging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - if (status.location != null) { - splitsByAddress.getOrElseUpdate(Option.apply(status.location), ListBuffer()) += - ((ShuffleBlockAttemptId(shuffleId, mapId, part, status.mapTaskAttemptId), size)) - } else { - splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) += - ((ShuffleBlockAttemptId(shuffleId, mapId, part, status.mapTaskAttemptId), size)) - } + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) } } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e6d48666f8dd5..eeefe4ada79d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -57,7 +57,6 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import org.apache.spark.scheduler.local.LocalSchedulerBackend -import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.status.{AppStatusSource, AppStatusStore} import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ @@ -216,7 +215,6 @@ class SparkContext(config: SparkConf) extends SafeLogging { private var _shutdownHookRef: AnyRef = _ private var _statusStore: AppStatusStore = _ private var _heartbeater: Heartbeater = _ - private var _shuffleDriverComponents: ShuffleDriverComponents = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -308,8 +306,6 @@ class SparkContext(config: SparkConf) extends SafeLogging { _dagScheduler = ds } - private[spark] def shuffleDriverComponents: ShuffleDriverComponents = _shuffleDriverComponents - /** * A unique identifier for the Spark application. * Its format depends on the scheduler implementation. 
@@ -495,14 +491,6 @@ class SparkContext(config: SparkConf) extends SafeLogging { executorEnvs ++= _conf.getExecutorEnv executorEnvs("SPARK_USER") = sparkUser - val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS) - val maybeIO = Utils.loadExtensions( - classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) - require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") - _shuffleDriverComponents = maybeIO.head.driver() - _shuffleDriverComponents.initializeApplication().asScala.foreach { - case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) } - // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) _heartbeatReceiver = env.rpcEnv.setupEndpoint( @@ -573,7 +561,7 @@ class SparkContext(config: SparkConf) extends SafeLogging { _cleaner = if (_conf.get(CLEANER_REFERENCE_TRACKING)) { - Some(new ContextCleaner(this, _shuffleDriverComponents)) + Some(new ContextCleaner(this)) } else { None } @@ -1971,7 +1959,6 @@ class SparkContext(config: SparkConf) extends SafeLogging { } _heartbeater = null } - _shuffleDriverComponents.cleanupApplication() if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index df30fd5c7f679..ea79c7310349d 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -56,8 +56,6 @@ class TaskMetrics private[spark] () extends Serializable { private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] - private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics = - Predef.identity[TempShuffleReadMetrics] /** * Time taken on the executor to deserialize this task. @@ -189,17 +187,11 @@ class TaskMetrics private[spark] () extends Serializable { * be lost. */ private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized { - val tempShuffleMetrics = new TempShuffleReadMetrics - val readMetrics = _decorFunc(tempShuffleMetrics) - tempShuffleReadMetrics += tempShuffleMetrics + val readMetrics = new TempShuffleReadMetrics + tempShuffleReadMetrics += readMetrics readMetrics } - private[spark] def decorateTempShuffleReadMetrics( - decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized { - _decorFunc = decorFunc - } - /** * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`. * This is expected to be called on executor heartbeat and at the end of a task. 
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 3dbe28cfb1776..6c5f36149e5f4 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -23,7 +23,6 @@ import org.apache.spark.api.conda.CondaBootstrapMode import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -798,12 +797,6 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SHUFFLE_IO_PLUGIN_CLASS = - ConfigBuilder("spark.shuffle.sort.io.plugin.class") - .doc("Name of the class to use for shuffle IO.") - .stringConf - .createWithDefault(classOf[LocalDiskShuffleDataIO].getName) - private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + @@ -980,7 +973,7 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SHUFFLE_UNSAFE_FAST_MERGE_ENABLE = + private[spark] val SHUFFLE_UNDAFE_FAST_MERGE_ENABLE = ConfigBuilder("spark.shuffle.unsafe.fastMergeEnabled") .doc("Whether to perform a fast spill merge.") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e8f0e419c8d4b..ef7429db84878 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -170,8 +170,6 @@ private[spark] class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] - private[scheduler] val shuffleDriverComponents = sc.shuffleDriverComponents - /** * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids * and its values are arrays indexed by partition numbers. Each array value is the set of @@ -724,7 +722,7 @@ private[spark] class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler) + val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( jobId, rdd, func2, partitions.toArray, callSite, waiter, SerializationUtils.clone(properties))) @@ -825,8 +823,7 @@ private[spark] class DAGScheduler( // This makes it easier to avoid race conditions between the user code and the map output // tracker that might result if we told the user the stage had finished, but then they queries // the map output tracker and some node failures had caused the output statistics to be lost. 
- val waiter = new JobWaiter[MapOutputStatistics]( - this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) eventProcessLoop.post(MapStageSubmitted( jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) waiter @@ -1409,24 +1406,14 @@ private[spark] class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] val status = event.result.asInstanceOf[MapStatus] - if (status.location != null) { - val execId = status.location.executorId - if (execId != null) { - logDebug("ShuffleMapTask finished on " + execId) - if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") - } else { - // The epoch of the task is acceptable (i.e., the task was launched after the most - // recent failure we're aware of for the executor), so mark the task's output as - // available. - mapOutputTracker.registerMapOutput( - shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) - } - } else { - mapOutputTracker.registerMapOutput( - shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) - } + val execId = status.location.executorId + logDebug("Registering shuffle output on executor " + execId) + if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { + // The epoch of the task is acceptable (i.e., the task was launched after the most + // recent failure we're aware of for the executor), so mark the task's output as + // available. mapOutputTracker.registerMapOutput( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) } @@ -1486,10 +1473,13 @@ private[spark] class DAGScheduler( logInfo("Ignoring result from " + rt + " because its job has finished") } - case smt: ShuffleMapTask => + case _: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] shuffleStage.pendingPartitions -= task.partitionId val status = event.result.asInstanceOf[MapStatus] + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") @@ -1671,31 +1661,21 @@ private[spark] class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - if (bmAddress.executorId == null) { - if (shuffleDriverComponents.shouldUnregisterOutputOnHostOnFetchFailure()) { - val currentEpoch = task.epoch - val host = bmAddress.host - logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) - mapOutputTracker.removeOutputsOnHost(host) - clearCacheLocs() - } + val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && + unRegisterOutputOnHostOnFetchFailure) { + // We had a fetch failure with the external shuffle service, so we + // assume all shuffle data on the node is bad. + Some(bmAddress.host) } else { - val hostToUnregisterOutputs = - if (shuffleDriverComponents.shouldUnregisterOutputOnHostOnFetchFailure()) { - // We had a fetch failure with the external shuffle service, so we - // assume all shuffle data on the node is bad. 
- Some(bmAddress.host) - } else { - // Unregister shuffle data just for one executor (we don't have any - // reason to believe shuffle data has been lost for the entire host). - None - } - removeExecutorAndUnregisterOutputs( - execId = bmAddress.executorId, - fileLost = true, - hostToUnregisterOutputs = hostToUnregisterOutputs, - maybeEpoch = Some(task.epoch)) + // Unregister shuffle data just for one executor (we don't have any + // reason to believe shuffle data has been lost for the entire host). + None } + removeExecutorAndUnregisterOutputs( + execId = bmAddress.executorId, + fileLost = true, + hostToUnregisterOutputs = hostToUnregisterOutputs, + maybeEpoch = Some(task.epoch)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 0d8bf57ab9162..64f0a060a247c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -36,8 +36,6 @@ private[spark] sealed trait MapStatus { /** Location where this task was run. */ def location: BlockManagerId - def mapTaskAttemptId: Long - /** * Estimated size for the reduce block, in bytes. * @@ -58,12 +56,11 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskAttemptId: Long) - : MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskAttemptId) + HighlyCompressedMapStatus(loc, uncompressedSizes) } else { - new CompressedMapStatus(loc, uncompressedSizes, mapTaskAttemptId) + new CompressedMapStatus(loc, uncompressedSizes) } } @@ -106,15 +103,13 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte], - private[this] var attemptNum: Long) + private[this] var compressedSizes: Array[Byte]) extends MapStatus with Externalizable { - // For deserialization only - protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskAttemptId: Long) { - this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskAttemptId) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc @@ -124,30 +119,17 @@ private[spark] class CompressedMapStatus( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - if (loc != null) { - out.writeBoolean(true) - loc.writeExternal(out) - } else { - out.writeBoolean(false) - } + loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) - out.writeLong(attemptNum) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - if (in.readBoolean()) { - loc = BlockManagerId(in) - } else { - loc = null - } + loc = BlockManagerId(in) val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) - attemptNum = in.readLong() } - - override def mapTaskAttemptId: Long = attemptNum } /** @@ -166,15 
+148,14 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], - private[this] var attemptNum: Long) + private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc @@ -191,12 +172,7 @@ private[spark] class HighlyCompressedMapStatus private ( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - if (loc != null) { - out.writeBoolean(true) - loc.writeExternal(out) - } else { - out.writeBoolean(false) - } + loc.writeExternal(out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -204,15 +180,10 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } - out.writeLong(attemptNum) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - if (in.readBoolean()) { - loc = BlockManagerId(in) - } else { - loc = null - } + loc = BlockManagerId(in) emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -224,15 +195,11 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesImpl(block) = size } hugeBlockSizes = hugeBlockSizesImpl - attemptNum = in.readLong() } - - override def mapTaskAttemptId: Long = attemptNum } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskAttemptId: Long) - : HighlyCompressedMapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. 
var i = 0 @@ -273,6 +240,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes, mapTaskAttemptId) + hugeBlockSizes) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index e614dbc8c9542..c5eefc7c5c049 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,15 +17,10 @@ package org.apache.spark.shuffle -import scala.collection.JavaConverters._ - import org.apache.spark._ -import org.apache.spark.api.java.Optional import org.apache.spark.internal.{config, Logging} -import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleExecutorComponents} -import org.apache.spark.storage.ShuffleBlockAttemptId +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -39,57 +34,33 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, - shuffleExecutorComponents: ShuffleExecutorComponents, serializerManager: SerializerManager = SparkEnv.get.serializerManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, - sparkConf: SparkConf = SparkEnv.get.conf) + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency - private val compressionCodec = CompressionCodec.createCodec(sparkConf) - - private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS) - /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val streamsIterator = - shuffleExecutorComponents.getPartitionReaders(new Iterable[ShuffleBlockInfo] { - override def iterator: Iterator[ShuffleBlockInfo] = { - mapOutputTracker - .getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) - .flatMap { shuffleLocationInfo => - shuffleLocationInfo._2.map { blockInfo => - val block = blockInfo._1.asInstanceOf[ShuffleBlockAttemptId] - new ShuffleBlockInfo( - block.shuffleId, - block.mapId, - block.reduceId, - blockInfo._2, - block.mapTaskAttemptId, - Optional.ofNullable(shuffleLocationInfo._1.orNull)) - } - } - } - }.asJava).iterator() - - val retryingWrappedStreams = streamsIterator.asScala.map(rawReaderStream => { - if (shuffleExecutorComponents.shouldWrapPartitionReaderStream()) { - if (compressShuffle) { - compressionCodec.compressedInputStream( - serializerManager.wrapForEncryption(rawReaderStream)) - } else { - serializerManager.wrapForEncryption(rawReaderStream) - } - } else { - // The default implementation checks for corrupt streams, so it will already have - // decompressed/decrypted the bytes - rawReaderStream - } - }) + val wrappedStreams = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + 
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + readMetrics).toCompletionIterator val serializerInstance = dep.serializer.newInstance() - val recordIter = retryingWrappedStreams.flatMap { wrappedStream => + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index a7ccb35d9c64c..265a8acfa8d61 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils /** @@ -58,8 +58,6 @@ private[spark] class FetchFailedException( def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) - - def getShuffleBlockId(): ShuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) } /** @@ -70,12 +68,3 @@ private[spark] class MetadataFetchFailedException( reduceId: Int, message: String) extends FetchFailedException(null, shuffleId, -1, reduceId, message) - -private[spark] class RemoteFetchFailedException( - shuffleId: Int, - mapId: Int, - reduceId: Int, - message: String, - host: String, - port: Int) - extends FetchFailedException(BlockManagerId(host, port), shuffleId, mapId, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index e5aad891541f6..d3f1c7ec1bbee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -55,7 +55,7 @@ private[spark] class IndexShuffleBlockResolver( blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } - def getIndexFile(shuffleId: Int, mapId: Int): File = { + private def getIndexFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala deleted file mode 100644 index e0affb858c359..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import java.io.{Closeable, IOException, OutputStream} - -import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.api.ShufflePartitionWriter -import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream} -import org.apache.spark.util.Utils -import org.apache.spark.util.collection.PairsWriter - -/** - * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an - * arbitrary partition writer instead of writing to local disk through the block manager. - */ -private[spark] class ShufflePartitionPairsWriter( - partitionWriter: ShufflePartitionWriter, - serializerManager: SerializerManager, - serializerInstance: SerializerInstance, - blockId: BlockId, - writeMetrics: ShuffleWriteMetricsReporter) - extends PairsWriter with Closeable { - - private var isClosed = false - private var partitionStream: OutputStream = _ - private var timeTrackingStream: OutputStream = _ - private var wrappedStream: OutputStream = _ - private var objOut: SerializationStream = _ - private var numRecordsWritten = 0 - private var curNumBytesWritten = 0L - - override def write(key: Any, value: Any): Unit = { - if (isClosed) { - throw new IOException("Partition pairs writer is already closed.") - } - if (objOut == null) { - open() - } - objOut.writeKey(key) - objOut.writeValue(value) - recordWritten() - } - - private def open(): Unit = { - try { - partitionStream = partitionWriter.openStream - timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream) - wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream) - objOut = serializerInstance.serializeStream(wrappedStream) - } catch { - case e: Exception => - Utils.tryLogNonFatalError { - close() - } - throw e - } - } - - override def close(): Unit = { - if (!isClosed) { - Utils.tryWithSafeFinally { - Utils.tryWithSafeFinally { - objOut = closeIfNonNull(objOut) - // Setting these to null will prevent the underlying streams from being closed twice - // just in case any stream's close() implementation is not idempotent. - wrappedStream = null - timeTrackingStream = null - partitionStream = null - } { - // Normally closing objOut would close the inner streams as well, but just in case there - // was an error in initialization etc. we make sure we clean the other streams up too. 
- Utils.tryWithSafeFinally { - wrappedStream = closeIfNonNull(wrappedStream) - // Same as above - if wrappedStream closes then assume it closes underlying - // partitionStream and don't close again in the finally - timeTrackingStream = null - partitionStream = null - } { - Utils.tryWithSafeFinally { - timeTrackingStream = closeIfNonNull(timeTrackingStream) - partitionStream = null - } { - partitionStream = closeIfNonNull(partitionStream) - } - } - } - updateBytesWritten() - } { - isClosed = true - } - } - } - - private def closeIfNonNull[T <: Closeable](closeable: T): T = { - if (closeable != null) { - closeable.close() - } - null.asInstanceOf[T] - } - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - private def recordWritten(): Unit = { - numRecordsWritten += 1 - writeMetrics.incRecordsWritten(1) - - if (numRecordsWritten % 16384 == 0) { - updateBytesWritten() - } - } - - private def updateBytesWritten(): Unit = { - val numBytesWritten = partitionWriter.getNumBytesWritten - val bytesWrittenDiff = numBytesWritten - curNumBytesWritten - writeMetrics.incBytesWritten(bytesWrittenDiff) - curNumBytesWritten = numBytesWritten - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala deleted file mode 100644 index 9e1c1816d306c..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/io/LocalDiskShuffleReadSupport.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.io - -import java.io.InputStream - -import scala.collection.JavaConverters._ - -import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} -import org.apache.spark.internal.config -import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.shuffle.api.ShuffleBlockInfo -import org.apache.spark.storage.{BlockManager, ShuffleBlockAttemptId, ShuffleBlockFetcherIterator, ShuffleBlockId} - -class LocalDiskShuffleReadSupport( - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker, - serializerManager: SerializerManager, - conf: SparkConf) { - - private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 - private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) - private val maxBlocksInFlightPerAddress = - conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) - private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) - private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) - - def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): - java.lang.Iterable[InputStream] = { - - val iterableToReturn = if (blockMetadata.asScala.isEmpty) { - Iterable.empty - } else { - val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) - .foldLeft(Int.MaxValue, 0) { - case ((min, max), elem) => (math.min(min, elem), math.max(max, elem)) - } - val shuffleId = blockMetadata.asScala.head.getShuffleId - new ShuffleBlockFetcherIterable( - TaskContext.get(), - blockManager, - serializerManager, - maxBytesInFlight, - maxReqsInFlight, - maxBlocksInFlightPerAddress, - maxReqSizeShuffleToMem, - detectCorrupt, - shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(), - minReduceId, - maxReduceId, - shuffleId, - mapOutputTracker - ) - } - iterableToReturn.asJava - } -} - -private class ShuffleBlockFetcherIterable( - context: TaskContext, - blockManager: BlockManager, - serializerManager: SerializerManager, - maxBytesInFlight: Long, - maxReqsInFlight: Int, - maxBlocksInFlightPerAddress: Int, - maxReqSizeShuffleToMem: Long, - detectCorruption: Boolean, - shuffleMetrics: ShuffleReadMetricsReporter, - minReduceId: Int, - maxReduceId: Int, - shuffleId: Int, - mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] { - - override def iterator: Iterator[InputStream] = { - new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1) - .map(loc => ( - loc._1.get, - loc._2.map { case(shuffleBlockAttemptId, size) => - val block = shuffleBlockAttemptId.asInstanceOf[ShuffleBlockAttemptId] - (ShuffleBlockId(block.shuffleId, block.mapId, block.reduceId), size) - })), - serializerManager.wrapStream, - maxBytesInFlight, - maxReqsInFlight, - maxBlocksInFlightPerAddress, - maxReqSizeShuffleToMem, - detectCorruption, - shuffleMetrics).toCompletionIterator - } - -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 610c04ace3b6f..b59fa8e8a3ccd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,13 +19,9 @@ package org.apache.spark.shuffle.sort import 
java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConverters._ - import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} -import org.apache.spark.util.Utils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -72,8 +68,6 @@ import org.apache.spark.util.Utils */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - import SortShuffleManager._ - if (!conf.getBoolean("spark.shuffle.spill", true)) { logWarning( "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + @@ -85,8 +79,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() - private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -126,11 +118,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, - endPartition, - context, - metrics, - shuffleExecutorComponents) + startPartition, endPartition, context, metrics) } /** Get a writer for a given partition. Called on executors by map tasks. */ @@ -146,25 +134,23 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), unsafeShuffleHandle, mapId, context, env.conf, - metrics, - shuffleExecutorComponents) + metrics) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, - context.taskAttemptId(), env.conf, - metrics, - shuffleExecutorComponents) + metrics) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) + new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } } @@ -219,21 +205,6 @@ private[spark] object SortShuffleManager extends Logging { true } } - - private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { - val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) - val maybeIO = Utils.loadExtensions( - classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) - require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") - val executorComponents = maybeIO.head.executor() - val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX) - .toMap - executorComponents.initializeExecutor( - conf.getAppId, - SparkEnv.get.executorId, - extraConfigs.asJava) - executorComponents - } } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 0082b4c9c6b24..16058de8bf3ff 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -21,15 +21,15 @@ import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} -import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( shuffleBlockResolver: IndexShuffleBlockResolver, handle: BaseShuffleHandle[K, V, C], mapId: Int, - context: TaskContext, - shuffleExecutorComponents: ShuffleExecutorComponents) + context: TaskContext) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,14 +64,18 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). - val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( - dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) - sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val commitMessage = mapOutputWriter.commitAllPartitions() - mapStatus = MapStatus( - commitMessage.getLocation.orElse(null), - commitMessage.getPartitionLengths, - context.taskAttemptId()) + val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val tmp = Utils.tempFileWith(output) + try { + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + } finally { + if (tmp.exists() && !tmp.delete()) { + logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") + } + } } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index ac67ca284f926..7ac2c71c18eb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import java.util.UUID import org.apache.spark.SparkException -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: @@ -56,13 +56,6 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends Blo override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -@Experimental -case class ShuffleBlockAttemptId(shuffleId: Int, mapId: Int, reduceId: Int, mapTaskAttemptId: Long) - extends BlockId { - override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + - reduceId + "_" + mapTaskAttemptId -} - @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 8d66cbbfb7562..d4a59c33b974c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -70,12 +70,7 @@ class BlockManagerId private ( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - if (executorId_ != null) { - out.writeBoolean(true) - out.writeUTF(executorId_) - } else { - out.writeBoolean(false) - } + out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) out.writeBoolean(topologyInfo_.isDefined) @@ -84,9 +79,7 @@ class BlockManagerId private ( } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - if (in.readBoolean()) { - executorId_ = in.readUTF() - } + executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() val isTopologyInfoAvailable = in.readBoolean() @@ -98,13 +91,8 @@ class BlockManagerId private ( override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)" - override def hashCode: Int = { - if (executorId != null) { - ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode - } else { - (host.hashCode * 41 + port) * 41 + topologyInfo.hashCode - } - } + override def hashCode: Int = + ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode override def equals(that: Any): Boolean = that match { case id: BlockManagerId => @@ -139,9 +127,6 @@ private[spark] object BlockManagerId { topologyInfo: Option[String] = None): BlockManagerId = getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo)) - def apply(host: String, port: Int): BlockManagerId = - getCachedBlockManagerId(new BlockManagerId(null, host, port, None)) - def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() obj.readExternal(in) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 758621c52495b..17390f9c60e79 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -24,7 +24,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils -import org.apache.spark.util.collection.PairsWriter /** * A class for writing JVM objects directly to a file on disk. This class allows data to be appended @@ -47,8 +46,7 @@ private[spark] class DiskBlockObjectWriter( writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream - with Logging - with PairsWriter { + with Logging { /** * Guards against close calls, e.g. from a wrapping stream. @@ -234,7 +232,7 @@ private[spark] class DiskBlockObjectWriter( /** * Writes a key-value pair. 
*/ - override def write(key: Any, value: Any) { + def write(key: Any, value: Any) { if (!streamOpen) { open() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 9ff21ae00060f..3966980a11ed0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -74,7 +74,7 @@ final class ShuffleBlockFetcherIterator( maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) - extends Iterator[InputStream] with DownloadFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -397,7 +397,7 @@ final class ShuffleBlockFetcherIterator( * * Throws a FetchFailedException if the next block could not be fetched. */ - override def next(): InputStream = { + override def next(): (BlockId, InputStream) = { if (!hasNext) { throw new NoSuchElementException() } @@ -495,6 +495,7 @@ final class ShuffleBlockFetcherIterator( in.close() } } + case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) } @@ -507,16 +508,11 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException() } currentResult = result.asInstanceOf[SuccessFetchResult] - new BufferReleasingInputStream(input, this) - } - - // for testing only - def getCurrentBlock(): ShuffleBlockId = { - currentResult.blockId.asInstanceOf[ShuffleBlockId] + (currentResult.blockId, new BufferReleasingInputStream(input, this)) } - def toCompletionIterator: Iterator[InputStream] = { - CompletionIterator[InputStream, this.type](this, + def toCompletionIterator: Iterator[(BlockId, InputStream)] = { + CompletionIterator[(BlockId, InputStream), this.type](this, onCompleteCallback.onComplete(context)) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c466842aec35e..73f700f7b2001 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.{Channels, FileChannel, WritableByteChannel} +import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.security.SecureRandom @@ -337,14 +337,10 @@ private[spark] object Utils extends Logging { def copyFileStreamNIO( input: FileChannel, - output: WritableByteChannel, + output: FileChannel, startPosition: Long, bytesToCopy: Long): Unit = { - val outputInitialState = output match { - case outputFileChannel: FileChannel => - Some((outputFileChannel.position(), outputFileChannel)) - case _ => None - } + val initialPos = output.position() var count = 0L // In case transferTo method transferred less data than we have required. while (count < bytesToCopy) { @@ -359,17 +355,15 @@ private[spark] object Utils extends Logging { // kernel version 2.6.32, this issue can be seen in // https://bugs.openjdk.java.net/browse/JDK-7052359 // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). 
- outputInitialState.foreach { case (initialPos, outputFileChannel) => - val finalPos = outputFileChannel.position() - val expectedPos = initialPos + bytesToCopy - assert(finalPos == expectedPos, - s""" - |Current position $finalPos do not equal to expected position $expectedPos - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) - } + val finalPos = output.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 1216a45415a74..4806c13967253 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,10 +29,7 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.shuffle.ShufflePartitionPairsWriter -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} -import org.apache.spark.util.{Utils => TryUtils} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -677,9 +674,11 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL - * project. We should figure out an alternative way to test that so that we can remove this - * otherwise unused code path. + * Write all the data added into this ExternalSorter into a file in the disk store. This is + * called by the SortShuffleWriter. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedFile( blockId: BlockId, @@ -723,74 +722,6 @@ private[spark] class ExternalSorter[K, V, C]( lengths } - /** - * Write all the data added into this ExternalSorter into a map output writer that pushes bytes - * to some arbitrary backing store. This is called by the SortShuffleWriter. 
- * - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) - */ - def writePartitionedMapOutput( - shuffleId: Int, - mapId: Int, - mapOutputWriter: ShuffleMapOutputWriter): Unit = { - if (spills.isEmpty) { - // Case where we only have in-memory data - val collection = if (aggregator.isDefined) map else buffer - val it = collection.destructiveSortedWritablePartitionedIterator(comparator) - while (it.hasNext()) { - val partitionId = it.nextPartition() - var partitionWriter: ShufflePartitionWriter = null - var partitionPairsWriter: ShufflePartitionPairsWriter = null - TryUtils.tryWithSafeFinally { - partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) - val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) - partitionPairsWriter = new ShufflePartitionPairsWriter( - partitionWriter, - serializerManager, - serInstance, - blockId, - context.taskMetrics().shuffleWriteMetrics) - while (it.hasNext && it.nextPartition() == partitionId) { - it.writeNext(partitionPairsWriter) - } - } { - if (partitionPairsWriter != null) { - partitionPairsWriter.close() - } - } - } - } else { - // We must perform merge-sort; get an iterator by partition and write everything directly. - for ((id, elements) <- this.partitionedIterator) { - val blockId = ShuffleBlockId(shuffleId, mapId, id) - var partitionWriter: ShufflePartitionWriter = null - var partitionPairsWriter: ShufflePartitionPairsWriter = null - TryUtils.tryWithSafeFinally { - partitionWriter = mapOutputWriter.getPartitionWriter(id) - partitionPairsWriter = new ShufflePartitionPairsWriter( - partitionWriter, - serializerManager, - serInstance, - blockId, - context.taskMetrics().shuffleWriteMetrics) - if (elements.hasNext) { - for (elem <- elements) { - partitionPairsWriter.write(elem._1, elem._2) - } - } - } { - if (partitionPairsWriter != null) { - partitionPairsWriter.close() - } - } - } - } - - context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) - context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - } - def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() @@ -854,7 +785,7 @@ private[spark] class ExternalSorter[K, V, C]( val inMemoryIterator = new WritablePartitionedIterator { private[this] var cur = if (upstream.hasNext) upstream.next() else null - def writeNext(writer: PairsWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (upstream.hasNext) upstream.next() else null } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala deleted file mode 100644 index 05ed72c3e3778..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -/** - * An abstraction of a consumer of key-value pairs, primarily used when - * persisting partitioned data, either through the shuffle writer plugins - * or via DiskBlockObjectWriter. - */ -private[spark] trait PairsWriter { - - def write(key: Any, value: Any): Unit -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 337b0673b4031..5232c2bd8d6f6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: PairsWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -96,7 +96,7 @@ private[spark] object WritablePartitionedPairCollection { * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: PairsWriter): Unit + def writeNext(writer: DiskBlockObjectWriter): Unit def hasNext(): Boolean diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 18f3a339e246c..9bf707f783d44 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -19,10 +19,8 @@ import java.io.*; import java.nio.ByteBuffer; -import java.nio.file.Files; import java.util.*; -import org.mockito.stubbing.Answer; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -38,11 +36,9 @@ import org.mockito.MockitoAnnotations; import org.apache.spark.HashPartitioner; -import org.apache.spark.MapOutputTracker; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; -import org.apache.spark.TaskContext$; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; @@ -57,7 +53,6 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -70,7 +65,6 @@ public class UnsafeShuffleWriterSuite { - static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int NUM_PARTITITONS = 4; TestMemoryManager memoryManager; TaskMemoryManager taskMemoryManager; @@ -88,12 +82,9 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = 
RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; - @Mock(answer = RETURNS_SMART_NULLS) MapOutputTracker mapOutputTracker; - @Mock(answer = RETURNS_SMART_NULLS) SerializerManager serializerManager; @After public void tearDown() { - TaskContext$.MODULE$.unset(); Utils.deleteRecursively(tempDir); final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); if (leakedMemory != 0) { @@ -141,27 +132,14 @@ public void setUp() throws IOException { }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - Answer renameTempAnswer = invocationOnMock -> { + doAnswer(invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = (File) invocationOnMock.getArguments()[3]; - if (!mergedOutputFile.delete()) { - throw new RuntimeException("Failed to delete old merged output file."); - } - if (tmp != null) { - Files.move(tmp.toPath(), mergedOutputFile.toPath()); - } else if (!mergedOutputFile.createNewFile()) { - throw new RuntimeException("Failed to create empty merged output file."); - } + mergedOutputFile.delete(); + tmp.renameTo(mergedOutputFile); return null; - }; - - doAnswer(renameTempAnswer) - .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); - - doAnswer(renameTempAnswer) - .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null)); + }).when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); @@ -173,26 +151,21 @@ public void setUp() throws IOException { when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); } - private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { + private UnsafeShuffleWriter createWriter( + boolean transferToEnabled) throws IOException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter( + return new UnsafeShuffleWriter<>( blockManager, + shuffleBlockResolver, taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics(), - new LocalDiskShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - shuffleBlockResolver, - BlockManagerId.apply("localhost", 7077))); + taskContext.taskMetrics().shuffleWriteMetrics() + ); } private void assertSpillFilesWereCleanedUp() { @@ -418,7 +391,7 @@ public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Except @Test public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { - conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false); + conf.set(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE(), false); testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); } @@ -471,10 +444,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro } private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); + 
memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { + for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2<>(i, i)); } writer.write(dataToWrite.iterator()); @@ -543,21 +516,16 @@ public void testPeakMemoryUsed() throws Exception { final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; taskMemoryManager = spy(taskMemoryManager); when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); - final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( + final UnsafeShuffleWriter writer = + new UnsafeShuffleWriter<>( blockManager, + shuffleBlockResolver, taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics(), - new LocalDiskShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - shuffleBlockResolver, - BlockManagerId.apply("localhost", 7077))); + taskContext.taskMetrics().shuffleWriteMetrics()); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 28cbeeda7a88d..62824a5bec9d1 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -210,8 +210,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { /** * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. 
*/ - private class SaveAccumContextCleaner(sc: SparkContext) extends - ContextCleaner(sc, null) { + private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { private val accumsRegistered = new ArrayBuffer[Long] override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 2b1258d7c923a..26a0fb0657af2 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MA import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -64,15 +64,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L), 0)) + Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L), 0)) + Array(10000L, 1000L))) val statuses = tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.map(status => (status._1.get, status._2)).toSet === - Seq((BlockManagerId("b", "hostB", 1000), - ArrayBuffer((ShuffleBlockAttemptId(10, 1, 0, 0), size10000))), - (BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000)))) + assert(statuses.toSet === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() @@ -88,9 +86,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000), 0)) + Array(compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000), 0)) + Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -111,9 +109,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000), 0)) + Array(compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000), 0)) + 
Array(compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -149,12 +147,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L), 0)) + BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) - .map(status => (status._1.get, status._2)).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000))))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch @@ -188,7 +184,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -222,11 +218,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), 0)) + Array(2L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), 0)) + Array(2L))) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L), 0)) + Array(3L))) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
@@ -266,7 +262,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -315,18 +311,16 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000), 0)) + Array(size0, size1000, size0, size10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0), 0)) + Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( - (Some(BlockManagerId("b", "hostB", 1000)), - Seq((ShuffleBlockAttemptId(10, 1, 0, 0), size10000), - (ShuffleBlockAttemptId(10, 1, 2, 0), size1000))), - (Some(BlockManagerId("a", "hostA", 1000)), - Seq((ShuffleBlockAttemptId(10, 0, 1, 0), size1000), - (ShuffleBlockAttemptId(10, 0, 3, 0), size10000))) + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) ) ) @@ -335,120 +329,14 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("shuffle map statuses with null blockManagerIds") { - val rpcEnv = createRpcEnv("test") - val tracker = newTrackerMaster() - tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 3) - assert(tracker.containsShuffle(10)) - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) - tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L), 0)) - tracker.registerMapOutput(10, 1, MapStatus(null, Array(10000L, 1000L), 0)) - tracker.registerMapOutput(10, 2, MapStatus(null, Array(1000L, 10000L), 0)) - var statuses = tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.toSet === - Seq( - (None, - ArrayBuffer((ShuffleBlockAttemptId(10, 1, 0, 0), size10000), - (ShuffleBlockAttemptId(10, 2, 0, 0), size1000))), - (Some(BlockManagerId("a", "hostA", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000)))) - .toSet) - assert(0 == tracker.getNumCachedSerializedBroadcast) - tracker.removeOutputsOnHost("hostA") - - tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(1000L, 10000L), 0)) - statuses = tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.toSet === - Seq( - (None, - ArrayBuffer((ShuffleBlockAttemptId(10, 1, 0, 0), size10000), - (ShuffleBlockAttemptId(10, 2, 0, 0), size1000))), - (Some(BlockManagerId("b", "hostB", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000)))) - .toSet) - tracker.unregisterMapOutput(10, 1, null) - - tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(1000L, 10000L), 0)) - statuses = 
tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.toSet === - Seq( - (Some(BlockManagerId("b", "hostB", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000), - (ShuffleBlockAttemptId(10, 1, 0, 0), size1000))), - (None, - ArrayBuffer((ShuffleBlockAttemptId(10, 2, 0, 0), size1000)))) - .toSet) - - val outputs = tracker.getLocationsWithLargestOutputs(10, 0, 2, 0.01) - assert(outputs.get.toSeq === Seq(BlockManagerId("b", "hostB", 1000))) - tracker.stop() - rpcEnv.shutdown() - } - - test("shuffle map statuses with null execIds") { - val rpcEnv = createRpcEnv("test") - val tracker = newTrackerMaster() - tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) - assert(tracker.containsShuffle(10)) - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) - tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId(null, "hostA", 1000), - Array(1000L, 10000L), 0)) - tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId(null, "hostB", 1000), - Array(10000L, 1000L), 0)) - var statuses = tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.toSet === - Seq( - (Some(BlockManagerId(null, "hostB", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 1, 0, 0), size10000))), - (Some(BlockManagerId(null, "hostA", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000)))) - .toSet) - assert(0 == tracker.getNumCachedSerializedBroadcast) - tracker.removeOutputsOnExecutor("a") - - statuses = tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.toSet === - Seq( - (Some(BlockManagerId(null, "hostB", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 1, 0, 0), size10000))), - (Some(BlockManagerId(null, "hostA", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000)))) - .toSet) - tracker.unregisterMapOutput(10, 1, BlockManagerId(null, "hostA", 1000)) - - tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(1000L, 10000L), 0)) - statuses = tracker.getMapSizesByExecutorId(10, 0) - assert(statuses.toSet === - Seq( - (Some(BlockManagerId(null, "hostB", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 1, 0, 0), size10000))), - (Some(BlockManagerId("b", "hostB", 1000)), - ArrayBuffer((ShuffleBlockAttemptId(10, 0, 0, 0), size1000)))) - .toSet) - val outputs = tracker.getLocationsWithLargestOutputs(10, 0, 2, 0.01) - assert(outputs.get.toSeq === Seq(BlockManagerId("b", "hostB", 1000))) - tracker.stop() - rpcEnv.shutdown() - } - test("correctly track executors and ExecutorShuffleStatus") { val tracker = newTrackerMaster() val bmId1 = BlockManagerId("exec1", "host1", 1000) val bmId2 = BlockManagerId("exec2", "host2", 1000) tracker.registerShuffle(11, 3) - tracker.registerMapOutput(11, 0, MapStatus(bmId1, Array(10), 0L)) - tracker.registerMapOutput(11, 1, MapStatus(bmId1, Array(100), 1L)) - tracker.registerMapOutput(11, 2, MapStatus(bmId2, Array(1000), 2L)) + tracker.registerMapOutput(11, 0, MapStatus(bmId1, Array(10))) + tracker.registerMapOutput(11, 1, MapStatus(bmId1, Array(100))) + tracker.registerMapOutput(11, 2, MapStatus(bmId2, Array(1000))) assert(tracker.hasOutputsOnExecutor("exec1")) assert(tracker.getExecutorShuffleStatus.keySet.equals(Set("exec1", "exec2"))) @@ -467,4 +355,5 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.markShuffleActive(11) assert(tracker.getExecutorShuffleStatus == Map("exec2" -> 
ExecutorShuffleStatus.Active)) } + } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 6eb8251ec4002..8b1084a8edc76 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} -import org.apache.spark.util.MutablePair +import org.apache.spark.util.{MutablePair, Utils} abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -368,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) val writer1 = manager.getWriter[Int, Int]( shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) - val data1 = (1 to 10).map { x => x -> x } + val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently @@ -383,18 +383,13 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int], - taskContext: TaskContext)( - iter: Iterator[(Int, Int)]): Option[MapStatus] = { - try { - val files = writer.write(iter) - writer.stop(true) - } finally { - TaskContext.unset() - } + writer: ShuffleWriter[Int, Int])( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + val files = writer.write(iter) + writer.stop(true) } val interleaver = new InterleaveIterators( - data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) + data1, writeAndClose(writer1), data2, writeAndClose(writer2)) val (mapOutput1, mapOutput2) = interleaver.run() // check that we can read the map output and it has the right data @@ -410,15 +405,12 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val taskContext = new TaskContextImpl( 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) - TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq - TaskContext.unset() assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) - TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala deleted file mode 100644 index 9d3a52a237cbe..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler - -import java.util - -import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} -import org.apache.spark.internal.config -import org.apache.spark.rdd.RDD -import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO -import org.apache.spark.storage.BlockManagerId - -class PluginShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - val localDiskShuffleDataIO = new LocalDiskShuffleDataIO(sparkConf) - override def driver(): ShuffleDriverComponents = - new PluginShuffleDriverComponents(localDiskShuffleDataIO.driver()) - - override def executor(): ShuffleExecutorComponents = localDiskShuffleDataIO.executor() -} - -class PluginShuffleDriverComponents(delegate: ShuffleDriverComponents) - extends ShuffleDriverComponents { - override def initializeApplication(): util.Map[String, String] = - delegate.initializeApplication() - - override def cleanupApplication(): Unit = - delegate.cleanupApplication() - - override def removeShuffle(shuffleId: Int, blocking: Boolean): Unit = - delegate.removeShuffle(shuffleId, blocking) - - override def shouldUnregisterOutputOnHostOnFetchFailure(): Boolean = true -} - -class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { - - private def setupTest(): (RDD[_], Int) = { - afterEach() - val conf = new SparkConf() - // unregistering all outputs on a host is enabled for the individual file server case - conf.set(config.SHUFFLE_IO_PLUGIN_CLASS, classOf[PluginShuffleDataIO].getName) - init(conf) - val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) - (reduceRdd, shuffleId) - } - - test("Test simple file server") { - val (reduceRdd, shuffleId) = setupTest() - submit(reduceRdd, Array(0, 1)) - - // Perform map task - val mapStatus1 = makeMapStatus(null, "hostA") - val mapStatus2 = makeMapStatus(null, "hostB") - complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) - assertMapShuffleLocations(shuffleId, Seq(mapStatus1, mapStatus2)) - - // perform reduce task - complete(taskSets(1), Seq((Success, 42), (Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty() - } - - test("Test simple file server fetch failure") { - val (reduceRdd, shuffleId) = setupTest() - submit(reduceRdd, Array(0, 1)) - - // Perform map task - val mapStatus1 = makeMapStatus(null, "hostA") - val mapStatus2 = makeMapStatus(null, "hostB") - complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) - - complete(taskSets(1), Seq((Success, 42), - (FetchFailed(BlockManagerId(null, "hostB", 1234), shuffleId, 1, 0, "ignored"), null))) - assertMapShuffleLocations(shuffleId, Seq(mapStatus1, null)) - - scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, mapStatus2))) - - complete(taskSets(3), Seq((Success, 43))) - assert(results === Map(0 
-> 42, 1 -> 43)) - assertDataStructuresEmpty() - } - - test("Test simple file fetch server - duplicate host") { - val (reduceRdd, shuffleId) = setupTest() - submit(reduceRdd, Array(0, 1)) - - // Perform map task - val mapStatus1 = makeMapStatus(null, "hostA") - val mapStatus2 = makeMapStatus(null, "hostA") - complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) - - complete(taskSets(1), Seq((Success, 42), - (FetchFailed(BlockManagerId(null, "hostA", 1234), shuffleId, 1, 0, "ignored"), null))) - assertMapShuffleLocations(shuffleId, Seq(null, null)) // removes both - - scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, mapStatus1), (Success, mapStatus2))) - - complete(taskSets(3), Seq((Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty() - } - - test("Test DFS case - empty BlockManagerId") { - val (reduceRdd, shuffleId) = setupTest() - submit(reduceRdd, Array(0, 1)) - - val mapStatus = makeEmptyMapStatus() - complete(taskSets(0), Seq((Success, mapStatus), (Success, mapStatus))) - assertMapShuffleLocations(shuffleId, Seq(mapStatus, mapStatus)) - - // perform reduce task - complete(taskSets(1), Seq((Success, 42), (Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty() - } - - test("Test DFS case - fetch failure") { - val (reduceRdd, shuffleId) = setupTest() - submit(reduceRdd, Array(0, 1)) - - // Perform map task - val mapStatus = makeEmptyMapStatus() - complete(taskSets(0), Seq((Success, mapStatus), (Success, mapStatus))) - - complete(taskSets(1), Seq((Success, 42), - (FetchFailed(null, shuffleId, 1, 0, "ignored"), null))) - assertMapShuffleLocations(shuffleId, Seq(mapStatus, null)) - - scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, mapStatus))) - - complete(taskSets(3), Seq((Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty() - } - - def makeMapStatus(execId: String, host: String): MapStatus = { - MapStatus(BlockManagerId(execId, host, 1234), Array.fill[Long](2)(2), 0) - } - - def makeEmptyMapStatus(): MapStatus = { - MapStatus(null, Array.fill[Long](2)(2), 0) - } - - def assertMapShuffleLocations(shuffleId: Int, set: Seq[MapStatus]): Unit = { - val actualShuffleLocations = mapOutputTracker.shuffleStatuses(shuffleId).mapStatuses - assert(set === actualShuffleLocations.toSeq) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8e8b90600d16e..5a128638297a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -236,7 +236,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi init(new SparkConf()) } - def init(testConf: SparkConf): Unit = { + private def init(testConf: SparkConf): Unit = { sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() @@ -306,7 +306,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi it.next.asInstanceOf[Tuple2[_, _]]._1 /** Send the given CompletionEvent messages for the tasks in the TaskSet. 
*/ - def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { + private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { @@ -332,7 +332,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } /** Submits a job to the scheduler and returns the job id. */ - def submit( + private def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, @@ -445,17 +445,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 0)), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 0)), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 0)), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 0)), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -700,7 +700,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) @@ -727,7 +727,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get.host).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) @@ -769,7 +769,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0) } } else { - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) } } @@ -1063,7 +1063,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. 
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get.host).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1193,9 +1193,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get.host).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.get.host).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. @@ -1386,7 +1386,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi Success, makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) // finish the next stage normally, which completes the job @@ -1792,7 +1792,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) // Make sure that the reduce stage was now submitted. 
@@ -2055,7 +2055,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran @@ -2101,7 +2101,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep @@ -2361,7 +2361,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -2377,7 +2377,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -2386,7 +2386,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(3), Seq( (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) assert(listener2.results.size === 1) @@ -2425,7 +2425,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -2451,7 +2451,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1.get).toSet === + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === 
Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing @@ -2956,7 +2956,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } - def assertDataStructuresEmpty(): Unit = { + private def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -3000,7 +3000,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 0) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index d819487816b5b..c1e7fb9a1db16 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -61,7 +61,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 0) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -75,7 +75,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes, 0) + val status = MapStatus(null, sizes) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -87,7 +87,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes, 0) + val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -109,7 +109,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes, 0) + val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -165,7 +165,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. 
val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 0) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -189,12 +189,4 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } - - test("Location can be empty") { - val sizes = (0L to 3000L).toArray - val status = MapStatus(null, sizes, 0) - val status1 = compressAndDecompressMapStatus(status) - assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - assert(status1.location == null) - } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index e5deb15d9db11..fb6a89f7807dc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1068,7 +1068,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B WorkerOffer("exec3", "host3", 2) // unknown ) val makeMapStatus = (offer: WorkerOffer) => - MapStatus(BlockManagerId(offer.executorId, offer.host, 1), Array(10), 0L) + MapStatus(BlockManagerId(offer.executorId, offer.host, 1), Array(10)) val mapOutputTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] mapOutputTracker.registerShuffle(0, 2) mapOutputTracker.registerShuffle(1, 1) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index d448125c7529e..16eec7e0bea1f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -351,8 +351,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus( - BlockManagerId("exec-1", "host", 1234), blockSizes, 0)) + ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 966a6fa9d005f..6d2ef17a7a790 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,19 +20,13 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer +import org.mockito.Mockito.{mock, when} import org.apache.spark._ import org.apache.spark.internal.config -import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents -import org.apache.spark.storage.{BlockId, 
BlockManager, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** * Wrapper for a managed buffer that keeps track of how many times retain and release are called. @@ -61,14 +55,11 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { - @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ - /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying * ManagedBuffers that contain the data are eventually released. */ test("read() releases resources on completion") { - MockitoAnnotations.initMocks(this) val testConf = new SparkConf(false) // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the // shuffle code calls SparkEnv.get()). @@ -87,14 +78,11 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() - val compressionCodec = CompressionCodec.createCodec(testConf) - val compressedOutputStream = compressionCodec.compressedOutputStream(byteOutputStream) - val serializationStream = serializer.newInstance().serializeStream(compressedOutputStream) + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) (0 until keyValuePairsPerMap).foreach { i => serializationStream.writeKey(i) serializationStream.writeValue(2*i) } - compressedOutputStream.close() // Setup the mocked BlockManager to return RecordingManagedBuffers. val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) @@ -114,19 +102,15 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)) - .thenAnswer(new Answer[Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])]] { - def answer(invocationOnMock: InvocationOnMock): - Iterator[(Option[BlockManagerId], Seq[(BlockId, Long)])] = { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockAttemptId(shuffleId, mapId, reduceId, 0) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq((Some(localBlockManagerId), shuffleBlockIdsAndSizes)).toIterator - } - }) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) + } + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + } // Create a mocked shuffle handle to pass into HashShuffleReader. 
val shuffleHandle = { @@ -140,29 +124,19 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val serializerManager = new SerializerManager( serializer, new SparkConf() - .set(config.SHUFFLE_COMPRESS, true) + .set(config.SHUFFLE_COMPRESS, false) .set(config.SHUFFLE_SPILL_COMPRESS, false)) val taskContext = TaskContext.empty() - TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - - val shuffleExecutorComponents = - new LocalDiskShuffleExecutorComponents( - testConf, - blockManager, - mapOutputTracker, - serializerManager, - blockResolver, - localBlockManagerId) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, - shuffleExecutorComponents, serializerManager, + blockManager, mapOutputTracker) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) @@ -173,6 +147,5 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext assert(buffer.callsToRetain === 1) assert(buffer.callsToRelease === 1) } - TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala deleted file mode 100644 index b571565cf4336..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.io.InputStream -import java.lang.{Iterable => JIterable} -import java.util.{Map => JMap} - -import com.google.common.collect.ImmutableMap - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS -import org.apache.spark.shuffle.api.{ShuffleBlockInfo, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleMapOutputWriter} -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents - -class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { - test(s"test serialization of shuffle initialization conf to executors") { - val testConf = new SparkConf() - .setAppName("testing") - .setMaster("local-cluster[2,1,1024]") - .set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO") - - sc = new SparkContext(testConf) - - sc.parallelize(Seq((1, "one"), (2, "two"), (3, "three")), 3) - .groupByKey() - .collect() - } -} - -class TestShuffleDriverComponents extends ShuffleDriverComponents { - override def initializeApplication(): JMap[String, String] = - ImmutableMap.of("test-key", "test-value") -} - -class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - override def driver(): ShuffleDriverComponents = new TestShuffleDriverComponents() - - override def executor(): ShuffleExecutorComponents = - new TestShuffleExecutorComponents(sparkConf) -} - -class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { - - private var delegate = new LocalDiskShuffleExecutorComponents(sparkConf) - - override def initializeExecutor( - appId: String, execId: String, extraConfigs: JMap[String, String]): Unit = { - assert(extraConfigs.get("test-key") == "test-value") - delegate.initializeExecutor(appId, execId, extraConfigs) - } - - override def createMapOutputWriter( - shuffleId: Int, - mapId: Int, - mapTaskAttemptId: Long, - numPartitions: Int): ShuffleMapOutputWriter = { - delegate.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions) - } - - override def getPartitionReaders( - blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] = { - delegate.getPartitionReaders(blockMetadata) - } - - override def shouldWrapPartitionReaderStream(): Boolean = { - delegate.shouldWrapPartitionReaderStream() - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index da1630e67a485..7f956c26d0ff0 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -22,11 +22,10 @@ import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyString} +import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -34,11 +33,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} -import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import 
org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,57 +45,50 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ - @Mock(answer = RETURNS_SMART_NULLS) private var serializerManager: SerializerManager = _ - @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _ private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ - private var shuffleExecutorComponents: ShuffleExecutorComponents = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) - .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { super.beforeEach() - MockitoAnnotations.initMocks(this) tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics + MockitoAnnotations.initMocks(this) shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( shuffleId = 0, numMaps = 2, dependency = dependency ) - val memoryManager = new TestMemoryManager(conf) - val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - when(blockManager.diskBlockManager).thenReturn(diskBlockManager) - when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) - - when(blockResolver.writeIndexFileAndCommit( - anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) - .thenAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + doAnswer(new Answer[Void] { + def answer(invocationOnMock: InvocationOnMock): Void = { + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null } - + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], any[File], any[SerializerInstance], anyInt(), - any[ShuffleWriteMetrics])) - .thenAnswer { (invocation: InvocationOnMock) => + any[ShuffleWriteMetrics] + )).thenAnswer(new Answer[DiskBlockObjectWriter] { + override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( @@ -109,33 +98,29 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(3).asInstanceOf[Int], syncWrites = false, 
args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId]) + blockId = args(0).asInstanceOf[BlockId] + ) } - - when(diskBlockManager.createTempShuffleBlock()) - .thenAnswer { (invocationOnMock: InvocationOnMock) => - val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = new File(tempDir, blockId.name) - blockIdToFileMap.put(blockId, file) - temporaryFilesCreated += file - (blockId, file) - } - - when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) => - blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) - } - - shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - blockResolver, - BlockManagerId("localhost", 7077)) + }) + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer[(TempShuffleBlockId, File)] { + override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = new File(tempDir, blockId.name) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated += file + (blockId, file) + } + }) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get + } + }) } override def afterEach(): Unit = { - TaskContext.unset() try { Utils.deleteRecursively(tempDir) blockIdToFileMap.clear() @@ -148,13 +133,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, + blockResolver, shuffleHandle, 0, // MapId - 0L, // MapTaskAttemptId conf, - taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents) - + taskContext.taskMetrics().shuffleWriteMetrics + ) writer.write(Iterator.empty) writer.stop( /* success = */ true) assert(writer.getPartitionLengths.sum === 0) @@ -168,31 +152,28 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - Seq(true, false).foreach { transferTo => - test(s"write with some empty partitions - transferTo $transferTo") { - val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString) - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - shuffleHandle, - 0, // MapId - 0L, - transferConf, - taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) - } + test("write with some empty partitions") { + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new 
BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + blockResolver, + shuffleHandle, + 0, // MapId + conf, + taskContext.taskMetrics().shuffleWriteMetrics + ) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) } test("only generate temp shuffle file for non-empty partition") { @@ -211,12 +192,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, + blockResolver, shuffleHandle, 0, // MapId - 0L, conf, - taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents) + taskContext.taskMetrics().shuffleWriteMetrics + ) intercept[SparkException] { writer.write(records) @@ -233,12 +214,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, + blockResolver, shuffleHandle, 0, // MapId - 0L, conf, - taskContext.taskMetrics().shuffleWriteMetrics, - shuffleExecutorComponents) + taskContext.taskMetrics().shuffleWriteMetrics + ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { @@ -252,12 +233,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(temporaryFilesCreated.count(_.exists()) === 0) } - /** - * This won't be necessary with Scala 2.12 - */ - private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { - new Answer[T] { - override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) - } - } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala deleted file mode 100644 index 326831749ce09..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort - -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Mockito._ -import org.scalatest.Matchers - -import org.apache.spark.{MapOutputTracker, Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} -import org.apache.spark.memory.MemoryTestingUtils -import org.apache.spark.serializer.{JavaSerializer, SerializerManager} -import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents -import org.apache.spark.storage.{BlockManager, BlockManagerId} -import org.apache.spark.util.Utils - - -class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { - - @Mock(answer = RETURNS_SMART_NULLS) - private var blockManager: BlockManager = _ - @Mock(answer = RETURNS_SMART_NULLS) - private var mapOutputTracker: MapOutputTracker = _ - @Mock(answer = RETURNS_SMART_NULLS) - private var serializerManager: SerializerManager = _ - - private val shuffleId = 0 - private val numMaps = 5 - private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _ - private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) - private val serializer = new JavaSerializer(conf) - private var shuffleExecutorComponents: ShuffleExecutorComponents = _ - - override def beforeEach(): Unit = { - super.beforeEach() - MockitoAnnotations.initMocks(this) - val partitioner = new Partitioner() { - def numPartitions = numMaps - def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions) - } - shuffleHandle = { - val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) - when(dependency.partitioner).thenReturn(partitioner) - when(dependency.serializer).thenReturn(serializer) - when(dependency.aggregator).thenReturn(None) - when(dependency.keyOrdering).thenReturn(None) - new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) - } - shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( - conf, - blockManager, - mapOutputTracker, - serializerManager, - shuffleBlockResolver, - BlockManagerId("localhost", 7077)) - } - - override def afterAll(): Unit = { - try { - shuffleBlockResolver.stop() - } finally { - super.afterAll() - } - } - - test("write empty iterator") { - val context = MemoryTestingUtils.fakeTaskContext(sc.env) - val writer = new SortShuffleWriter[Int, Int, Int]( - shuffleBlockResolver, - shuffleHandle, - mapId = 1, - context, - shuffleExecutorComponents) - writer.write(Iterator.empty) - writer.stop(success = true) - val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1) - val writeMetrics = context.taskMetrics().shuffleWriteMetrics - assert(!dataFile.exists()) - assert(writeMetrics.bytesWritten === 0) - assert(writeMetrics.recordsWritten === 0) - } - - test("write with some records") { - val context = MemoryTestingUtils.fakeTaskContext(sc.env) - val records = List[(Int, Int)]((1, 2), (2, 3), (4, 4), (6, 5)) - val writer = new SortShuffleWriter[Int, Int, Int]( - shuffleBlockResolver, - shuffleHandle, - mapId = 2, - context, - shuffleExecutorComponents) - writer.write(records.toIterator) - writer.stop(success = true) - val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2) - val writeMetrics = context.taskMetrics().shuffleWriteMetrics - assert(dataFile.exists()) - assert(dataFile.length() === writeMetrics.bytesWritten) - assert(records.size === 
writeMetrics.recordsWritten) - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala deleted file mode 100644 index 8aa9f51e09494..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort.io - -import java.io.{File, FileInputStream} -import java.nio.channels.FileChannel -import java.nio.file.Files -import java.util.Arrays - -import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt} -import org.mockito.Mock -import org.mockito.Mockito.when -import org.mockito.MockitoAnnotations -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.shuffle.IndexShuffleBlockResolver -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.Utils - -class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { - - @Mock(answer = RETURNS_SMART_NULLS) - private var blockResolver: IndexShuffleBlockResolver = _ - - private val NUM_PARTITIONS = 4 - private val BLOCK_MANAGER_ID = BlockManagerId("localhost", 7077) - private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p => - if (p == 3) { - Array.emptyByteArray - } else { - (0 to p * 10).map(_ + p).map(_.toByte).toArray - } - }.toArray - - private val partitionLengths = data.map(_.length) - - private var tempFile: File = _ - private var mergedOutputFile: File = _ - private var tempDir: File = _ - private var partitionSizesInMergedFile: Array[Long] = _ - private var conf: SparkConf = _ - private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _ - - override def afterEach(): Unit = { - try { - Utils.deleteRecursively(tempDir) - } finally { - super.afterEach() - } - } - - override def beforeEach(): Unit = { - MockitoAnnotations.initMocks(this) - tempDir = Utils.createTempDir() - mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) - tempFile = File.createTempFile("tempfile", "", tempDir) - partitionSizesInMergedFile = null - conf = new SparkConf() - .set("spark.app.id", "example.spark.app") - .set("spark.shuffle.unsafe.file.output.buffer", "16k") - when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) - when(blockResolver.writeIndexFileAndCommit( - anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) - .thenAnswer { (invocationOnMock: InvocationOnMock) => - partitionSizesInMergedFile 
= invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - mergedOutputFile.delete() - tmp.renameTo(mergedOutputFile) - } - null - } - mapOutputWriter = new LocalDiskShuffleMapOutputWriter( - 0, - 0, - NUM_PARTITIONS, - blockResolver, - BLOCK_MANAGER_ID, - conf) - } - - test("writing to an outputstream") { - (0 until NUM_PARTITIONS).foreach { p => - val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.openStream() - data(p).foreach { i => stream.write(i) } - stream.close() - intercept[IllegalStateException] { - stream.write(p) - } - } - verifyWrittenRecords() - } - - test("writing to a channel") { - (0 until NUM_PARTITIONS).foreach { p => - val writer = mapOutputWriter.getPartitionWriter(p) - val outputTempFile = File.createTempFile("channelTemp", "", tempDir) - Files.write(outputTempFile.toPath, data(p)) - val tempFileInput = new FileInputStream(outputTempFile) - val channel = writer.openChannelWrapper() - Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput => - Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper => - assert(channelWrapper.channel().isInstanceOf[FileChannel], - "Underlying channel should be a file channel") - Utils.copyFileStreamNIO( - tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length) - } - } - } - verifyWrittenRecords() - } - - private def readRecordsFromFile() = { - val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath) - val result = (0 until NUM_PARTITIONS).map { part => - val startOffset = data.slice(0, part).map(_.length).sum - val partitionSize = data(part).length - Arrays.copyOfRange(mergedOutputBytes, startOffset, startOffset + partitionSize) - }.toArray - result - } - - private def verifyWrittenRecords(): Unit = { - val committedLengths = mapOutputWriter.commitAllPartitions() - assert(partitionSizesInMergedFile === partitionLengths) - assert(committedLengths.getPartitionLengths === partitionLengths) - assert(committedLengths.getLocation.isPresent) - assert(committedLengths.getLocation.get === BLOCK_MANAGER_ID) - assert(mergedOutputFile.length() === partitionLengths.sum) - assert(data === readRecordsFromFile()) - } - - /** - * This won't be necessary with Scala 2.12 - */ - private implicit def functionToAnswer[T](func: InvocationOnMock => T): Answer[T] = { - new Answer[T] { - override def answer(invocationOnMock: InvocationOnMock): T = func(invocationOnMock) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a4b6920be04c0..98fe9663b6211 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,13 +21,14 @@ import java.io.{File, InputStream, IOException} import java.util.UUID import java.util.concurrent.Semaphore +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future import org.apache.spark.{SparkFunSuite, TaskContext} import 
org.apache.spark.network._ @@ -124,8 +125,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val inputStream = iterator.next() - val blockId = iterator.getCurrentBlock() + val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) @@ -200,11 +200,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next().close() // close() first block's input stream + iterator.next()._2.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next() + val subIter = iterator.next()._2 // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -402,8 +402,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - iterator.next() - val id1 = iterator.getCurrentBlock() + val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) @@ -423,7 +422,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } sem.acquire() - intercept[FetchFailedException] { iterator.next() } } @@ -465,11 +463,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. 
- iterator.next() - val blockId1 = iterator.getCurrentBlock() - iterator.next() - val blockId2 = iterator.getCurrentBlock() - assert(Set(blockId1, blockId2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) } test("retry corrupt blocks (disabled)") { @@ -527,14 +522,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - iterator.next() - val id1 = iterator.getCurrentBlock() + val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - iterator.next() - val id2 = iterator.getCurrentBlock() + val (id2, _) = iterator.next() assert(id2 === ShuffleBlockId(0, 1, 0)) - iterator.next() - val id3 = iterator.getCurrentBlock() + val (id3, _) = iterator.next() assert(id3 === ShuffleBlockId(0, 2, 0)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 22cfbf506c645..079ff25fcb67e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -156,11 +156,10 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. - context.taskMetrics().decorateTempShuffleReadMetrics( - tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics)) - val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -169,7 +168,7 @@ class ShuffledRowRDD( shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, context, - tempMetrics) + sqlMetricsReporter) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) }