Skip to content

Commit

Permalink
Merge branch 'main' into docs-rewrite << upstream update
Browse files Browse the repository at this point in the history
  • Loading branch information
daveads committed Jul 19, 2023
2 parents 012dc49 + 288a646 commit 3b3a3f4
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 10 deletions.
4 changes: 2 additions & 2 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ import sbt._

object Dependencies {
val JwtCoreVersion = "9.1.1"
val NettyVersion = "4.1.93.Final"
val NettyVersion = "4.1.94.Final"
val NettyIncubatorVersion = "0.0.20.Final"
val ScalaCompactCollectionVersion = "2.8.1"
val ScalaCompactCollectionVersion = "2.11.0"
val ZioVersion = "2.0.13"
val ZioCliVersion = "0.5.0"
val ZioSchemaVersion = "0.4.12"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ private[cli] object CliEndpoint {
def loop(prefix: List[String], schema: zio.schema.Schema[_]): Set[CliEndpoint[_]] =
schema match {
case record: Schema.Record[_] =>
Set(
val endpoints =
record.fields
.foldLeft(Set.empty[CliEndpoint[_]]) { (cliEndpoints, field) =>
cliEndpoints ++ loop(prefix :+ field.name, field.schema).map { cliEndpoint =>
Expand All @@ -305,8 +305,9 @@ private[cli] object CliEndpoint {
}
}
}
.reduce(_ ++ _), // TODO review the case of nested sealed trait inside case class
)

if (endpoints.isEmpty) Set.empty
else Set(endpoints.reduce(_ ++ _)) // TODO review the case of nested sealed trait inside case class
case enumeration: Schema.Enum[_] =>
enumeration.cases.foldLeft(Set.empty[CliEndpoint[_]]) { (cliEndpoints, enumCase) =>
cliEndpoints ++ loop(prefix, enumCase.schema)
Expand Down
6 changes: 4 additions & 2 deletions zio-http/src/main/scala/zio/http/FormField.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,10 @@ object FormField {
private[http] def getContentType(ast: Chunk[FormAST]): MediaType =
ast.collectFirst {
case header: FormAST.Header if header.name == "Content-Type" =>
MediaType.forContentType(header.preposition)
}.flatten.getOrElse(MediaType.text.plain)
MediaType
.forContentType(header.preposition)
.getOrElse(MediaType.application.`octet-stream`) // Unknown content type defaults to binary
}.getOrElse(MediaType.text.plain) // Missing content type defaults to text

private[http] def incomingStreamingBinary(
ast: Chunk[FormAST],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package zio.http.netty.client

import zio.{Promise, Trace, Unsafe}

import zio.http.Response
import zio.http.netty.NettyRuntime

import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}

/** Handles failures happening in ClientInboundHandler */
final class ClientFailureHandler(
rtm: NettyRuntime,
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
)(implicit trace: Trace)
extends ChannelInboundHandlerAdapter {
implicit private val unsafeClass: Unsafe = Unsafe.unsafe

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(
onResponse.fail(cause) *> onComplete.fail(cause),
)(unsafeClass, trace)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,17 @@ final case class NettyClientDriver private (
)

pipeline.addLast(Names.ClientInboundHandler, clientInbound)

toRemove.add(clientInbound)

val clientFailureHandler =
new ClientFailureHandler(
nettyRuntime,
onResponse,
onComplete,
)
pipeline.addLast(Names.ClientFailureHandler, clientFailureHandler)
toRemove.add(clientFailureHandler)

pipeline.fireChannelRegistered()
pipeline.fireChannelActive()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ private[zio] object NettyRequestEncoder {
val method = Conversions.methodToNetty(req.method)
val jVersion = Conversions.versionToNetty(req.version)

def replaceEmptyPathWithSlash(url: zio.http.URL) = if (url.path.isEmpty) url.addLeadingSlash else url

// As per the spec, the path should contain only the relative path.
// Host and port information should be in the headers.
val path = req.url.relative.encode
val path = replaceEmptyPathWithSlash(req.url).relative.encode

val encodedReqHeaders = Conversions.headersToNetty(req.allHeaders)

Expand Down
1 change: 1 addition & 0 deletions zio-http/src/main/scala/zio/http/netty/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ package object netty {
val HttpServerExpectContinue = "HTTP_SERVER_EXPECT_CONTINUE"
val HttpServerFlushConsolidation = "HTTP_SERVER_FLUSH_CONSOLIDATION"
val ClientInboundHandler = "CLIENT_INBOUND_HANDLER"
val ClientFailureHandler = "CLIENT_FAILURE_HANDLER"
val ClientStreamingBodyHandler = "CLIENT_STREAMING_BODY_HANDLER"
val WebSocketClientProtocolHandler = "WEB_SOCKET_CLIENT_PROTOCOL_HANDLER"
val HttpRequestDecompression = "HTTP_REQUEST_DECOMPRESSION"
Expand Down
15 changes: 15 additions & 0 deletions zio-http/src/test/scala/zio/http/ClientStreamingSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package zio.http

import zio._
import zio.test.Assertion.{equalTo, fails, hasMessage}
import zio.test.TestAspect._
import zio.test._

Expand Down Expand Up @@ -242,6 +243,20 @@ object ClientStreamingSpec extends HttpRunnableSpec {
}
} yield result
} @@ samples(20) @@ timeout(5.minutes) @@ TestAspect.ifEnvNotSet("CI"), // NOTE: random hangs on CI
test("failed stream") {
for {
port <- server(streamingServer)
client <- ZIO.service[Client]
response <- client
.request(
Request.post(
URL.decode(s"http://localhost:$port/simple-post").toOption.get,
Body.fromStream(ZStream.fail(new RuntimeException("Some error"))),
),
)
.exit
} yield assert(response)(fails(hasMessage(equalTo("Some error"))))
},
)

private def streamingOnlyTests =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import zio.test._
import zio.http.internal.HttpGen
import zio.http.netty._
import zio.http.netty.model.Conversions
import zio.http.{Body, Request}
import zio.http.{Body, QueryParams, Request, URL}

import io.netty.handler.codec.http.HttpHeaderNames

Expand All @@ -42,6 +42,13 @@ object NettyRequestEncoderSpec extends ZIOSpecDefault {
urlGen = HttpGen.genAbsoluteURL,
)

val clientParamWithEmptyPathAndQueryParams = HttpGen.requestGen(
dataGen = HttpGen.body(
Gen.listOf(Gen.alphaNumericString),
),
urlGen = Gen.const(URL.empty.addQueryParams(QueryParams(("p", "1")))),
)

def clientParamWithFiniteData(size: Int): Gen[Sized, Request] = HttpGen.requestGen(
for {
content <- Gen.alphaNumericStringBounded(size, size)
Expand Down Expand Up @@ -111,5 +118,12 @@ object NettyRequestEncoderSpec extends ZIOSpecDefault {
assertZIO(req)(equalTo(Conversions.versionToNetty(params.version)))
}
},
test("url with an empty path and query params") {
check(clientParamWithEmptyPathAndQueryParams) { params =>
val uri = encode(params).map(_.uri)
assertZIO(uri)(not(equalTo(params.url.encode)))
assertZIO(uri)(equalTo(params.url.addLeadingSlash.encode))
}
},
)
}

0 comments on commit 3b3a3f4

Please sign in to comment.