From 1b3ae58063d8cfb08c6c80565319b757b2c4028b Mon Sep 17 00:00:00 2001 From: Julien Viet Date: Wed, 6 Nov 2024 10:06:14 +0100 Subject: [PATCH] Decouple the handshake part ServerWebSocket API. Motivation: The server WebSocket API can control handshake implicitly (e.g. sending a message) or explicitly (accept or any WebSocket interaction). This result in a more complex implementation than it should be for such API. Changes: Extract the handshake API of the ServerWebSocket API in a new ServerWebSocketHandshake API for which an handler can be set when WebSocket handshake needs to be controlled. This is a backport of Vert.x 5 server WebSocket handshake handler. The current API is maintained but deprecated. --- src/main/asciidoc/http.adoc | 10 +- src/main/java/examples/HTTPExamples.java | 48 +++--- .../java/io/vertx/core/http/HttpServer.java | 12 ++ .../io/vertx/core/http/ServerWebSocket.java | 12 ++ .../core/http/ServerWebSocketHandshake.java | 138 ++++++++++++++++++ .../http/impl/Http1xServerRequestHandler.java | 103 +++++++++++-- .../impl/HttpServerConnectionHandler.java | 4 + .../vertx/core/http/impl/HttpServerImpl.java | 16 +- .../io/vertx/core/http/WebSocketTest.java | 92 +++++++++--- 9 files changed, 386 insertions(+), 49 deletions(-) create mode 100644 src/main/java/io/vertx/core/http/ServerWebSocketHandshake.java diff --git a/src/main/asciidoc/http.adoc b/src/main/asciidoc/http.adoc index 8001bfe946d..aa4b6209bf0 100644 --- a/src/main/asciidoc/http.adoc +++ b/src/main/asciidoc/http.adoc @@ -1907,14 +1907,20 @@ When a WebSocket connection is made to the server, the handler will be called, p {@link examples.HTTPExamples#example51} ---- -You can choose to reject the WebSocket by calling {@link io.vertx.core.http.ServerWebSocket#reject()}. +===== Server WebSocket handshake + +By default, the server accepts any inbound WebSocket. + +You can set a WebSocket handshake handler to control the outcome of a WebSocket handshake, i.e. accept or reject an incoming WebSocket. + +You can choose to reject the WebSocket by calling {@link io.vertx.core.http.ServerWebSocketHandshake#accept()} or {@link io.vertx.core.http.ServerWebSocketHandshake#reject()}. [source,$lang] ---- {@link examples.HTTPExamples#example52} ---- -You can perform an asynchronous handshake by calling {@link io.vertx.core.http.ServerWebSocket#setHandshake} with a `Future`: +You can perform an asynchronous handshake: [source,$lang] ---- diff --git a/src/main/java/examples/HTTPExamples.java b/src/main/java/examples/HTTPExamples.java index b2decdae8ef..cf1e47b404f 100644 --- a/src/main/java/examples/HTTPExamples.java +++ b/src/main/java/examples/HTTPExamples.java @@ -1047,29 +1047,39 @@ public void example51(HttpServer server) { public void example52(HttpServer server) { - server.webSocketHandler(webSocket -> { - if (webSocket.path().equals("/myapi")) { - webSocket.reject(); - } else { + server + .webSocketHandshakeHandler(handshake -> { + if (handshake.path().equals("/myapi")) { + handshake.reject(); + } else { + handshake.accept(); + } + }) + .webSocketHandler(webSocket -> { // Do something - } - }); + }); } public void exampleAsynchronousHandshake(HttpServer server) { - server.webSocketHandler(webSocket -> { - Promise promise = Promise.promise(); - webSocket.setHandshake(promise.future()); - authenticate(webSocket.headers(), ar -> { - if (ar.succeeded()) { - // Terminate the handshake with the status code 101 (Switching Protocol) - // Reject the handshake with 401 (Unauthorized) - promise.complete(ar.result() ? 101 : 401); - } else { - // Will send a 500 error - promise.fail(ar.cause()); - } - }); + server + .webSocketHandshakeHandler(handshake -> { + authenticate(handshake.headers(), ar -> { + if (ar.succeeded()) { + if (ar.result()) { + // Terminate the handshake with the status code 101 (Switching Protocol) + handshake.accept(); + } else { + // Reject the handshake with 401 (Unauthorized) + handshake.reject(401); + } + } else { + // Will send a 500 error + handshake.reject(500); + } + }); + }) + .webSocketHandler(webSocket -> { + // Do something }); } diff --git a/src/main/java/io/vertx/core/http/HttpServer.java b/src/main/java/io/vertx/core/http/HttpServer.java index 5bcffc57e51..3eb1fb50746 100644 --- a/src/main/java/io/vertx/core/http/HttpServer.java +++ b/src/main/java/io/vertx/core/http/HttpServer.java @@ -119,6 +119,18 @@ public interface HttpServer extends Measured { @Fluent HttpServer webSocketHandler(Handler handler); + /** + * Set a handler for WebSocket handshake. + * + *

When an inbound HTTP request presents a WebSocket upgrade, this handler is called first. The handler + * can chose to {@link ServerWebSocketHandshake#accept()} or {@link ServerWebSocketHandshake#reject()} the request.

+ * + *

Setting no handler, implicitly accepts any HTTP request connection presenting an upgrade header and upgrades it + * to a WebSocket.

+ */ + @Fluent + HttpServer webSocketHandshakeHandler(Handler handler); + /** * @return the WebSocket handler */ diff --git a/src/main/java/io/vertx/core/http/ServerWebSocket.java b/src/main/java/io/vertx/core/http/ServerWebSocket.java index bc832cf8dba..b13b8155714 100644 --- a/src/main/java/io/vertx/core/http/ServerWebSocket.java +++ b/src/main/java/io/vertx/core/http/ServerWebSocket.java @@ -130,7 +130,9 @@ public interface ServerWebSocket extends WebSocketBase { * terminate the WebSocket handshake. * * @throws IllegalStateException when the WebSocket handshake is already set + * @deprecated instead use {@link ServerWebSocketHandshake#accept()} */ + @Deprecated void accept(); /** @@ -143,12 +145,17 @@ public interface ServerWebSocket extends WebSocketBase { * You might use this method, if for example you only want to accept WebSockets with a particular path. * * @throws IllegalStateException when the WebSocket handshake is already set + * @deprecated instead use {@link ServerWebSocketHandshake#reject()} */ + @Deprecated void reject(); /** * Like {@link #reject()} but with a {@code status}. + * + * @deprecated instead use {@link ServerWebSocketHandshake#reject(int)} */ + @Deprecated void reject(int status); /** @@ -172,12 +179,17 @@ public interface ServerWebSocket extends WebSocketBase { * @param future the future to complete with * @param handler the completion handler * @throws IllegalStateException when the WebSocket has already an asynchronous result + * @deprecated instead use {@link ServerWebSocketHandshake} */ + @Deprecated void setHandshake(Future future, Handler> handler); /** * Like {@link #setHandshake(Future, Handler)} but returns a {@code Future} of the asynchronous result + * + * @deprecated instead use {@link ServerWebSocketHandshake} */ + @Deprecated Future setHandshake(Future future); /** diff --git a/src/main/java/io/vertx/core/http/ServerWebSocketHandshake.java b/src/main/java/io/vertx/core/http/ServerWebSocketHandshake.java new file mode 100644 index 00000000000..55897a27216 --- /dev/null +++ b/src/main/java/io/vertx/core/http/ServerWebSocketHandshake.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2011-2024 Contributors to the Eclipse Foundation + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 + * which is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + */ +package io.vertx.core.http; + +import io.vertx.codegen.annotations.CacheReturn; +import io.vertx.codegen.annotations.GenIgnore; +import io.vertx.codegen.annotations.Nullable; +import io.vertx.codegen.annotations.VertxGen; +import io.vertx.core.Future; +import io.vertx.core.MultiMap; +import io.vertx.core.net.HostAndPort; +import io.vertx.core.net.SocketAddress; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import java.security.cert.Certificate; +import java.util.List; + +/** + * A server WebSocket handshake, allows to control acceptance or rejection of a WebSocket. + * + * @author Julien Viet + */ +@VertxGen +public interface ServerWebSocketHandshake { + + /** + * Returns the HTTP headers. + * + * @return the headers + */ + MultiMap headers(); + + /** + * @return the WebSocket handshake scheme + */ + @Nullable + String scheme(); + + /** + * @return the WebSocket handshake authority + */ + @Nullable + HostAndPort authority(); + + /* + * @return the WebSocket handshake URI. This is a relative URI. + */ + String uri(); + + /** + * @return the WebSocket handshake path. + */ + String path(); + + /** + * @return the WebSocket handshake query string. + */ + @Nullable + String query(); + + /** + * Accept the WebSocket and terminate the WebSocket handshake. + *

+ * This method should be called from the WebSocket handler to explicitly accept the WebSocket and + * terminate the WebSocket handshake. + * + * @throws IllegalStateException when the WebSocket handshake is already set + */ + Future accept(); + + /** + * Reject the WebSocket. + *

+ * Calling this method from the WebSocket handler when it is first passed to you gives you the opportunity to reject + * the WebSocket, which will cause the WebSocket handshake to fail by returning + * a {@literal 502} response code. + *

+ * You might use this method, if for example you only want to accept WebSockets with a particular path. + * + * @throws IllegalStateException when the WebSocket handshake is already set + */ + default Future reject() { + // SC_BAD_GATEWAY + return reject(502); + } + + /** + * Like {@link #reject()} but with a {@code status}. + */ + Future reject(int status); + + /** + * @return the remote address for this connection, possibly {@code null} (e.g a server bound on a domain socket). + * If {@code useProxyProtocol} is set to {@code true}, the address returned will be of the actual connecting client. + */ + @CacheReturn + SocketAddress remoteAddress(); + + /** + * @return the local address for this connection, possibly {@code null} (e.g a server bound on a domain socket) + * If {@code useProxyProtocol} is set to {@code true}, the address returned will be of the proxy. + */ + @CacheReturn + SocketAddress localAddress(); + + /** + * @return true if this {@link io.vertx.core.http.HttpConnection} is encrypted via SSL/TLS. + */ + boolean isSsl(); + + /** + * @return SSLSession associated with the underlying socket. Returns null if connection is + * not SSL. + * @see javax.net.ssl.SSLSession + */ + @GenIgnore(GenIgnore.PERMITTED_TYPE) + SSLSession sslSession(); + + /** + * @return an ordered list of the peer certificates. Returns null if connection is + * not SSL. + * @throws javax.net.ssl.SSLPeerUnverifiedException SSL peer's identity has not been verified. + * @see SSLSession#getPeerCertificates() () + * @see #sslSession() + */ + @GenIgnore() + List peerCertificates() throws SSLPeerUnverifiedException; + +} diff --git a/src/main/java/io/vertx/core/http/impl/Http1xServerRequestHandler.java b/src/main/java/io/vertx/core/http/impl/Http1xServerRequestHandler.java index f7a520f1fd2..7d0ebfd74be 100644 --- a/src/main/java/io/vertx/core/http/impl/Http1xServerRequestHandler.java +++ b/src/main/java/io/vertx/core/http/impl/Http1xServerRequestHandler.java @@ -10,9 +10,21 @@ */ package io.vertx.core.http.impl; +import io.vertx.codegen.annotations.Nullable; +import io.vertx.core.Future; import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.Promise; import io.vertx.core.http.HttpServerRequest; import io.vertx.core.http.ServerWebSocket; +import io.vertx.core.http.ServerWebSocketHandshake; +import io.vertx.core.net.HostAndPort; +import io.vertx.core.net.SocketAddress; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import java.security.cert.Certificate; +import java.util.List; import static io.vertx.core.http.HttpHeaders.UPGRADE; import static io.vertx.core.http.HttpHeaders.WEBSOCKET; @@ -39,19 +51,90 @@ public Http1xServerRequestHandler(HttpServerConnectionHandler handlers) { public void handle(HttpServerRequest req) { Handler wsHandler = handlers.wsHandler; Handler reqHandler = handlers.requestHandler; - if (wsHandler != null ) { + Handler wsHandshakeHandler = handlers.wsHandshakeHandler; + if (wsHandler != null) { if (req.headers().contains(UPGRADE, WEBSOCKET, true) && handlers.server.wsAccept()) { // Missing upgrade header + null request handler will be handled when creating the handshake by sending a 400 error // handle((Http1xServerRequest) req, wsHandler); - ((Http1xServerRequest)req).webSocket().onComplete(ar -> { - if (ar.succeeded()) { - ServerWebSocketImpl ws = (ServerWebSocketImpl) ar.result(); - wsHandler.handle(ws); - ws.tryHandshake(101); - } else { - // ???? - } - }); + Future fut = ((Http1xServerRequest) req).webSocket(); + if (wsHandshakeHandler != null) { + fut.onComplete(ar -> { + if (ar.succeeded()) { + ServerWebSocket ws = ar.result(); + Promise promise = Promise.promise(); + ws.setHandshake(promise.future()); + wsHandshakeHandler.handle(new ServerWebSocketHandshake() { + @Override + public MultiMap headers() { + return ws.headers(); + } + @Override + public @Nullable String scheme() { + return ws.scheme(); + } + @Override + public @Nullable HostAndPort authority() { + return ws.authority(); + } + @Override + public String uri() { + return ws.uri(); + } + @Override + public String path() { + return ws.path(); + } + @Override + public @Nullable String query() { + return ws.query(); + } + @Override + public Future accept() { + promise.complete(101); + wsHandler.handle(ws); + return Future.succeededFuture(ws); + } + @Override + public Future reject(int status) { + promise.complete(status); + return Future.succeededFuture(); + } + @Override + public SocketAddress remoteAddress() { + return ws.remoteAddress(); + } + @Override + public SocketAddress localAddress() { + return ws.localAddress(); + } + @Override + public boolean isSsl() { + return ws.isSsl(); + } + @Override + public SSLSession sslSession() { + return ws.sslSession(); + } + @Override + public List peerCertificates() throws SSLPeerUnverifiedException { + return ws.peerCertificates(); + } + }); + } else { + + } + }); + } else { + fut.onComplete(ar -> { + if (ar.succeeded()) { + ServerWebSocketImpl ws = (ServerWebSocketImpl) ar.result(); + wsHandler.handle(ws); + ws.tryHandshake(101); + } else { + // ???? + } + }); + } } else { if (reqHandler != null) { reqHandler.handle(req); diff --git a/src/main/java/io/vertx/core/http/impl/HttpServerConnectionHandler.java b/src/main/java/io/vertx/core/http/impl/HttpServerConnectionHandler.java index dada122d6b5..163e7f2ea36 100644 --- a/src/main/java/io/vertx/core/http/impl/HttpServerConnectionHandler.java +++ b/src/main/java/io/vertx/core/http/impl/HttpServerConnectionHandler.java @@ -21,6 +21,7 @@ import io.vertx.core.http.HttpConnection; import io.vertx.core.http.HttpServerRequest; import io.vertx.core.http.ServerWebSocket; +import io.vertx.core.http.ServerWebSocketHandshake; import io.vertx.core.impl.ContextInternal; import java.util.ArrayList; @@ -35,6 +36,7 @@ public class HttpServerConnectionHandler implements Handler requestHandler; final Handler invalidRequestHandler; + final Handler wsHandshakeHandler; final Handler wsHandler; final Handler connectionHandler; final Handler exceptionHandler; @@ -43,12 +45,14 @@ public HttpServerConnectionHandler( HttpServerImpl server, Handler requestHandler, Handler invalidRequestHandler, + Handler wsHandshakeHandler, Handler wsHandler, Handler connectionHandler, Handler exceptionHandler) { this.server = server; this.requestHandler = requestHandler; this.invalidRequestHandler = invalidRequestHandler == null ? HttpServerRequest.DEFAULT_INVALID_REQUEST_HANDLER : invalidRequestHandler; + this.wsHandshakeHandler = wsHandshakeHandler; this.wsHandler = wsHandler; this.connectionHandler = connectionHandler; this.exceptionHandler = exceptionHandler; diff --git a/src/main/java/io/vertx/core/http/impl/HttpServerImpl.java b/src/main/java/io/vertx/core/http/impl/HttpServerImpl.java index 69fdcf082f9..5079a3a8fd5 100644 --- a/src/main/java/io/vertx/core/http/impl/HttpServerImpl.java +++ b/src/main/java/io/vertx/core/http/impl/HttpServerImpl.java @@ -56,6 +56,7 @@ public class HttpServerImpl extends TCPServerBase implements HttpServer, Closeab private Handler invalidRequestHandler; private Handler connectionHandler; private Handler exceptionHandler; + private Handler webSocketHandshakeHandler; public HttpServerImpl(VertxInternal vertx, HttpServerOptions options) { super(vertx, options); @@ -90,6 +91,12 @@ public HttpServer webSocketHandler(Handler handler) { return this; } + @Override + public HttpServer webSocketHandshakeHandler(Handler handler) { + webSocketHandshakeHandler = handler; + return this; + } + @Override public Handler requestHandler() { return requestStream.handler(); @@ -145,7 +152,14 @@ protected BiConsumer childHandler(ContextInternal c String host = address.isInetSocket() ? address.host() : "localhost"; int port = address.port(); String serverOrigin = (options.isSsl() ? "https" : "http") + "://" + host + ":" + port; - HttpServerConnectionHandler hello = new HttpServerConnectionHandler(this, requestStream.handler, invalidRequestHandler, wsStream.handler, connectionHandler, exceptionHandler == null ? DEFAULT_EXCEPTION_HANDLER : exceptionHandler); + HttpServerConnectionHandler hello = new HttpServerConnectionHandler( + this, + requestStream.handler, + invalidRequestHandler, + webSocketHandshakeHandler, + wsStream.handler, + connectionHandler, + exceptionHandler == null ? DEFAULT_EXCEPTION_HANDLER : exceptionHandler); Supplier streamContextSupplier = context::duplicate; return new HttpServerWorker( connContext, diff --git a/src/test/java/io/vertx/core/http/WebSocketTest.java b/src/test/java/io/vertx/core/http/WebSocketTest.java index b2a8d856804..663d2facdca 100644 --- a/src/test/java/io/vertx/core/http/WebSocketTest.java +++ b/src/test/java/io/vertx/core/http/WebSocketTest.java @@ -131,18 +131,33 @@ protected VertxOptions getOptions() { } @Test - public void testRejectHybi00() throws Exception { - testReject(WebsocketVersion.V00, null, 502, "Bad Gateway"); + public void testRejectHybi00_1() throws Exception { + testReject1(WebsocketVersion.V00, null, 502, "Bad Gateway"); } @Test - public void testRejectHybi08() throws Exception { - testReject(WebsocketVersion.V08, null, 502, "Bad Gateway"); + public void testRejectHybi08_1() throws Exception { + testReject1(WebsocketVersion.V08, null, 502, "Bad Gateway"); } @Test - public void testRejectWithStatusCode() throws Exception { - testReject(WebsocketVersion.V08, 404, 404, "Not Found"); + public void testRejectWithStatusCode_1() throws Exception { + testReject1(WebsocketVersion.V08, 404, 404, "Not Found"); + } + + @Test + public void testRejectHybi00_2() throws Exception { + testReject2(WebsocketVersion.V00, null, 502, "Bad Gateway"); + } + + @Test + public void testRejectHybi08_2() throws Exception { + testReject2(WebsocketVersion.V08, null, 502, "Bad Gateway"); + } + + @Test + public void testRejectWithStatusCode_2() throws Exception { + testReject2(WebsocketVersion.V08, 404, 404, "Not Found"); } @Test @@ -1353,24 +1368,45 @@ private void testInvalidHandshake(BiConsumer { + server.webSocketHandler(ws -> { + assertEquals("/some/path", ws.path()); + if (rejectionStatus != null) { + ws.reject(rejectionStatus); + } else { + ws.reject(); + } + }); + }); + } - String path = "/some/path"; + private void testReject2(WebsocketVersion version, Integer rejectionStatus, int expectedRejectionStatus, String expectedBody) throws Exception { + testReject(version, expectedRejectionStatus, expectedBody, server -> { + server.webSocketHandshakeHandler(handshake -> { + assertEquals("/some/path", handshake.path()); + if (rejectionStatus != null) { + handshake.reject(rejectionStatus); + } else { + handshake.reject(); + } + }); + server.webSocketHandler(ws -> { - server = vertx.createHttpServer(new HttpServerOptions().setPort(DEFAULT_HTTP_PORT)).webSocketHandler(ws -> { - assertEquals(path, ws.path()); - if (rejectionStatus != null) { - ws.reject(rejectionStatus); - } else { - ws.reject(); - } + }); }); + } + + private void testReject(WebsocketVersion version, int expectedRejectionStatus, String expectedBody, Consumer handler) throws Exception { + + server = vertx.createHttpServer(new HttpServerOptions().setPort(DEFAULT_HTTP_PORT)); + handler.accept(server); server.listen(onSuccess(s -> { WebSocketConnectOptions options = new WebSocketConnectOptions() .setPort(DEFAULT_HTTP_PORT) .setHost(DEFAULT_HTTP_HOST) - .setURI(path) + .setURI("/some/path") .setVersion(version); client = vertx.createWebSocketClient(); client.connect(options, onFailure(t -> { @@ -1386,7 +1422,29 @@ private void testReject(WebsocketVersion version, Integer rejectionStatus, int e } @Test - public void testAsyncAccept() { + public void testAsyncAccept1() { + AtomicBoolean resolved = new AtomicBoolean(); + server = vertx.createHttpServer(new HttpServerOptions().setPort(DEFAULT_HTTP_PORT)) + .webSocketHandshakeHandler(handshake -> { + vertx.setTimer(500, id -> { + resolved.set(true); + handshake.accept(); + }); + }) + .webSocketHandler(ws -> { + }); + server.listen(onSuccess(s -> { + client = vertx.createWebSocketClient(); + client.connect(DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, "/some/path", onSuccess(ws -> { + assertTrue(resolved.get()); + testComplete(); + })); + })); + await(); + } + + @Test + public void testAsyncAccept2() { AtomicBoolean resolved = new AtomicBoolean(); server = vertx.createHttpServer(new HttpServerOptions().setPort(DEFAULT_HTTP_PORT)).webSocketHandler(ws -> { Promise promise = Promise.promise();