Skip to content

Commit

Permalink
Merge pull request #40 from davenverse/addRedisStreams
Browse files Browse the repository at this point in the history
Support Redis Streams Revisited (xadd, xread, RedisStream)
  • Loading branch information
ChristopherDavenport authored Jan 5, 2022
2 parents 3a97989 + 8de48fe commit e07198c
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 10 deletions.
7 changes: 3 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType}

val catsV = "2.6.1"
val catsEffectV = "3.2.0"
// val fs2V = "3.0.6"
val fs2V = "3.1.0"
val catsV = "2.7.0"
val catsEffectV = "3.3.3"
val fs2V = "3.2.3"

val munitCatsEffectV = "1.0.7"

Expand Down
111 changes: 109 additions & 2 deletions core/src/main/scala/io/chrisdavenport/rediculous/RedisCommands.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,115 @@ object RedisCommands {

// TODO Scan
// TODO LEX
// TODO xadd
// TODO xread

sealed trait Trimming
object Trimming {
case object Approximate extends Trimming
case object Exact extends Trimming
implicit val arg: RedisArg[Trimming] = RedisArg[String].contramap[Trimming]{
case Approximate => "~"
case Exact => "="
}
}

final case class XAddOpts(
id: Option[String],
maxLength: Option[Long],
trimming: Option[Trimming],
noMkStream: Boolean,
minId: Option[String],
limit: Option[Long]
)
object XAddOpts {
val default = XAddOpts(None, None, None, false, None, None)
}

def xadd[F[_]: RedisCtx](stream: String, body: List[(String, String)], xaddOpts: XAddOpts = XAddOpts.default): F[String] = {
val maxLen = xaddOpts.maxLength.toList.flatMap{ l => List("MAXLEN".some, xaddOpts.trimming.map(_.encode), l.encode.some).flattenOption }
val minId = xaddOpts.minId.toList.flatMap{ l => List("MINID".some, xaddOpts.trimming.map(_.encode), l.encode.some).flattenOption }
val limit = xaddOpts.limit.toList.flatMap(l=> if (xaddOpts.trimming.contains(Trimming.Approximate)) List("LIMIT", l.encode) else List.empty)
val noMkStream = Alternative[List].guard(xaddOpts.noMkStream).as("NOMKSTREAM")
val id = List(xaddOpts.id.getOrElse("*"))
val bodyEnd = body.foldLeft(List.empty[String]){ case (s, (k,v)) => s ::: List(k.encode, v.encode) }

RedisCtx[F].unkeyed(NEL("XADD", stream :: maxLen ::: minId ::: limit ::: noMkStream ::: id ::: bodyEnd))
}

final case class XReadOpts(
blockMillisecond: Option[Long],
count: Option[Long],
noAck: Boolean
)
object XReadOpts {
val default = XReadOpts(None, None, false)
}

sealed trait StreamOffset {
def stream: String
def offset: String
}

object StreamOffset {
case class All(stream: String) extends StreamOffset {
override def offset: String = "0"
}
case class Latest(stream: String) extends StreamOffset {
override def offset: String = "$"
}
case class From(stream: String, offset: String) extends StreamOffset
}

final case class StreamsRecord(
recordId: String,
keyValues: List[(String, String)]
)

object StreamsRecord {
implicit val result : RedisResult[StreamsRecord] = new RedisResult[StreamsRecord] {
def decode(resp: Resp): Either[Resp,StreamsRecord] = {
def two[A](l: List[A], acc: List[(A, A)] = List.empty): List[(A, A)] = l match {
case first :: second :: rest => two(rest, (first, second):: acc)
case otherwise => acc.reverse
}
resp match {
case Resp.Array(Some(Resp.BulkString(Some(recordId)) :: Resp.Array(Some(rawKeyValues)) :: Nil)) =>
for {
keyValuesList <- rawKeyValues.traverse(RedisResult[String].decode).map(two(_))
} yield StreamsRecord(recordId, keyValuesList)
case otherwise => Left(otherwise)
}
}
}
}

final case class XReadResponse(
stream: String,
records: List[StreamsRecord]
)
object XReadResponse{
implicit val result: RedisResult[XReadResponse] = new RedisResult[XReadResponse] {
def decode(resp: Resp): Either[Resp,XReadResponse] = {
resp match {
case Resp.Array(Some(Resp.BulkString(Some(stream)) :: Resp.Array(Some(list)) :: Nil)) =>
list.traverse(RedisResult[StreamsRecord].decode).map(l =>
XReadResponse(stream, l)
)
case otherwise => Left(otherwise)
}
}
}
}

def xread[F[_]: RedisCtx](streams: Set[StreamOffset], xreadOpts: XReadOpts = XReadOpts.default): F[Option[List[XReadResponse]]] = {//F[Option[List[List[(String, List[List[(String, List[(String, String)])]])]]]] = {
val block = xreadOpts.blockMillisecond.toList.flatMap(l => List("BLOCK", l.encode))
val count = xreadOpts.count.toList.flatMap(l => List("COUNT", l.encode))
val noAck = Alternative[List].guard(xreadOpts.noAck).as("NOACK")
val streamKeys = streams.map(_.stream.encode).toList
val streamOffsets = streams.map(_.offset.encode).toList
val streamPairs = "STREAMS" :: streamKeys ::: streamOffsets

RedisCtx[F].unkeyed(NEL("XREAD", block ::: count ::: noAck ::: streamPairs))
}

def xgroupcreate[F[_]: RedisCtx](stream: String, groupName: String, startId: String): F[Status] =
RedisCtx[F].unkeyed(NEL.of("XGROUP", "CREATE", stream, groupName, startId))
Expand Down
61 changes: 61 additions & 0 deletions core/src/main/scala/io/chrisdavenport/rediculous/RedisStream.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package io.chrisdavenport.rediculous

import cats.implicits._
import fs2.{Stream, Pipe}
import scala.concurrent.duration.Duration
import RedisCommands.{XAddOpts, XReadOpts, StreamOffset, Trimming, xadd, xread}
import cats.effect._


trait RedisStream[F[_]] {
def append(messages: List[RedisStream.XAddMessage]): F[List[String]]

def read(
streams: Set[String],
chunkSize: Int,
initialOffset: String => StreamOffset = {(s: String) => StreamOffset.All(s)},
block: Duration = Duration.Zero,
count: Option[Long] = None
): Stream[F, RedisCommands.XReadResponse]
}

object RedisStream {

final case class XAddMessage(stream: String, body: List[(String, String)], approxMaxlen: Option[Long] = None)

/**
* Create a RedisStream from a connection.
*
**/
def fromConnection[F[_]: Async](connection: RedisConnection[F]): RedisStream[F] = new RedisStream[F] {
def append(messages: List[XAddMessage]): F[List[String]] = {
messages
.traverse{ case msg =>
val opts = msg.approxMaxlen.map(l => XAddOpts.default.copy(maxLength = l.some, trimming = Trimming.Approximate.some))
xadd[RedisPipeline](msg.stream, msg.body, opts getOrElse XAddOpts.default)
}
.pipeline[F]
.run(connection)
}

private val nextOffset: String => RedisCommands.StreamsRecord => StreamOffset =
key => msg => StreamOffset.From(key, msg.recordId)

private val offsetsByKey: List[RedisCommands.StreamsRecord] => Map[String, Option[StreamOffset]] =
list => list.groupBy(_.recordId).map { case (k, values) => k -> values.lastOption.map(nextOffset(k)) }

def read(keys: Set[String], chunkSize: Int, initialOffset: String => StreamOffset, block: Duration, count: Option[Long]): Stream[F, RedisCommands.XReadResponse] = {
val initial = keys.map(k => k -> initialOffset(k)).toMap
val opts = XReadOpts.default.copy(blockMillisecond = block.toMillis.some, count = count)
Stream.eval(Ref.of[F, Map[String, StreamOffset]](initial)).flatMap { ref =>
(for {
offsets <- Stream.eval(ref.get)
list <- Stream.eval(xread(offsets.values.toSet, opts).run(connection)).flattenOption
newOffsets = offsetsByKey(list.flatMap(_.records)).collect { case (key, Some(value)) => key -> value }.toList
_ <- Stream.eval(newOffsets.map { case (k, v) => ref.update(_.updated(k, v)) }.sequence)
result <- Stream.emits(list)
} yield result).repeat
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import cats.syntax.all._
import cats.effect._
import munit.CatsEffectSuite
import scala.concurrent.duration._
import io.chrisdavenport.whaletail.Docker
import io.chrisdavenport.whaletail.manager._
import _root_.io.chrisdavenport.whaletail.Docker
import _root_.io.chrisdavenport.whaletail.manager._
import com.comcast.ip4s.Host
import com.comcast.ip4s.Port

Expand Down Expand Up @@ -40,14 +40,32 @@ class RedisCommandsSpec extends CatsEffectSuite {
override def munitFixtures: Seq[Fixture[_]] = Seq(
redisConnection
)
test("set a value"){ //connection =>
test("set/get parity"){ //connection =>
redisConnection().flatMap{connection =>
val key = "foo"
val value = "bar"
val action = RedisCommands.set[RedisIO](key, value) >> RedisCommands.get[RedisIO](key)
val action = RedisCommands.set[RedisIO](key, value) >>
RedisCommands.get[RedisIO](key) <*
RedisCommands.del[RedisIO]("foo")
action.run(connection)
}.map{
assertEquals(_, Some("bar"))
}
}

test("xadd/xread parity"){
redisConnection().flatMap{ connection =>
val kv = "bar" -> "baz"
val action = RedisCommands.xadd[RedisIO]("foo", List(kv)) >>
RedisCommands.xread[RedisIO](Set(RedisCommands.StreamOffset.All("foo"))) <*
RedisCommands.del[RedisIO]("foo")

val extract = (resp: Option[List[RedisCommands.XReadResponse]]) =>
resp.flatMap(_.headOption).flatMap(_.records.headOption).flatMap(_.keyValues.headOption)

action.run(connection).map{ resp =>
assertEquals(extract(resp), Some(kv))
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package io.chrisdavenport.rediculous

import cats.syntax.all._
import cats.effect._
import munit.CatsEffectSuite
import scala.concurrent.duration._
import _root_.io.chrisdavenport.whaletail.Docker
import _root_.io.chrisdavenport.whaletail.manager._
import com.comcast.ip4s.Host
import com.comcast.ip4s.Port

class RedisStreamSpec extends CatsEffectSuite {
val resource = Docker.default[IO].flatMap(client =>
WhaleTailContainer.build(client, "redis", "latest".some, Map(6379 -> None), Map.empty, Map.empty)
.evalTap(
ReadinessStrategy.checkReadiness(
client,
_,
ReadinessStrategy.LogRegex(".*Ready to accept connections.*\\s".r),
30.seconds
)
)
).flatMap(container =>
for {
t <- Resource.eval(
container.ports.get(6379).liftTo[IO](new Throwable("Missing Port"))
)
(hostS, portI) = t
host <- Resource.eval(Host.fromString(hostS).liftTo[IO](new Throwable("Invalid Host")))
port <- Resource.eval(Port.fromInt(portI).liftTo[IO](new Throwable("Invalid Port")))
connection <- RedisConnection.pool[IO].withHost(host).withPort(port).build
} yield connection

)
// Not available on scala.js
val redisConnection = UnsafeResourceSuiteLocalDeferredFixture(
"redisconnection",
resource
)
override def munitFixtures: Seq[Fixture[_]] = Seq(
redisConnection
)
test("send a single message"){ //connection =>
val messages = List(
RedisStream.XAddMessage("foo", List("bar" -> "baz", "zoom" -> "zad"))
)
redisConnection().flatMap{connection =>

val rStream = RedisStream.fromConnection(connection)
rStream.append(messages) >>
rStream.read(Set("foo"), 512).take(1).compile.lastOrError

}.map{ xrr =>
val i = xrr.stream
assertEquals(xrr.stream, "foo")
val i2 = xrr.records.flatMap(sr => sr.keyValues)
assertEquals(i2, messages.flatMap(_.body))
}
}

}


91 changes: 91 additions & 0 deletions examples/src/main/scala/StreamsExample.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import io.chrisdavenport.rediculous._
import java.util.concurrent.TimeoutException
import scala.collection.immutable.Queue
import scala.concurrent.duration._
import scala.util.Random
import cats.effect._
import cats.implicits._
import fs2._
import fs2.timeseries.{TimeStamped, TimeSeries}
import fs2.io.net._
import com.comcast.ip4s._

object StreamRate {
def rate[A] =
TimeStamped.withPerSecondRate[Option[Chunk[A]], Float](_.map(chunk => chunk.size.toFloat).getOrElse(0.0f))

def averageRate[A] =
rate[A].andThen(Scan.stateful1(Queue.empty[Float]) {
case (q, tsv @ TimeStamped(_, Right(_))) => (q, tsv)
case (q, TimeStamped(t, Left(sample))) =>
val q2 = (sample +: q).take(10)
val average = q2.sum / q2.size
(q, TimeStamped(t, Left(average)))
})

implicit class Logger[F[_]: Temporal, A](input: Stream[F, A]) {
def logAverageRate(logger: Float => F[Unit]): Stream[F, A] =
TimeSeries.timePulled(input.chunks, 1.second, 1.second)
.through(averageRate.toPipe)
.flatMap {
case TimeStamped(_, Left(rate)) => Stream.exec(logger(rate))
case TimeStamped(_, Right(Some(chunk))) => Stream.chunk(chunk)
case TimeStamped(_, Right(None)) => Stream.empty
}
}
}

object StreamProducerExample extends IOApp {
import StreamRate._

def putStrLn[A](a: A): IO[Unit] = IO(println(a))

def randomMessage: IO[List[(String, String)]] = {
val rndKey = IO(Random.nextInt(1000).toString)
val rndValue = IO(Random.nextString(10))
(rndKey, rndValue).parMapN{ case (k, v) => List(k -> v) }
}

def run(args: List[String]): IO[ExitCode] = {
val mystream = "mystream"

RedisConnection.pool[IO].withHost(host"localhost").withPort(port"6379").build
.map(RedisStream.fromConnection[IO])
.use { rs =>
val consumer = rs
.read(Set(mystream), 10000)
.evalMap(putStrLn)
.onError{ case err => Stream.exec(IO.println(s"Consumer err: $err"))}
.logAverageRate(rate => IO.println(s"Consumer rate: $rate/s"))

val producer =
Stream
.repeatEval(randomMessage)
.map(RedisStream.XAddMessage(mystream, _))
.chunkMin(10000)
.flatMap{ chunk =>
Stream.evalSeq(rs.append(chunk.toList))
}
.onError{ case err => Stream.exec(IO.println(s"Producer err: $err"))}
.logAverageRate(rate => IO.println(s"Producer rate: $rate/s"))

val stream =
// Stream.exec( RedisCommands.del[RedisPipeline]("mystream").pipeline[IO].run(client).void) ++
Stream.exec(IO.println("Started")) ++
consumer
.concurrently(producer)
.interruptAfter(7.second)

// Stream.eval( RedisCommands.xlen[RedisPipeline]("mystream").pipeline[IO].run(client).flatMap(length => IO.println(s"Finished: $length")))

stream.compile.count.flatTap(l => putStrLn(s"Length: $l"))
}
.redeem(
{ t =>
IO.println(s"Error: $t, Something went wrong")
ExitCode(1)
},
_ => ExitCode.Success
)
}
}

0 comments on commit e07198c

Please sign in to comment.