From 089bd18c9d786357e83cd2fbd93925f1b83bcad8 Mon Sep 17 00:00:00 2001
From: Sun He
Date: Sat, 14 May 2016 13:55:40 +0800
Subject: [PATCH 1/3] Add support for Spark SQL

---
 .../provider/redis/sql/DefaultSource.scala         | 180 ++++++++++++++++++
 1 file changed, 180 insertions(+)
 create mode 100644 src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala

diff --git a/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala b/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala
new file mode 100644
index 00000000..4002cefa
--- /dev/null
+++ b/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala
@@ -0,0 +1,180 @@
+package com.redislabs.provider.redis.sql
+
+import scala.collection.JavaConversions._
+import com.redislabs.provider.redis._
+import com.redislabs.provider.redis.rdd.{Keys, RedisKeysRDD}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import redis.clients.jedis.Protocol
+import redis.clients.util.JedisClusterCRC16
+import java.security.MessageDigest
+
+
+case class RedisRelation(parameters: Map[String, String], userSchema: StructType)
+                        (@transient val sqlContext: SQLContext)
+  extends BaseRelation with PrunedFilteredScan with InsertableRelation with Keys {
+
+  val tableName: String = parameters.getOrElse("table", "PANIC")
+
+  val redisConfig: RedisConfig = {
+    new RedisConfig({
+        if ((parameters.keySet & Set("host", "port", "auth", "dbNum", "timeout")).size == 0) {
+          new RedisEndpoint(sqlContext.sparkContext.getConf)
+        } else {
+          val host = parameters.getOrElse("host", Protocol.DEFAULT_HOST)
+          val port = parameters.getOrElse("port", Protocol.DEFAULT_PORT.toString).toInt
+          val auth = parameters.getOrElse("auth", null)
+          val dbNum = parameters.getOrElse("dbNum", Protocol.DEFAULT_DATABASE.toString).toInt
+          val timeout = parameters.getOrElse("timeout", Protocol.DEFAULT_TIMEOUT.toString).toInt
+          new RedisEndpoint(host, port, auth, dbNum, timeout)
+        }
+      }
+    )
+  }
+
+  val partitionNum: Int = parameters.getOrElse("partitionNum", 3.toString).toInt
+
+  val schema = userSchema
+
+  def getNode(key: String): RedisNode = {
+    val slot = JedisClusterCRC16.getSlot(key)
+    /* Master only */
+    redisConfig.hosts.filter(node => { node.startSlot <= slot && node.endSlot >= slot }).filter(_.idx == 0)(0)
+  }
+  def insert(data: DataFrame, overwrite: Boolean): Unit = {
+    data.foreach{
+      row => {
+        val key = tableName + ":" + MessageDigest.getInstance("MD5").digest(System.currentTimeMillis.toString.getBytes)
+        val conn = getNode(key).endpoint.connect
+        conn.hmset(key, row.getValuesMap(row.schema.fieldNames).map(x => (x._1, x._2.toString)))
+        conn.close
+      }
+    }
+  }
+
+  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+    new RedisKeysRDD(sqlContext.sparkContext, redisConfig, tableName + ":*", partitionNum, null).
+      mapPartitions {
+        partition: Iterator[String] => {
+          groupKeysByNode(redisConfig.hosts, partition).flatMap {
+            x => {
+              val conn = x._1.endpoint.connect()
+              val rowKeys: Array[String] = filterKeysByType(conn, x._2, "hash")
+              val res = rowKeys.map {
+                key => {
+                  val res = conn.hmget(key, schema.fieldNames: _*)
+                  val pass = filters.zip(filters.map(filter => res.get((schema.fieldIndex(getAttr(filter)))))).forall{
+                    x => parseFilter(x._1, x._2)
+                  }
+                  if (pass) {
+                    requiredColumns.map{
+                      c => {
+                        val idx = schema.fieldIndex(c)
+                        castToTarget(res.get(idx), schema.fields(idx))
+                      }
+                    }
+                  } else {
+                    null
+                  }
+                }
+              }
+              conn.close
+              res.filter(_!=null)
+            }
+          }.toIterator.map(x => Row.fromSeq(x))
+        }
+      }
+  }
+
+  private def getAttr(f: Filter): String = {
+    f match {
+      case EqualTo(attribute, value) => attribute
+      case GreaterThan(attribute, value) => attribute
+      case GreaterThanOrEqual(attribute, value) => attribute
+      case LessThan(attribute, value) => attribute
+      case LessThanOrEqual(attribute, value) => attribute
+      case In(attribute, values) => attribute
+      case IsNull(attribute) => attribute
+      case IsNotNull(attribute) => attribute
+      case StringStartsWith(attribute, value) => attribute
+      case StringEndsWith(attribute, value) => attribute
+      case StringContains(attribute, value) => attribute
+    }
+  }
+
+  private def castToTarget(value: String, field: StructField) = {
+    field.dataType match {
+      case IntegerType => value.toString.toInt
+      case DoubleType => value.toString.toDouble
+      case StringType => value.toString
+      case _ => value.toString
+    }
+  }
+
+  private def getDataType(attr: String) = {
+    schema.fields(schema.fieldIndex(attr)).dataType
+  }
+  private def parseFilter(f: Filter, target: String) = {
+    f match {
+      case EqualTo(attribute, value) => {
+        value.toString == target
+      }
+      case GreaterThan(attribute, value) => {
+        getDataType(attribute) match {
+          case IntegerType => value.toString.toInt < target.toInt
+          case DoubleType => value.toString.toDouble < target.toDouble
+          case StringType => value.toString < target
+          case _ => value.toString < target
+        }
+      }
+      case GreaterThanOrEqual(attribute, value) => {
+        getDataType(attribute) match {
+          case IntegerType => value.toString.toInt <= target.toInt
+          case DoubleType => value.toString.toDouble <= target.toDouble
+          case StringType => value.toString <= target
+          case _ => value.toString <= target
+        }
+      }
+      case LessThan(attribute, value) => {
+        getDataType(attribute) match {
+          case IntegerType => value.toString.toInt > target.toInt
+          case DoubleType => value.toString.toDouble > target.toDouble
+          case StringType => value.toString > target
+          case _ => value.toString > target
+        }
+      }
+      case LessThanOrEqual(attribute, value) => {
+        getDataType(attribute) match {
+          case IntegerType => value.toString.toInt >= target.toInt
+          case DoubleType => value.toString.toDouble >= target.toDouble
+          case StringType => value.toString >= target
+          case _ => value.toString >= target
+        }
+      }
+      case In(attribute, values) => {
+        getDataType(attribute) match {
+          case IntegerType => values.map(_.toString.toInt).contains(target.toInt)
+          case DoubleType => values.map(_.toString.toDouble).contains(target.toDouble)
+          case StringType => values.map(_.toString).contains(target)
+          case _ => values.map(_.toString).contains(target)
+        }
+      }
+      case IsNull(attribute) => target == null
+      case IsNotNull(attribute) => target != null
+      case StringStartsWith(attribute, value) => target.startsWith(value.toString)
+      case StringEndsWith(attribute, value) => target.endsWith(value.toString)
+      case StringContains(attribute, value) => target.contains(value.toString)
+      case _ => false
+    }
+  }
+}
+
+class DefaultSource extends SchemaRelationProvider {
+  def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType) = {
+    RedisRelation(parameters, schema)(sqlContext)
+  }
+}
+

From 8099a9d904a92f92e2e8e84304ba7085a5d4f74d Mon Sep 17 00:00:00 2001
From: Sun He
Date: Mon, 30 May 2016 02:06:01 +0800
Subject: [PATCH 2/3] use pipeline for fetch and set

---
 .../provider/redis/sql/DefaultSource.scala         | 76 +++++++++++++------
 1 file changed, 51 insertions(+), 25 deletions(-)

diff --git a/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala b/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala
index 4002cefa..dc816eaa 100644
--- a/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala
+++ b/src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala
@@ -1,5 +1,7 @@
 package com.redislabs.provider.redis.sql
 
+import java.util
+
 import scala.collection.JavaConversions._
 import com.redislabs.provider.redis._
 import com.redislabs.provider.redis.rdd.{Keys, RedisKeysRDD}
@@ -44,47 +46,71 @@ case class RedisRelation(parameters: Map[String, String], userSchema: StructType
     /* Master only */
     redisConfig.hosts.filter(node => { node.startSlot <= slot && node.endSlot >= slot }).filter(_.idx == 0)(0)
   }
+
   def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    data.foreach{
-      row => {
-        val key = tableName + ":" + MessageDigest.getInstance("MD5").digest(System.currentTimeMillis.toString.getBytes)
-        val conn = getNode(key).endpoint.connect
-        conn.hmset(key, row.getValuesMap(row.schema.fieldNames).map(x => (x._1, x._2.toString)))
-        conn.close
+    data.foreachPartition{
+      partition => {
+        val m: Map[String, Row] = partition.map {
+          row => {
+            val tn = tableName + ":" + MessageDigest.getInstance("MD5").digest(
+              row.getValuesMap(schema.fieldNames).map(_._2.toString).reduce(_ + " " + _).getBytes)
+            (tn, row)
+          }
+        }.toMap
+        groupKeysByNode(redisConfig.hosts, m.keysIterator).foreach{
+          case(node, keys) => {
+            val conn = node.connect
+            val pipeline = conn.pipelined
+            keys.foreach{
+              key => {
+                val row = m.get(key).get
+                pipeline.hmset(key, row.getValuesMap(row.schema.fieldNames).map(x => (x._1, x._2.toString)))
+              }
+            }
+            pipeline.sync
+            conn.close
+          }
+        }
       }
     }
   }
 
   def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+    val colsForFilter = filters.map(getAttr(_)).sorted.distinct
+    val colsForFilterWithIndex = colsForFilter.zipWithIndex.toMap
+    val requiredColumnsType = requiredColumns.map(getDataType(_))
     new RedisKeysRDD(sqlContext.sparkContext, redisConfig, tableName + ":*", partitionNum, null).
       mapPartitions {
         partition: Iterator[String] => {
           groupKeysByNode(redisConfig.hosts, partition).flatMap {
             x => {
               val conn = x._1.endpoint.connect()
-              val rowKeys: Array[String] = filterKeysByType(conn, x._2, "hash")
-              val res = rowKeys.map {
-                key => {
-                  val res = conn.hmget(key, schema.fieldNames: _*)
-                  val pass = filters.zip(filters.map(filter => res.get((schema.fieldIndex(getAttr(filter)))))).forall{
-                    x => parseFilter(x._1, x._2)
-                  }
-                  if (pass) {
-                    requiredColumns.map{
-                      c => {
-                        val idx = schema.fieldIndex(c)
-                        castToTarget(res.get(idx), schema.fields(idx))
-                      }
+              val pipeline = conn.pipelined
+              val keys: Array[String] = filterKeysByType(conn, x._2, "hash")
+              val rowKeys = if (colsForFilter.length == 0) {
+                keys
+              } else {
+                keys.foreach(key => pipeline.hmget(key, colsForFilter:_*))
+                keys.zip(pipeline.syncAndReturnAll).filter {
+                  x => {
+                    val content = x._2.asInstanceOf[util.ArrayList[String]]
+                    filters.forall {
+                      filter => parseFilter(filter, content(colsForFilterWithIndex.get(getAttr(filter)).get))
                     }
-                  } else {
-                    null
                   }
+                }.map(_._1)
+              }
+
+              rowKeys.foreach(pipeline.hmget(_, requiredColumns:_*))
+              val res = pipeline.syncAndReturnAll.map{
+                _.asInstanceOf[util.ArrayList[String]].zip(requiredColumnsType).map {
+                  case(col, targetType) => castToTarget(col, targetType)
                 }
               }
               conn.close
-              res.filter(_!=null)
+              res
             }
-          }.toIterator.map(x => Row.fromSeq(x))
+          }.toIterator.map(Row.fromSeq(_))
         }
       }
   }
@@ -105,8 +131,8 @@ case class RedisRelation(parameters: Map[String, String], userSchema: StructType
     }
   }
 
-  private def castToTarget(value: String, field: StructField) = {
-    field.dataType match {
+  private def castToTarget(value: String, dataType: DataType) = {
+    dataType match {
       case IntegerType => value.toString.toInt
       case DoubleType => value.toString.toDouble
       case StringType => value.toString

From 26880474b62f7cf26ecd4835efd1a05dcba17955 Mon Sep 17 00:00:00 2001
From: Sun He
Date: Sun, 5 Jun 2016 23:55:22 +0800
Subject: [PATCH 3/3] add test for SQL

---
 .../redis/rdd/RedisSparkSQLClusterSuite.scala      | 69 ++++++++++++++++++
 .../rdd/RedisSparkSQLStandaloneSuite.scala         | 70 +++++++++++++++++++
 2 files changed, 139 insertions(+)
 create mode 100644 src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLClusterSuite.scala
 create mode 100644 src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLStandaloneSuite.scala

diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLClusterSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLClusterSuite.scala
new file mode 100644
index 00000000..51506860
--- /dev/null
+++ b/src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLClusterSuite.scala
@@ -0,0 +1,69 @@
+package com.redislabs.provider.redis.rdd
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.scalatest.{BeforeAndAfterAll, ShouldMatchers, FunSuite}
+import org.apache.spark.sql.SQLContext
+import com.redislabs.provider.redis._
+
+class RedisSparkSQLClusterSuite extends FunSuite with ENV with BeforeAndAfterAll with ShouldMatchers {
+
+  var sqlContext: SQLContext = null
+  override def beforeAll() {
+    super.beforeAll()
+
+    sc = new SparkContext(new SparkConf()
+      .setMaster("local").setAppName(getClass.getName)
+      .set("redis.host", "127.0.0.1")
+      .set("redis.port", "7379")
+    )
+    redisConfig = new RedisConfig(new RedisEndpoint("127.0.0.1", 7379))
+
+    // Flush all the hosts
+    redisConfig.hosts.foreach( node => {
+      val conn = node.connect
+      conn.flushAll
+      conn.close
+    })
+
+    sqlContext = new SQLContext(sc)
+    sqlContext.sql( s"""
+                       |CREATE TEMPORARY TABLE rl
+                       |(name STRING, score INT)
+                       |USING com.redislabs.provider.redis.sql
+                       |OPTIONS (table 'rl')
+      """.stripMargin)
+
+    (1 to 64).foreach{
+      index => {
+        sqlContext.sql(s"insert overwrite table rl select t.* from (select 'rl${index}', ${index}) t")
+      }
+    }
+  }
+
+  test("RedisKVRDD - default(cluster)") {
+    val df = sqlContext.sql(
+      s"""
+         |SELECT *
+         |FROM rl
+       """.stripMargin)
+    df.filter(df("score") > 10).count should be (54)
+    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
+  }
+
+  test("RedisKVRDD - cluster") {
+    implicit val c: RedisConfig = redisConfig
+    val df = sqlContext.sql(
+      s"""
+         |SELECT *
+         |FROM rl
+       """.stripMargin)
+    df.filter(df("score") > 10).count should be (54)
+    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
+  }
+
+  override def afterAll(): Unit = {
+    sc.stop
+    System.clearProperty("spark.driver.port")
+  }
+}
+
diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLStandaloneSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLStandaloneSuite.scala
new file mode 100644
index 00000000..667528af
--- /dev/null
+++ b/src/test/scala/com/redislabs/provider/redis/rdd/RedisSparkSQLStandaloneSuite.scala
@@ -0,0 +1,70 @@
+package com.redislabs.provider.redis.rdd
+
+import com.redislabs.provider.redis._
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, ShouldMatchers}
+
+class RedisSparkSQLStandaloneSuite extends FunSuite with ENV with BeforeAndAfterAll with ShouldMatchers {
+
+  var sqlContext: SQLContext = null
+  override def beforeAll() {
+    super.beforeAll()
+
+    sc = new SparkContext(new SparkConf()
+      .setMaster("local").setAppName(getClass.getName)
+      .set("redis.host", "127.0.0.1")
+      .set("redis.port", "6379")
+      .set("redis.auth", "passwd")
+    )
+    redisConfig = new RedisConfig(new RedisEndpoint("127.0.0.1", 6379, "passwd"))
+
+    // Flush all the hosts
+    redisConfig.hosts.foreach( node => {
+      val conn = node.connect
+      conn.flushAll
+      conn.close
+    })
+
+    sqlContext = new SQLContext(sc)
+    sqlContext.sql( s"""
+                       |CREATE TEMPORARY TABLE rl
+                       |(name STRING, score INT)
+                       |USING com.redislabs.provider.redis.sql
+                       |OPTIONS (table 'rl')
+      """.stripMargin)
+
+    (1 to 64).foreach{
+      index => {
+        sqlContext.sql(s"insert overwrite table rl select t.* from (select 'rl${index}', ${index}) t")
+      }
+    }
+  }
+
+  test("RedisKVRDD - default(standalone)") {
+    val df = sqlContext.sql(
+      s"""
+         |SELECT *
+         |FROM rl
+       """.stripMargin)
+    df.filter(df("score") > 10).count should be (54)
+    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
+  }
+
+  test("RedisKVRDD - standalone") {
+    implicit val c: RedisConfig = redisConfig
+    val df = sqlContext.sql(
+      s"""
+         |SELECT *
+         |FROM rl
+       """.stripMargin)
+    df.filter(df("score") > 10).count should be (54)
+    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
+  }
+
+  override def afterAll(): Unit = {
+    sc.stop
+    System.clearProperty("spark.driver.port")
+  }
+}
+
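For reference, a minimal usage sketch of the data source these patches add, mirroring what the test suites do through CREATE TEMPORARY TABLE ... USING. It assumes spark-redis is on the classpath and a Redis server is reachable through the redis.host/redis.port settings; the table name "people", the column names, and the app name are illustrative only and are not part of the patches. Reads resolve through the SchemaRelationProvider/PrunedFilteredScan path, so the schema must be supplied explicitly; writes go through InsertableRelation, which the tests exercise with INSERT OVERWRITE TABLE.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.types._

    object RedisSqlExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf()
          .setMaster("local").setAppName("redis-sql-example")
          .set("redis.host", "127.0.0.1")
          .set("redis.port", "6379"))
        val sqlContext = new SQLContext(sc)

        // Rows live in Redis hashes whose keys start with "people:", one hash field per column.
        val schema = StructType(Seq(
          StructField("name", StringType),
          StructField("age", IntegerType)))

        // Supply the schema explicitly and point the relation at the key prefix via the
        // "table" option; column pruning and filters are pushed down to the Redis scan.
        val people = sqlContext.read
          .format("com.redislabs.provider.redis.sql")
          .schema(schema)
          .option("table", "people")
          .load()

        people.registerTempTable("people")
        sqlContext.sql("SELECT name FROM people WHERE age > 30").show()

        sc.stop()
      }
    }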