diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 43e6af2351..1e729dee57 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -483,7 +483,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T, tellMaster: Boolean = true) = env.broadcastManager.newBroadcast[T](value, isLocal, tellMaster) /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 5f18b1e15b..7266d329df 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -308,7 +308,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value) + def broadcast[T](value: T, tellMaster: Boolean): Broadcast[T] = sc.broadcast(value, tellMaster) /** Shut down the SparkContext. 
*/ def stop() { diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index adcb2d2415..f55e9a6d23 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -11,7 +11,7 @@ import scala.math import spark._ import spark.storage.StorageLevel -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) extends Broadcast[T](id) with Logging with Serializable { @@ -21,7 +21,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: def blockId: String = "broadcast_" + id MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + //If tellMaster is true, let BlockManagerMaster know that we have the broadcast + //block so that it can later notify us to remove it. 
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -58,6 +60,27 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: if (!isLocal) { sendBroadcast() } + + override def remove(toReleaseSource: Boolean = false) { + logInfo("Remove broadcast variable " + blockId) + if (tellMaster) { + logInfo("remove broadcast variable block" + blockId + " on slaves") + SparkEnv.get.blockManager.master.removeBlock(blockId) + } + SparkEnv.get.blockManager.removeBlock(blockId, false) + if (toReleaseSource) { + releaseSource() + } + } + + def releaseSource(){ + arrayOfBlocks = null + hasBlocksBitVector = null + numCopiesSent = null + listOfSources = null + serveMR = null + guideMR = null + } def sendBroadcast() { logInfo("Local host address: " + hostAddress) @@ -116,7 +139,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: private def readObject(in: ObjectInputStream) { in.defaultReadObject() MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { + SparkEnv.get.blockManager.getSingleLocal(blockId) match { case Some(x) => value_ = x.asInstanceOf[T] @@ -139,8 +162,11 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: val receptionSucceeded = receiveBroadcast(id) if (receptionSucceeded) { value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + //Store the received value in the local BlockManager. + //If tellMaster is true, let BlockManagerMaster know that we have the broadcast + //block so that it can later notify us to remove it. 
SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster) } else { logError("Reading broadcast variable " + id + " failed") } @@ -1033,8 +1059,8 @@ private[spark] class BitTorrentBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) = + new BitTorrentBroadcast[T](value_, isLocal, id, tellMaster) def stop() { MultiTracker.stop() } } diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 415bde5d67..96896c9600 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -12,6 +12,10 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { // readObject having to be 'private' in sub-classes. override def toString = "spark.Broadcast(" + id + ")" + + // Remove a Broadcast block from the SparkContext and Executors that have it. + // Set toReleaseSource true to also remove the Broadcast value from its source. 
+ def remove(toReleaseSource: Boolean) } private[spark] @@ -45,9 +49,9 @@ class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable } private val nextBroadcastId = new AtomicLong(0) - - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - + + def newBroadcast[T](value_ : T, isLocal: Boolean, tellMaster: Boolean) = + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), tellMaster) + def isDriver = _isDriver } diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala index 5c6184c3c7..00023e96ed 100644 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -8,6 +8,6 @@ package spark.broadcast */ private[spark] trait BroadcastFactory { def initialize(isDriver: Boolean): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T](value: T, isLocal: Boolean, id: Long, tellMaster: Boolean): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7e30b8f7d2..edd583c38a 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -13,7 +13,7 @@ import spark._ import spark.storage.StorageLevel import util.{MetadataCleaner, TimeStampedHashSet} -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean = true) extends Broadcast[T](id) with Logging with Serializable { def value = value_ @@ -21,24 +21,47 @@ extends Broadcast[T](id) with Logging with Serializable { def blockId: String = "broadcast_" + id 
HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + //If tellMaster is true, let BlockManagerMaster know that we have the broadcast + //block so that it can later notify us to remove it. + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster) } if (!isLocal) { HttpBroadcast.write(id, value_) } + + override def remove(toReleaseSource: Boolean = false) { + logInfo("Remove broadcast variable " + blockId) + if (tellMaster) { + logInfo("remove broadcast variable block" + blockId + " on slaves") + SparkEnv.get.blockManager.master.removeBlock(blockId) + } + SparkEnv.get.blockManager.removeBlock(blockId, false) + if (toReleaseSource) { + releaseSource() + } + } + + def releaseSource(){ + val path: String = HttpBroadcast.broadcastDir + "/" + "broadcast-" + id + HttpBroadcast.files.internalMap.remove(path) + new File(path).delete() + logInfo("Deleted source broadcast file '" + path + "'") + } // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { + SparkEnv.get.blockManager.getSingleLocal(blockId) match { case Some(x) => value_ = x.asInstanceOf[T] case None => { logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + //If tellMaster is true, let BlockManagerMaster know that we have the broadcast + //block so that it can later notify us to remove it. 
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -50,8 +73,8 @@ extends Broadcast[T](id) with Logging with Serializable { private[spark] class HttpBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) = - new HttpBroadcast[T](value_, isLocal, id, tellMaster) def stop() { HttpBroadcast.stop() } } diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index c55c476117..687385b972 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -10,7 +10,7 @@ import scala.math import spark._ import spark.storage.StorageLevel -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ @@ -18,7 +18,9 @@ extends Broadcast[T](id) with Logging with Serializable { def blockId = "broadcast_" + id MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + //If tellMaster is true, let BlockManagerMaster know that we have the broadcast + //block so that it can later notify us to remove it. 
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -46,6 +48,25 @@ extends Broadcast[T](id) with Logging with Serializable { if (!isLocal) { sendBroadcast() } + + override def remove(toReleaseSource: Boolean = false) { + logInfo("Remove broadcast variable " + blockId) + if (tellMaster) { + logInfo("remove broadcast variable block" + blockId + " on slaves") + SparkEnv.get.blockManager.master.removeBlock(blockId) + } + SparkEnv.get.blockManager.removeBlock(blockId, false) + if (toReleaseSource) { + releaseSource() + } + } + + def releaseSource(){ + arrayOfBlocks = null + listOfSources = null + serveMR = null + guideMR = null + } def sendBroadcast() { logInfo("Local host address: " + hostAddress) @@ -92,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable { private def readObject(in: ObjectInputStream) { in.defaultReadObject() MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { + SparkEnv.get.blockManager.getSingleLocal(blockId) match { case Some(x) => value_ = x.asInstanceOf[T] @@ -114,8 +135,10 @@ extends Broadcast[T](id) with Logging with Serializable { val receptionSucceeded = receiveBroadcast(id) if (receptionSucceeded) { value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + //If tellMaster is true, let BlockManagerMaster know that we have the broadcast + //block so that it can later notify us to remove it. 
SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster) } else { logError("Reading broadcast variable " + id + " failed") } @@ -578,8 +601,8 @@ private[spark] class TreeBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, tellMaster: Boolean) = + new TreeBroadcast[T](value_, isLocal, id, tellMaster) def stop() { MultiTracker.stop() } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 9b39d3aadf..e0efdc342d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -772,6 +772,13 @@ private[spark] class BlockManager( def getSingle(blockId: String): Option[Any] = { get(blockId).map(_.next()) } + + /** + * Read a block consisting of a single object only from local BlockManager. + */ + def getSingleLocal(blockId: String): Option[Any] = { + getLocal(blockId).map(_.next()) + } /** * Write a block consisting of a single object. diff --git a/examples/src/main/scala/spark/examples/BroadcastTest.scala b/examples/src/main/scala/spark/examples/BroadcastTest.scala index ba59be1687..757b2a80f0 100644 --- a/examples/src/main/scala/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/spark/examples/BroadcastTest.scala @@ -22,10 +22,11 @@ object BroadcastTest { for (i <- 0 until 2) { println("Iteration " + i) println("===========") - val barr1 = sc.broadcast(arr1) + val barr1 = sc.broadcast(arr1, (i == 0)) sc.parallelize(1 to 10, slices).foreach { i => println(barr1.value.size) } + barr1.remove(true) } System.exit(0)