diff --git a/doc/rdd.md b/doc/rdd.md index c22d47df..9ef505c7 100644 --- a/doc/rdd.md +++ b/doc/rdd.md @@ -137,6 +137,15 @@ sc.toRedisFixedLIST(listRDD, listName, listSize) The `listRDD` is an RDD that contains all of the list's string elements in order, and `listName` is the list's key name. `listSize` is an integer which specifies the size of the Redis list; it is optional, and will default to an unlimited size. +Use the following to store an RDD of binary values in a Redis List: + +```scala +sc.toRedisByteLIST(byteListRDD) +``` + +The `byteListRDD` is an RDD of (`list name`, `list values`) tuples, where the list name and each list value are byte arrays. + + #### Sets For storing data in a Redis Set, use `toRedisSET` as follows: diff --git a/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala b/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala index 57c2e0a6..fd47c097 100644 --- a/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala +++ b/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala @@ -166,14 +166,25 @@ class RedisConfig(val initialHost: RedisEndpoint) extends Serializable { } /** - * @param key * *IMPORTANT* Please remember to close after using - * @return jedis who is a connection for a given key + * + * @param key + * @return jedis that is a connection for a given key */ def connectionForKey(key: String): Jedis = { getHost(key).connect() } + /** + * *IMPORTANT* Please remember to close after using + * + * @param key + * @return jedis that is a connection for a given key + */ + def connectionForKey(key: Array[Byte]): Jedis = { + getHost(key).connect() + } + /** * @param initialHost any redis endpoint of a cluster or a single server * @return true if the target server is in cluster mode @@ -195,9 +206,22 @@ class RedisConfig(val initialHost: RedisEndpoint) extends Serializable { */ def getHost(key: String): RedisNode = { val slot = JedisClusterCRC16.getSlot(key) - hosts.filter(host => { + getHostBySlot(slot) + } + + /** + * @param key + * @return 
host whose slots should involve key + */ + def getHost(key: Array[Byte]): RedisNode = { + val slot = JedisClusterCRC16.getSlot(key) + getHostBySlot(slot) + } + + private def getHostBySlot(slot: Int): RedisNode = { + hosts.filter { host => host.startSlot <= slot && host.endSlot >= slot - })(0) + }(0) } diff --git a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala index 563e87c8..d59394cb 100644 --- a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala +++ b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala @@ -1,6 +1,7 @@ package com.redislabs.provider.redis import com.redislabs.provider.redis.rdd._ +import com.redislabs.provider.redis.util.ConnectionUtils.withConnection import com.redislabs.provider.redis.util.PipelineUtils._ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -299,6 +300,19 @@ class RedisContext(@transient val sc: SparkContext) extends Serializable { vs.foreachPartition(partition => setList(listName, partition, ttl, redisConfig, readWriteConfig)) } + /** + * Write RDD of binary values to Redis List. 
+ * + * @param rdd RDD of tuples (list name, list values) + * @param ttl time to live + */ + def toRedisByteLIST(rdd: RDD[(Array[Byte], Seq[Array[Byte]])], ttl: Int = 0) + (implicit + redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf), + readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) { + rdd.foreachPartition(partition => setList(partition, ttl, redisConfig, readWriteConfig)) + } + /** * @param vs RDD of values * @param listName target list's name which hold all the vs @@ -415,6 +429,30 @@ object RedisContext extends Serializable { conn.close() } + + def setList(keyValues: Iterator[(Array[Byte], Seq[Array[Byte]])], + ttl: Int, + redisConfig: RedisConfig, + readWriteConfig: ReadWriteConfig) { + implicit val rwConf: ReadWriteConfig = readWriteConfig + + keyValues + .map { case (key, listValues) => + (redisConfig.getHost(key), (key, listValues)) + } + .toArray + .groupBy(_._1) + .foreach { case (node, arr) => + withConnection(node.endpoint.connect()) { conn => + foreachWithPipeline(conn, arr) { (pipeline, a) => + val (key, listVals) = a._2 + pipeline.rpush(key, listVals: _*) + if (ttl > 0) pipeline.expire(key, ttl) + } + } + } + } + /** * @param key * @param listSize diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddSuite.scala index 7f6ada5f..3ac2c0b0 100644 --- a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddSuite.scala +++ b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddSuite.scala @@ -1,7 +1,9 @@ package com.redislabs.provider.redis.rdd +import com.redislabs.provider.redis.util.ConnectionUtils.withConnection import com.redislabs.provider.redis.{RedisConfig, SparkRedisSuite, toRedisContext} import org.scalatest.Matchers +import scala.collection.JavaConverters._ import scala.io.Source.fromInputStream @@ -109,6 +111,27 @@ trait RedisRddSuite extends SparkRedisSuite with Keys with Matchers { setContents should be(ws) 
} + test("toRedisLIST, byte array") { + val list1 = Seq("a1", "b1", "c1") + val list2 = Seq("a2", "b2", "c2") + val keyValues = Seq( + ("list1", list1), + ("list2", list2) + ) + val keyValueBytes = keyValues.map {case (k, list) => (k.getBytes, list.map(_.getBytes())) } + val rdd = sc.parallelize(keyValueBytes) + sc.toRedisByteLIST(rdd) + + def verify(list: String, vals: Seq[String]): Unit = { + withConnection(redisConfig.getHost(list).endpoint.connect()) { conn => + conn.lrange(list, 0, vals.size).asScala should be(vals.toList) + } + } + + verify("list1", list1) + verify("list2", list2) + } + test("Expire") { val expireTime = 1 val prefix = s"#expire in $expireTime#:"