diff --git a/.scalafmt.conf b/.scalafmt.conf index 1e01bb04..8d384f41 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -10,4 +10,4 @@ fileOverride { "glob:**/scala-3/**" { runner.dialect = scala3 } -} \ No newline at end of file +} diff --git a/build.sbt b/build.sbt index e5028fc4..73fd4662 100644 --- a/build.sbt +++ b/build.sbt @@ -134,7 +134,10 @@ lazy val e2e = (projectMatrix in file("e2e")) .settings( codeGenClasspath := (codeGenJVM212 / Compile / fullClasspath).value, libraryDependencies := Nil, - libraryDependencies ++= List(scalaPbGrpcRuntime, scalaPbRuntime, scalaPbRuntime % "protobuf", ceMunit % Test), + libraryDependencies ++= List(grpcApi, scalaPbGrpcRuntime, scalaPbRuntime, scalaPbRuntime % "protobuf", ceMunit % Test), + libraryDependencies ++= Seq( + "io.grpc" % "grpc-services" % versions.grpc % Test, + ), Compile / PB.targets := Seq( scalapb.gen() -> (Compile / sourceManaged).value / "scalapb", genModule(codegenFullName + "$") -> (Compile / sourceManaged).value / "fs2-grpc" diff --git a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala index 0f536c70..eec85c6b 100644 --- a/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala +++ b/codegen/src/main/scala/fs2/grpc/codegen/Fs2GrpcServicePrinter.scala @@ -47,18 +47,23 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d }) } - private[this] def handleMethod(method: MethodDescriptor) = { + private[this] def methodName(method: MethodDescriptor) = method.streamType match { - case StreamType.Unary => "unaryToUnaryCall" - case StreamType.ClientStreaming => "streamingToUnaryCall" - case StreamType.ServerStreaming => "unaryToStreamingCall" - case StreamType.Bidirectional => "streamingToStreamingCall" + case StreamType.Unary => "unaryToUnary" + case StreamType.ClientStreaming => "streamingToUnary" + case StreamType.ServerStreaming => "unaryToStreaming" + case StreamType.Bidirectional => "streamingToStreaming" } - } + + private[this] def handleMethod(method: MethodDescriptor) = + methodName(method) + "Call" + + private[this] def visitMethod(method: MethodDescriptor) = + "visit" + methodName(method).capitalize private[this] def createClientCall(method: MethodDescriptor) = { val basicClientCall = - s"$Fs2ClientCall[F](channel, ${method.grpcDescriptor.fullName}, dispatcher, clientOptions)" + s"$Fs2ClientCall[G](channel, ${method.grpcDescriptor.fullName}, dispatcher, clientOptions)" if (method.isServerStreaming) s"$Stream.eval($basicClientCall)" else @@ -66,29 +71,43 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d } private[this] def serviceMethodImplementation(method: MethodDescriptor): PrinterEndo = { p => - val mkMetadata = if (method.isServerStreaming) s"$Stream.eval(mkMetadata(ctx))" else "mkMetadata(ctx)" + val inType = method.inputType.scalaType + val outType = method.outputType.scalaType + val descriptor = method.grpcDescriptor.fullName - p.add(serviceMethodSignature(method) + " = {") - .indent - .add(s"$mkMetadata.flatMap { m =>") - .indent - .add(s"${createClientCall(method)}.flatMap(_.${handleMethod(method)}(request, m))") - .outdent - .add("}") - .outdent - .add("}") + p + .add(serviceMethodSignature(method) + " =") + .indented { + _.addStringMargin( + s"""|clientAspect.${visitMethod(method)}[$inType, $outType]( + | ${ClientCallContext}(ctx, $descriptor), + | request, + | (req, m) => ${createClientCall(method)}.flatMap(_.${handleMethod(method)}(req, m)) + |)""".stripMargin + ) + } } private[this] def serviceBindingImplementation(method: MethodDescriptor): PrinterEndo = { p => val inType = method.inputType.scalaType val outType = method.outputType.scalaType val descriptor = method.grpcDescriptor.fullName - val handler = s"$Fs2ServerCallHandler[F](dispatcher, serverOptions).${handleMethod(method)}[$inType, $outType]" + val handler = s"$Fs2ServerCallHandler[G](dispatcher, serverOptions).${handleMethod(method)}[$inType, $outType]" val serviceCall = s"serviceImpl.${method.name}" - val eval = if (method.isServerStreaming) s"$Stream.eval(mkCtx(m))" else "mkCtx(m)" - p.add(s".addMethod($descriptor, $handler((r, m) => $eval.flatMap($serviceCall(r, _))))") + p.addStringMargin( + s"""|.addMethod( + | $descriptor, + | $handler{ (r, m) => + | serviceAspect.${visitMethod(method)}[$inType, $outType]( + | ${ServerCallContext}(m, $descriptor), + | r, + | (r, m) => $serviceCall(r, m) + | ) + | } + |)""" + ) } private[this] def serviceMethods: PrinterEndo = _.seq(service.methods.map(serviceMethodSignature)) @@ -116,8 +135,13 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d .add("}") private[this] def serviceClient: PrinterEndo = { - _.add( - s"def mkClient[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], channel: $Channel, mkMetadata: $Ctx => F[$Metadata], clientOptions: $ClientOptions): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {" + _.addStringMargin( + s"""|def mkClientFull[F[_], G[_]: $Async, $Ctx]( + | dispatcher: $Dispatcher[G], + | channel: $Channel, + | clientAspect: ${ClientAspect}[F, G, $Ctx], + | clientOptions: $ClientOptions + |): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {""" ).indent .call(serviceMethodImplementations) .outdent @@ -125,8 +149,13 @@ class Fs2GrpcServicePrinter(service: ServiceDescriptor, serviceSuffix: String, d } private[this] def serviceBinding: PrinterEndo = { - _.add( - s"protected def serviceBinding[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], serviceImpl: $serviceNameFs2[F, $Ctx], mkCtx: $Metadata => F[$Ctx], serverOptions: $ServerOptions): $ServerServiceDefinition = {" + _.addStringMargin( + s"""|protected def serviceBindingFull[F[_], G[_]: $Async, $Ctx]( + | dispatcher: $Dispatcher[G], + | serviceImpl: $serviceNameFs2[F, $Ctx], + | serviceAspect: ${ServiceAspect}[F, G, $Ctx], + | serverOptions: $ServerOptions + |) = {""" ).indent .add(s"$ServerServiceDefinition") .call(serviceBindingImplementations) @@ -152,6 +181,8 @@ object Fs2GrpcServicePrinter { private val effPkg = "_root_.cats.effect" private val fs2Pkg = "_root_.fs2" private val fs2grpcPkg = "_root_.fs2.grpc" + private val fs2grpcServerPkg = "_root_.fs2.grpc.server" + private val fs2grpcClientPkg = "_root_.fs2.grpc.client" private val grpcPkg = "_root_.io.grpc" // / @@ -173,6 +204,10 @@ object Fs2GrpcServicePrinter { val Channel = s"$grpcPkg.Channel" val Metadata = s"$grpcPkg.Metadata" + val ServiceAspect = s"${fs2grpcServerPkg}.ServiceAspect" + val ServerCallContext = s"${fs2grpcServerPkg}.ServerCallContext" + val ClientAspect = s"${fs2grpcClientPkg}.ClientAspect" + val ClientCallContext = s"${fs2grpcClientPkg}.ClientCallContext" } } diff --git a/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt b/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt index 3c4b5add..c738c35e 100644 --- a/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt +++ b/e2e/src/test/resources/TestServiceFs2Grpc.scala.txt @@ -11,36 +11,86 @@ trait TestServiceFs2Grpc[F[_], A] { object TestServiceFs2Grpc extends _root_.fs2.grpc.GeneratedCompanion[TestServiceFs2Grpc] { - def mkClient[F[_]: _root_.cats.effect.Async, A](dispatcher: _root_.cats.effect.std.Dispatcher[F], channel: _root_.io.grpc.Channel, mkMetadata: A => F[_root_.io.grpc.Metadata], clientOptions: _root_.fs2.grpc.client.ClientOptions): TestServiceFs2Grpc[F, A] = new TestServiceFs2Grpc[F, A] { - def noStreaming(request: hello.world.TestMessage, ctx: A): F[hello.world.TestMessage] = { - mkMetadata(ctx).flatMap { m => - _root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_NO_STREAMING, dispatcher, clientOptions).flatMap(_.unaryToUnaryCall(request, m)) - } - } - def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[hello.world.TestMessage] = { - mkMetadata(ctx).flatMap { m => - _root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, dispatcher, clientOptions).flatMap(_.streamingToUnaryCall(request, m)) - } - } - def serverStreaming(request: hello.world.TestMessage, ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] = { - _root_.fs2.Stream.eval(mkMetadata(ctx)).flatMap { m => - _root_.fs2.Stream.eval(_root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING, dispatcher, clientOptions)).flatMap(_.unaryToStreamingCall(request, m)) - } - } - def bothStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] = { - _root_.fs2.Stream.eval(mkMetadata(ctx)).flatMap { m => - _root_.fs2.Stream.eval(_root_.fs2.grpc.client.Fs2ClientCall[F](channel, hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING, dispatcher, clientOptions)).flatMap(_.streamingToStreamingCall(request, m)) - } - } + def mkClientFull[F[_], G[_]: _root_.cats.effect.Async, A]( + dispatcher: _root_.cats.effect.std.Dispatcher[G], + channel: _root_.io.grpc.Channel, + clientAspect: _root_.fs2.grpc.client.ClientAspect[F, G, A], + clientOptions: _root_.fs2.grpc.client.ClientOptions + ): TestServiceFs2Grpc[F, A] = new TestServiceFs2Grpc[F, A] { + def noStreaming(request: hello.world.TestMessage, ctx: A): F[hello.world.TestMessage] = + clientAspect.visitUnaryToUnary[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.client.ClientCallContext(ctx, hello.world.TestServiceGrpc.METHOD_NO_STREAMING), + request, + (req, m) => _root_.fs2.grpc.client.Fs2ClientCall[G](channel, hello.world.TestServiceGrpc.METHOD_NO_STREAMING, dispatcher, clientOptions).flatMap(_.unaryToUnaryCall(req, m)) + ) + def clientStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): F[hello.world.TestMessage] = + clientAspect.visitStreamingToUnary[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.client.ClientCallContext(ctx, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING), + request, + (req, m) => _root_.fs2.grpc.client.Fs2ClientCall[G](channel, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, dispatcher, clientOptions).flatMap(_.streamingToUnaryCall(req, m)) + ) + def serverStreaming(request: hello.world.TestMessage, ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] = + clientAspect.visitUnaryToStreaming[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.client.ClientCallContext(ctx, hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING), + request, + (req, m) => _root_.fs2.Stream.eval(_root_.fs2.grpc.client.Fs2ClientCall[G](channel, hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING, dispatcher, clientOptions)).flatMap(_.unaryToStreamingCall(req, m)) + ) + def bothStreaming(request: _root_.fs2.Stream[F, hello.world.TestMessage], ctx: A): _root_.fs2.Stream[F, hello.world.TestMessage] = + clientAspect.visitStreamingToStreaming[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.client.ClientCallContext(ctx, hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING), + request, + (req, m) => _root_.fs2.Stream.eval(_root_.fs2.grpc.client.Fs2ClientCall[G](channel, hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING, dispatcher, clientOptions)).flatMap(_.streamingToStreamingCall(req, m)) + ) } - protected def serviceBinding[F[_]: _root_.cats.effect.Async, A](dispatcher: _root_.cats.effect.std.Dispatcher[F], serviceImpl: TestServiceFs2Grpc[F, A], mkCtx: _root_.io.grpc.Metadata => F[A], serverOptions: _root_.fs2.grpc.server.ServerOptions): _root_.io.grpc.ServerServiceDefinition = { + protected def serviceBindingFull[F[_], G[_]: _root_.cats.effect.Async, A]( + dispatcher: _root_.cats.effect.std.Dispatcher[G], + serviceImpl: TestServiceFs2Grpc[F, A], + serviceAspect: _root_.fs2.grpc.server.ServiceAspect[F, G, A], + serverOptions: _root_.fs2.grpc.server.ServerOptions + ) = { _root_.io.grpc.ServerServiceDefinition .builder(hello.world.TestServiceGrpc.SERVICE) - .addMethod(hello.world.TestServiceGrpc.METHOD_NO_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).unaryToUnaryCall[hello.world.TestMessage, hello.world.TestMessage]((r, m) => mkCtx(m).flatMap(serviceImpl.noStreaming(r, _)))) - .addMethod(hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).streamingToUnaryCall[hello.world.TestMessage, hello.world.TestMessage]((r, m) => mkCtx(m).flatMap(serviceImpl.clientStreaming(r, _)))) - .addMethod(hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).unaryToStreamingCall[hello.world.TestMessage, hello.world.TestMessage]((r, m) => _root_.fs2.Stream.eval(mkCtx(m)).flatMap(serviceImpl.serverStreaming(r, _)))) - .addMethod(hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING, _root_.fs2.grpc.server.Fs2ServerCallHandler[F](dispatcher, serverOptions).streamingToStreamingCall[hello.world.TestMessage, hello.world.TestMessage]((r, m) => _root_.fs2.Stream.eval(mkCtx(m)).flatMap(serviceImpl.bothStreaming(r, _)))) + .addMethod( + hello.world.TestServiceGrpc.METHOD_NO_STREAMING, + _root_.fs2.grpc.server.Fs2ServerCallHandler[G](dispatcher, serverOptions).unaryToUnaryCall[hello.world.TestMessage, hello.world.TestMessage]{ (r, m) => + serviceAspect.visitUnaryToUnary[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.server.ServerCallContext(m, hello.world.TestServiceGrpc.METHOD_NO_STREAMING), + r, + (r, m) => serviceImpl.noStreaming(r, m) + ) + } + ) + .addMethod( + hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING, + _root_.fs2.grpc.server.Fs2ServerCallHandler[G](dispatcher, serverOptions).streamingToUnaryCall[hello.world.TestMessage, hello.world.TestMessage]{ (r, m) => + serviceAspect.visitStreamingToUnary[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.server.ServerCallContext(m, hello.world.TestServiceGrpc.METHOD_CLIENT_STREAMING), + r, + (r, m) => serviceImpl.clientStreaming(r, m) + ) + } + ) + .addMethod( + hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING, + _root_.fs2.grpc.server.Fs2ServerCallHandler[G](dispatcher, serverOptions).unaryToStreamingCall[hello.world.TestMessage, hello.world.TestMessage]{ (r, m) => + serviceAspect.visitUnaryToStreaming[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.server.ServerCallContext(m, hello.world.TestServiceGrpc.METHOD_SERVER_STREAMING), + r, + (r, m) => serviceImpl.serverStreaming(r, m) + ) + } + ) + .addMethod( + hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING, + _root_.fs2.grpc.server.Fs2ServerCallHandler[G](dispatcher, serverOptions).streamingToStreamingCall[hello.world.TestMessage, hello.world.TestMessage]{ (r, m) => + serviceAspect.visitStreamingToStreaming[hello.world.TestMessage, hello.world.TestMessage]( + _root_.fs2.grpc.server.ServerCallContext(m, hello.world.TestServiceGrpc.METHOD_BOTH_STREAMING), + r, + (r, m) => serviceImpl.bothStreaming(r, m) + ) + } + ) .build() } diff --git a/e2e/src/test/scala/fs2/grpc/e2e/AspectSpec.scala b/e2e/src/test/scala/fs2/grpc/e2e/AspectSpec.scala new file mode 100644 index 00000000..6cefd257 --- /dev/null +++ b/e2e/src/test/scala/fs2/grpc/e2e/AspectSpec.scala @@ -0,0 +1,297 @@ +package fs2.grpc.e2e + +import hello.world._ +import io.grpc.inprocess._ +import cats.effect._ +import cats.implicits._ +import munit._ +import cats.effect.std.UUIDGen +import io.grpc._ +import scala.jdk.CollectionConverters._ +import cats.effect.std.Dispatcher +import fs2.grpc.server._ +import fs2.grpc.client._ +import fs2.grpc.syntax.all._ +import cats.data._ +import cats._ + +class AspectSpec extends CatsEffectSuite with CatsEffectFunFixtures { + def startServices[F[_]](id: String)(xs: ServerServiceDefinition*)(implicit F: Sync[F]): Resource[F, Server] = + InProcessServerBuilder + .forName(id.toString()) + .addServices(xs.toList.asJava) + .resource[F] + .evalTap(s => F.delay(s.start())) + + def testConnection[F[_]: Async, G[_], A, B]( + service: TestServiceFs2Grpc[G, A], + serviceAspect: ServiceAspect[G, F, A], + clientAspect: ClientAspect[G, F, B] + ): Resource[F, TestServiceFs2Grpc[G, B]] = + Dispatcher.parallel[F].flatMap { d => + Resource.eval(UUIDGen.randomUUID[F]).flatMap { id => + startServices[F](id.toString())( + TestServiceFs2Grpc.serviceFull[G, F, A]( + d, + service, + serviceAspect, + ServerOptions.default + ) + ) >> InProcessChannelBuilder.forName(id.toString()).usePlaintext().resource[F].map { conn => + TestServiceFs2Grpc.mkClientFull[G, F, B]( + d, + conn, + clientAspect, + ClientOptions.default + ) + } + } + } + + test("tracing requests should work as expected") { + case class TracingKey(value: String) + case class Span(name: String, parent: Either[Span, Option[TracingKey]]) { + def traceKey: Option[TracingKey] = parent.leftMap(_.traceKey).merge + } + case class SpanInfo(span: Span, messages: List[String]) + type WriteSpan[A] = WriterT[IO, List[SpanInfo], A] + type Traced[A] = Kleisli[WriteSpan, Span, A] + val Traced: Monad[Traced] = Monad[Traced] + val liftK: IO ~> Traced = WriterT.liftK[IO, List[SpanInfo]] andThen Kleisli.liftK[WriteSpan, Span] + def span[A](name: String)(fa: Traced[A]): Traced[A] = + fa.local[Span](parent => Span(name, Left(parent))) + + def spanStream[A](name: String)(fa: fs2.Stream[Traced, A]): fs2.Stream[Traced, A] = { + val current: Traced[Span] = Kleisli.ask + fs2.Stream.eval(current).flatMap { parent => + fa.translate(new (Traced ~> Traced) { + def apply[B](fa: Traced[B]): Traced[B] = + Kleisli.local((_: Span) => Span(name, Left(parent)))(fa) + }) + } + } + + def tell(spanInfos: List[SpanInfo]): Traced[Unit] = + Kleisli.liftF(WriterT.tell[IO, List[SpanInfo]](spanInfos)) + + def log(msgs: String*): Traced[Unit] = Kleisli.ask[WriteSpan, Span].flatMap { current => + tell(List(SpanInfo(current, msgs.toList))) + } + + val tracingHeaderKey = Metadata.Key.of("TRACE_KEY", Metadata.ASCII_STRING_MARSHALLER) + def getTraceHeader(ctx: Metadata): Option[TracingKey] = + Option(ctx.get(tracingHeaderKey)).map(TracingKey(_)) + + def serializeTraceHeader(key: TracingKey): Metadata = { + val m = new Metadata + m.put(tracingHeaderKey, key.value) + m + } + + def getTracingHeader: Traced[Metadata] = + Kleisli.ask[WriteSpan, Span].map { span => + span.traceKey.map(serializeTraceHeader).getOrElse(new Metadata) + } + + val service = new TestServiceFs2Grpc[Traced, Metadata] { + override def noStreaming(request: TestMessage, ctx: Metadata): Traced[TestMessage] = + span("noStreaming") { + log("noStreaming") >> + Traced.pure(TestMessage.defaultInstance) + } + + override def clientStreaming(request: fs2.Stream[Traced, TestMessage], ctx: Metadata): Traced[TestMessage] = + span("clientStreaming") { + log("clientStreaming") >> + request.compile.last.map(_.getOrElse(TestMessage.defaultInstance)) + } + + override def serverStreaming(request: TestMessage, ctx: Metadata): fs2.Stream[Traced, TestMessage] = + spanStream("serverStreaming") { + fs2.Stream(request).repeatN(2).evalTap(_ => log("serverStreaming")) + } + + override def bothStreaming( + request: fs2.Stream[Traced, TestMessage], + ctx: Metadata + ): fs2.Stream[Traced, TestMessage] = + spanStream("bothStreaming") { + request.evalTap(_ => log("bothStreaming")) + } + } + + IO.ref(List.empty[SpanInfo]).flatMap { state => + def runRootTrace[A](cctx: ServerCallContext[?, ?])(fa: Traced[A]): IO[A] = { + val root = Span(cctx.methodDescriptor.getFullMethodName(), Right(getTraceHeader(cctx.metadata))) + fa.run(root).run.flatMap { case (xs, a) => + state.update(_ ++ xs) as a + } + } + + def runRootTraceStreamed[A](cctx: ServerCallContext[?, ?])(fa: fs2.Stream[Traced, A]): fs2.Stream[IO, A] = + fa.translate(new (Traced ~> IO) { + def apply[B](fa: Traced[B]): IO[B] = runRootTrace(cctx)(fa) + }) + + val tracingServiceAspect = new ServiceAspect[Traced, IO, Metadata] { + override def visitUnaryToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + next: (Req, Metadata) => Traced[Res] + ): IO[Res] = runRootTrace(callCtx)(next(req, callCtx.metadata)) + + override def visitUnaryToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + next: (Req, Metadata) => fs2.Stream[Traced, Res] + ): fs2.Stream[IO, Res] = runRootTraceStreamed(callCtx)(next(req, callCtx.metadata)) + + override def visitStreamingToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[IO, Req], + next: (fs2.Stream[Traced, Req], Metadata) => Traced[Res] + ): IO[Res] = runRootTrace(callCtx)(next(req.translate(liftK), callCtx.metadata)) + + override def visitStreamingToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[IO, Req], + next: (fs2.Stream[Traced, Req], Metadata) => fs2.Stream[Traced, Res] + ): fs2.Stream[IO, Res] = runRootTraceStreamed(callCtx)(next(req.translate(liftK), callCtx.metadata)) + } + + val tracingClientAspect = new ClientAspect[Traced, IO, Unit] { + override def visitUnaryToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, Unit], + req: Req, + request: (Req, Metadata) => IO[Res] + ): Traced[Res] = + getTracingHeader.flatMap(md => liftK(request(req, md))) + + override def visitUnaryToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, Unit], + req: Req, + request: (Req, Metadata) => fs2.Stream[IO, Res] + ): fs2.Stream[Traced, Res] = + fs2.Stream.eval(getTracingHeader).flatMap(md => request(req, md).translate(liftK)) + + override def visitStreamingToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, Unit], + req: fs2.Stream[Traced, Req], + request: (fs2.Stream[IO, Req], Metadata) => IO[Res] + ): Traced[Res] = Kleisli.ask[WriteSpan, Span].flatMap { parent => + getTracingHeader.flatMap { md => + liftK(IO.ref(List.empty[SpanInfo])).flatMap { state => + val req2 = req.translate(new (Traced ~> IO) { + def apply[A](fa: Traced[A]): IO[A] = + fa.run(parent).run.flatMap { case (xs, a) => + state.update(_ ++ xs) as a + } + }) + liftK(request(req2, md)) <* (liftK(state.get) >>= tell) + } + } + } + + override def visitStreamingToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, Unit], + req: fs2.Stream[Traced, Req], + request: (fs2.Stream[IO, Req], Metadata) => fs2.Stream[IO, Res] + ): fs2.Stream[Traced, Res] = + fs2.Stream.eval(Kleisli.ask[WriteSpan, Span]).flatMap { parent => + fs2.Stream.eval(getTracingHeader).flatMap { md => + fs2.Stream.eval(liftK(IO.ref(List.empty[SpanInfo]))).flatMap { state => + val req2 = req.translate(new (Traced ~> IO) { + def apply[A](fa: Traced[A]): IO[A] = + fa.run(parent).run.flatMap { case (xs, a) => + state.update(_ ++ xs) as a + } + }) + request(req2, md).translate(liftK) ++ fs2.Stream.exec((liftK(state.get) >>= tell)) + } + } + } + } + + testConnection[IO, Traced, Metadata, Unit]( + service, + tracingServiceAspect, + tracingClientAspect + ).use{ (client: TestServiceFs2Grpc[Traced, Unit]) => + + def testWithKey(rootKey: Option[TracingKey] = None) = { + def trackServer[A](fa: IO[A]): IO[List[SpanInfo]] = + state.set(Nil) >> fa >> state.get + + def trackClient[A](traced: Traced[A]): IO[List[SpanInfo]] = + traced.run(Span("root", Right(rootKey))).written + + def trackAndAssertServer[A](name: String, n: Int)(fa: IO[A])(implicit loc: Location): IO[Unit] = + trackServer(fa).map{ serverInfos => + assertEquals(serverInfos.size, n) + serverInfos.foreach{ si => + assertEquals(si.span.name, name) + assert(clue(si.span.parent).isLeft, "is child") + assertEquals(si.messages, List(name)) + } + } + + def trackAndAssertClient[A](name: String, n: Int)(fa: Traced[A])(implicit loc: Location): IO[Unit] = + trackClient(fa).map{ clientInfos => + assertEquals(clientInfos.size, n) + clientInfos.foreach{ ci => + assertEquals(ci.span.name, s"client-${name}") + assert(clue(ci.span.parent).isLeft, "is child") + assertEquals(ci.messages, List(s"client-${name}")) + } + } + + val noStreaming = trackAndAssertServer("noStreaming", 1) { + trackAndAssertClient("noStreaming", 1){ + span("client-noStreaming") { + log("client-noStreaming") >> + client.noStreaming(TestMessage.defaultInstance, ()) + } + } + } + + val clientStreaming = trackAndAssertServer("clientStreaming", 1) { + trackAndAssertClient("clientStreaming", 2){ + val req = fs2.Stream.eval{ + span("client-clientStreaming") { + log("client-clientStreaming").as(TestMessage.defaultInstance) + } + }.repeatN(2) + + client.clientStreaming(req, ()) + } + } + + val serverStreaming = trackAndAssertServer("serverStreaming", 2) { + trackAndAssertClient("serverStreaming", 1){ + span("client-serverStreaming") { + log("client-serverStreaming") >> + client.serverStreaming(TestMessage.defaultInstance, ()).compile.drain + } + } + } + + val bothStreaming = trackAndAssertServer("bothStreaming", 2) { + trackAndAssertClient("bothStreaming", 2){ + val req = fs2.Stream.eval{ + span("client-bothStreaming") { + log("client-bothStreaming").as(TestMessage.defaultInstance) + } + }.repeatN(2) + + client.bothStreaming(req, ()).compile.drain + } + } + + noStreaming >> clientStreaming >> serverStreaming >> bothStreaming + } + + testWithKey() >> testWithKey(Some(TracingKey("my_tracing_key"))) + } + } + } +} diff --git a/runtime/src/main/scala/fs2/grpc/GeneratedCompanion.scala b/runtime/src/main/scala/fs2/grpc/GeneratedCompanion.scala index a083d623..070a6398 100644 --- a/runtime/src/main/scala/fs2/grpc/GeneratedCompanion.scala +++ b/runtime/src/main/scala/fs2/grpc/GeneratedCompanion.scala @@ -27,7 +27,9 @@ import cats.effect.{Async, Resource} import cats.effect.std.Dispatcher import io.grpc._ import fs2.grpc.client.ClientOptions +import fs2.grpc.client.ClientAspect import fs2.grpc.server.ServerOptions +import fs2.grpc.server.ServiceAspect trait GeneratedCompanion[Service[*[_], _]] { @@ -35,12 +37,25 @@ trait GeneratedCompanion[Service[*[_], _]] { ///=== Client ========================================================================================================== + def mkClientFull[F[_], G[_]: Async, A]( + dispatcher: Dispatcher[G], + channel: Channel, + clientAspect: ClientAspect[F, G, A], + clientOptions: ClientOptions + ): Service[F, A] + def mkClient[F[_]: Async, A]( dispatcher: Dispatcher[F], channel: Channel, mkMetadata: A => F[Metadata], clientOptions: ClientOptions - ): Service[F, A] + ): Service[F, A] = + mkClientFull[F, F, A]( + dispatcher, + channel, + ClientAspect.default[F].contraModify(mkMetadata), + clientOptions + ) final def mkClient[F[_]: Async, A]( dispatcher: Dispatcher[F], @@ -116,12 +131,33 @@ trait GeneratedCompanion[Service[*[_], _]] { ///=== Service ========================================================================================================= + protected def serviceBindingFull[F[_], G[_]: Async, A]( + dispatcher: Dispatcher[G], + serviceImpl: Service[F, A], + serviceAspect: ServiceAspect[F, G, A], + serverOptions: ServerOptions + ): ServerServiceDefinition + protected def serviceBinding[F[_]: Async, A]( dispatcher: Dispatcher[F], serviceImpl: Service[F, A], mkCtx: Metadata => F[A], serverOptions: ServerOptions - ): ServerServiceDefinition + ): ServerServiceDefinition = + serviceBindingFull[F, F, A]( + dispatcher, + serviceImpl, + ServiceAspect.default[F].modify(mkCtx), + serverOptions + ) + + final def serviceFull[F[_], G[_]: Async, A]( + dispatcher: Dispatcher[G], + serviceImpl: Service[F, A], + serviceAspect: ServiceAspect[F, G, A], + serverOptions: ServerOptions + ): ServerServiceDefinition = + serviceBindingFull[F, G, A](dispatcher, serviceImpl, serviceAspect, serverOptions) final def service[F[_]: Async, A]( dispatcher: Dispatcher[F], diff --git a/runtime/src/main/scala/fs2/grpc/client/ClientAspect.scala b/runtime/src/main/scala/fs2/grpc/client/ClientAspect.scala new file mode 100644 index 00000000..24938be9 --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/client/ClientAspect.scala @@ -0,0 +1,99 @@ +package fs2.grpc.client + +import io.grpc._ +import cats._ +import cats.syntax.all._ +import fs2.Stream + +final case class ClientCallContext[Req, Res, A]( + ctx: A, + methodDescriptor: MethodDescriptor[Req, Res] +) + +trait ClientAspect[F[_], G[_], A] { self => + def visitUnaryToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, A], + req: Req, + request: (Req, Metadata) => G[Res] + ): F[Res] + + def visitUnaryToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, A], + req: Req, + request: (Req, Metadata) => Stream[G, Res] + ): Stream[F, Res] + + def visitStreamingToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, A], + req: Stream[F, Req], + request: (Stream[G, Req], Metadata) => G[Res] + ): F[Res] + + def visitStreamingToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, A], + req: Stream[F, Req], + request: (Stream[G, Req], Metadata) => Stream[G, Res] + ): Stream[F, Res] + + def contraModify[B](f: B => F[A])(implicit F: Monad[F]): ClientAspect[F, G, B] = + new ClientAspect[F, G, B] { + def modCtx[Req, Res](ccc: ClientCallContext[Req, Res, B]): F[ClientCallContext[Req, Res, A]] = + f(ccc.ctx).map(a => ccc.copy(ctx = a)) + + override def visitUnaryToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, B], + req: Req, + request: (Req, Metadata) => G[Res] + ): F[Res] = + modCtx(callCtx).flatMap(self.visitUnaryToUnary(_, req, request)) + + override def visitUnaryToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, B], + req: Req, + request: (Req, Metadata) => Stream[G, Res] + ): Stream[F, Res] = + Stream.eval(modCtx(callCtx)).flatMap(self.visitUnaryToStreaming(_, req, request)) + + override def visitStreamingToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, B], + req: Stream[F, Req], + request: (Stream[G, Req], Metadata) => G[Res] + ): F[Res] = + modCtx(callCtx).flatMap(self.visitStreamingToUnary(_, req, request)) + + override def visitStreamingToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, B], + req: Stream[F, Req], + request: (Stream[G, Req], Metadata) => Stream[G, Res] + ): Stream[F, Res] = + Stream.eval(modCtx(callCtx)).flatMap(self.visitStreamingToStreaming(_, req, request)) + } +} + +object ClientAspect { + def default[F[_]] = new ClientAspect[F, F, Metadata] { + override def visitUnaryToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, Metadata], + req: Req, + request: (Req, Metadata) => F[Res] + ): F[Res] = request(req, callCtx.ctx) + + override def visitUnaryToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, Metadata], + req: Req, + request: (Req, Metadata) => Stream[F, Res] + ): Stream[F, Res] = request(req, callCtx.ctx) + + override def visitStreamingToUnary[Req, Res]( + callCtx: ClientCallContext[Req, Res, Metadata], + req: Stream[F, Req], + request: (Stream[F, Req], Metadata) => F[Res] + ): F[Res] = request(req, callCtx.ctx) + + override def visitStreamingToStreaming[Req, Res]( + callCtx: ClientCallContext[Req, Res, Metadata], + req: Stream[F, Req], + request: (Stream[F, Req], Metadata) => Stream[F, Res] + ): Stream[F, Res] = request(req, callCtx.ctx) + } +} diff --git a/runtime/src/main/scala/fs2/grpc/server/ServiceAspect.scala b/runtime/src/main/scala/fs2/grpc/server/ServiceAspect.scala new file mode 100644 index 00000000..56c7365d --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/server/ServiceAspect.scala @@ -0,0 +1,112 @@ +package fs2.grpc.server + +import io.grpc._ +import cats._ +import cats.syntax.all._ +import fs2.Stream + +final case class ServerCallContext[Req, Res]( + metadata: Metadata, + methodDescriptor: MethodDescriptor[Req, Res] +) + +trait ServiceAspect[F[_], G[_], A] { self => + def visitUnaryToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + next: (Req, A) => F[Res] + ): G[Res] + + def visitUnaryToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + next: (Req, A) => fs2.Stream[F, Res] + ): fs2.Stream[G, Res] + + def visitStreamingToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[G, Req], + next: (fs2.Stream[F, Req], A) => F[Res] + ): G[Res] + + def visitStreamingToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[G, Req], + next: (fs2.Stream[F, Req], A) => fs2.Stream[F, Res] + ): fs2.Stream[G, Res] + + def modify[B](f: A => F[B])(implicit F: Monad[F]): ServiceAspect[F, G, B] = + new ServiceAspect[F, G, B] { + override def visitUnaryToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + request: (Req, B) => F[Res] + ): G[Res] = + self.visitUnaryToUnary[Req, Res]( + callCtx, + req, + (req, a) => f(a).flatMap(request(req, _)) + ) + + override def visitUnaryToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + request: (Req, B) => Stream[F, Res] + ): Stream[G, Res] = + self.visitUnaryToStreaming[Req, Res]( + callCtx, + req, + (req, a) => fs2.Stream.eval(f(a)).flatMap(request(req, _)) + ) + + override def visitStreamingToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[G, Req], + request: (Stream[F, Req], B) => F[Res] + ): G[Res] = + self.visitStreamingToUnary[Req, Res]( + callCtx, + req, + (req, a) => f(a).flatMap(request(req, _)) + ) + + override def visitStreamingToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[G, Req], + request: (Stream[F, Req], B) => Stream[F, Res] + ): Stream[G, Res] = + self.visitStreamingToStreaming[Req, Res]( + callCtx, + req, + (req, a) => fs2.Stream.eval(f(a)).flatMap(request(req, _)) + ) + } +} + +object ServiceAspect { + def default[F[_]] = new ServiceAspect[F, F, Metadata] { + override def visitUnaryToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + request: (Req, Metadata) => F[Res] + ): F[Res] = request(req, callCtx.metadata) + + override def visitUnaryToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: Req, + request: (Req, Metadata) => Stream[F, Res] + ): Stream[F, Res] = request(req, callCtx.metadata) + + override def visitStreamingToUnary[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[F, Req], + request: (Stream[F, Req], Metadata) => F[Res] + ): F[Res] = request(req, callCtx.metadata) + + override def visitStreamingToStreaming[Req, Res]( + callCtx: ServerCallContext[Req, Res], + req: fs2.Stream[F, Req], + request: (Stream[F, Req], Metadata) => Stream[F, Res] + ): Stream[F, Res] = request(req, callCtx.metadata) + } +}