From 2166cbe932b29b6419f9cd751e8b27c647e1238e Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Tue, 11 Apr 2023 15:34:50 +0100 Subject: [PATCH] Add support for proxying in `WebsocketClient` (#130) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WebsocketClient supports proxying Added support for TLS and plain text proxying of Websocket traffic. In the TLS case a CONNECT header is first sent establishing the proxied traffic. In the plain text case the modified URI in the initial upgrade request header indicates to the proxy server that the traffic is to be proxied. Use `NIOWebSocketFrameAggregator` to handle aggregating frame fragments. This brings with it more protections e.g. against memory exhaustion. Accompanying config has been added to support this change. * Reduce allocations and copies in WebSocket.send Reduce allocation and copies necessary to send `ByteBuffer` and `ByteBufferView` through `WebSocket.send`. In fact sending `ByteBuffer` or `ByteBufferView` doesn’t require any allocation or copy of the data. Sending a `String` now correctly pre allocates the `ByteBuffer` if multibyte characters are present in the `String`. Remove custom random websocket mask generation which would only generate bytes between `UInt8.min.. HTTPUpgradeRequestHandler.swift} | 34 +- Sources/WebSocketKit/WebSocket+Connect.swift | 125 +++++ Sources/WebSocketKit/WebSocket.swift | 56 +- Sources/WebSocketKit/WebSocketClient.swift | 184 +++++-- Sources/WebSocketKit/WebSocketHandler.swift | 81 ++- .../WebSocketKitTests/WebSocketKitTests.swift | 485 +++++++++++++++++- 8 files changed, 911 insertions(+), 65 deletions(-) rename Sources/WebSocketKit/{HTTPInitialRequestHandler.swift => HTTPUpgradeRequestHandler.swift} (60%) diff --git a/NOTICES.txt b/NOTICES.txt index 52e13d63..3a9cc733 100644 --- a/NOTICES.txt +++ b/NOTICES.txt @@ -17,3 +17,12 @@ This product contains a derivation of `NIOSSLTestHelpers.swift` from SwiftNIO SS * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/apple/swift-nio-ssl + +--- + +This product contains derivations of "HTTPProxySimulator" and "HTTPBin" test utils from AsyncHTTPClient. + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/swift-server/async-http-client diff --git a/Package.swift b/Package.swift index 74dcbd6b..1fd56471 100644 --- a/Package.swift +++ b/Package.swift @@ -13,6 +13,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), @@ -22,6 +23,7 @@ let package = Package( .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOExtras", package: "swift-nio-extras"), .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), diff --git a/Sources/WebSocketKit/HTTPInitialRequestHandler.swift b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift similarity index 60% rename from Sources/WebSocketKit/HTTPInitialRequestHandler.swift rename to Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift index e74cbf2e..84af52d1 100644 --- a/Sources/WebSocketKit/HTTPInitialRequestHandler.swift +++ b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift @@ -1,7 +1,7 @@ import NIO import NIOHTTP1 -final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHandler { +final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = HTTPClientResponsePart typealias OutboundOut = HTTPClientRequestPart @@ -11,6 +11,8 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa let headers: HTTPHeaders let upgradePromise: EventLoopPromise + private var requestSent = false + init(host: String, path: String, query: String?, headers: HTTPHeaders, upgradePromise: EventLoopPromise) { self.host = host self.path = path @@ -20,10 +22,33 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa } func channelActive(context: ChannelHandlerContext) { + self.sendRequest(context: context) + context.fireChannelActive() + } + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.sendRequest(context: context) + } + } + + private func sendRequest(context: ChannelHandlerContext) { + if self.requestSent { + // we might run into this handler twice, once in handlerAdded and once in channelActive. + return + } + self.requestSent = true + var headers = self.headers headers.add(name: "Host", value: self.host) - var uri = self.path.hasPrefix("/") ? self.path : "/" + self.path + var uri: String + if self.path.hasPrefix("/") || self.path.hasPrefix("ws://") || self.path.hasPrefix("wss://") { + uri = self.path + } else { + uri = "/" + self.path + } + if let query = self.query { uri += "?\(query)" } @@ -43,10 +68,13 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa } func channelRead(context: ChannelHandlerContext, data: NIOAny) { + // `NIOHTTPClientUpgradeHandler` should consume the first response in the success case, + // any response we see here indicates a failure. Report the failure and tidy up at the end of the response. let clientResponse = self.unwrapInboundIn(data) switch clientResponse { case .head(let responseHead): - self.upgradePromise.fail(WebSocketClient.Error.invalidResponseStatus(responseHead)) + let error = WebSocketClient.Error.invalidResponseStatus(responseHead) + self.upgradePromise.fail(error) case .body: break case .end: context.close(promise: nil) diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift index b996c798..643fcb6c 100644 --- a/Sources/WebSocketKit/WebSocket+Connect.swift +++ b/Sources/WebSocketKit/WebSocket+Connect.swift @@ -3,6 +3,15 @@ import NIOHTTP1 import Foundation extension WebSocket { + /// Establish a WebSocket connection. + /// + /// - Parameters: + /// - url: URL for the WebSocket server. + /// - headers: Headers to send to the WebSocket server. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the WebSocket server is established. public static func connect( to url: String, headers: HTTPHeaders = [:], @@ -22,6 +31,15 @@ extension WebSocket { ) } + /// Establish a WebSocket connection. + /// + /// - Parameters: + /// - url: URL for the WebSocket server. + /// - headers: Headers to send to the WebSocket server. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the WebSocket server is established. public static func connect( to url: URL, headers: HTTPHeaders = [:], @@ -43,6 +61,19 @@ extension WebSocket { ) } + /// Establish a WebSocket connection. + /// + /// - Parameters: + /// - scheme: Scheme component of the URI for the WebSocket server. + /// - host: Host component of the URI for the WebSocket server. + /// - port: Port on which to connect to the WebSocket server. + /// - path: Path component of the URI for the WebSocket server. + /// - query: Query component of the URI for the WebSocket server. + /// - headers: Headers to send to the WebSocket server. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the WebSocket server is established. public static func connect( scheme: String = "ws", host: String, @@ -67,4 +98,98 @@ extension WebSocket { onUpgrade: onUpgrade ) } + + /// Establish a WebSocket connection via a proxy server. + /// + /// - Parameters: + /// - scheme: Scheme component of the URI for the origin server. + /// - host: Host component of the URI for the origin server. + /// - port: Port on which to connect to the origin server. + /// - path: Path component of the URI for the origin server. + /// - query: Query component of the URI for the origin server. + /// - headers: Headers to send to the origin server. + /// - proxy: Host component of the URI for the proxy server. + /// - proxyPort: Port on which to connect to the proxy server. + /// - proxyHeaders: Headers to send to the proxy server. + /// - proxyConnectDeadline: Deadline for establishing the proxy connection. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the origin server is established. + public static func connect( + scheme: String = "ws", + host: String, + port: Int = 80, + path: String = "/", + query: String? = nil, + headers: HTTPHeaders = [:], + proxy: String?, + proxyPort: Int? = nil, + proxyHeaders: HTTPHeaders = [:], + proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return WebSocketClient( + eventLoopGroupProvider: .shared(eventLoopGroup), + configuration: configuration + ).connect( + scheme: scheme, + host: host, + port: port, + path: path, + query: query, + headers: headers, + proxy: proxy, + proxyPort: proxyPort, + proxyHeaders: proxyHeaders, + proxyConnectDeadline: proxyConnectDeadline, + onUpgrade: onUpgrade + ) + } + + + /// Description + /// - Parameters: + /// - url: URL for the origin server. + /// - headers: Headers to send to the origin server. + /// - proxy: Host component of the URI for the proxy server. + /// - proxyPort: Port on which to connect to the proxy server. + /// - proxyHeaders: Headers to send to the proxy server. + /// - proxyConnectDeadline: Deadline for establishing the proxy connection. + /// - configuration: Configuration for the WebSocket client. + /// - eventLoopGroup: Event loop group to be used by the WebSocket client. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the origin server is established. + public static func connect( + to url: String, + headers: HTTPHeaders = [:], + proxy: String?, + proxyPort: Int? = nil, + proxyHeaders: HTTPHeaders = [:], + proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, + configuration: WebSocketClient.Configuration = .init(), + on eventLoopGroup: EventLoopGroup, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + guard let url = URL(string: url) else { + return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL) + } + let scheme = url.scheme ?? "ws" + return self.connect( + scheme: scheme, + host: url.host ?? "localhost", + port: url.port ?? (scheme == "wss" ? 443 : 80), + path: url.path, + query: url.query, + headers: headers, + proxy: proxy, + proxyPort: proxyPort, + proxyHeaders: proxyHeaders, + proxyConnectDeadline: proxyConnectDeadline, + on: eventLoopGroup, + onUpgrade: onUpgrade + ) + } } diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index 26c7f6d6..ce915c6d 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -24,7 +24,9 @@ public final class WebSocket { self.channel.closeFuture } - private let channel: Channel + @usableFromInline + /* private but @usableFromInline */ + internal let channel: Channel private var onTextCallback: (WebSocket, String) -> () private var onBinaryCallback: (WebSocket, ByteBuffer) -> () private var onPongCallback: (WebSocket) -> () @@ -64,10 +66,10 @@ public final class WebSocket { } /// If set, this will trigger automatic pings on the connection. If ping is not answered before - /// the next ping is sent, then the WebSocket will be presumed innactive and will be closed + /// the next ping is sent, then the WebSocket will be presumed inactive and will be closed /// automatically. /// These pings can also be used to keep the WebSocket alive if there is some other timeout - /// mechanism shutting down innactive connections, such as a Load Balancer deployed in + /// mechanism shutting down inactive connections, such as a Load Balancer deployed in /// front of the server. public var pingInterval: TimeAmount? { didSet { @@ -82,13 +84,13 @@ public final class WebSocket { } } + @inlinable public func send(_ text: S, promise: EventLoopPromise? = nil) where S: Collection, S.Element == Character { let string = String(text) - var buffer = channel.allocator.buffer(capacity: text.count) - buffer.writeString(string) - self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise) + let buffer = channel.allocator.buffer(string: string) + self.send(buffer, opcode: .text, fin: true, promise: promise) } @@ -105,6 +107,7 @@ public final class WebSocket { ) } + @inlinable public func send( raw data: Data, opcode: WebSocketOpcode, @@ -113,13 +116,32 @@ public final class WebSocket { ) where Data: DataProtocol { - var buffer = channel.allocator.buffer(capacity: data.count) - buffer.writeBytes(data) + if let byteBufferView = data as? ByteBufferView { + // optimisation: converting from `ByteBufferView` to `ByteBuffer` doesn't allocate or copy any data + send(ByteBuffer(byteBufferView), opcode: opcode, fin: fin, promise: promise) + } else { + let buffer = channel.allocator.buffer(bytes: data) + send(buffer, opcode: opcode, fin: fin, promise: promise) + } + } + + /// Send the provided data in a WebSocket frame. + /// - Parameters: + /// - data: Data to be sent. + /// - opcode: Frame opcode. + /// - fin: The value of the fin bit. + /// - promise: A promise to be completed when the write is complete. + public func send( + _ data: ByteBuffer, + opcode: WebSocketOpcode = .binary, + fin: Bool = true, + promise: EventLoopPromise? = nil + ) { let frame = WebSocketFrame( fin: fin, opcode: opcode, maskKey: self.makeMaskKey(), - data: buffer + data: data ) self.channel.writeAndFlush(frame, promise: promise) } @@ -164,11 +186,7 @@ public final class WebSocket { func makeMaskKey() -> WebSocketMaskingKey? { switch type { case .client: - var bytes: [UInt8] = [] - for _ in 0..<4 { - bytes.append(.random(in: .min ..< .max)) - } - return WebSocketMaskingKey(bytes) + return WebSocketMaskingKey.random() case .server: return nil } @@ -237,14 +255,8 @@ public final class WebSocket { frameSequence.append(frame) self.frameSequence = frameSequence case .continuation: - // we must have an existing sequence - if var frameSequence = self.frameSequence { - // append this frame and update - frameSequence.append(frame) - self.frameSequence = frameSequence - } else { - self.close(code: .protocolError, promise: nil) - } + /// continuations are filtered by ``NIOWebSocketFrameAggregator`` + preconditionFailure("We will never receive a continuation frame") default: // We ignore all other frames. break diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index 2f13cfc7..b5cd072d 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -1,6 +1,7 @@ import Foundation import NIO import NIOConcurrencyHelpers +import NIOExtras import NIOHTTP1 import NIOWebSocket import NIOSSL @@ -26,12 +27,25 @@ public final class WebSocketClient { public var tlsConfiguration: TLSConfiguration? public var maxFrameSize: Int + /// Defends against small payloads in frame aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var minNonFinalFragmentSize: Int + /// Max number of fragments in an aggregated frame. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameCount: Int + /// Maximum frame size after aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameSize: Int + public init( tlsConfiguration: TLSConfiguration? = nil, maxFrameSize: Int = 1 << 14 ) { self.tlsConfiguration = tlsConfiguration self.maxFrameSize = maxFrameSize + self.minNonFinalFragmentSize = 0 + self.maxAccumulatedFrameCount = Int.max + self.maxAccumulatedFrameSize = Int.max } } @@ -59,30 +73,71 @@ public final class WebSocketClient { query: String? = nil, headers: HTTPHeaders = [:], onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + self.connect(scheme: scheme, host: host, port: port, path: path, query: query, headers: headers, proxy: nil, onUpgrade: onUpgrade) + } + + /// Establish a WebSocket connection via a proxy server. + /// + /// - Parameters: + /// - scheme: Scheme component of the URI for the origin server. + /// - host: Host component of the URI for the origin server. + /// - port: Port on which to connect to the origin server. + /// - path: Path component of the URI for the origin server. + /// - query: Query component of the URI for the origin server. + /// - headers: Headers to send to the origin server. + /// - proxy: Host component of the URI for the proxy server. + /// - proxyPort: Port on which to connect to the proxy server. + /// - proxyHeaders: Headers to send to the proxy server. + /// - proxyConnectDeadline: Deadline for establishing the proxy connection. + /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. + /// - Returns: An future which completes when the connection to the origin server is established. + public func connect( + scheme: String, + host: String, + port: Int, + path: String = "/", + query: String? = nil, + headers: HTTPHeaders = [:], + proxy: String?, + proxyPort: Int? = nil, + proxyHeaders: HTTPHeaders = [:], + proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { assert(["ws", "wss"].contains(scheme)) let upgradePromise = self.group.next().makePromise(of: Void.self) let bootstrap = WebSocketClient.makeBootstrap(on: self.group) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) - .channelInitializer { channel in - let httpHandler = HTTPInitialRequestHandler( + .channelInitializer { channel -> EventLoopFuture in + + let uri: String + var upgradeRequestHeaders = headers + if proxy == nil { + uri = path + } else { + let relativePath = path.hasPrefix("/") ? path : "/" + path + let port = proxyPort.map { ":\($0)" } ?? "" + uri = "\(scheme)://\(host)\(relativePath)\(port)" + + if scheme == "ws" { + upgradeRequestHeaders.add(contentsOf: proxyHeaders) + } + } + + let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler( host: host, - path: path, + path: uri, query: query, - headers: headers, + headers: upgradeRequestHeaders, upgradePromise: upgradePromise ) - var key: [UInt8] = [] - for _ in 0..<16 { - key.append(.random(in: .min ..< .max)) - } let websocketUpgrader = NIOWebSocketClientUpgrader( - requestKey: Data(key).base64EncodedString(), maxFrameSize: self.configuration.maxFrameSize, automaticErrorHandling: true, upgradePipelineHandler: { channel, req in - return WebSocket.client(on: channel, onUpgrade: onUpgrade) + return WebSocket.client(on: channel, config: .init(clientConfig: self.configuration), onUpgrade: onUpgrade) } ) @@ -90,46 +145,105 @@ public final class WebSocketClient { upgraders: [websocketUpgrader], completionHandler: { context in upgradePromise.succeed(()) - channel.pipeline.removeHandler(httpHandler, promise: nil) + channel.pipeline.removeHandler(httpUpgradeRequestHandler, promise: nil) } ) - if scheme == "wss" { - do { - let context = try NIOSSLContext( - configuration: self.configuration.tlsConfiguration ?? .makeClientConfiguration() - ) - let tlsHandler: NIOSSLClientHandler + if proxy == nil || scheme == "ws" { + if scheme == "wss" { do { - tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: host) - } catch let error as NIOSSLExtraError where error == .cannotUseIPAddressInSNI { - tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: nil) + let tlsHandler = try self.makeTLSHandler(tlsConfiguration: self.configuration.tlsConfiguration, host: host) + // The sync methods here are safe because we're on the channel event loop + // due to the promise originating on the event loop of the channel. + try channel.pipeline.syncOperations.addHandler(tlsHandler) + } catch { + return channel.pipeline.close(mode: .all) } - return channel.pipeline.addHandler(tlsHandler).flatMap { - channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config) - }.flatMap { - channel.pipeline.addHandler(httpHandler) - } - } catch { - return channel.pipeline.close(mode: .all) } - } else { + return channel.pipeline.addHTTPClientHandlers( leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config ).flatMap { - channel.pipeline.addHandler(httpHandler) + channel.pipeline.addHandler(httpUpgradeRequestHandler) + } + } + + // TLS + proxy + // we need to handle connecting with an additional CONNECT request + let proxyEstablishedPromise = channel.eventLoop.makePromise(of: Void.self) + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) + + var connectHeaders = proxyHeaders + connectHeaders.add(name: "Host", value: host) + + let proxyRequestHandler = NIOHTTP1ProxyConnectHandler( + targetHost: host, + targetPort: port, + headers: connectHeaders, + deadline: proxyConnectDeadline, + promise: proxyEstablishedPromise + ) + + // This code block adds HTTP handlers to allow the proxy request handler to function. + // They are then removed upon completion only to be re-added in `addHTTPClientHandlers`. + // This is done because the HTTP decoder is not valid after an upgrade, the CONNECT request being counted as one. + do { + try channel.pipeline.syncOperations.addHandler(encoder) + try channel.pipeline.syncOperations.addHandler(decoder) + try channel.pipeline.syncOperations.addHandler(proxyRequestHandler) + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + + proxyEstablishedPromise.futureResult.flatMap { + channel.pipeline.removeHandler(decoder) + }.flatMap { + channel.pipeline.removeHandler(encoder) + }.whenComplete { result in + switch result { + case .success: + do { + let tlsHandler = try self.makeTLSHandler(tlsConfiguration: self.configuration.tlsConfiguration, host: host) + // The sync methods here are safe because we're on the channel event loop + // due to the promise originating on the event loop of the channel. + try channel.pipeline.syncOperations.addHandler(tlsHandler) + try channel.pipeline.syncOperations.addHTTPClientHandlers( + leftOverBytesStrategy: .forwardBytes, + withClientUpgrade: config + ) + try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandler) + } catch { + channel.pipeline.close(mode: .all, promise: nil) + } + case .failure: + channel.pipeline.close(mode: .all, promise: nil) } } + + return channel.eventLoop.makeSucceededVoidFuture() } - let connect = bootstrap.connect(host: host, port: port) + let connect = bootstrap.connect(host: proxy ?? host, port: proxyPort ?? port) connect.cascadeFailure(to: upgradePromise) return connect.flatMap { channel in return upgradePromise.futureResult } } + private func makeTLSHandler(tlsConfiguration: TLSConfiguration?, host: String) throws -> NIOSSLClientHandler { + let context = try NIOSSLContext( + configuration: self.configuration.tlsConfiguration ?? .makeClientConfiguration() + ) + let tlsHandler: NIOSSLClientHandler + do { + tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: host) + } catch let error as NIOSSLExtraError where error == .cannotUseIPAddressInSNI { + tlsHandler = try NIOSSLClientHandler(context: context, serverHostname: nil) + } + return tlsHandler + } public func syncShutdown() throws { switch self.eventLoopGroupProvider { @@ -153,13 +267,13 @@ public final class WebSocketClient { if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { return tsBootstrap } - #endif + #endif - if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap - } + if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + return nioBootstrap + } - fatalError("No matching bootstrap found") + fatalError("No matching bootstrap found") } deinit { diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index b54f9fb7..1af0f76a 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -2,27 +2,100 @@ import NIO import NIOWebSocket extension WebSocket { + + /// Stores configuration for a WebSocket client/server instance + public struct Configuration { + /// Defends against small payloads in frame aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var minNonFinalFragmentSize: Int + /// Max number of fragments in an aggregated frame. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameCount: Int + /// Maximum frame size after aggregation. + /// See `NIOWebSocketFrameAggregator` for details. + public var maxAccumulatedFrameSize: Int + + public init() { + self.minNonFinalFragmentSize = 0 + self.maxAccumulatedFrameCount = Int.max + self.maxAccumulatedFrameSize = Int.max + } + + internal init(clientConfig: WebSocketClient.Configuration) { + self.minNonFinalFragmentSize = clientConfig.minNonFinalFragmentSize + self.maxAccumulatedFrameCount = clientConfig.maxAccumulatedFrameCount + self.maxAccumulatedFrameSize = clientConfig.maxAccumulatedFrameSize + } + } + + /// Sets up a channel to operate as a WebSocket client. + /// - Parameters: + /// - channel: NIO channel which the client will use to communicate. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. + public static func client( + on channel: Channel, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.configure(on: channel, as: .client, with: Configuration(), onUpgrade: onUpgrade) + } + + /// Sets up a channel to operate as a WebSocket client. + /// - Parameters: + /// - channel: NIO channel which the client/server will use to communicate. + /// - config: Configuration for the client channel handlers. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. public static func client( + on channel: Channel, + config: Configuration, + onUpgrade: @escaping (WebSocket) -> () + ) -> EventLoopFuture { + return self.configure(on: channel, as: .client, with: config, onUpgrade: onUpgrade) + } + + /// Sets up a channel to operate as a WebSocket server. + /// - Parameters: + /// - channel: NIO channel which the server will use to communicate. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. + public static func server( on channel: Channel, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .client, onUpgrade: onUpgrade) + return self.configure(on: channel, as: .server, with: Configuration(), onUpgrade: onUpgrade) } + /// Sets up a channel to operate as a WebSocket server. + /// - Parameters: + /// - channel: NIO channel which the server will use to communicate. + /// - config: Configuration for the server channel handlers. + /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. + /// - Returns: An future which completes when the WebSocket connection to the server is established. public static func server( on channel: Channel, + config: Configuration, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .server, onUpgrade: onUpgrade) + return self.configure(on: channel, as: .server, with: config, onUpgrade: onUpgrade) } - private static func handle( + private static func configure( on channel: Channel, as type: PeerType, + with config: Configuration, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { let webSocket = WebSocket(channel: channel, type: type) - return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in + + return channel.pipeline.addHandlers([ + NIOWebSocketFrameAggregator( + minNonFinalFragmentSize: config.minNonFinalFragmentSize, + maxAccumulatedFrameCount: config.maxAccumulatedFrameCount, + maxAccumulatedFrameSize: config.maxAccumulatedFrameSize + ), + WebSocketHandler(webSocket: webSocket) + ]).map { _ in onUpgrade(webSocket) } } diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index 19bef5a1..7cc33a92 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -1,5 +1,7 @@ import XCTest +import Atomics import NIO +import NIOExtras import NIOHTTP1 import NIOSSL import NIOWebSocket @@ -125,7 +127,7 @@ final class WebSocketKitTests: XCTestCase { let pingPromise = self.elg.next().makePromise(of: String.self) let pongPromise = self.elg.next().makePromise(of: String.self) let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8) - + let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onPing { ws in pingPromise.succeed("ping") @@ -150,6 +152,41 @@ final class WebSocketKitTests: XCTestCase { try server.close(mode: .all).wait() } + func testWebSocketAggregateFrames() throws { + func byteBuffView(_ str: String) -> ByteBufferView { + ByteBuffer(string: str).readableBytesView + } + + let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in + ws.onText { ws, text in + ws.send(text, opcode: .text, fin: false) + ws.send(" th", opcode: .continuation, fin: false) + ws.send("e mo", opcode: .continuation, fin: false) + ws.send("st", opcode: .continuation, fin: true) + } + }.bind(host: "localhost", port: 0).wait() + + guard let port = server.localAddress?.port else { + XCTFail("couldn't get port from \(server.localAddress.debugDescription)") + return + } + + let promise = elg.next().makePromise(of: String.self) + let closePromise = elg.next().makePromise(of: Void.self) + WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in + ws.send("Hel", opcode: .text, fin: false) + ws.send("lo! Vapor r", opcode: .continuation, fin: false) + ws.send("ules", opcode: .continuation, fin: true) + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } + }.cascadeFailure(to: promise) + try XCTAssertEqual(promise.futureResult.wait(), "Hello! Vapor rules the most") + XCTAssertNoThrow(try closePromise.futureResult.wait()) + try server.close(mode: .all).wait() + } + func testErrorCode() throws { let promise = self.elg.next().makePromise(of: WebSocketErrorCode.self) @@ -299,6 +336,122 @@ final class WebSocketKitTests: XCTestCase { try server.close(mode: .all).wait() } + func testProxy() throws { + let promise = elg.next().makePromise(of: String.self) + + let localWebsocketBin: WebsocketBin + let verifyProxyHead = { (ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) in + XCTAssertEqual(requestHead.uri, "ws://apple.com/:\(ctx.localAddress!.port!)") + XCTAssertEqual(requestHead.headers.first(name: "Host"), "apple.com") + } + localWebsocketBin = WebsocketBin( + .http1_1(ssl: false), + proxy: .simulate( + config: WebsocketBin.ProxyConfig(tls: false, headVerification: verifyProxyHead), + authorization: "token amFwcGxlc2VlZDpwYXNzMTIz" + ), + sslContext: nil + ) { req, ws in + ws.onText { ws, text in + ws.send(text) + } + } + + defer { + XCTAssertNoThrow(try localWebsocketBin.shutdown()) + } + + let closePromise = elg.next().makePromise(of: Void.self) + + let client = WebSocketClient( + eventLoopGroupProvider: .shared(self.elg), + configuration: .init() + ) + + client.connect( + scheme: "ws", + host: "apple.com", + port: localWebsocketBin.port, + proxy: "localhost", + proxyPort: localWebsocketBin.port, + proxyHeaders: HTTPHeaders([("proxy-authorization", "token amFwcGxlc2VlZDpwYXNzMTIz")]) + ) { ws in + ws.send("hello") + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } + }.cascadeFailure(to: promise) + + XCTAssertEqual(try promise.futureResult.wait(), "hello") + XCTAssertNoThrow(try closePromise.futureResult.wait()) + } + + func testProxyTLS() throws { + let promise = elg.next().makePromise(of: String.self) + + let (cert, key) = generateSelfSignedCert() + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(cert)], + privateKey: .privateKey(key) + ) + let sslContext = try! NIOSSLContext(configuration: configuration) + + let verifyProxyHead = { (ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) in + // CONNECT uses a special form of request target, unique to this method, consisting of + // only the host and port number of the tunnel destination, separated by a colon. + // https://httpwg.org/specs/rfc9110.html#CONNECT + XCTAssertEqual(requestHead.uri, "apple.com:\(ctx.localAddress!.port!)") + XCTAssertEqual(requestHead.headers.first(name: "Host"), "apple.com") + } + let localWebsocketBin = WebsocketBin( + .http1_1(ssl: true), + proxy: .simulate( + config: WebsocketBin.ProxyConfig(tls: true, headVerification: verifyProxyHead), + authorization: "token amFwcGxlc2VlZDpwYXNzMTIz" + ), + sslContext: sslContext + ) { req, ws in + ws.onText { ws, text in + ws.send(text) + } + } + + defer { + XCTAssertNoThrow(try localWebsocketBin.shutdown()) + } + + let closePromise = elg.next().makePromise(of: Void.self) + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + + let client = WebSocketClient( + eventLoopGroupProvider: .shared(self.elg), + configuration: .init( + tlsConfiguration: tlsConfiguration + ) + ) + + client.connect( + scheme: "wss", + host: "apple.com", + port: localWebsocketBin.port, + proxy: "localhost", + proxyPort: localWebsocketBin.port, + proxyHeaders: HTTPHeaders([("proxy-authorization", "token amFwcGxlc2VlZDpwYXNzMTIz")]) + ) { ws in + ws.send("hello") + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } + }.cascadeFailure(to: promise) + + XCTAssertEqual(try promise.futureResult.wait(), "hello") + XCTAssertNoThrow(try closePromise.futureResult.wait()) + } + + var elg: EventLoopGroup! override func setUp() { // needs to be at least two to avoid client / server on same EL timing issues @@ -347,3 +500,333 @@ extension ServerBootstrap { } } } + +fileprivate extension WebSocket { + func send( + _ data: String, + opcode: WebSocketOpcode, + fin: Bool = true, + promise: EventLoopPromise? = nil + ) { + self.send(raw: ByteBuffer(string: data).readableBytesView, opcode: opcode, fin: fin, promise: promise) + } +} + + + +internal final class WebsocketBin { + enum BindTarget { + case unixDomainSocket(String) + case localhostIPv4RandomPort + case localhostIPv6RandomPort + } + + enum Mode { + // refuses all connections + case refuse + // supports http1.1 connections only, which can be either plain text or encrypted + case http1_1(ssl: Bool = false) + } + + enum Proxy { + case none + case simulate(config: ProxyConfig, authorization: String?) + } + + struct ProxyConfig { + var tls: Bool + let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + + var port: Int { + return Int(self.serverChannel.localAddress!.port!) + } + + private let mode: Mode + private let sslContext: NIOSSLContext? + private var serverChannel: Channel! + private let isShutdown = ManagedAtomic(false) + + init( + _ mode: Mode = .http1_1(ssl: false), + proxy: Proxy = .none, + bindTarget: BindTarget = .localhostIPv4RandomPort, + sslContext: NIOSSLContext?, + onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () + ) { + self.mode = mode + self.sslContext = sslContext + + let socketAddress: SocketAddress + switch bindTarget { + case .localhostIPv4RandomPort: + socketAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 0) + case .localhostIPv6RandomPort: + socketAddress = try! SocketAddress(ipAddress: "::1", port: 0) + case .unixDomainSocket(let path): + socketAddress = try! SocketAddress(unixDomainSocketPath: path) + } + + self.serverChannel = try! ServerBootstrap(group: self.group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + do { + + if case .refuse = mode { + throw HTTPBinError.refusedConnection + } + + let webSocket = NIOWebSocketServerUpgrader( + shouldUpgrade: { channel, req in + return channel.eventLoop.makeSucceededFuture([:]) + }, + upgradePipelineHandler: { channel, req in + return WebSocket.server(on: channel) { ws in + onUpgrade(req, ws) + } + } + ) + + // if we need to simulate a proxy, we need to add those handlers first + if case .simulate(config: let config, authorization: let expectedAuthorization) = proxy { + if config.tls { + try self.syncAddTLSHTTPProxyHandlers( + to: channel, + proxyConfig: config, + expectedAuthorization: expectedAuthorization, + upgraders: [webSocket] + ) + } else { + try self.syncAddHTTPProxyHandlers( + to: channel, + proxyConfig: config, + expectedAuthorization: expectedAuthorization, + upgraders: [webSocket] + ) + } + return channel.eventLoop.makeSucceededVoidFuture() + } + + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + + // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withServerUpgrade: ( + upgraders: [webSocket], + completionHandler: { ctx in + // complete + } + ), + withErrorHandling: true + ) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + }.bind(to: socketAddress).wait() + } + + + // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially + // rather than re-using parts because the requestDecoder stops parsing after a CONNECT request + private func syncAddTLSHTTPProxyHandlers( + to channel: Channel, + proxyConfig: ProxyConfig, + expectedAuthorization: String?, + upgraders: [HTTPServerProtocolUpgrader] + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + + try sync.addHandler(proxySimulator) + + promise.futureResult.flatMap { _ in + channel.pipeline.removeHandler(proxySimulator) + }.flatMap { _ in + channel.pipeline.removeHandler(responseEncoder) + }.flatMap { _ in + channel.pipeline.removeHandler(requestDecoder) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + self.httpProxyEstablished(channel, upgraders: upgraders) + break + } + } + } + + + // In the plain-text case we must set up the 'proxy' and the 'server' handlers simultaneously + // so that the combined proxy/upgrade request can be processed by the separate proxy and upgrade handlers + private func syncAddHTTPProxyHandlers( + to channel: Channel, + proxyConfig: ProxyConfig, + expectedAuthorization: String?, + upgraders: [HTTPServerProtocolUpgrader] + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) + + let serverPipelineHandler = HTTPServerPipelineHandler() + let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler() + + let extraHTTPHandlers: [RemovableChannelHandler] = [ + requestDecoder, + serverPipelineHandler, + serverProtocolErrorHandler + ] + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + + try sync.addHandler(proxySimulator) + + try sync.addHandler(serverPipelineHandler) + try sync.addHandler(serverProtocolErrorHandler) + + + let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders, + httpEncoder: responseEncoder, + extraHTTPHandlers: extraHTTPHandlers, + upgradeCompletionHandler: { ctx in + // complete + }) + + + try sync.addHandler(upgrader) + + promise.futureResult.flatMap { () -> EventLoopFuture in + channel.pipeline.removeHandler(proxySimulator) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + break + } + } + } + + private func httpProxyEstablished(_ channel: Channel, upgraders: [HTTPServerProtocolUpgrader]) { + do { + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withServerUpgrade: ( + upgraders: upgraders, + completionHandler: { ctx in + // complete + } + ), + withErrorHandling: true + ) + } catch { + // in case of an while modifying the pipeline we should close the connection + channel.close(mode: .all, promise: nil) + } + } + + func shutdown() throws { + self.isShutdown.store(true, ordering: .relaxed) + try self.group.syncShutdownGracefully() + } +} + +enum HTTPBinError: Error { + case refusedConnection + case invalidProxyRequest +} + +final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = HTTPServerRequestPart + typealias InboundOut = HTTPServerResponsePart + typealias OutboundOut = HTTPServerResponsePart + + + // the promise to succeed, once the proxy connection is setup + let promise: EventLoopPromise + let config: WebsocketBin.ProxyConfig + let expectedAuthorization: String? + + var head: HTTPResponseHead + + init(promise: EventLoopPromise, config: WebsocketBin.ProxyConfig, expectedAuthorization: String?) { + self.promise = promise + self.config = config + self.expectedAuthorization = expectedAuthorization + self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head(let head): + if self.config.tls { + guard head.method == .CONNECT else { + self.head.status = .badRequest + return + } + } else { + guard head.method == .GET else { + self.head.status = .badRequest + return + } + } + + self.config.headVerification(context, head) + + if let expectedAuthorization = self.expectedAuthorization { + guard let authorization = head.headers["proxy-authorization"].first, + expectedAuthorization == authorization else { + self.head.status = .proxyAuthenticationRequired + return + } + } + if !self.config.tls { + context.fireChannelRead(data) + } + + case .body: + () + case .end: + if self.self.config.tls { + context.write(self.wrapOutboundOut(.head(self.head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + if self.head.status == .ok { + if !self.config.tls { + context.fireChannelRead(data) + } + self.promise.succeed(()) + } else { + self.promise.fail(HTTPBinError.invalidProxyRequest) + } + } + } +} +