Skip to content

Commit

Permalink
Add trailers metadata support to clientStreaming and noStreaming
Browse files Browse the repository at this point in the history
  • Loading branch information
TalkingFoxMid committed Mar 12, 2024
1 parent 172bca0 commit 75b9ba9
Show file tree
Hide file tree
Showing 14 changed files with 457 additions and 141 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.grpc.codegen

import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor}
import fs2.grpc.codegen.Fs2AbstractServicePrinter.constants.{
Async,
Channel,
ClientOptions,
Companion,
Ctx,
Dispatcher,
Fs2ClientCall,
Fs2ServerCallHandler,
Metadata,
ServerOptions,
ServerServiceDefinition,
Stream
}
import scalapb.compiler.{DescriptorImplicits, FunctionalPrinter}
import scalapb.compiler.FunctionalPrinter.PrinterEndo

abstract class Fs2AbstractServicePrinter extends Fs2ServicePrinter {

val service: ServiceDescriptor
val serviceSuffix: String
val di: DescriptorImplicits

import di._

private[this] val serviceName: String = service.name
private[this] val serviceNameFs2: String = s"$serviceName${serviceSuffix}"
private[this] val servicePkgName: String = service.getFile.scalaPackage.fullName

protected def serviceMethodSignature(method: MethodDescriptor): String

protected[this] def handleMethod(method: MethodDescriptor): String

private[this] def createClientCall(method: MethodDescriptor) = {
val basicClientCall =
s"$Fs2ClientCall[F](channel, ${method.grpcDescriptor.fullName}, dispatcher, clientOptions)"
if (method.isServerStreaming)
s"$Stream.eval($basicClientCall)"
else
basicClientCall
}

private[this] def serviceMethodImplementation(method: MethodDescriptor): PrinterEndo = { p =>
val mkMetadata = if (method.isServerStreaming) s"$Stream.eval(mkMetadata(ctx))" else "mkMetadata(ctx)"

p.add(serviceMethodSignature(method) + " = {")
.indent
.add(s"$mkMetadata.flatMap { m =>")
.indent
.add(s"${createClientCall(method)}.flatMap(_.${handleMethod(method)}(request, m))")
.outdent
.add("}")
.outdent
.add("}")
}

private[this] def serviceBindingImplementation(method: MethodDescriptor): PrinterEndo = { p =>
val inType = method.inputType.scalaType
val outType = method.outputType.scalaType
val descriptor = method.grpcDescriptor.fullName
val handler = s"$Fs2ServerCallHandler[F](dispatcher, serverOptions).${handleMethod(method)}[$inType, $outType]"

val serviceCall = s"serviceImpl.${method.name}"
val eval = if (method.isServerStreaming) s"$Stream.eval(mkCtx(m))" else "mkCtx(m)"

p.add(s".addMethod($descriptor, $handler((r, m) => $eval.flatMap($serviceCall(r, _))))")
}

private[this] def serviceMethods: PrinterEndo = _.seq(service.methods.map(serviceMethodSignature))

private[this] def serviceMethodImplementations: PrinterEndo =
_.call(service.methods.map(serviceMethodImplementation): _*)

private[this] def serviceBindingImplementations: PrinterEndo =
_.indent
.add(s".builder(${service.grpcDescriptor.fullName})")
.call(service.methods.map(serviceBindingImplementation): _*)
.add(".build()")
.outdent

private[this] def serviceTrait: PrinterEndo =
_.add(s"trait $serviceNameFs2[F[_], $Ctx] {").indent.call(serviceMethods).outdent.add("}")

private[this] def serviceObject: PrinterEndo =
_.add(s"object $serviceNameFs2 extends $Companion[$serviceNameFs2] {").indent.newline
.call(serviceClient)
.newline
.call(serviceBinding)
.outdent
.newline
.add("}")

private[this] def serviceClient: PrinterEndo = {
_.add(
s"def mkClient[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], channel: $Channel, mkMetadata: $Ctx => F[$Metadata], clientOptions: $ClientOptions): $serviceNameFs2[F, $Ctx] = new $serviceNameFs2[F, $Ctx] {"
).indent
.call(serviceMethodImplementations)
.outdent
.add("}")
}

private[this] def serviceBinding: PrinterEndo = {
_.add(
s"protected def serviceBinding[F[_]: $Async, $Ctx](dispatcher: $Dispatcher[F], serviceImpl: $serviceNameFs2[F, $Ctx], mkCtx: $Metadata => F[$Ctx], serverOptions: $ServerOptions): $ServerServiceDefinition = {"
).indent
.add(s"$ServerServiceDefinition")
.call(serviceBindingImplementations)
.outdent
.add("}")
}

// /

def printService(printer: FunctionalPrinter): FunctionalPrinter = {
printer
.add(s"package $servicePkgName", "", "import _root_.cats.syntax.all._", "")
.call(serviceTrait)
.newline
.call(serviceObject)
}
}

object Fs2AbstractServicePrinter {
private[codegen] object constants {

private val effPkg = "_root_.cats.effect"
private val fs2Pkg = "_root_.fs2"
private val fs2grpcPkg = "_root_.fs2.grpc"
private val grpcPkg = "_root_.io.grpc"

// /

val Ctx = "A"

val Async = s"$effPkg.Async"
val Resource = s"$effPkg.Resource"
val Dispatcher = s"$effPkg.std.Dispatcher"
val Stream = s"$fs2Pkg.Stream"

val Fs2ServerCallHandler = s"$fs2grpcPkg.server.Fs2ServerCallHandler"
val Fs2ClientCall = s"$fs2grpcPkg.client.Fs2ClientCall"
val ClientOptions = s"$fs2grpcPkg.client.ClientOptions"
val ServerOptions = s"$fs2grpcPkg.server.ServerOptions"
val Companion = s"$fs2grpcPkg.GeneratedCompanion"

val ServerServiceDefinition = s"$grpcPkg.ServerServiceDefinition"
val Channel = s"$grpcPkg.Channel"
val Metadata = s"$grpcPkg.Metadata"

}

}
49 changes: 37 additions & 12 deletions codegen/src/main/scala/fs2/grpc/codegen/Fs2CodeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,55 @@

package fs2.grpc.codegen

import com.google.protobuf.Descriptors.FileDescriptor
import com.google.protobuf.Descriptors.{FileDescriptor, ServiceDescriptor}
import com.google.protobuf.ExtensionRegistry
import com.google.protobuf.compiler.PluginProtos
import protocgen.{CodeGenApp, CodeGenRequest, CodeGenResponse}
import scalapb.compiler.{DescriptorImplicits, FunctionalPrinter, GeneratorParams}
import scalapb.options.Scalapb
import scala.jdk.CollectionConverters._

import scala.jdk.CollectionConverters.*

final case class Fs2Params(serviceSuffix: String = "Fs2Grpc")

object Fs2CodeGenerator extends CodeGenApp {

private def generateServiceFile(
file: FileDescriptor,
service: ServiceDescriptor,
serviceSuffix: String,
di: DescriptorImplicits,
p: ServiceDescriptor => Fs2ServicePrinter
): PluginProtos.CodeGeneratorResponse.File = {
import di.{ExtendedServiceDescriptor, ExtendedFileDescriptor}

val code = p(service).printService(FunctionalPrinter()).result()
val b = PluginProtos.CodeGeneratorResponse.File.newBuilder()
b.setName(file.scalaDirectory + "/" + service.name + s"$serviceSuffix.scala")
b.setContent(code)
b.build
}

def generateServiceFiles(
file: FileDescriptor,
fs2params: Fs2Params,
di: DescriptorImplicits
): Seq[PluginProtos.CodeGeneratorResponse.File] = {
file.getServices.asScala.map { service =>
import di.{ExtendedServiceDescriptor, ExtendedFileDescriptor}

val p = new Fs2GrpcServicePrinter(service, fs2params.serviceSuffix, di)
val code = p.printService(FunctionalPrinter()).result()
val b = PluginProtos.CodeGeneratorResponse.File.newBuilder()
b.setName(file.scalaDirectory + "/" + service.name + s"${fs2params.serviceSuffix}.scala")
b.setContent(code)
b.build
file.getServices.asScala.flatMap { service =>
generateServiceFile(
file,
service,
fs2params.serviceSuffix + "Trailers",
di,
new Fs2GrpcExhaustiveTrailersServicePrinter(_, fs2params.serviceSuffix + "Trailers", di)
) ::
generateServiceFile(
file,
service,
fs2params.serviceSuffix,
di,
new Fs2GrpcServicePrinter(_, fs2params.serviceSuffix, di)
) :: Nil
}.toSeq
}

Expand All @@ -66,7 +89,9 @@ object Fs2CodeGenerator extends CodeGenApp {
parseParameters(request.parameter) match {
case Right((params, fs2params)) =>
val implicits = DescriptorImplicits.fromCodeGenRequest(params, request)
val srvFiles = request.filesToGenerate.flatMap(generateServiceFiles(_, fs2params, implicits))
val srvFiles = request.filesToGenerate.flatMap(
generateServiceFiles(_, fs2params, implicits)
)
CodeGenResponse.succeed(
srvFiles,
Set(PluginProtos.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.grpc.codegen

import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor}
import scalapb.compiler.{DescriptorImplicits, StreamType}

class Fs2GrpcExhaustiveTrailersServicePrinter(
val service: ServiceDescriptor,
val serviceSuffix: String,
val di: DescriptorImplicits
) extends Fs2AbstractServicePrinter {
import fs2.grpc.codegen.Fs2AbstractServicePrinter.constants._
import di._

override protected def serviceMethodSignature(method: MethodDescriptor): String = {

val scalaInType = method.inputType.scalaType
val scalaOutType = method.outputType.scalaType
val ctx = s"ctx: $Ctx"

s"def ${method.name}" + (method.streamType match {
case StreamType.Unary => s"(request: $scalaInType, $ctx): F[($scalaOutType, $Metadata)]"
case StreamType.ClientStreaming => s"(request: $Stream[F, $scalaInType], $ctx): F[($scalaOutType, $Metadata)]"
case StreamType.ServerStreaming => s"(request: $scalaInType, $ctx): $Stream[F, $scalaOutType]"
case StreamType.Bidirectional => s"(request: $Stream[F, $scalaInType], $ctx): $Stream[F, $scalaOutType]"
})
}

override protected def handleMethod(method: MethodDescriptor): String = {
method.streamType match {
case StreamType.Unary => "unaryToUnaryCallTrailers"
case StreamType.ClientStreaming => "streamingToUnaryCallTrailers"
case StreamType.ServerStreaming => "unaryToStreamingCall"
case StreamType.Bidirectional => "streamingToStreamingCall"
}
}

}
Loading

0 comments on commit 75b9ba9

Please sign in to comment.