Merge pull request #29 from RedisLabs/sql
Add support for Spark SQL
dvirsky committed Jun 7, 2016
2 parents 3c15858 + 2688047 commit 5b23452
Showing 3 changed files with 345 additions and 0 deletions.
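
For orientation, a minimal usage sketch of the new data source, mirroring the test suites below; the Redis endpoint, table name and columns here are illustrative only and not part of this commit.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Sketch only: endpoint, table name and columns are illustrative.
val sc = new SparkContext(new SparkConf()
  .setMaster("local").setAppName("redis-sql-example")
  .set("redis.host", "127.0.0.1")
  .set("redis.port", "6379"))
val sqlContext = new SQLContext(sc)

// Register a temporary table backed by Redis hashes named "people:*".
sqlContext.sql(
  """
    |CREATE TEMPORARY TABLE people
    |(name STRING, age INT)
    |USING com.redislabs.provider.redis.sql
    |OPTIONS (table 'people')
  """.stripMargin)

sqlContext.sql("INSERT OVERWRITE TABLE people SELECT t.* FROM (SELECT 'alice', 30) t")
sqlContext.sql("SELECT name FROM people WHERE age > 21").show()
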
206 changes: 206 additions & 0 deletions src/main/scala/com/redislabs/provider/redis/sql/DefaultSource.scala
@@ -0,0 +1,206 @@
package com.redislabs.provider.redis.sql

import java.util

import scala.collection.JavaConversions._
import com.redislabs.provider.redis._
import com.redislabs.provider.redis.rdd.{Keys, RedisKeysRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import redis.clients.jedis.Protocol
import redis.clients.util.JedisClusterCRC16
import java.security.MessageDigest


case class RedisRelation(parameters: Map[String, String], userSchema: StructType)
    (@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan with InsertableRelation with Keys {

  val tableName: String = parameters.getOrElse("table", "PANIC")

  val redisConfig: RedisConfig = {
    new RedisConfig({
        if ((parameters.keySet & Set("host", "port", "auth", "dbNum", "timeout")).size == 0) {
          new RedisEndpoint(sqlContext.sparkContext.getConf)
        } else {
          val host = parameters.getOrElse("host", Protocol.DEFAULT_HOST)
          val port = parameters.getOrElse("port", Protocol.DEFAULT_PORT.toString).toInt
          val auth = parameters.getOrElse("auth", null)
          val dbNum = parameters.getOrElse("dbNum", Protocol.DEFAULT_DATABASE.toString).toInt
          val timeout = parameters.getOrElse("timeout", Protocol.DEFAULT_TIMEOUT.toString).toInt
          new RedisEndpoint(host, port, auth, dbNum, timeout)
        }
      }
    )
  }

  val partitionNum: Int = parameters.getOrElse("partitionNum", 3.toString).toInt

  val schema = userSchema

  def getNode(key: String): RedisNode = {
    val slot = JedisClusterCRC16.getSlot(key)
    /* Master only */
    redisConfig.hosts.filter(node => { node.startSlot <= slot && node.endSlot >= slot }).filter(_.idx == 0)(0)
  }

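  // insert() writes every Row as a Redis hash: the key is the table name plus an
  // MD5-based digest of the row's column values, and the hash fields are the column
  // names. Keys are grouped by the node that owns them, and each node is written
  // through one pipeline, one HMSET per key.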
  def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.foreachPartition {
      partition => {
        val m: Map[String, Row] = partition.map {
          row => {
            val tn = tableName + ":" + MessageDigest.getInstance("MD5").digest(
              row.getValuesMap(schema.fieldNames).map(_._2.toString).reduce(_ + " " + _).getBytes)
            (tn, row)
          }
        }.toMap
        groupKeysByNode(redisConfig.hosts, m.keysIterator).foreach {
          case (node, keys) => {
            val conn = node.connect
            val pipeline = conn.pipelined
            keys.foreach {
              key => {
                val row = m.get(key).get
                pipeline.hmset(key, row.getValuesMap(row.schema.fieldNames).map(x => (x._1, x._2.toString)))
              }
            }
            pipeline.sync
            conn.close
          }
        }
      }
    }
  }

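  // buildScan() lists all keys matching "<table>:*", keeps only hash keys, evaluates
  // the pushed-down filters on the executor against pipelined HMGET results, then
  // fetches the required columns for the surviving keys and casts them to the
  // declared schema types.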
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val colsForFilter = filters.map(getAttr(_)).sorted.distinct
    val colsForFilterWithIndex = colsForFilter.zipWithIndex.toMap
    val requiredColumnsType = requiredColumns.map(getDataType(_))
    new RedisKeysRDD(sqlContext.sparkContext, redisConfig, tableName + ":*", partitionNum, null).
      mapPartitions {
        partition: Iterator[String] => {
          groupKeysByNode(redisConfig.hosts, partition).flatMap {
            x => {
              val conn = x._1.endpoint.connect()
              val pipeline = conn.pipelined
              val keys: Array[String] = filterKeysByType(conn, x._2, "hash")
              val rowKeys = if (colsForFilter.length == 0) {
                keys
              } else {
                keys.foreach(key => pipeline.hmget(key, colsForFilter:_*))
                keys.zip(pipeline.syncAndReturnAll).filter {
                  x => {
                    val content = x._2.asInstanceOf[util.ArrayList[String]]
                    filters.forall {
                      filter => parseFilter(filter, content(colsForFilterWithIndex.get(getAttr(filter)).get))
                    }
                  }
                }.map(_._1)
              }

              rowKeys.foreach(pipeline.hmget(_, requiredColumns:_*))
              val res = pipeline.syncAndReturnAll.map {
                _.asInstanceOf[util.ArrayList[String]].zip(requiredColumnsType).map {
                  case (col, targetType) => castToTarget(col, targetType)
                }
              }
              conn.close
              res
            }
          }.toIterator.map(Row.fromSeq(_))
        }
      }
  }

  private def getAttr(f: Filter): String = {
    f match {
      case EqualTo(attribute, value) => attribute
      case GreaterThan(attribute, value) => attribute
      case GreaterThanOrEqual(attribute, value) => attribute
      case LessThan(attribute, value) => attribute
      case LessThanOrEqual(attribute, value) => attribute
      case In(attribute, values) => attribute
      case IsNull(attribute) => attribute
      case IsNotNull(attribute) => attribute
      case StringStartsWith(attribute, value) => attribute
      case StringEndsWith(attribute, value) => attribute
      case StringContains(attribute, value) => attribute
    }
  }

  private def castToTarget(value: String, dataType: DataType) = {
    dataType match {
      case IntegerType => value.toString.toInt
      case DoubleType => value.toString.toDouble
      case StringType => value.toString
      case _ => value.toString
    }
  }

  private def getDataType(attr: String) = {
    schema.fields(schema.fieldIndex(attr)).dataType
  }
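  // parseFilter() evaluates one pushed-down filter against the raw string fetched
  // from the hash. For the range filters the comparison is written as value <op> target
  // with the operator flipped: GreaterThan(attr, value) holds when value < target,
  // i.e. when the stored field is greater than the filter constant.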
  private def parseFilter(f: Filter, target: String) = {
    f match {
      case EqualTo(attribute, value) => {
        value.toString == target
      }
      case GreaterThan(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt < target.toInt
          case DoubleType => value.toString.toDouble < target.toDouble
          case StringType => value.toString < target
          case _ => value.toString < target
        }
      }
      case GreaterThanOrEqual(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt <= target.toInt
          case DoubleType => value.toString.toDouble <= target.toDouble
          case StringType => value.toString <= target
          case _ => value.toString <= target
        }
      }
      case LessThan(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt > target.toInt
          case DoubleType => value.toString.toDouble > target.toDouble
          case StringType => value.toString > target
          case _ => value.toString > target
        }
      }
      case LessThanOrEqual(attribute, value) => {
        getDataType(attribute) match {
          case IntegerType => value.toString.toInt >= target.toInt
          case DoubleType => value.toString.toDouble >= target.toDouble
          case StringType => value.toString >= target
          case _ => value.toString >= target
        }
      }
      case In(attribute, values) => {
        getDataType(attribute) match {
          case IntegerType => values.map(_.toString.toInt).contains(target.toInt)
          case DoubleType => values.map(_.toString.toDouble).contains(target.toDouble)
          case StringType => values.map(_.toString).contains(target)
          case _ => values.map(_.toString).contains(target)
        }
      }
      case IsNull(attribute) => target == null
      case IsNotNull(attribute) => target != null
      case StringStartsWith(attribute, value) => target.startsWith(value.toString)
      case StringEndsWith(attribute, value) => target.endsWith(value.toString)
      case StringContains(attribute, value) => target.contains(value.toString)
      case _ => false
    }
  }
}

class DefaultSource extends SchemaRelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType) = {
    RedisRelation(parameters, schema)(sqlContext)
  }
}
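Because DefaultSource implements SchemaRelationProvider, the relation can also be built through the DataFrameReader API with a user-supplied schema; a minimal sketch, assuming an existing SQLContext named sqlContext, with an illustrative "people" table that is not part of this commit.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Sketch only: the "people" table and its schema are illustrative.
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)))

val people = sqlContext.read
  .format("com.redislabs.provider.redis.sql")
  .schema(schema)
  .option("table", "people")
  .load()

people.filter(people("age") > 21).show()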

@@ -0,0 +1,69 @@
package com.redislabs.provider.redis.rdd

import org.apache.spark.{SparkContext, SparkConf}
import org.scalatest.{BeforeAndAfterAll, ShouldMatchers, FunSuite}
import org.apache.spark.sql.SQLContext
import com.redislabs.provider.redis._

class RedisSparkSQLClusterSuite extends FunSuite with ENV with BeforeAndAfterAll with ShouldMatchers {

  var sqlContext: SQLContext = null
  override def beforeAll() {
    super.beforeAll()

    sc = new SparkContext(new SparkConf()
      .setMaster("local").setAppName(getClass.getName)
      .set("redis.host", "127.0.0.1")
      .set("redis.port", "7379")
    )
    redisConfig = new RedisConfig(new RedisEndpoint("127.0.0.1", 7379))

    // Flush all the hosts
    redisConfig.hosts.foreach( node => {
      val conn = node.connect
      conn.flushAll
      conn.close
    })

    sqlContext = new SQLContext(sc)
    sqlContext.sql( s"""
      |CREATE TEMPORARY TABLE rl
      |(name STRING, score INT)
      |USING com.redislabs.provider.redis.sql
      |OPTIONS (table 'rl')
    """.stripMargin)

    (1 to 64).foreach {
      index => {
        sqlContext.sql(s"insert overwrite table rl select t.* from (select 'rl${index}', ${index}) t")
      }
    }
  }

  test("RedisKVRDD - default(cluster)") {
    val df = sqlContext.sql(
      s"""
         |SELECT *
         |FROM rl
      """.stripMargin)
    // Scores are 1..64, so 54 rows have score > 10 and 9 rows fall strictly between 10 and 20.
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  test("RedisKVRDD - cluster") {
    implicit val c: RedisConfig = redisConfig
    val df = sqlContext.sql(
      s"""
         |SELECT *
         |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  override def afterAll(): Unit = {
    sc.stop
    System.clearProperty("spark.driver.port")
  }
}

@@ -0,0 +1,70 @@
package com.redislabs.provider.redis.rdd

import com.redislabs.provider.redis._
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite, ShouldMatchers}

class RedisSparkSQLStandaloneSuite extends FunSuite with ENV with BeforeAndAfterAll with ShouldMatchers {

  var sqlContext: SQLContext = null
  override def beforeAll() {
    super.beforeAll()

    sc = new SparkContext(new SparkConf()
      .setMaster("local").setAppName(getClass.getName)
      .set("redis.host", "127.0.0.1")
      .set("redis.port", "6379")
      .set("redis.auth", "passwd")
    )
    redisConfig = new RedisConfig(new RedisEndpoint("127.0.0.1", 6379, "passwd"))

    // Flush all the hosts
    redisConfig.hosts.foreach( node => {
      val conn = node.connect
      conn.flushAll
      conn.close
    })

    sqlContext = new SQLContext(sc)
    sqlContext.sql( s"""
      |CREATE TEMPORARY TABLE rl
      |(name STRING, score INT)
      |USING com.redislabs.provider.redis.sql
      |OPTIONS (table 'rl')
    """.stripMargin)

    (1 to 64).foreach {
      index => {
        sqlContext.sql(s"insert overwrite table rl select t.* from (select 'rl${index}', ${index}) t")
      }
    }
  }

  test("RedisKVRDD - default(standalone)") {
    val df = sqlContext.sql(
      s"""
         |SELECT *
         |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  test("RedisKVRDD - standalone") {
    implicit val c: RedisConfig = redisConfig
    val df = sqlContext.sql(
      s"""
         |SELECT *
         |FROM rl
      """.stripMargin)
    df.filter(df("score") > 10).count should be (54)
    df.filter(df("score") > 10 and df("score") < 20).count should be (9)
  }

  override def afterAll(): Unit = {
    sc.stop
    System.clearProperty("spark.driver.port")
  }
}
