Shuffle consolidation #669

Open · wants to merge 8 commits into master
Changes from 4 commits
39 changes: 20 additions & 19 deletions core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -18,33 +18,34 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockManager = SparkEnv.get.blockManager

val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
val (mapLocations, blockSizes)= SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
Member
Put a space after the =


logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))

val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}

val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = blockSizes.map { s =>
val bm = s._1
val groups = s._2
Member
Just FYI you could also write this as val (bm, groups) = s or write blockSizes.map { case (bm, groups) =>.
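For reference, a standalone sketch of the two unpacking styles mentioned here, using a plain String in place of BlockManagerId so the snippet compiles on its own (not the real Spark types):

object UnpackStyles {
  // Stand-in data with the same (id, groups) pair shape as blockSizes above.
  val blockSizes: Seq[(String, Seq[(Int, Array[Long])])] =
    Seq(("bm-1", Seq((0, Array(10L, 20L)))))

  // Style 1: destructure the tuple inside the closure body.
  val viaVal = blockSizes.map { s =>
    val (bm, groups) = s
    (bm, groups.map(_._1))
  }

  // Style 2: pattern-match on the tuple directly in the closure.
  val viaCase = blockSizes.map { case (bm, groups) =>
    (bm, groups.map(_._1))
  }
}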

val blockIds = groups.flatMap { v =>
val groupId = v._1
v._2.zipWithIndex.map(x=>(("shuffle_%d_%d_%d_%d").format(shuffleId, groupId, reduceId, x._2), x._1))
}
(bm, blockIds)
}

def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
val blockId = blockPair._1
val blockOption = blockPair._2
def unpackBlock(blockTuple: (BlockManagerId, String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
val address = blockTuple._1
val blockId = blockTuple._2
val blockOption = blockTuple._3
Member
Same here with the unpacking, you can do val (address, blockId, blockOption) = blockTuple

blockOption match {
case Some(block) => {
block.asInstanceOf[Iterator[(K, V)]]
}
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case regex(shufId, _, _, _) =>
throw new FetchFailedException(address, shufId.toInt, -1, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
@@ -53,12 +54,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}

val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val blockFetcherItr = blockManager.getMultiple(
blocksByAddress, shuffleId, reduceId, mapLocations, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)

CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
148 changes: 115 additions & 33 deletions core/src/main/scala/spark/MapOutputTracker.scala
@@ -43,7 +43,8 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _

private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
private val mapStatuses = new TimeStampedHashMap[Int, Array[MapOutputLocation]]
private val shuffleBlockSizes = new TimeStampedHashMap[Int, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]]

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -76,36 +77,35 @@ private[spark] class MapOutputTracker extends Logging {
}

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapOutputLocation](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}

def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}

def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
statuses: Array[MapOutputLocation],
sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray],
changeGeneration: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
mapStatuses.put(shuffleId, Array[MapOutputLocation]() ++ statuses)
shuffleBlockSizes.put(shuffleId, HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]() ++ sizes)

if (changeGeneration) {
incrementGeneration()
}
}

def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
val array = mapStatuses.get(shuffleId).orNull
val sizes = shuffleBlockSizes.get(shuffleId).orNull
if (array != null) {
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
if (sizes!= null) {
sizes.remove(bmAddress)
}
}
incrementGeneration()
} else {
@@ -116,12 +116,18 @@ private[spark] class MapOutputTracker extends Logging {
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]

// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
/**
* Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
* Return an array of MapOutputLocation of the specific reduceId, one for each ShuffleMapTask,
* and sizes of all segments in for the shuffle (bucket) in the form of
* Seq(BlockManagerId, Seq(groupId, size array for all the segments in the bucket))
*/
def getServerStatuses(shuffleId: Int, reduceId: Int): (Array[MapOutputLocation], Seq[(BlockManagerId, Seq[(Int, Array[Long])])]) = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
var fetchedStatuses: Array[MapOutputLocation] = null
var fetchedSizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -151,9 +157,12 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
val tuple = deserializeStatuses(fetchedBytes)
fetchedStatuses = tuple._1
fetchedSizes = tuple._2
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
shuffleBlockSizes.put(shuffleId, fetchedSizes)
} finally {
fetching.synchronized {
fetching -= shuffleId
@@ -162,8 +171,10 @@ private[spark] class MapOutputTracker extends Logging {
}
}
if (fetchedStatuses != null) {
logDebug("ShufCon - getServerStatuses for shuffle " + shuffleId + ": " + fetachedResultStr(fetchedStatuses, fetchedSizes))

fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
return (fetchedStatuses, MapOutputTracker.convertShuffleBlockSizes(shuffleId, reduceId, fetchedSizes))
}
}
else{
@@ -172,13 +183,25 @@ private[spark] class MapOutputTracker extends Logging {
}
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
return (statuses, MapOutputTracker.convertShuffleBlockSizes(shuffleId, reduceId, shuffleBlockSizes.get(shuffleId).orNull))
}
}
}

private def fetachedResultStr (fetchedStatuses: Array[MapOutputLocation], fetchedSizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]) = {
var str = "(fetchedStatuses="
fetchedStatuses.zipWithIndex.foreach { s =>
str += (if (s._2 != 0) ", " else "") + "map[" + s._2 + "]=" + s._1.debugString
}
str += "), fetchedSizes=("
fetchedSizes.foreach { s => str += "(" + s._1 + ", " + s._2.debugString + ") "}
str += ")"
str
}

private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
shuffleBlockSizes.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}

@@ -219,7 +242,8 @@ private[spark] class MapOutputTracker extends Logging {
}

def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var statuses: Array[MapOutputLocation] = null
var sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
@@ -231,12 +255,13 @@ private[spark] class MapOutputTracker extends Logging {
return bytes
case None =>
statuses = mapStatuses(shuffleId)
sizes = shuffleBlockSizes.get(shuffleId).orNull
generationGotten = generation
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
val bytes = serializeStatuses((statuses, sizes))
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
@@ -250,24 +275,25 @@ private[spark] class MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
private def serializeStatuses(tuple: (Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray])): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
val statuses = tuple._1
statuses.synchronized {
objOut.writeObject(statuses)
objOut.writeObject(tuple)
}
objOut.close()
out.toByteArray
}

// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
def deserializeStatuses(bytes: Array[Byte]) = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().
// // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
asInstanceOf[Tuple2[Array[MapOutputLocation], HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]]] // .filter( _ != null )
}
}

@@ -277,18 +303,19 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
private def convertMapStatuses(
private def convertShuffleBlockSizes(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
assert (statuses != null)
statuses.map {
status =>
if (status == null) {
sizes: HashMap[BlockManagerId, ShuffleBlockGroupSizeArray]): Seq[(BlockManagerId, Seq[(Int, Array[Long])])] = {
assert (sizes != null)
sizes.toSeq.map {
case (bmId, groups) =>
if (groups == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
} else {
(status.location, decompressSize(status.compressedSizes(reduceId)))
val seq = for (i <- 0 until groups.groupNum) yield (i, groups(i).bucketSizes(reduceId).map(decompressSize(_)))
Member
This line is longer than 100 characters; split into two
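One possible split, sketched with simplified stand-ins for GroupBucketSizes and decompressSize so the fragment compiles by itself (not the real Spark definitions):

object SplitExample {
  case class Group(bucketSizes: Array[Array[Byte]])                      // stand-in for GroupBucketSizes
  def decompressSize(compressedSize: Byte): Long = compressedSize.toLong // placeholder decompression
  val groups = Array(Group(Array(Array[Byte](1, 2), Array[Byte](3, 4))))
  val groupNum = groups.length
  val reduceId = 0

  // The long yield expression, broken across two lines.
  val seq =
    for (i <- 0 until groupNum)
      yield (i, groups(i).bucketSizes(reduceId).map(decompressSize(_)))
}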

(bmId, seq)
}
}
}
@@ -319,3 +346,58 @@
}
}
}

private[spark] class MapOutputLocation (val location: BlockManagerId, val sequence: Int) extends Serializable {
def this (status: MapStatus) = this (status.location, status.sequence)
def debugString = "MapOutputLocation(location=" + location + ", sequence=" + sequence +")"
}

private[spark] class GroupBucketSizes (var sequence: Int, var bucketSizes: Array[Array[Byte]]) extends Serializable {
Member
This line is longer than 100 characters

def this (status: MapStatus) = this (status.sequence, status.compressedSizes)
def debugString = {
var str = "GroupBucketSizes(sequence=" + sequence + ", "
bucketSizes.zipWithIndex.foreach { s =>
str += (if (s._2 != 0) ", " else "") + "bucket[" + s._2 + "]=("
s._1.zipWithIndex.foreach{ x =>
str += (if (x._2 != 0) ", " else "") + x._1
}
str += ")"
}
str += ")"
str
}
}

private[spark] class ShuffleBlockGroupSizeArray () extends Serializable {
Member
You don't need the empty parens () if there are no constructor arguments; just do class X extends Y.
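The point in miniature (toy classes, illustration only):

class WithParens() extends Serializable  // compiles, but the empty parens add nothing
class WithoutParens extends Serializable // preferred when there are no constructor arguments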

var groupNum = 0
private var groupSizeArray = Array.fill[GroupBucketSizes](32)(null)

def apply(idx: Int) = {
groupSizeArray(idx)
}

def update(idx: Int, elem: GroupBucketSizes) {
if (idx >= groupSizeArray.length){
var newLen = groupSizeArray.length * 2
while (idx >= newLen)
newLen = newLen * 2

val newArray = Array.fill[GroupBucketSizes](newLen)(null)
scala.compat.Platform.arraycopy(groupSizeArray, 0, newArray, 0, groupNum)
groupSizeArray = newArray
}

if (idx >= groupNum)
groupNum = idx + 1

groupSizeArray(idx) = elem
}

def debugString = {
var str = "ShuffleBlockGroupSizeArray("
for (i <- 0 until groupNum) {
str += (if (i != 0) str += ", " else "") + "group_" + i + "=" + (if (groupSizeArray(i) == null) "null" else groupSizeArray(i).debugString)
}
str + ")"
}
}
8 changes: 5 additions & 3 deletions core/src/main/scala/spark/executor/Executor.scala
@@ -16,7 +16,8 @@ import java.nio.ByteBuffer
/**
* The Mesos executor for Spark.
*/
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)], cores: Int = 128)
Member
This line is also too long and should be split
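One way the signature could be wrapped, sketched with a stubbed Logging trait so the fragment stands alone (not the real spark.Logging):

package spark

trait Logging  // stub standing in for the real spark.Logging trait

private[spark] class Executor(
    executorId: String,
    slaveHostname: String,
    properties: Seq[(String, String)],
    cores: Int = 128)
  extends Logging {
  // class body unchanged
}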

extends Logging {

// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
@@ -39,7 +40,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
for ((key, value) <- properties) {
System.setProperty(key, value)
}

System.setProperty("spark.slaveCores", cores.toString)
Member
It's a bit unfortunate that we have to pass this to the block manager through a system property. Can you add it to the constructor of BlockManager / ShuffleBlocksPool instead? I believe that gets created in SparkEnv so you could pass it to SparkEnv.

Author
Yes, we can do that if Spark is running in cluster (including local cluster) mode, as each executor will construct its own SparkEnv. However, if Spark is running in local mode, it has no executors and will just use the BlockManager in SparkContext, which is already constructed; currently we just set the system property in LocalScheduler.

An alternative is, instead of using a system property, to pass it as a variable in BlockManager (setting it as appropriate in Executor and LocalScheduler).

Member
Yeah, let's make it a variable in the BlockManager then.
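A hypothetical sketch of that direction, passing the core count through constructors instead of a system property; the class names and parameters below are stand-ins, not the real Spark signatures:

// Stand-in BlockManager that takes the slave's core count directly.
class BlockManager(val slaveCores: Int) {
  // e.g. size the shuffle block pool from slaveCores instead of reading
  // a "spark.slaveCores" system property
}

// Cluster / local-cluster path: each Executor builds its own environment.
class Executor(cores: Int) {
  val blockManager = new BlockManager(cores)
}

// Local-mode path mentioned above: the scheduler supplies the value instead.
class LocalScheduler(threads: Int) {
  val blockManager = new BlockManager(threads)
}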


// Create our ClassLoader and set it on this thread
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
@@ -77,7 +79,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert

// Start worker thread pool
val threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
1, cores, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])

def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
threadPool.execute(new TaskRunner(context, taskId, serializedTask))
@@ -41,7 +41,7 @@ private[spark] class StandaloneExecutorBackend(
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
// Make this host instead of hostPort ?
executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, cores)

case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)