diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 74340d825e..a7f9d0ba91 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -87,12 +87,19 @@ private[zio] final case class ServerInboundHandler( ) releaseRequest() } else { - val req = makeZioRequest(ctx, jReq) - val exit = handler(req) - if (attemptImmediateWrite(ctx, req.method, exit)) { + val req = makeZioRequest(ctx, jReq) + if (!validateHostHeader(req)) { + attemptFastWrite(ctx, req.method, Response.status(Status.BadRequest)) releaseRequest() } else { - writeResponse(ctx, runtime, exit, req)(releaseRequest) + + val exit = handler(req) + if (attemptImmediateWrite(ctx, req.method, exit)) { + releaseRequest() + } else { + writeResponse(ctx, runtime, exit, req)(releaseRequest) + + } } } } finally { @@ -108,6 +115,34 @@ private[zio] final case class ServerInboundHandler( } + private def validateHostHeader(req: Request): Boolean = { + req.headers.get("Host") match { + case Some(host) => + val parts = host.split(":") + val hostname = parts(0) + val isValidHost = validateHostname(hostname) + val isValidPort = parts.length == 1 || (parts.length == 2 && parts(1).forall(_.isDigit)) + val isValid = isValidHost && isValidPort + println(s"Host: $host, isValidHost: $isValidHost, isValidPort: $isValidPort, isValid: $isValid") + isValid + case None => + println("Host header missing!") + false + } + } + +// Validate a regular hostname (based on RFC 1035) + private def validateHostname(hostname: String): Boolean = { + if (hostname.isEmpty || hostname.contains("_")) { + return false + } + val labels = hostname.split("\\.") + if (labels.exists(label => label.isEmpty || label.length > 63 || label.startsWith("-") || label.endsWith("-"))) { + return false + } + hostname.forall(c => c.isLetterOrDigit || c == '.' || c == '-') && hostname.length <= 253 + } + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = cause match { case ioe: IOException if { @@ -262,7 +297,6 @@ private[zio] final case class ServerInboundHandler( remoteCertificate = clientCert, ) } - } /*