Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

668 aspect oriented middleware #669

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ fileOverride {
"glob:**/scala-3/**" {
runner.dialect = scala3
}
}
}
137 changes: 113 additions & 24 deletions codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,19 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d
})
}

private[this] def handleMethod(method: MethodDescriptor) = {
private[this] def methodName(method: MethodDescriptor) =
method.streamType match {
case StreamType.Unary => "unaryToUnaryCall"
case StreamType.ClientStreaming => "streamingToUnaryCall"
case StreamType.ServerStreaming => "unaryToStreamingCall"
case StreamType.Bidirectional => "streamingToStreamingCall"
case StreamType.Unary => "unaryToUnary"
case StreamType.ClientStreaming => "streamingToUnary"
case StreamType.ServerStreaming => "unaryToStreaming"
case StreamType.Bidirectional => "streamingToStreaming"
}
}

private[this] def handleMethod(method: MethodDescriptor) =
methodName(method) + "Call"

private[this] def visitMethod(method: MethodDescriptor) =
"visit" + methodName(method).capitalize

private[this] def createClientCall(method: MethodDescriptor) = {
val basicClientCall =
Expand All @@ -66,17 +71,21 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d
}

private[this] def serviceMethodImplementation(method: MethodDescriptor): PrinterEndo = { p =>
val mkMetadata = if (method.isServerStreaming) s"$Stream.eval(mkMetadata(ctx))" else "mkMetadata(ctx)"
val inType = method.inputType.scalaType
val outType = method.outputType.scalaType
val descriptor = method.grpcDescriptor.fullName

p.add(serviceMethodSignature(method) + " = {")
.indent
.add(s"$mkMetadata.flatMap { m =>")
.indent
.add(s"${createClientCall(method)}.flatMap(_.${handleMethod(method)}(request, m))")
.outdent
.add("}")
.outdent
.add("}")
p
.add(serviceMethodSignature(method) + " =")
.indented {
_.addStringMargin(
s"""|clientAspect.${visitMethod(method)}[$inType, $outType](
| ${ClientCallContext}(ctx, $descriptor, implicitly[Dom[$inType]], implicitly[Cod[$outType]]),
| request,
| (req, m) => ${createClientCall(method)}.flatMap(_.${handleMethod(method)}(req, m))
|)""".stripMargin
)
}
}

private[this] def serviceBindingImplementation(method: MethodDescriptor): PrinterEndo = { p =>
Expand All @@ -86,9 +95,19 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d
val handler = s"$Fs2ServerCallHandler[F](dispatcher, serverOptions).${handleMethod(method)}[$inType, $outType]"

val serviceCall = s"serviceImpl.${method.name}"
val eval = if (method.isServerStreaming) s"$Stream.eval(mkCtx(m))" else "mkCtx(m)"

p.add(s".addMethod($descriptor, $handler((r, m) => $eval.flatMap($serviceCall(r, _))))")
p.addStringMargin(
s"""|.addMethod(
| $descriptor,
| $handler{ (r, m) =>
| serviceAspect.${visitMethod(method)}[$inType, $outType](
| ${ServerCallContext}(m, $descriptor, implicitly[Dom[$inType]], implicitly[Cod[$outType]]),
| r,
| (r, m) => $serviceCall(r, m)
| )
| }
|)"""
)
}

private[this] def serviceMethods: PrinterEndo = _.seq(service.methods.map(serviceMethodSignature))
Expand All @@ -115,23 +134,85 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d
.newline
.add("}")

private[this] def typeclasses: PrinterEndo = { p =>
val doms = service.methods
.map(_.inputType.scalaType)
.distinct
.zipWithIndex
.map { case (n, i) => s"dom$i: Dom[$n]" }

val cods = service.methods
.map(_.outputType.scalaType)
.distinct
.zipWithIndex
.map { case (n, i) => s"cod$i: Cod[$n]" }

p.addWithDelimiter(",")(doms ++ cods)
}

private[this] def serviceClient: PrinterEndo = {
_.add(
s"def mkClient[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], channel: $Channel, mkMetadata: $Ctx => F[$Metadata], clientOptions: $ClientOptions): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {"
).indent
_.addStringMargin(
s"""|def mkClientFull[F[_]: $Async, Dom[_], Cod[_], $Ctx](
| dispatcher: $Dispatcher[F],
| channel: $Channel,
| clientAspect: ${ClientAspect}[F, Dom, Cod, $Ctx],
| clientOptions: $ClientOptions
|)(implicit"""
)
.indented(typeclasses)
.add(s"): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {")
.indent
.call(serviceMethodImplementations)
.outdent
.add("}")
.newline
.addStringMargin(
s"""|def mkClientTrivial[F[_]: $Async, $Ctx](
| dispatcher: $Dispatcher[F],
| channel: $Channel,
| clientAspect: ${ClientAspect}[F, $Trivial, $Trivial, $Ctx],
| clientOptions: $ClientOptions
|) =
| mkClientFull[F, $Trivial, $Trivial, $Ctx](
| dispatcher,
| channel,
| clientAspect,
| clientOptions
| )"""
)
}

private[this] def serviceBinding: PrinterEndo = {
_.add(
s"protected def serviceBinding[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], serviceImpl: $serviceNameFs2[F, $Ctx], mkCtx: $Metadata => F[$Ctx], serverOptions: $ServerOptions): $ServerServiceDefinition = {"
).indent
_.addStringMargin(
s"""|protected def serviceBindingFull[F[_]: $Async, Dom[_], Cod[_], $Ctx](
| dispatcher: $Dispatcher[F],
| serviceImpl: $serviceNameFs2[F, $Ctx],
| serviceAspect: ${ServiceAspect}[F, Dom, Cod, $Ctx],
| serverOptions: $ServerOptions
|)(implicit"""
)
.indented(typeclasses)
.add(") = {")
.indent
.add(s"$ServerServiceDefinition")
.call(serviceBindingImplementations)
.outdent
.add("}")
.newline
.addStringMargin(
s"""|protected def serviceBindingTrivial[F[_]: $Async, $Ctx](
| dispatcher: $Dispatcher[F],
| serviceImpl: $serviceNameFs2[F, $Ctx],
| serviceAspect: ${ServiceAspect}[F, $Trivial, $Trivial, $Ctx],
| serverOptions: $ServerOptions
|) =
| serviceBindingFull[F, $Trivial, $Trivial, $Ctx](
| dispatcher,
| serviceImpl,
| serviceAspect,
| serverOptions
| )"""
)
}

// /
Expand All @@ -152,6 +233,9 @@ object Fs2GrpcServicePrinter {
private val effPkg = "_root_.cats.effect"
private val fs2Pkg = "_root_.fs2"
private val fs2grpcPkg = "_root_.fs2.grpc"
private val fs2grpcServerPkg = "_root_.fs2.grpc.server"
private val fs2grpcClientPkg = "_root_.fs2.grpc.client"
private val fs2grpcSharedPkg = "_root_.fs2.grpc.shared"
private val grpcPkg = "_root_.io.grpc"

// /
Expand All @@ -173,6 +257,11 @@ object Fs2GrpcServicePrinter {
val Channel = s"$grpcPkg.Channel"
val Metadata = s"$grpcPkg.Metadata"

val ServiceAspect = s"${fs2grpcServerPkg}.ServiceAspect"
val ServerCallContext = s"${fs2grpcServerPkg}.ServerCallContext"
val ClientAspect = s"${fs2grpcClientPkg}.ClientAspect"
val ClientCallContext = s"${fs2grpcClientPkg}.ClientCallContext"
val Trivial = s"${fs2grpcSharedPkg}.Trivial"
}

}
33 changes: 31 additions & 2 deletions runtime/src/main/scala/fs2/grpc/GeneratedCompanion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,35 @@ import cats.effect.std.Dispatcher
import io.grpc._
import fs2.grpc.client.ClientOptions
import fs2.grpc.server.ServerOptions
import fs2.grpc.client.ClientAspect
import fs2.grpc.shared.Trivial
import fs2.grpc.server.ServiceAspect

trait GeneratedCompanion[Service[*[_], _]] {

implicit final def serviceCompanion: GeneratedCompanion[Service] = this

///=== Client ==========================================================================================================

def mkClientTrivial[F[_]: Async, A](
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protected?

dispatcher: Dispatcher[F],
channel: Channel,
clientAspect: ClientAspect[F, Trivial, Trivial, A],
clientOptions: ClientOptions
): Service[F, A]

def mkClient[F[_]: Async, A](
dispatcher: Dispatcher[F],
channel: Channel,
mkMetadata: A => F[Metadata],
clientOptions: ClientOptions
): Service[F, A]
): Service[F, A] =
mkClientTrivial[F, A](
dispatcher,
channel,
ClientAspect.default[F, Trivial, Trivial].contraModify(mkMetadata),
clientOptions
)

final def mkClient[F[_]: Async, A](
dispatcher: Dispatcher[F],
Expand Down Expand Up @@ -116,12 +132,25 @@ trait GeneratedCompanion[Service[*[_], _]] {

///=== Service =========================================================================================================

protected def serviceBindingTrivial[F[_]: Async, A](
dispatcher: Dispatcher[F],
serviceImpl: Service[F, A],
serviceAspect: ServiceAspect[F, Trivial, Trivial, A],
serverOptions: ServerOptions
): ServerServiceDefinition

protected def serviceBinding[F[_]: Async, A](
dispatcher: Dispatcher[F],
serviceImpl: Service[F, A],
mkCtx: Metadata => F[A],
serverOptions: ServerOptions
): ServerServiceDefinition
): ServerServiceDefinition =
serviceBindingTrivial[F, A](
dispatcher,
serviceImpl,
ServiceAspect.default[F, Trivial, Trivial].modify(mkCtx),
serverOptions
)

final def service[F[_]: Async, A](
dispatcher: Dispatcher[F],
Expand Down
101 changes: 101 additions & 0 deletions runtime/src/main/scala/fs2/grpc/client/ClientAspect.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package fs2.grpc.client

import io.grpc._
import cats._
import cats.syntax.all._
import fs2.Stream

final case class ClientCallContext[Req, Res, Dom[_], Cod[_], A](
ctx: A,
methodDescriptor: MethodDescriptor[Req, Res],
dom: Dom[Req],
cod: Cod[Res]
)

trait ClientAspect[F[_], Dom[_], Cod[_], A] { self =>
def visitUnaryToUnary[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, A],
req: Req,
request: (Req, Metadata) => F[Res]
): F[Res]

def visitUnaryToStreaming[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, A],
req: Req,
request: (Req, Metadata) => Stream[F, Res]
): Stream[F, Res]

def visitStreamingToUnary[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, A],
req: Stream[F, Req],
request: (Stream[F, Req], Metadata) => F[Res]
): F[Res]

def visitStreamingToStreaming[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, A],
req: Stream[F, Req],
request: (Stream[F, Req], Metadata) => Stream[F, Res]
): Stream[F, Res]

def contraModify[B](f: B => F[A])(implicit F: Monad[F]): ClientAspect[F, Dom, Cod, B] =
new ClientAspect[F, Dom, Cod, B] {
def modCtx[Req, Res](ccc: ClientCallContext[Req, Res, Dom, Cod, B]): F[ClientCallContext[Req, Res, Dom, Cod, A]] =
f(ccc.ctx).map(a => ccc.copy(ctx = a))

override def visitUnaryToUnary[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, B],
req: Req,
request: (Req, Metadata) => F[Res]
): F[Res] =
modCtx(callCtx).flatMap(self.visitUnaryToUnary(_, req, request))

override def visitUnaryToStreaming[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, B],
req: Req,
request: (Req, Metadata) => Stream[F, Res]
): Stream[F, Res] =
Stream.eval(modCtx(callCtx)).flatMap(self.visitUnaryToStreaming(_, req, request))

override def visitStreamingToUnary[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, B],
req: Stream[F, Req],
request: (Stream[F, Req], Metadata) => F[Res]
): F[Res] =
modCtx(callCtx).flatMap(self.visitStreamingToUnary(_, req, request))

override def visitStreamingToStreaming[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, B],
req: Stream[F, Req],
request: (Stream[F, Req], Metadata) => Stream[F, Res]
): Stream[F, Res] =
Stream.eval(modCtx(callCtx)).flatMap(self.visitStreamingToStreaming(_, req, request))
}
}

object ClientAspect {
def default[F[_], Dom[_], Cod[_]] = new ClientAspect[F, Dom, Cod, Metadata] {
override def visitUnaryToUnary[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, Metadata],
req: Req,
request: (Req, Metadata) => F[Res]
): F[Res] = request(req, callCtx.ctx)

override def visitUnaryToStreaming[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, Metadata],
req: Req,
request: (Req, Metadata) => Stream[F, Res]
): Stream[F, Res] = request(req, callCtx.ctx)

override def visitStreamingToUnary[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, Metadata],
req: Stream[F, Req],
request: (Stream[F, Req], Metadata) => F[Res]
): F[Res] = request(req, callCtx.ctx)

override def visitStreamingToStreaming[Req, Res](
callCtx: ClientCallContext[Req, Res, Dom, Cod, Metadata],
req: Stream[F, Req],
request: (Stream[F, Req], Metadata) => Stream[F, Res]
): Stream[F, Res] = request(req, callCtx.ctx)
}
}
Loading
Loading