diff --git a/src/main/scala/shark/SharkConfVars.scala b/src/main/scala/shark/SharkConfVars.scala
index eaa45ab9..a4476564 100755
--- a/src/main/scala/shark/SharkConfVars.scala
+++ b/src/main/scala/shark/SharkConfVars.scala
@@ -20,6 +20,9 @@ object SharkConfVars {
 
   // If true, then cache any table whose name ends in "_cached".
   val CHECK_TABLENAME_FLAG = new ConfVar("shark.cache.flag.checkTableName", false)
+
+  // If true, then query plans are compressed before being sent.
+  val COMPRESS_QUERY_PLAN = new ConfVar("shark.compressQueryPlan", true)
 
   def getIntVar(conf: Configuration, variable: ConfVar): Int = {
     require(variable.valClass == classOf[Int])
diff --git a/src/main/scala/shark/execution/TableScanOperator.scala b/src/main/scala/shark/execution/TableScanOperator.scala
index b0f9b307..35242149 100755
--- a/src/main/scala/shark/execution/TableScanOperator.scala
+++ b/src/main/scala/shark/execution/TableScanOperator.scala
@@ -130,7 +130,7 @@ with HiveTopOperator {
     val parts = SharkEnv.sc.hadoopFile(
       tablePath, ifc, classOf[Writable], classOf[Writable]).map(_._2)
 
-    val serializedHconf = XmlSerializer.serialize(localHconf)
+    val serializedHconf = XmlSerializer.serialize(localHconf, localHconf)
     val partRDD = parts.mapPartitions { iter =>
       // Map each tuple to a row object
      val hconf = XmlSerializer.deserialize(serializedHconf).asInstanceOf[HiveConf]
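
The new flag is a plain ConfVar, so it can be overridden per-session through the usual Configuration machinery. A minimal sketch of how a caller might turn plan compression off; the HiveConf setup here is illustrative and not part of this patch:

    import org.apache.hadoop.hive.conf.HiveConf
    import shark.SharkConfVars

    val conf = new HiveConf()
    // Plan compression defaults to true; override it for this session.
    // Assumes getBoolVar reads the ConfVar's key ("shark.compressQueryPlan").
    conf.setBoolean("shark.compressQueryPlan", false)
    assert(!SharkConfVars.getBoolVar(conf, SharkConfVars.COMPRESS_QUERY_PLAN))
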
diff --git a/src/main/scala/shark/execution/serialization.scala b/src/main/scala/shark/execution/serialization.scala
index dfa11251..c29b8bc1 100644
--- a/src/main/scala/shark/execution/serialization.scala
+++ b/src/main/scala/shark/execution/serialization.scala
@@ -4,11 +4,15 @@
 import java.beans.{XMLDecoder, XMLEncoder, PersistenceDelegate}
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutput, ObjectInput}
 import java.nio.ByteBuffer
 
+import com.ning.compress.lzf.{LZFEncoder, LZFDecoder}
+
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.ql.exec.Utilities.EnumDelegate
 import org.apache.hadoop.hive.ql.plan.GroupByDesc
 import org.apache.hadoop.hive.ql.plan.PlanUtils.ExpressionTypes
 
-import shark.LogHelper
+import shark.{SharkConfVars, LogHelper}
+
 
 /**
@@ -44,7 +48,7 @@ class OperatorSerializationWrapper[T <: Operator[_ <: HiveOperator]]
 
   def value_= (v: T):Unit = {
     _value = v
-    opSerialized = XmlSerializer.serialize(value)
+    opSerialized = XmlSerializer.serialize(value, v.hconf)
     objectInspectorsSerialized = KryoSerializer.serialize(value.objectInspectors)
   }
 
@@ -72,21 +76,40 @@ object OperatorSerializationWrapper {
  * serialize byte arrays because it is extremely inefficient.
 */
 object XmlSerializer {
+  // We prepend the buffer with a byte indicating whether the payload is compressed.
+  val COMPRESSION_ENABLED: Byte = 1
+  val COMPRESSION_DISABLED: Byte = 0
 
-  def serialize[T](o: T): Array[Byte] = {
-    val out = new ByteArrayOutputStream()
-    val e = new XMLEncoder(out)
+  def serialize[T](o: T, conf: Configuration): Array[Byte] = {
+    val byteStream = new ByteArrayOutputStream()
+    val e = new XMLEncoder(byteStream)
     // workaround for java 1.5
     e.setPersistenceDelegate(classOf[ExpressionTypes], new EnumDelegate())
     e.setPersistenceDelegate(classOf[GroupByDesc.Mode], new EnumDelegate())
     e.writeObject(o)
     e.close()
-    out.toByteArray()
+
+    val useCompression = conf match {
+      case null => SharkConfVars.COMPRESS_QUERY_PLAN.defaultBoolVal
+      case _ => SharkConfVars.getBoolVar(conf, SharkConfVars.COMPRESS_QUERY_PLAN)
+    }
+
+    if (useCompression) {
+      COMPRESSION_ENABLED +: LZFEncoder.encode(byteStream.toByteArray())
+    } else {
+      COMPRESSION_DISABLED +: byteStream.toByteArray()
+    }
   }
 
   def deserialize[T](bytes: Array[Byte]): T = {
     val cl = Thread.currentThread.getContextClassLoader
-    val d: XMLDecoder = new XMLDecoder(new ByteArrayInputStream(bytes), null, null, cl)
+    val decodedStream =
+      if (bytes(0) == COMPRESSION_ENABLED) {
+        new ByteArrayInputStream(LZFDecoder.decode(bytes.slice(1, bytes.size)))
+      } else {
+        new ByteArrayInputStream(bytes.slice(1, bytes.size))
+      }
+    val d: XMLDecoder = new XMLDecoder(decodedStream, null, null, cl)
     val ret = d.readObject()
     d.close()
     ret.asInstanceOf[T]
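
The resulting wire format is a single marker byte followed by the (possibly LZF-compressed) XML payload. A hedged round-trip sketch of the new API, assuming XmlSerializer lives in shark.execution as the file path suggests; the ArrayList stands in for a real query plan object:

    import shark.execution.XmlSerializer

    val plan = new java.util.ArrayList[String]()
    plan.add("select key, count(1) from src_cached group by key")

    // A null conf falls back to the COMPRESS_QUERY_PLAN default (true),
    // so byte 0 should carry the compression marker.
    val bytes = XmlSerializer.serialize(plan, null)
    assert(bytes(0) == XmlSerializer.COMPRESSION_ENABLED)

    // deserialize() checks the marker and inflates the rest before decoding.
    val copy = XmlSerializer.deserialize[java.util.ArrayList[String]](bytes)
    assert(copy == plan)

Putting the marker in the payload itself, rather than consulting the conf at read time, means deserialize needs no Configuration argument, which matters because it runs on the workers.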