diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 3239f4c385..d87f95c39b 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -18,33 +18,29 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val (mapLocations, blockSizes) = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) - } + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = blockSizes.map { case (bm, groups) => + val blockIds = groups.flatMap { case (groupId, segments) => + segments.zipWithIndex.map(x=>(("shuffle_%d_%d_%d_%d").format(shuffleId, groupId, reduceId, x._2), x._1)) + } + (bm, blockIds) + } - def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = { - val blockId = blockPair._1 - val blockOption = blockPair._2 + def unpackBlock(blockTuple: (BlockManagerId, String, Option[Iterator[Any]])) : Iterator[(K, V)] = { + val (address, blockId, blockOption) = blockTuple blockOption match { case Some(block) => { block.asInstanceOf[Iterator[(K, V)]] } case None => { - val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r + val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)_([0-9]*)".r blockId match { - case regex(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) + case regex(shufId, _, _, _) => + throw new FetchFailedException(address, shufId.toInt, -1, reduceId, null) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block") @@ -53,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val blockFetcherItr = blockManager.getMultiple( + blocksByAddress, shuffleId, reduceId, mapLocations, serializer) val itr = blockFetcherItr.flatMap(unpackBlock) CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index fde597ffd1..c40fbe7b87 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -43,7 +43,8 @@ private[spark] class MapOutputTracker extends Logging { // Set to the MapOutputTrackerActor living on the driver var 
trackerActor: ActorRef = _ - private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + private val mapStatuses = new TimeStampedHashMap[Int, Array[MapOutputLocation]] + private val shuffleBlockSizes = new TimeStampedHashMap[Int, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -76,36 +77,36 @@ private[spark] class MapOutputTracker extends Logging { } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapOutputLocation](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - } - - def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } + shuffleBlockSizes.put(shuffleId, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]()) } def registerMapOutputs( shuffleId: Int, - statuses: Array[MapStatus], + statuses: Array[MapOutputLocation], + sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray], changeGeneration: Boolean = false) { - mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) + mapStatuses.put(shuffleId, Array[MapOutputLocation]() ++ statuses) + shuffleBlockSizes.put(shuffleId, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]() ++ sizes) + if (changeGeneration) { incrementGeneration() } } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - var array = arrayOpt.get + val array = mapStatuses.get(shuffleId).orNull + val sizes = shuffleBlockSizes.get(shuffleId).orNull + if (array != null) { array.synchronized { if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null } + if (sizes!= null) { + sizes.remove(bmAddress) + } } incrementGeneration() } else { @@ -116,12 +117,21 @@ private[spark] class MapOutputTracker extends Logging { // Remembers which map output locations are currently being fetched on a worker private val fetching = new HashSet[Int] - // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + /** + * Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle + * Return an array of MapOutputLocation of the specific reduceId, one for each ShuffleMapTask, + * and sizes of all segments in for the shuffle (bucket) in the form of + * Seq(BlockManagerId, Seq(groupId, size array for all the segments in the bucket)) + */ + def getServerStatuses(shuffleId: Int, reduceId: Int): + (Array[MapOutputLocation], Seq[(BlockManagerId, Seq[(Int, Seq[Long])])]) = { val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { + val sizes = shuffleBlockSizes.get(shuffleId).orNull + + if (statuses == null || sizes == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - var fetchedStatuses: Array[MapStatus] = null + var fetchedStatuses: Array[MapOutputLocation] = null + var fetchedSizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] = null fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -151,9 +161,12 @@ 
private[spark] class MapOutputTracker extends Logging { try { val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) + val tuple = deserializeStatuses(fetchedBytes) + fetchedStatuses = tuple._1 + fetchedSizes = tuple._2 logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) + shuffleBlockSizes.put(shuffleId, fetchedSizes) } finally { fetching.synchronized { fetching -= shuffleId @@ -161,9 +174,13 @@ private[spark] class MapOutputTracker extends Logging { } } } - if (fetchedStatuses != null) { + if (fetchedStatuses != null && fetchedSizes != null) { + logDebug("ShufCon - getServerStatuses for shuffle " + shuffleId + ": " + + fetachedResultStr(fetchedStatuses, fetchedSizes)) + fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) + return MapOutputTracker.convertShuffleBlockSizes(shuffleId, reduceId, + fetchedStatuses, fetchedSizes) } } else{ @@ -172,19 +189,33 @@ private[spark] class MapOutputTracker extends Logging { } } else { statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + return MapOutputTracker.convertShuffleBlockSizes(shuffleId, reduceId, statuses, sizes) } } } + + private def fetachedResultStr (fetchedStatuses: Array[MapOutputLocation], + fetchedSizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]) = { + var str = "(fetchedStatuses=" + fetchedStatuses.zipWithIndex.foreach { s => + str += (if (s._2 != 0) ", " else "") + "map[" + s._2 + "]=" + s._1.debugString + } + str += "), fetchedSizes=(" + fetchedSizes.foreach { s => str += "(" + s._1 + ", " + s._2.debugString + ") "} + str += ")" + str + } private def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) + shuffleBlockSizes.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime) } def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() + shuffleBlockSizes.clear() metadataCleaner.cancel() trackerActor = null } @@ -219,7 +250,8 @@ private[spark] class MapOutputTracker extends Logging { } def getSerializedLocations(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null + var statuses: Array[MapOutputLocation] = null + var sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] = null var generationGotten: Long = -1 generationLock.synchronized { if (generation > cacheGeneration) { @@ -231,12 +263,13 @@ private[spark] class MapOutputTracker extends Logging { return bytes case None => statuses = mapStatuses(shuffleId) + sizes = shuffleBlockSizes.get(shuffleId).orNull generationGotten = generation } } // If we got here, we failed to find the serialized locations in the cache, so we pulled // out a snapshot of the locations as "locs"; let's serialize and return that - val bytes = serializeStatuses(statuses) + val bytes = serializeStatuses((statuses, sizes)) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the generation hasn't changed while we were working generationLock.synchronized { @@ -250,47 +283,64 @@ private[spark] class MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. 
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + private def serializeStatuses(tuple: (Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray])): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) // Since statuses can be modified in parallel, sync on it + val statuses = tuple._1 statuses.synchronized { - objOut.writeObject(statuses) + objOut.writeObject(tuple) } objOut.close() out.toByteArray } // Opposite of serializeStatuses. - def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { + def deserializeStatuses(bytes: Array[Byte]) = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) objIn.readObject(). // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present // comment this out - nulls could be due to missing location ? - asInstanceOf[Array[MapStatus]] // .filter( _ != null ) + asInstanceOf[Tuple2[Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]]] // .filter( _ != null ) } } -private[spark] object MapOutputTracker { +private[spark] object MapOutputTracker extends Logging{ private val LOG_BASE = 1.1 // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If // any of the statuses is null (indicating a missing location due to a failed mapper), // throw a FetchFailedException. - private def convertMapStatuses( + private def convertShuffleBlockSizes( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapOutputLocation], + sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]): + (Array[MapOutputLocation], Seq[(BlockManagerId, Seq[(Int, Seq[Long])])]) = { assert (statuses != null) + assert (sizes != null) + statuses.map { status => if (status == null) { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) + } + } + + val segments = sizes.toSeq.map { case (bmId, groups) => + if (groups == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing output blocks for shuffle " + shuffleId + " on " + bmId)) } else { - (status.location, decompressSize(status.compressedSizes(reduceId))) + val seq = + for (i <- 0 until groups.groupNum if groups(i) != null) + yield (i, groups(i).bucketSizes(reduceId).map(decompressSize(_)).toSeq) + (bmId, seq) } } + + (statuses, segments) } /** @@ -319,3 +369,81 @@ private[spark] object MapOutputTracker { } } } + +private[spark] class MapOutputLocation(val location: BlockManagerId, val sequence: Int) +extends Serializable { + def this (status: MapStatus) = this (status.location, status.sequence) + def debugString = "MapOutputLocation(location=" + location + ", sequence=" + sequence +")" + + override def equals(that: Any) = that match { + case loc: MapOutputLocation => + location == loc.location && sequence == loc.sequence + case _ => + false + } + +} + +private[spark] class GroupBucketSizes(var sequence: Int, var bucketSizes: Array[Array[Byte]]) +extends Serializable { + def this(status: MapStatus) = this(status.sequence, status.compressedSizes) + def debugString = { + var str = "GroupBucketSizes(sequence=" + sequence + ", " + bucketSizes.zipWithIndex.foreach { s => + str += (if (s._2 != 0) ", " else "") + "bucket[" + s._2 + "]=(" + s._1.zipWithIndex.foreach{ x => + str += (if (x._2 != 0) ", " else "") + x._1 + } + 
str += ")" + } + str += ")" + str + } +} + +private[spark] class ShuffleBlockGroupSizeArray extends Serializable { + var groupNum = 0 + private var groupSizeArray = Array.fill[GroupBucketSizes](32)(null) + + def apply(idx: Int) = if (idx >= groupSizeArray.length) null else groupSizeArray(idx) + + def update(idx: Int, elem: GroupBucketSizes) { + if (idx >= groupSizeArray.length){ + var newLen = groupSizeArray.length * 2 + while (idx >= newLen) + newLen = newLen * 2 + + val newArray = Array.fill[GroupBucketSizes](newLen)(null) + scala.compat.Platform.arraycopy(groupSizeArray, 0, newArray, 0, groupNum) + groupSizeArray = newArray + } + + if (idx >= groupNum) + groupNum = idx + 1 + + groupSizeArray(idx) = elem + } + + def +=(elem: GroupBucketSizes) { + this(groupNum) = elem + } + + def debugString = { + var str = "ShuffleBlockGroupSizeArray(" + for (i <- 0 until groupNum) { + str += (if (i != 0) str += ", " else "") + "group_" + i + "=" + (if (groupSizeArray(i) == null) "null" else groupSizeArray(i).debugString) + } + str + ")" + } +} + +private[spark] object ShuffleBlockGroupSizeArray { + def apply(xs: GroupBucketSizes*) = { + val sizes = new ShuffleBlockGroupSizeArray() + xs.foreach { x => + sizes += x + } + sizes + } +} + diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2bf55ea9a9..917ef57958 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -16,7 +16,11 @@ import java.nio.ByteBuffer /** * The Mesos executor for Spark. */ -private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging { +private[spark] class Executor(executorId: String, + slaveHostname: String, + properties: Seq[(String, String)], + cores: Int = 128) +extends Logging { // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. 
@@ -39,7 +43,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert for ((key, value) <- properties) { System.setProperty(key, value) } - + // Create our ClassLoader and set it on this thread private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) @@ -72,12 +76,13 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert // Initialize Spark environment (using system properties read above) val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) + env.blockManager.maxShuffleGroups = cores SparkEnv.set(env) private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") // Start worker thread pool val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + 1, cores, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { threadPool.execute(new TaskRunner(context, taskId, serializedTask)) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index ebe2ac68d8..cdddd004dd 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -41,7 +41,7 @@ private[spark] class StandaloneExecutorBackend( case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? - executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) + executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, cores) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index 8d5194a737..5aa0ad6656 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -8,45 +8,41 @@ import io.netty.util.CharsetUtil import spark.Logging import spark.network.ConnectionManagerId +import spark.storage.BlockManagerId import scala.collection.JavaConverters._ - private[spark] class ShuffleCopier extends Logging { - def getBlock(host: String, port: Int, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(bmId: BlockManagerId, blockId: String, + resultCollectCallback: (BlockManagerId, String, Long, ByteBuf) => Unit) { - val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) + val cmId = new ConnectionManagerId(bmId.host, bmId.nettyPort) + val handler = new ShuffleCopier.ShuffleClientHandler(bmId, resultCollectCallback) val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt val fc = new FileClient(handler, connectTimeout) try { fc.init() - fc.connect(host, port) + fc.connect(cmId.host, cmId.port) fc.sendRequest(blockId) fc.waitForClose() fc.close() } catch { // Handle any socket-related exceptions in FileClient case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) + logError("Shuffle copy of block " + blockId + " from " + cmId.host + ":" + cmId.port + " failed", e) handler.handleError(blockId) } } } - def 
getBlock(cmId: ConnectionManagerId, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) - } - - def getBlocks(cmId: ConnectionManagerId, + def getBlocks(bmId: BlockManagerId, blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + resultCollectCallback: (BlockManagerId, String, Long, ByteBuf) => Unit) { for ((blockId, size) <- blocks) { - getBlock(cmId, blockId, resultCollectCallback) + getBlock(bmId, blockId, resultCollectCallback) } } } @@ -54,22 +50,22 @@ private[spark] class ShuffleCopier extends Logging { private[spark] object ShuffleCopier extends Logging { - private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + private class ShuffleClientHandler(bmId: BlockManagerId, resultCollectCallBack: (BlockManagerId, String, Long, ByteBuf) => Unit) extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + resultCollectCallBack(bmId, header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } override def handleError(blockId: String) { if (!isComplete) { - resultCollectCallBack(blockId, -1, null) + resultCollectCallBack(bmId, blockId, -1, null) } } } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + def echoResultCollectCallBack(bmId: BlockManagerId, blockId: String, size: Long, content: ByteBuf) { if (size != -1) { logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") } @@ -90,7 +86,7 @@ private[spark] object ShuffleCopier extends Logging { Executors.callable(new Runnable() { def run() { val copier = new ShuffleCopier() - copier.getBlock(host, port, file, echoResultCollectCallBack) + copier.getBlock(BlockManagerId("0", host, -1, port), file, echoResultCollectCallBack) } }) }).asJava diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index f7d60be5db..42ae2e803f 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -560,6 +560,7 @@ class DAGScheduler( mapOutputTracker.registerMapOutputs( stage.shuffleDep.get.shuffleId, stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + stage.shuffleBlockSizes, true) } clearCacheLocs() @@ -640,7 +641,7 @@ class DAGScheduler( for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, true) + mapOutputTracker.registerMapOutputs(shuffleId, locs, stage.shuffleBlockSizes, true) } if (shuffleToMapStage.isEmpty) { mapOutputTracker.incrementGeneration() diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index 203abb917b..1dbe6d5550 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -4,24 +4,51 @@ import spark.storage.BlockManagerId import java.io.{ObjectOutput, ObjectInput, Externalizable} /** - * Result returned by a ShuffleMapTask to a scheduler. 
Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. - * The map output sizes are compressed using MapOutputTracker.compressSize. + * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address + * and sequence number for the map output location, as well as the current sizes of all segment + * files in the specific group (organized as 2-d array sizes[bucketId][segmentId]) for passing + * on to the reduce tasks. The sizes are compressed using MapOutputTracker.compressSize. */ -private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) +private[spark] class MapStatus(var location: BlockManagerId, var groupId: Int, var sequence: Int, + var compressedSizes: Array[Array[Byte]]) extends Externalizable { - def this() = this(null, null) // For deserialization only + def this() = this(null, -1, -1, null) // For deserialization only def writeExternal(out: ObjectOutput) { location.writeExternal(out) + out.writeInt(groupId) + out.writeInt(sequence) out.writeInt(compressedSizes.length) - out.write(compressedSizes) + compressedSizes.foreach{ s => + out.writeInt(s.length) + out.write(s) + } } def readExternal(in: ObjectInput) { location = BlockManagerId(in) - compressedSizes = new Array[Byte](in.readInt()) - in.readFully(compressedSizes) + groupId = in.readInt() + sequence = in.readInt() + val len = in.readInt() + compressedSizes = Array.tabulate[Array[Byte]](len) { idx => + val n = in.readInt() + val sizes = new Array[Byte](n) + in.readFully(sizes) + sizes + } + } + + def debugString = { + var str = "MapStatus(location=" + location + ", groupId=" + groupId + ", sequence=" + sequence + ", compressedSizes=(" + compressedSizes.zipWithIndex.foreach { s => + str += "bucket[" + s._2 + "]=(" + s._1.zipWithIndex.foreach { x=> + str += (if (x._2 != 0) ", " else "") + x._1 + } + str += ")" + } + str += ")" + str } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 95647389c3..074ed11efc 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -129,52 +129,42 @@ private[spark] class ShuffleMapTask( val taskContext = new TaskContext(stageId, partition, attemptId) metrics = Some(taskContext.taskMetrics) - + + import ShuffleBlockManager._ val blockManager = SparkEnv.get.blockManager var shuffle: ShuffleBlocks = null - var buckets: ShuffleWriterGroup = null + var group: ShuffleWriterGroup = null try { // Obtain all the block writers for shuffle blocks. + this.logDebug ("ShufCon - " + this + "run attemp " + attemptId) val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) + group = shuffle.acquireWriters(partition) // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = dep.partitioner.getPartition(pair._1) - buckets.writers(bucketId).write(pair) + group.writers(bucketId).write(pair) } - // Commit the writes. Get the size of each bucket block (total block size). 
- var totalBytes = 0L - val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.size() - totalBytes += size - MapOutputTracker.compressSize(size) - } + val totalBytes = shuffle.commitWrites(group) + val compressedSizes = group.getBucketSizes.map(_.map(MapOutputTracker.compressSize(_))) // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - return new MapStatus(blockManager.blockManagerId, compressedSizes) + + val status = new MapStatus(blockManager.blockManagerId, group.groupId, group.sequence, compressedSizes) + this.logDebug ("ShufCon - " + this + " return " + status.debugString) + return status } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. - if (buckets != null) { - buckets.writers.foreach(_.revertPartialWrites()) - } throw e } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && buckets != null) { - shuffle.releaseWriters(buckets) - } + if (shuffle != null && group != null) + shuffle.releaseWriters(group) // Execute the callbacks on task completion. taskContext.executeOnCompleteCallbacks() } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 7fc9e13fd9..a2ec0497e2 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -1,6 +1,7 @@ package spark.scheduler import java.net.URI +import scala.collection.mutable.HashMap import spark._ import spark.storage.BlockManagerId @@ -29,7 +30,8 @@ private[spark] class Stage( val isShuffleMap = shuffleDep != None val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + val outputLocs = Array.fill[List[MapOutputLocation]](numPartitions)(Nil) + val shuffleBlockSizes = new HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] var numAvailableOutputs = 0 /** When first task was submitted to scheduler. 
*/ @@ -47,9 +49,29 @@ private[spark] class Stage( def addOutputLoc(partition: Int, status: MapStatus) { val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList + outputLocs(partition) = new MapOutputLocation(status) :: prevList if (prevList == Nil) numAvailableOutputs += 1 + + val shufBlockSize = { + shuffleBlockSizes.get(status.location) match { + case Some(size) => size + case None => + val size = new ShuffleBlockGroupSizeArray() + shuffleBlockSizes.put(status.location, size) + size + } + } + + val groupSize = shufBlockSize(status.groupId) + if (groupSize == null) + shufBlockSize(status.groupId) = new GroupBucketSizes(status) + else if (status.sequence > groupSize.sequence) { + groupSize.sequence = status.sequence + groupSize.bucketSizes = status.compressedSizes + } + + logDebug ("ShufCon - addOutputLoc for map " + partition + ": " + this.debugString) } def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { @@ -61,11 +83,20 @@ private[spark] class Stage( } } + private def blockManagerOnExecutor(bmAddress: BlockManagerId, execId: String) = { + if (bmAddress.executorId == execId) { + shuffleBlockSizes.remove(bmAddress) + true + } + else + false + } + def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) + val newList = prevList.filterNot(x=>blockManagerOnExecutor(x.location, execId)) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true @@ -76,6 +107,7 @@ private[spark] class Stage( logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( this, execId, numAvailableOutputs, numPartitions, isAvailable)) } + logDebug ("ShufCon - removeOutputsOnExecutor for execId " + execId + ": " + this.debugString) } def newAttemptId(): Int = { @@ -86,6 +118,17 @@ private[spark] class Stage( def origin: String = rdd.origin + def debugString = { + var str = "Stage(id=" + id + ", outputLocs=" + outputLocs.zipWithIndex.foreach { s => + str += (if (s._2 != 0) ", " else "") + "map[" + s._2 + "]=" + (if(s._1 == Nil) "Nil" else s._1.head.debugString) + } + str += "), shuffleBlockSizes=(" + shuffleBlockSizes.foreach { s => str += "(" + s._1 + ", " + s._2.debugString + ") " } + str += ")" + str + } + override def toString = "Stage " + id override def hashCode(): Int = id diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 93d4318b29..218c7e54ca 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -48,10 +48,10 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with Logging { - var attemptId = new AtomicInteger(0) var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get + env.blockManager.maxShuffleGroups = threads var listener: TaskSchedulerListener = null // Application dependencies (added through SparkContext) that we've fetched so far on this node. 
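// For reference, a short REPL-style sketch of how the reduce side (BlockStoreShuffleFetcher earlier
// in this patch) expands the (groupId, segment sizes) pairs returned by
// MapOutputTracker.getServerStatuses into the new four-part block IDs
// shuffle_<shuffleId>_<groupId>_<reduceId>_<segmentId>. The IDs and sizes below are made-up values.
val shuffleId = 10
val reduceId = 0
// For one BlockManager: one entry per group, holding the sizes of this reduce bucket's segments.
val groups: Seq[(Int, Seq[Long])] = Seq((0, Seq(1000L, 250L)), (1, Seq(10000L)))
val blockIds: Seq[(String, Long)] = groups.flatMap { case (groupId, segments) =>
  segments.zipWithIndex.map { case (size, segmentId) =>
    ("shuffle_%d_%d_%d_%d".format(shuffleId, groupId, reduceId, segmentId), size)
  }
}
// blockIds == Seq(("shuffle_10_0_0_0", 1000), ("shuffle_10_0_0_1", 250), ("shuffle_10_1_0_0", 10000))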
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index bec876213e..72a412b457 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -9,9 +9,7 @@ import scala.collection.mutable.Queue import io.netty.buffer.ByteBuf -import spark.Logging -import spark.Utils -import spark.SparkException +import spark.{Logging, Utils, SparkException, MapOutputLocation, FetchFailedException} import spark.network.BufferMessage import spark.network.ConnectionManagerId import spark.network.netty.ShuffleCopier @@ -30,7 +28,7 @@ import spark.serializer.Serializer */ private[storage] -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] +trait BlockFetcherIterator extends Iterator[(BlockManagerId, String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { def initialize() } @@ -47,17 +45,21 @@ object BlockFetcherIterator { // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize // the block (since we want all deserializaton to happen in the calling thread); can also // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val bmId: BlockManagerId, val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } class BasicBlockFetcherIterator( private val blockManager: BlockManager, + val shuffleId: Int, + val reduceId: Int, val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + val mapLocations: Array[MapOutputLocation], serializer: Serializer) extends BlockFetcherIterator { import blockManager._ + import ShuffleBlockManager._ private var _remoteBytesRead = 0l private var _remoteFetchTime = 0l @@ -115,8 +117,8 @@ object BlockFetcherIterator { "Unexpected message " + blockMessage.getType + " received from " + cmId) } val blockId = blockMessage.getId - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) + results.put(new FetchResult(req.address, blockId, sizeMap(blockId), + () => shuffleBlockDeserialize(blockManager, req.address, mapLocations, blockId, blockMessage.getData, serializer))) _remoteBytesRead += req.size logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -124,7 +126,7 @@ object BlockFetcherIterator { case None => { logError("Could not get block(s) from " + cmId) for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) + results.put(new FetchResult(req.address, blockId, -1, null)) } } } @@ -184,10 +186,11 @@ object BlockFetcherIterator { // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { + getLocalBytes(id) match { + case Some(bytes) => { // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) + results.put(new FetchResult(blockManagerId, id, 0, + () => shuffleBlockDeserialize(blockManager, blockManagerId, mapLocations, id, bytes, serializer))) logDebug("Got local block " + id) } case None => { @@ -223,7 +226,7 @@ object BlockFetcherIterator { override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - override def next(): (String, 
Option[Iterator[Any]]) = { + override def next(): (BlockManagerId, String, Option[Iterator[Any]]) = { resultsGotten += 1 val startFetchWait = System.currentTimeMillis() val result = results.take() @@ -234,7 +237,7 @@ object BlockFetcherIterator { (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + (result.bmId, result.blockId, if (result.failed) None else Some(result.deserialize())) } // Implementing BlockFetchTracker trait. @@ -249,12 +252,16 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, + shuffleId: Int, + reduceId: Int, blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + mapLocations: Array[MapOutputLocation], serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + extends BasicBlockFetcherIterator( + blockManager, shuffleId, reduceId, blocksByAddress, mapLocations, serializer) { import blockManager._ - + import ShuffleBlockManager._ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] private def startCopiers(numCopiers: Int): List[_ <: Thread] = { @@ -285,17 +292,16 @@ object BlockFetcherIterator { override protected def sendRequest(req: FetchRequest) { - def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { - val fetchResult = new FetchResult(blockId, blockSize, - () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) + def putResult(bmId: BlockManagerId, blockId: String, blockSize: Long, blockData: ByteBuf) { + val fetchResult = new FetchResult(bmId, blockId, blockSize, + () => shuffleBlockDeserialize(blockManager, bmId, mapLocations, blockId, blockData.nioBuffer, serializer)) results.put(fetchResult) } logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) val cpier = new ShuffleCopier - cpier.getBlocks(cmId, req.blocks, putResult) + cpier.getBlocks(req.address, req.blocks, putResult) logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) } @@ -319,11 +325,11 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockManagerId, String, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() // If all the results has been retrieved, copiers will exit automatically - (result.blockId, if (result.failed) None else Some(result.deserialize())) + (result.bmId, result.blockId, if (result.failed) None else Some(result.deserialize())) } } // End of NettyBlockFetcherIterator diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 9b39d3aadf..5a906a0bc7 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -4,6 +4,7 @@ import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} +import scala.collection.JavaConversions._ import akka.actor.{ActorSystem, Cancellable, Props} import akka.dispatch.{Await, Future} @@ -14,7 +15,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import 
it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{Logging, SparkEnv, SparkException, Utils} +import spark.{Logging, SparkEnv, SparkException, Utils, MapOutputLocation} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -85,6 +86,7 @@ private[spark] class BlockManager( } } + var maxShuffleGroups = 0 val shuffleBlockManager = new ShuffleBlockManager(this) private val blockInfo = new TimeStampedHashMap[String, BlockInfo] @@ -288,6 +290,7 @@ private[spark] class BlockManager( * never deletes (recent) items. */ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + shuffleBlockManager.closeBlock(blockId) diskStore.getValues(blockId, serializer).orElse( sys.error("Block " + blockId + " not found on disk, though it should be")) } @@ -383,6 +386,8 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (ShuffleBlockManager.isShuffle(blockId)) { + //close the shuffle Writers for blockId + shuffleBlockManager.closeBlock(blockId) return diskStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) @@ -484,15 +489,20 @@ private[spark] class BlockManager( * fashion as they're received. Expects a size in bytes to be provided for each block fetched, * so that we can control the maxMegabytesInFlight for the fetch. */ - def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) + def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + shuffleId: Int, + reduceId: Int, + mapLocations: Array[MapOutputLocation], + serializer: Serializer) : BlockFetcherIterator = { val iter = if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.NettyBlockFetcherIterator( + this, shuffleId, reduceId, blocksByAddress, mapLocations, serializer) } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.BasicBlockFetcherIterator( + this, shuffleId, reduceId, blocksByAddress, mapLocations, serializer) } iter.initialize() @@ -511,9 +521,9 @@ private[spark] class BlockManager( * This is currently used for writing shuffle files out. Callers should handle error * cases. 
*/ - def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + def getShuffleBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) + val writer = diskStore.getShuffleBlockWriter(blockId, serializer, bufferSize) writer.registerCloseEventHandler(() => { val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) blockInfo.put(blockId, myInfo) @@ -934,9 +944,10 @@ private[spark] class BlockManager( def dataDeserialize( blockId: String, bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + serializer: Serializer = defaultSerializer, + dispose: Boolean = true): Iterator[Any] = { bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) + val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, dispose)) serializer.newInstance().deserializeStream(stream).asIterator } diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala index 42e2b07d5c..71e449ad75 100644 --- a/core/src/main/scala/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala @@ -14,7 +14,7 @@ abstract class BlockObjectWriter(val blockId: String) { var closeEventHandler: () => Unit = _ - def open(): BlockObjectWriter + def open(id: Int, seq: Int): BlockObjectWriter def close() { closeEventHandler() @@ -27,16 +27,16 @@ abstract class BlockObjectWriter(val blockId: String) { } /** - * Flush the partial writes and commit them as a single atomic block. Return the - * number of bytes written for this commit. + * Persist the current chunk on disk. Return the number of + * bytes written for this commit. */ def commit(): Long /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. + * Complete the current chunk - called after all the related + * commits of are successful */ - def revertPartialWrites() + def complete() /** * Writes an object. diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index da859eebcb..d8d03aaf3d 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -8,6 +8,7 @@ import java.util.{Random, Date} import java.text.SimpleDateFormat import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -25,75 +26,83 @@ import spark.network.netty.PathResolver private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) with Logging { - class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) - extends BlockObjectWriter(blockId) { + class ShuffleBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) + extends BlockObjectWriter(blockId) with Logging { - private val f: File = createFile(blockId /*, allowAppendExisting */) - - // The file channel, used for repositioning / truncating the file. 
- private var channel: FileChannel = null - private var bs: OutputStream = null + //create a new file when the writer is created + private var f: File = createFile(blockId) private var objOut: SerializationStream = null - private var lastValidPosition = 0L - private var initialized = false + private var written = false + + //list of end positions of completed chunks + private val chunkPositionList = ListBuffer[Int](0) + private var currentPosition = 0L - override def open(): DiskBlockObjectWriter = { + private def lastvalidPosition = chunkPositionList.last + + //open by first reverting uncompleted chunks + private def _open(): OutputStream = { val fos = new FileOutputStream(f, true) - channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true + val channel = fos.getChannel() + val lastValidPosition = chunkPositionList.last + if (channel.size() > lastValidPosition) + channel.truncate(lastValidPosition) + + new FastBufferedOutputStream(fos) + } + + override def open(id: Int, seq: Int): ShuffleBlockObjectWriter = { + written = false + + val bs = _open() + ShuffleBlockManager.writeChunkHeader(id, seq, bs) + objOut = serializer.newInstance().serializeStream(blockManager.wrapForCompression(blockId, bs)) this } + private def _close() { + objOut.close() + objOut = null + } + override def close() { - if (initialized) { - objOut.close() - bs.close() - channel = null - bs = null - objOut = null - } + if (objOut != null) + _close() + + val bs = _open() + ShuffleBlockManager.writeSegmentTailer(chunkPositionList, bs) + bs.close() + // Invoke the close callback handler. super.close() } override def isOpen: Boolean = objOut != null - // Flush the partial writes, and set valid length to be the length of the entire file. + // Commit by close the stream // Return the number of bytes written for this commit. override def commit(): Long = { - if (initialized) { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition - } + _close() + if (written == true) + currentPosition = f.length() + val lastValidPosition = chunkPositionList.last + currentPosition - lastValidPosition } - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. 
- objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + //complete the chunk by recording its end position + override def complete() { + val lastValidPosition = chunkPositionList.last + if (lastValidPosition != currentPosition) { + chunkPositionList += currentPosition.toInt } } - + override def write(value: Any) { - if (!initialized) { - open() - } + written = true objOut.writeObject(value) } - override def size(): Long = lastValidPosition + override def size(): Long = chunkPositionList.last } private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @@ -108,11 +117,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + def getShuffleBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer, bufferSize) + new ShuffleBlockObjectWriter(blockId, serializer, bufferSize) } + override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -210,10 +220,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { - // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task. - logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") - file.delete() + throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } @@ -303,6 +310,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) if (!blockId.startsWith("shuffle_")) { return null } + DiskStore.this.blockManager.shuffleBlockManager.closeBlock(blockId) DiskStore.this.getFile(blockId).getAbsolutePath() } } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 44638e0c2d..35ca265335 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -2,49 +2,349 @@ package spark.storage import spark.serializer.Serializer +import java.util.concurrent.{ConcurrentLinkedQueue,ConcurrentHashMap} +import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import java.io.OutputStream -private[spark] -class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) +import spark.util.MetadataCleaner +import spark.MapOutputLocation +import scala.collection.JavaConversions +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer +import spark.Logging +import spark.SparkException +import scala.collection.JavaConverters._ +/** + * A shuffle block manager (ShuffleBlockManager) runs on each slave and maintains, for each shuffle, + * a shuffle block pool (ShuffleBlocksPool), which in turn maintains a list of shuffle block groups + * (ShuffleWriterGroup). A group maintains a list of bucket writers (ShuffleBucketWriter), each for + * a different bucket (or reduce partition). Each bucket is organized as a list of segments, each + * of which is a shuffle file containing a list of chunks followed by a list of chunk sizes: + * (chunk_1, ..., chunk_k, size_1, ..., size_k, k). + * In each bucket, only the latest segment can be written, and previous segments are read-only.
+ * + * Each chunk data is a completely encoded (compressed and serialized) output of a specific map task + * (as identified by mapId) for that bucket; sequence is a unique (monotonically increasing) sequence number + * assigned by the shuffle block pool to each map task. Different chunks can be separately located and decoded + * using the chunk size list. + * + * Each segment writer maintains a chunk list (containing the end positions of all successfully committed chunks). + * A map task commits its writes by + * (1) Persisting current chunks of all buckets on disk + * (2) Completing current chunks by appending their end positions to respective segment writers + * + * Before a map task can append to the segment file, uncompleted chunks are reverted by truncating the segment file + * to the end position of its last completed chunk. Similarly, before a reduce task can fetch the segment file, + * the shuffle manager will close the segment file by (1) reverting uncompleted chunks and (2) appending the chunk size + * list to the file. A previously closed segment file is read-only, and a new segment file will be created for writing + * new data to the bucket. + */ private[spark] -trait ShuffleBlocks { - def acquireWriters(mapId: Int): ShuffleWriterGroup - def releaseWriters(group: ShuffleWriterGroup) -} +class ShuffleBlockManager(val blockManager: BlockManager) extends Logging { + initLogging() + val metadataCleaner = new MetadataCleaner("ShuffleBlockManager", this.cleanup) -private[spark] -class ShuffleBlockManager(blockManager: BlockManager) { - + val pools = new ConcurrentHashMap[Int, ShuffleBlocksPool] + import ShuffleBlockManager._ + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { new ShuffleBlocks { + val pool = getPool(shuffleId) + // Get a group of writers for a map task.
-      override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
+      override def acquireWriters(mapId: Int) = pool.acquireGroup(mapId, numBuckets, serializer)
+      override def releaseWriters(group: ShuffleWriterGroup) = pool.releaseGroup(group)
+      override def commitWrites(group: ShuffleWriterGroup) = pool.commitWrites(group)
+    }
+  }
+
+  private def getPool(shuffleId: Int) : ShuffleBlocksPool = {
+    val pool = pools.get(shuffleId)
+    if (pool == null) {
+      pools.putIfAbsent(shuffleId, new ShuffleBlocksPool(shuffleId))
+      pools.get(shuffleId)
+    }
+    else
+      pool
+  }
+
+  def closeBlock(blockId: String) {
+    logDebug("ShufCon - ShuffleBlockManager closeBlock " + blockId)
+
+    val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)_([0-9]*)".r
+    blockId match {
+      case regex(shuffleId, groupId, bucketId, segmentId) =>
+        logDebug("closeBlock shuffleId: " + shuffleId + ", groupId: " + groupId + ", bucketId: " + bucketId + ", segmentId: " + segmentId)
+        val pool = getPool(shuffleId.toInt)
+        if (pool != null) {
+          pool.closeBlock(groupId.toInt, bucketId.toInt, segmentId.toInt)
+        }
+        else
+          throw new SparkException(
+            "Failed to get shuffle block " + blockId + ", which is not stored in the ShuffleBlockManager")
+
+      case _ =>
+        throw new SparkException(
+          "Failed to get block " + blockId + ", which is not a shuffle block")
+    }
+  }
+
+  def cleanup(cleanupTime: Long){
+    pools.asScala.retain( (shuffleId, pool) => !pool.allGroupsClosed )
+  }
+
+  // A shuffle block pool maintaining a list of shuffle block groups (ShuffleWriterGroup)
+  class ShuffleBlocksPool (val shuffleId: Int) extends Logging {
+    val allGroups = Array.fill[ShuffleWriterGroup](blockManager.maxShuffleGroups)(null)
+    val freeGroups = new ConcurrentLinkedQueue[ShuffleWriterGroup]
+    val nextGroupId = new AtomicInteger(0)
+    val nextSequence = new AtomicInteger(0)
+
+    def acquireGroup(mapId: Int, numBuckets: Int, serializer: Serializer) : ShuffleWriterGroup = {
+      var group = freeGroups.poll()
+      if (group == null) {
         val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
-          val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
-          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+        val groupId = nextGroupId.getAndIncrement()
+        val writers = Array.tabulate[ShuffleBucketWriter](numBuckets) { bucketId =>
+          new ShuffleBucketWriter(blockManager, shuffleId, groupId, bucketId, serializer, bufferSize)
+        }
+        group = new ShuffleWriterGroup(shuffleId, groupId, writers)
+        allGroups(groupId) = group
+      }
+      group.open(mapId, nextSequence.getAndIncrement())
+      group
+    }
+
+    def commitWrites(group: ShuffleWriterGroup) = {
+      var size = 0L
+      // 2-phase commit across all writers
+      size = group.writers.map(_.commit).sum
+      group.writers.foreach(_.complete)
+      size
+    }
+
+    def releaseGroup(group: ShuffleWriterGroup) {
+      freeGroups.add(group)
+      group.writers.foreach(_.markDone)
+    }
+
+    def closeBlock(groupId:Int, bucketId:Int, segmentId: Int) {
+      val writer = allGroups(groupId.toInt).writers(bucketId.toInt)
+      writer.close(segmentId)
+    }
+
+    private def allGroupsReleased = (freeGroups.size == nextGroupId.get - 1)
+    private def groupClosed(group: ShuffleWriterGroup) = (group == null || group.writers.forall(!_.isOpen))
+    def allGroupsClosed = (allGroupsReleased && allGroups.forall(groupClosed(_)))
+  }
+}
+
+private[spark] object ShuffleBlockManager extends Logging{
+  def blockId(shuffleId: Int, groupId: Int, bucketId: Int, segmentId: Int): String = {
+    "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId + "_" + segmentId
+  }
+
+  def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+
+  private def writeInt (bs: OutputStream, value: Int) {
+    bs.write(value & 0xFF)
+    bs.write((value >>> 8) & 0xFF)
+    bs.write((value >>> 16) & 0xFF)
+    bs.write((value >>> 24) & 0xFF)
+  }
+
+  private def readInt (buffer: ByteBuffer) = {
+    var result = buffer.get() & 0xFF
+    result |= (buffer.get() & 0xFF) << 8
+    result |= (buffer.get() & 0xFF) << 16
+    result |= (buffer.get() & 0xFF) << 24
+    result
+  }
+
+  def writeChunkHeader(id: Int, seq: Int, bs: OutputStream) {
+    writeInt(bs, id)
+    writeInt(bs, seq)
+  }
+
+  private def readChunkHeader (bytes: ByteBuffer) = {
+    val mapId = readInt (bytes)
+    val sequence = readInt (bytes)
+    (mapId, sequence)
+  }
+
+  def writeSegmentTailer(chunkPositionList: ListBuffer[Int], bs: OutputStream) {
+    val itr = chunkPositionList.toIterator
+    var prev = itr.next()
+    var count = 0
+    itr.foreach {pos =>
+      val size = pos - prev
+      writeInt(bs, size)
+      count += 1
+      prev = pos
+    }
+    writeInt(bs, count)
+  }
+
+  private def readSegmentTailer (bytes: ByteBuffer) = {
+    val limit = bytes.limit()
+    bytes.position(limit - 4)
+    val count = readInt (bytes)
+
+    val chunkSizeList = ListBuffer[Int]()
+    bytes.position(limit - 4 - count * 4)
+    for (i <- 1 to count) {
+      val len = readInt (bytes)
+      chunkSizeList += len
+    }
+    chunkSizeList
+  }
+
+  def shuffleBlockDeserialize(blockManager: BlockManager, bmId: BlockManagerId, mapLocations: Array[MapOutputLocation],
+      blockId: String, bytes: ByteBuffer, serializer: Serializer): Iterator[Any] = {
+    var pos = bytes.position
+    logDebug("ShufCon - shuffleBlockDeserialize: " + blockId + "(" + pos + " - " + bytes.limit + ")")
+    val chunkSizeList = readSegmentTailer (bytes)
+    val itrs = ListBuffer[Iterator[Any]]()
+    chunkSizeList.foreach { size =>
+      bytes.position(pos).limit(pos + size)
+      val (mapId, sequence) = readChunkHeader(bytes)
+      val loc = mapLocations(mapId)
+      if (loc.location == bmId && loc.sequence == sequence) {
+        val block = bytes.slice()
+        itrs += blockManager.dataDeserialize(blockId, block, serializer, false)
+      }
+      pos += size
+    }
+
+    new Iterator[Any] {
+      val iter = itrs.toIterator.flatMap(x=>x)
+      def hasNext() = iter.hasNext || {BlockManager.dispose(bytes); logDebug("ShufCon - dispose " + blockId); false}
+      def next() = iter.next()
+    }
+  }
+
+  // A shuffle bucket writer maintaining a list of segments
+  class ShuffleBucketWriter(val blockManager: BlockManager, val shuffleId: Int, val groupId: Int, val bucketId: Int,
+      val serializer: Serializer, val bufferSize: Int) extends Logging {
+    private val segmentSizeLimit = System.getProperty("spark.shuffle.file.size.limit", "256000000").toInt
+    private var writer = blockManager.getShuffleBlockWriter(
+      ShuffleBlockManager.blockId(shuffleId, groupId, bucketId, 0), serializer, bufferSize)
+
+    private var nextSegmentId = 1
+    private val prevSegmentSizes = ListBuffer[Long]()
+
+    import BucketState._
+    private var state = OPEN
+
+    def getBucketSizes() = (prevSegmentSizes ++ List(writer.size)).toArray
+
+    private def newSegmentWriter() = {
+      val size = writer.size
+      prevSegmentSizes += size
+      nextSegmentId += 1
+
+      blockManager.getShuffleBlockWriter(
+        ShuffleBlockManager.blockId(shuffleId, groupId, bucketId, nextSegmentId-1), serializer, bufferSize)
+    }
+
+    def open(mapId: Int, seq: Int) = {
+      state.synchronized {
+        if (state == CLOSED) {
+          writer = newSegmentWriter()
        }
-        new ShuffleWriterGroup(mapId, writers)
+        else if (writer.size >= segmentSizeLimit) {
+          writer.close()
+          state = CLOSED
+          writer = newSegmentWriter()
+          logDebug ("ShufCon - " + this.debugString + " newSegmentWriter")
+        }
+        state = WRITING
      }
+      writer.open(mapId, seq)
+      logDebug ("ShufCon - " + this.debugString + " open")
+    }
+
+    def close(segmentId: Int) {
+      state.synchronized {
+        if (segmentId == nextSegmentId - 1) {
+          state match {
+            case OPEN =>
+              writer.close()
+              state = CLOSED
+              logDebug ("ShufCon - " + this.debugString + " close ")
+            case WRITING =>
+              val blockId = ShuffleBlockManager.blockId(shuffleId, groupId, bucketId, segmentId)
+              throw new SparkException("Failed to close block " + blockId + ", which is currently being written")
+            case CLOSED =>
+              logDebug ("ShufCon - " + this.debugString + " already closed ")
+          }
+        }
+      }
+    }
-      override def releaseWriters(group: ShuffleWriterGroup) = {
-        // Nothing really to release here.
+    def markDone() {
+      state.synchronized {
+        state = OPEN
      }
+      logDebug ("ShufCon - " + this.debugString + " markDone ")
    }
+
+    def write (value: Any) {
+      writer.write(value)
+      logDebug ("ShufCon - " + this.debugString + " write " + value)
+    }
+
+    def commit() = {
+      val size = writer.commit()
+      logDebug ("ShufCon - " + this.debugString + " commit ")
+      size
+    }
+
+    def complete() {
+      writer.complete()
+      logDebug ("ShufCon - " + this.debugString + " complete ")
+    }
+
+    def isOpen = (state != CLOSED)
+    private def currentSegmentId = nextSegmentId - 1
+    def debugString = "ShuffleBucketWriter(shuffleId=" + shuffleId + ", groupId=" + groupId + ", bucketId=" +
+      bucketId + ", segmentId=" + currentSegmentId + ", state=" + state + ")"
+  }
+
+  object BucketState extends Enumeration {
+    type BucketState = Value
+    val OPEN = Value("OPEN")
+    val WRITING = Value("WRITING")
+    val CLOSED = Value("CLOSED")
  }
-}
+  // A shuffle block group that maintains a group of shuffle bucket writers (ShuffleBucketWriter)
+  class ShuffleWriterGroup(val shuffleId: Int, val groupId: Int, val writers: Array[ShuffleBucketWriter]) extends Logging {
+    var mapId = -1
+    var sequence = -1
+
+    def open(id: Int, seq: Int) {
+      mapId = id
+      sequence = seq
+      writers.foreach { _.open(id, seq) }
+    }
-private[spark]
-object ShuffleBlockManager {
+    def getBucketSizes() = {
+      val sizes = writers.map(_.getBucketSizes())
+      sizes
+    }
-  // Returns the block id for a given shuffle block.
-  def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
-    "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
+    def debugString = "ShuffleWriterGroup(shuffleId=" + shuffleId + ", groupId=" + groupId +
+      ", mapId=" + mapId + ", sequence=" + sequence + ")"
  }
-  // Returns true if the block is a shuffle block.
-  def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
-}
+  trait ShuffleBlocks {
+    def acquireWriters(mapId: Int): ShuffleWriterGroup
+    def commitWrites(group: ShuffleWriterGroup): Long
+    def releaseWriters(group: ShuffleWriterGroup)
+  }
+}
\ No newline at end of file
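
A note on the on-disk format introduced above: within a consolidated segment, every chunk is prefixed by an 8-byte little-endian header (mapId, sequence) written by writeChunkHeader, and the segment ends with a trailer listing each chunk's size followed by the chunk count, which is what readSegmentTailer walks back from the buffer's limit. The sketch below mirrors that layout outside of Spark; it re-implements writeInt/readInt locally, and the object name, payloads and sequence numbers are made up for illustration, so treat it as a sketch of the format rather than code from the patch.

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ListBuffer

object SegmentLayoutSketch {
  // Little-endian int, mirroring the writeInt/readInt helpers in the patch.
  def writeInt(bs: ByteArrayOutputStream, value: Int) {
    bs.write(value & 0xFF); bs.write((value >>> 8) & 0xFF)
    bs.write((value >>> 16) & 0xFF); bs.write((value >>> 24) & 0xFF)
  }
  def readInt(buf: ByteBuffer) = {
    var r = buf.get() & 0xFF
    r |= (buf.get() & 0xFF) << 8; r |= (buf.get() & 0xFF) << 16; r |= (buf.get() & 0xFF) << 24
    r
  }

  def main(args: Array[String]) {
    val bs = new ByteArrayOutputStream()
    val chunkPositions = ListBuffer[Int](0)              // segment start plus the end offset of each chunk
    val payloads = Seq("aa".getBytes, "bbbb".getBytes)   // hypothetical serialized map outputs
    payloads.zipWithIndex.foreach { case (payload, mapId) =>
      writeInt(bs, mapId)        // chunk header: map id
      writeInt(bs, 42 + mapId)   // chunk header: hypothetical sequence number from the group's open()
      bs.write(payload)
      chunkPositions += bs.size  // remember where this chunk ends
    }
    // trailer: size of each chunk (deltas of the recorded positions), then the chunk count
    val positions = chunkPositions.toList
    positions.zip(positions.tail).foreach { case (prev, pos) => writeInt(bs, pos - prev) }
    writeInt(bs, positions.size - 1)

    // read side: jump to the trailer, recover the chunk sizes, then walk the chunks forward
    val bytes = ByteBuffer.wrap(bs.toByteArray)
    bytes.position(bytes.limit - 4)
    val count = readInt(bytes)
    bytes.position(bytes.limit - 4 - count * 4)
    val sizes = (1 to count).map(_ => readInt(bytes))
    var pos = 0
    sizes.foreach { size =>
      bytes.position(pos)
      println("chunk at " + pos + ": mapId=" + readInt(bytes) + ", seq=" + readInt(bytes) +
        ", payload=" + (size - 8) + " bytes")
      pos += size
    }
  }
}

Putting the count at the very end is what lets a reader locate the trailer without knowing the number of chunks up front.
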
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 6e585e1c3a..9558054f63 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -2,6 +2,7 @@ package spark
 import org.scalatest.FunSuite
+import scala.collection.mutable.HashMap
 import akka.actor._
 import spark.scheduler.MapStatus
 import spark.storage.BlockManagerId
@@ -41,17 +42,24 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     val tracker = new MapOutputTracker()
     tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.registerShuffle(10, 2)
+
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
     val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
-    tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
-        Array(compressedSize1000, compressedSize10000)))
-    tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
-        Array(compressedSize10000, compressedSize1000)))
-    val statuses = tracker.getServerStatuses(10, 0)
-    assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000),
-                                  (BlockManagerId("b", "hostB", 1000, 0), size10000)))
+
+    val bmIdA = BlockManagerId("a", "hostA", 1000, 0)
+    val bmIdB = BlockManagerId("b", "hostB", 1000, 0)
+    val mapLocations = Array[MapOutputLocation](new MapOutputLocation(bmIdA, 1), new MapOutputLocation(bmIdB, 1))
+    val bucketA = new GroupBucketSizes(1, Array(Array[Byte](compressedSize1000)))
+    val bucketB = new GroupBucketSizes(1, Array(Array[Byte](compressedSize10000)))
+    val segmentSizes = HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]((bmIdA, ShuffleBlockGroupSizeArray(bucketA)),
+      (bmIdB, ShuffleBlockGroupSizeArray(bucketB)))
+    tracker.registerMapOutputs(10, mapLocations, segmentSizes)
+
+    val (statuses, sizes) = tracker.getServerStatuses(10, 0)
+    assert(statuses.toSeq === Seq(new MapOutputLocation(bmIdA, 1), new MapOutputLocation(bmIdB, 1)))
+    assert(sizes.toSeq === Seq((bmIdA, Seq((0, Seq(size1000)))), (bmIdB, Seq((0, Seq(size10000))))))
     tracker.stop()
   }
@@ -64,10 +72,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
     val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
-    tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
-        Array(compressedSize1000, compressedSize1000, compressedSize1000)))
-    tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
-        Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+
+    val bmIdA = BlockManagerId("a", "hostA", 1000, 0)
+    val bmIdB = BlockManagerId("b", "hostB", 1000, 0)
+    val mapLocations = Array[MapOutputLocation](new MapOutputLocation(bmIdA, 1), new MapOutputLocation(bmIdB, 1))
+
+    val bucketA = new GroupBucketSizes(1, Array(Array[Byte](compressedSize1000)))
+    val bucketB = new GroupBucketSizes(1, Array(Array[Byte](compressedSize10000)))
+    val segmentSizes = HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]((bmIdA, ShuffleBlockGroupSizeArray(bucketA)),
+      (bmIdB, ShuffleBlockGroupSizeArray(bucketB)))
+
+    tracker.registerMapOutputs(10, mapLocations, segmentSizes)
 
     // As if we had two simulatenous fetch failures
     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@@ -101,12 +116,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
-    masterTracker.registerMapOutput(10, 0, new MapStatus(
-      BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+    val bmIdA = BlockManagerId("a", "hostA", 1000, 0)
+    val mapLocations = Array[MapOutputLocation](new MapOutputLocation(bmIdA, 1))
+    val bucketA = new GroupBucketSizes(1, Array(Array[Byte](compressedSize1000)))
+    val segmentSizes = HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]((bmIdA, ShuffleBlockGroupSizeArray(bucketA)))
+    masterTracker.registerMapOutputs(10, mapLocations, segmentSizes)
     masterTracker.incrementGeneration()
+
     slaveTracker.updateGeneration(masterTracker.getGeneration)
-    assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-           Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+    val (statuses, sizes) = slaveTracker.getServerStatuses(10, 0)
+    assert(statuses.toSeq === Seq(new MapOutputLocation(bmIdA, 1)))
+    assert(sizes.toSeq === Seq((bmIdA, Seq((0, Seq(size1000))))))
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
     masterTracker.incrementGeneration()
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 950218fa28..76199ba2a2 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -50,8 +50,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     // All blocks must have non-zero size
     (0 until NUM_BLOCKS).foreach { id =>
-      val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
-      assert(statuses.forall(s => s._2 > 0))
+      val (statuses, sizes) = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+      assert(sizes.forall(s => s._2.forall(_._2.forall(_ > 0))))
     }
   }
@@ -85,8 +85,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(c.count === 4)
     val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
-      val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
-      statuses.map(x => x._2)
+      val (statuses, sizes) = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+      sizes.flatMap(x => x._2.flatMap(v => v._2))
     }
     val nonEmptyBlocks = blockSizes.filter(x => x > 0)
@@ -110,8 +110,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(c.count === 4)
     val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
-      val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
-      statuses.map(x => x._2)
+      val (statuses, sizes) = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+      sizes.flatMap(x => x._2.flatMap(v => v._2))
     }
     val nonEmptyBlocks = blockSizes.filter(x => x > 0)
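
As the updated ShuffleSuite assertions above show, getServerStatuses now yields a pair: the per-map MapOutputLocation array plus the sizes nested as Seq(address, Seq(groupId, segment sizes)). The nested forall/flatMap chains in the tests are easy to misread, so the following standalone helper spells out the same traversal; the object and method names are hypothetical, and the address type is left generic so the sketch compiles without any Spark classes.

object SizesSketch {
  // sizes: Seq of (address, Seq of (groupId, segment sizes in bytes)), the second half of the
  // pair returned by the new getServerStatuses; Addr stands in for BlockManagerId here.
  def totalBytesByAddress[Addr](sizes: Seq[(Addr, Seq[(Int, Seq[Long])])]): Map[Addr, Long] =
    sizes.map { case (addr, groups) => addr -> groups.flatMap(_._2).sum }.toMap

  // The non-empty-block check from ShuffleSuite, spelled out: every segment of every group
  // on every address must have a positive size.
  def allSegmentsNonEmpty[Addr](sizes: Seq[(Addr, Seq[(Int, Seq[Long])])]): Boolean =
    sizes.forall(s => s._2.forall(_._2.forall(_ > 0)))

  def main(args: Array[String]) {
    val sizes = Seq("hostA" -> Seq((0, Seq(1000L, 256L))), "hostB" -> Seq((0, Seq(10000L))))
    println(totalBytesByAddress(sizes))   // Map(hostA -> 1256, hostB -> 10000)
    println(allSegmentsNonEmpty(sizes))   // true
  }
}
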
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 30e6fef950..837ddf31c7 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -245,7 +245,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))))
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0)._1.map(_.location) ===
           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
     complete(taskSets(1), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
@@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     // have the 2nd attempt pass
     complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
     // we can see both result blocks now
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0)._1.map(_.location.host) === Array("hostA", "hostB"))
     complete(taskSets(3), Seq((Success, 43)))
     assert(results === Map(0 -> 42, 1 -> 43))
   }
@@ -298,7 +298,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     // should work because it's a new generation
     taskSet.tasks(1).generation = newGeneration
     runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0)._1.map(_.location) ===
           Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
     complete(taskSets(1), Seq((Success, 42), (Success, 43)))
     assert(results === Map(0 -> 42, 1 -> 43))
@@ -320,7 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
        (Success, makeMapStatus("hostB", 1))))
     // have hostC complete the resubmitted task
     complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0)._1.map(_.location) ===
           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
     complete(taskSets(2), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
@@ -395,7 +395,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
   }
 
   private def makeMapStatus(host: String, reduces: Int): MapStatus =
-    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+    new MapStatus(makeBlockManagerId(host), 0, 1, Array(Array.fill[Byte](reduces)(2)))
 
   private def makeBlockManagerId(host: String): BlockManagerId =
     BlockManagerId("exec-" + host, host, 12345, 0)
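
Across all of these files the four-part block ID, shuffle_<shuffleId>_<groupId>_<bucketId>_<segmentId>, is the glue: the bucket writer names segments with it, closeBlock parses it, and the fetcher's failure path matches it with the same regex. A standalone round-trip of that naming scheme (hypothetical object and helper names; in the patch the builder is ShuffleBlockManager.blockId and the regex appears in closeBlock and BlockStoreShuffleFetcher) looks like this:

object BlockIdSketch {
  def blockId(shuffleId: Int, groupId: Int, bucketId: Int, segmentId: Int): String =
    "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId + "_" + segmentId

  // Same pattern used by closeBlock and the fetcher's failure handling.
  private val ShuffleBlockId = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)_([0-9]*)".r

  def parse(id: String): Option[(Int, Int, Int, Int)] = id match {
    case ShuffleBlockId(s, g, b, seg) => Some((s.toInt, g.toInt, b.toInt, seg.toInt))
    case _ => None
  }

  def main(args: Array[String]) {
    val id = blockId(10, 2, 0, 1)
    println(id)                // shuffle_10_2_0_1
    println(parse(id))         // Some((10,2,0,1))
    println(parse("rdd_3_7"))  // None: not a shuffle block
  }
}
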