From 17c0afbf24f4e288e183bc34317dc97b5cd084bd Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Mon, 15 May 2023 00:18:10 -0500 Subject: [PATCH] Update for new NIOSSL (#132) * Bump to Swift 5.6 minimum and update README * The usual CI cleanup * Require a more recent NIO, switch to NIOCore instead of NIO, use EventLoopGroup.any() instead of .next(), make EventLoopGroupProvider an alias for NIO's version, use annotated exports * Ditch unneeded Concurrency checks (which also incidentally enables back-deployment to 10.15) * Update SSLTestHelpers to reflect the current state of the original upstream (which basically just means using OpaquePointer instead of UMP, partly because EVP_PKEY is no longer visible from CNIOBoringSSL * Add API breakage allowlist for this branch * Add test coverage for all the connect methods, binary frames, and ping intervals. --- ...allowlist-branch-update-for-new-niossl.txt | 3 + .github/workflows/main-codecov.yml | 9 - .github/workflows/projectboard.yml | 28 +- .github/workflows/test.yml | 17 +- Package.swift | 8 +- README.md | 9 +- .../Concurrency/WebSocket+Concurrency.swift | 4 - Sources/WebSocketKit/Exports.swift | 29 +- .../HTTPUpgradeRequestHandler.swift | 2 +- Sources/WebSocketKit/WebSocket+Connect.swift | 4 +- Sources/WebSocketKit/WebSocket.swift | 2 +- Sources/WebSocketKit/WebSocketClient.swift | 10 +- Sources/WebSocketKit/WebSocketHandler.swift | 2 +- .../AsyncWebSocketKitTests.swift | 94 +++- Tests/WebSocketKitTests/SSLTestHelpers.swift | 6 +- Tests/WebSocketKitTests/Utilities.swift | 363 +++++++++++++ .../WebSocketKitTests/WebSocketKitTests.swift | 508 ++++-------------- 17 files changed, 627 insertions(+), 471 deletions(-) create mode 100644 .api-breakage/allowlist-branch-update-for-new-niossl.txt delete mode 100644 .github/workflows/main-codecov.yml create mode 100644 Tests/WebSocketKitTests/Utilities.swift diff --git a/.api-breakage/allowlist-branch-update-for-new-niossl.txt b/.api-breakage/allowlist-branch-update-for-new-niossl.txt new file mode 100644 index 00000000..a32a2389 --- /dev/null +++ b/.api-breakage/allowlist-branch-update-for-new-niossl.txt @@ -0,0 +1,3 @@ +API breakage: import NIO has been renamed to import NIOCore +API breakage: import NIO has been renamed to import NIOPosix + diff --git a/.github/workflows/main-codecov.yml b/.github/workflows/main-codecov.yml deleted file mode 100644 index 1d0fe384..00000000 --- a/.github/workflows/main-codecov.yml +++ /dev/null @@ -1,9 +0,0 @@ -name: Update code coverage baselines -on: - push: { branches: [ main ] } -jobs: - update-main-codecov: - uses: vapor/ci/.github/workflows/run-unit-tests.yml@reusable-workflows - with: - with_coverage: true - with_tsan: true diff --git a/.github/workflows/projectboard.yml b/.github/workflows/projectboard.yml index 0e4e66f6..a0e6d988 100644 --- a/.github/workflows/projectboard.yml +++ b/.github/workflows/projectboard.yml @@ -5,27 +5,7 @@ on: types: [reopened, closed, labeled, unlabeled, assigned, unassigned] jobs: - setup_matrix_input: - runs-on: ubuntu-latest - - steps: - - id: set-matrix - run: | - output=$(curl ${{ github.event.issue.url }}/labels | jq '.[] | .name') || output="" - - echo '======================' - echo 'Process incoming data' - echo '======================' - json=$(echo $output | sed 's/"\s"/","/g') - echo $json - echo "::set-output name=matrix::$(echo $json)" - outputs: - issueTags: ${{ steps.set-matrix.outputs.matrix }} - - Manage_project_issues: - needs: setup_matrix_input - uses: vapor/ci/.github/workflows/issues-to-project-board.yml@main - with: - labelsJson: ${{ needs.setup_matrix_input.outputs.issueTags }} - secrets: - PROJECT_BOARD_AUTOMATION_PAT: "${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }}" + update_project_boards: + name: Update project boards + uses: vapor/ci/.github/workflows/update-project-boards-for-issue.yml@reusable-workflows + secrets: inherit diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bff42997..b2288dc2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,11 +1,17 @@ name: test +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true on: -- pull_request + pull_request: { types: [opened, reopened, synchronize, ready_for_review] } + push: { branches: [ main ] } + jobs: + vapor-integration: + if: ${{ !(github.event.pull_request.draft || false) }} runs-on: ubuntu-latest - container: - image: swift:5.7 + container: swift:5.8-jammy steps: - name: Check out package uses: actions/checkout@v3 @@ -14,8 +20,7 @@ jobs: uses: actions/checkout@v3 with: { repository: 'vapor/vapor', path: 'vapor' } - name: Use local package in Vapor - run: | - swift package --package-path vapor edit websocket-kit --path websocket-kit + run: swift package --package-path vapor edit websocket-kit --path websocket-kit - name: Run Vapor tests run: swift test --package-path vapor @@ -24,6 +29,4 @@ jobs: with: with_coverage: true with_tsan: false - coverage_ignores: '/Tests/' with_public_api_check: ${{ github.event_name == 'pull_request' }} - diff --git a/Package.swift b/Package.swift index 1fd56471..4841d15c 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.4 +// swift-tools-version:5.6 import PackageDescription let package = Package( @@ -6,21 +6,21 @@ let package = Package( platforms: [ .macOS(.v10_15), .iOS(.v13), + .watchOS(.v6), .tvOS(.v13), ], products: [ .library(name: "WebSocketKit", targets: ["WebSocketKit"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.53.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.24.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), ], targets: [ .target(name: "WebSocketKit", dependencies: [ - .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOExtras", package: "swift-nio-extras"), diff --git a/README.md b/README.md index 71370ce1..2066c308 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,13 @@ MIT License - - Continuous Integration + + Continuous Integration - Swift 5.2 + Swift 5.6 + + + Swift 5.8

diff --git a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift index 6b7ecc3b..d0e9c0d3 100644 --- a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift +++ b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift @@ -1,10 +1,8 @@ -#if compiler(>=5.5) && canImport(_Concurrency) import NIOCore import NIOWebSocket import Foundation import NIOHTTP1 -@available(macOS 12, iOS 15, watchOS 8, tvOS 15, *) extension WebSocket { public func send(_ text: S) async throws where S: Collection, S.Element == Character @@ -142,5 +140,3 @@ extension WebSocket { ).get() } } - -#endif diff --git a/Sources/WebSocketKit/Exports.swift b/Sources/WebSocketKit/Exports.swift index 9c1931e1..11ce0e49 100644 --- a/Sources/WebSocketKit/Exports.swift +++ b/Sources/WebSocketKit/Exports.swift @@ -1,14 +1,27 @@ -#if !BUILDING_DOCC +#if swift(>=5.8) -@_exported import struct NIO.ByteBuffer -@_exported import protocol NIO.Channel -@_exported import protocol NIO.EventLoop -@_exported import protocol NIO.EventLoopGroup -@_exported import struct NIO.EventLoopPromise -@_exported import class NIO.EventLoopFuture +@_documentation(visibility: internal) @_exported import struct NIOCore.ByteBuffer +@_documentation(visibility: internal) @_exported import protocol NIOCore.Channel +@_documentation(visibility: internal) @_exported import protocol NIOCore.EventLoop +@_documentation(visibility: internal) @_exported import protocol NIOCore.EventLoopGroup +@_documentation(visibility: internal) @_exported import struct NIOCore.EventLoopPromise +@_documentation(visibility: internal) @_exported import class NIOCore.EventLoopFuture + +@_documentation(visibility: internal) @_exported import struct NIOHTTP1.HTTPHeaders + +@_documentation(visibility: internal) @_exported import struct Foundation.URL + +#else + +@_exported import struct NIOCore.ByteBuffer +@_exported import protocol NIOCore.Channel +@_exported import protocol NIOCore.EventLoop +@_exported import protocol NIOCore.EventLoopGroup +@_exported import struct NIOCore.EventLoopPromise +@_exported import class NIOCore.EventLoopFuture @_exported import struct NIOHTTP1.HTTPHeaders @_exported import struct Foundation.URL -#endif \ No newline at end of file +#endif diff --git a/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift index 84af52d1..817d67b3 100644 --- a/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift +++ b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import NIOHTTP1 final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler { diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift index 643fcb6c..ca94540b 100644 --- a/Sources/WebSocketKit/WebSocket+Connect.swift +++ b/Sources/WebSocketKit/WebSocket+Connect.swift @@ -20,7 +20,7 @@ extension WebSocket { onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { - return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL) + return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL) } return self.connect( to: url, @@ -174,7 +174,7 @@ extension WebSocket { onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { - return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL) + return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL) } let scheme = url.scheme ?? "ws" return self.connect( diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index ce915c6d..3e6f679c 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import NIOWebSocket import NIOHTTP1 import NIOSSL diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index b5cd072d..08bf48e5 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -1,5 +1,6 @@ import Foundation -import NIO +import NIOCore +import NIOPosix import NIOConcurrencyHelpers import NIOExtras import NIOHTTP1 @@ -18,10 +19,7 @@ public final class WebSocketClient { } } - public enum EventLoopGroupProvider { - case shared(EventLoopGroup) - case createNew - } + public typealias EventLoopGroupProvider = NIOEventLoopGroupProvider public struct Configuration { public var tlsConfiguration: TLSConfiguration? @@ -106,7 +104,7 @@ public final class WebSocketClient { onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { assert(["ws", "wss"].contains(scheme)) - let upgradePromise = self.group.next().makePromise(of: Void.self) + let upgradePromise = self.group.any().makePromise(of: Void.self) let bootstrap = WebSocketClient.makeBootstrap(on: self.group) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) .channelInitializer { channel -> EventLoopFuture in diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index 1af0f76a..45f266ce 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -1,4 +1,4 @@ -import NIO +import NIOCore import NIOWebSocket extension WebSocket { diff --git a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift index 04538f65..e20ffa93 100644 --- a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift @@ -1,25 +1,23 @@ -#if compiler(>=5.5) && canImport(_Concurrency) import XCTest import NIO import NIOHTTP1 import NIOWebSocket @testable import WebSocketKit -@available(macOS 12, iOS 15, watchOS 8, tvOS 15, *) final class AsyncWebSocketKitTests: XCTestCase { func testWebSocketEcho() async throws { - let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in + let server = try await ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in ws.send(text) } - }.bind(host: "localhost", port: 0).wait() + }.bind(host: "localhost", port: 0).get() guard let port = server.localAddress?.port else { XCTFail("couldn't get port from \(server.localAddress.debugDescription)") return } - let promise = elg.next().makePromise(of: String.self) + let promise = elg.any().makePromise(of: String.self) try await WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in do { @@ -42,7 +40,91 @@ final class AsyncWebSocketKitTests: XCTestCase { try await server.close(mode: .all) } + func testAlternateWebsocketConnectMethods() async throws { + let server = try await ServerBootstrap.webSocket(on: self.elg) { $1.onText { $0.send($1) } }.bind(host: "localhost", port: 0).get() + let promise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") + } + try await WebSocket.connect(scheme: "ws", host: "localhost", port: port, on: self.elg) { (ws) async in + do { try await ws.send("hello") } catch { promise.fail(error); try? await ws.close() } + ws.onText { ws, _ in + promise.succeed(()) + do { try await ws.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + } + } + try await promise.futureResult.get() + try await server.close(mode: .all) + } + + func testBadURLInWebsocketConnect() async throws { + do { + try await WebSocket.connect(to: "%w", on: self.elg, onUpgrade: { _ async in }) + XCTAssertThrowsError({}()) + } catch { + XCTAssertThrowsError(try { throw error }()) { + guard case .invalidURL = $0 as? WebSocketClient.Error else { + return XCTFail("Expected .invalidURL but got \(String(reflecting: $0))") + } + } + } + } + + func testOnBinary() async throws { + let server = try await ServerBootstrap.webSocket(on: self.elg) { $1.onBinary { $0.send($1) } }.bind(host: "localhost", port: 0).get() + let promise = self.elg.any().makePromise(of: [UInt8].self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") + } + try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + do { try await ws.send([0x01]) } catch { promise.fail(error); try? await ws.close() } + ws.onBinary { ws, buf in + promise.succeed(.init(buf.readableBytesView)) + do { try await ws.close() } + catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + } + } + let result = try await promise.futureResult.get() + XCTAssertEqual(result, [0x01]) + try await server.close(mode: .all) + } + + func testSendPing() async throws { + let server = try await ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).get() + let promise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") + } + try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in + do { try await ws.sendPing() } catch { promise.fail(error); try? await ws.close() } + ws.onPong { + promise.succeed(()) + do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + } + } + try await promise.futureResult.get() + try await server.close(mode: .all) + } + + func testSetPingInterval() async throws { + let server = try await ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).get() + let promise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") + } + try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in + ws.pingInterval = .milliseconds(100) + ws.onPong { + promise.succeed(()) + do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + } + } + try await promise.futureResult.get() + try await server.close(mode: .all) + } + var elg: EventLoopGroup! + override func setUp() { // needs to be at least two to avoid client / server on same EL timing issues self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2) @@ -51,5 +133,3 @@ final class AsyncWebSocketKitTests: XCTestCase { try! self.elg.syncShutdownGracefully() } } - -#endif diff --git a/Tests/WebSocketKitTests/SSLTestHelpers.swift b/Tests/WebSocketKitTests/SSLTestHelpers.swift index 574da640..d515fe3e 100644 --- a/Tests/WebSocketKitTests/SSLTestHelpers.swift +++ b/Tests/WebSocketKitTests/SSLTestHelpers.swift @@ -15,6 +15,7 @@ import Foundation @_implementationOnly import CNIOBoringSSL +import NIOCore @testable import NIOSSL // This function generates a random number suitable for use in an X509 @@ -61,7 +62,7 @@ func randomSerialNumber() -> ASN1_INTEGER { return asn1int } -func generateRSAPrivateKey() -> UnsafeMutablePointer { +func generateRSAPrivateKey() -> OpaquePointer { let exponent = CNIOBoringSSL_BN_new() defer { CNIOBoringSSL_BN_free(exponent) @@ -91,7 +92,7 @@ func addExtension(x509: OpaquePointer, nid: CInt, value: String) { CNIOBoringSSL_X509_EXTENSION_free(ext) } -func generateSelfSignedCert(keygenFunction: () -> UnsafeMutablePointer = generateRSAPrivateKey) -> (NIOSSLCertificate, NIOSSLPrivateKey) { +func generateSelfSignedCert(keygenFunction: () -> OpaquePointer = generateRSAPrivateKey) -> (NIOSSLCertificate, NIOSSLPrivateKey) { let pkey = keygenFunction() let x = CNIOBoringSSL_X509_new()! CNIOBoringSSL_X509_set_version(x, 2) @@ -139,4 +140,3 @@ func generateSelfSignedCert(keygenFunction: () -> UnsafeMutablePointer return (NIOSSLCertificate.fromUnsafePointer(takingOwnership: x), NIOSSLPrivateKey.fromUnsafePointer(takingOwnership: pkey)) } - diff --git a/Tests/WebSocketKitTests/Utilities.swift b/Tests/WebSocketKitTests/Utilities.swift new file mode 100644 index 00000000..60043ec4 --- /dev/null +++ b/Tests/WebSocketKitTests/Utilities.swift @@ -0,0 +1,363 @@ +import XCTest +import Atomics +import NIO +import NIOExtras +import NIOHTTP1 +import NIOSSL +import NIOWebSocket +@testable import WebSocketKit + +extension ServerBootstrap { + static func webSocket( + on eventLoopGroup: EventLoopGroup, + tls: Bool = false, + onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () + ) -> ServerBootstrap { + return ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in + if tls { + let (cert, key) = generateSelfSignedCert() + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(cert)], + privateKey: .privateKey(key) + ) + let sslContext = try! NIOSSLContext(configuration: configuration) + let handler = NIOSSLServerHandler(context: sslContext) + _ = channel.pipeline.addHandler(handler) + } + let webSocket = NIOWebSocketServerUpgrader( + shouldUpgrade: { channel, req in + return channel.eventLoop.makeSucceededFuture([:]) + }, + upgradePipelineHandler: { channel, req in + return WebSocket.server(on: channel) { ws in + onUpgrade(req, ws) + } + } + ) + return channel.pipeline.configureHTTPServerPipeline( + withServerUpgrade: ( + upgraders: [webSocket], + completionHandler: { ctx in + // complete + } + ) + ) + } + } +} + +internal final class WebsocketBin { + enum BindTarget { + case unixDomainSocket(String) + case localhostIPv4RandomPort + case localhostIPv6RandomPort + } + + enum Mode { + // refuses all connections + case refuse + // supports http1.1 connections only, which can be either plain text or encrypted + case http1_1(ssl: Bool = false) + } + + enum Proxy { + case none + case simulate(config: ProxyConfig, authorization: String?) + } + + struct ProxyConfig { + var tls: Bool + let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + + var port: Int { + return Int(self.serverChannel.localAddress!.port!) + } + + private let mode: Mode + private let sslContext: NIOSSLContext? + private var serverChannel: Channel! + private let isShutdown = ManagedAtomic(false) + + init( + _ mode: Mode = .http1_1(ssl: false), + proxy: Proxy = .none, + bindTarget: BindTarget = .localhostIPv4RandomPort, + sslContext: NIOSSLContext?, + onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () + ) { + self.mode = mode + self.sslContext = sslContext + + let socketAddress: SocketAddress + switch bindTarget { + case .localhostIPv4RandomPort: + socketAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 0) + case .localhostIPv6RandomPort: + socketAddress = try! SocketAddress(ipAddress: "::1", port: 0) + case .unixDomainSocket(let path): + socketAddress = try! SocketAddress(unixDomainSocketPath: path) + } + + self.serverChannel = try! ServerBootstrap(group: self.group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + do { + + if case .refuse = mode { + throw HTTPBinError.refusedConnection + } + + let webSocket = NIOWebSocketServerUpgrader( + shouldUpgrade: { channel, req in + return channel.eventLoop.makeSucceededFuture([:]) + }, + upgradePipelineHandler: { channel, req in + return WebSocket.server(on: channel) { ws in + onUpgrade(req, ws) + } + } + ) + + // if we need to simulate a proxy, we need to add those handlers first + if case .simulate(config: let config, authorization: let expectedAuthorization) = proxy { + if config.tls { + try self.syncAddTLSHTTPProxyHandlers( + to: channel, + proxyConfig: config, + expectedAuthorization: expectedAuthorization, + upgraders: [webSocket] + ) + } else { + try self.syncAddHTTPProxyHandlers( + to: channel, + proxyConfig: config, + expectedAuthorization: expectedAuthorization, + upgraders: [webSocket] + ) + } + return channel.eventLoop.makeSucceededVoidFuture() + } + + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + + // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withServerUpgrade: ( + upgraders: [webSocket], + completionHandler: { ctx in + // complete + } + ), + withErrorHandling: true + ) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + }.bind(to: socketAddress).wait() + } + + + // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially + // rather than re-using parts because the requestDecoder stops parsing after a CONNECT request + private func syncAddTLSHTTPProxyHandlers( + to channel: Channel, + proxyConfig: ProxyConfig, + expectedAuthorization: String?, + upgraders: [HTTPServerProtocolUpgrader] + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + + try sync.addHandler(proxySimulator) + + promise.futureResult.flatMap { _ in + channel.pipeline.removeHandler(proxySimulator) + }.flatMap { _ in + channel.pipeline.removeHandler(responseEncoder) + }.flatMap { _ in + channel.pipeline.removeHandler(requestDecoder) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + self.httpProxyEstablished(channel, upgraders: upgraders) + break + } + } + } + + + // In the plain-text case we must set up the 'proxy' and the 'server' handlers simultaneously + // so that the combined proxy/upgrade request can be processed by the separate proxy and upgrade handlers + private func syncAddHTTPProxyHandlers( + to channel: Channel, + proxyConfig: ProxyConfig, + expectedAuthorization: String?, + upgraders: [HTTPServerProtocolUpgrader] + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) + + let serverPipelineHandler = HTTPServerPipelineHandler() + let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler() + + let extraHTTPHandlers: [RemovableChannelHandler] = [ + requestDecoder, + serverPipelineHandler, + serverProtocolErrorHandler + ] + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + + try sync.addHandler(proxySimulator) + + try sync.addHandler(serverPipelineHandler) + try sync.addHandler(serverProtocolErrorHandler) + + + let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders, + httpEncoder: responseEncoder, + extraHTTPHandlers: extraHTTPHandlers, + upgradeCompletionHandler: { ctx in + // complete + }) + + + try sync.addHandler(upgrader) + + promise.futureResult.flatMap { () -> EventLoopFuture in + channel.pipeline.removeHandler(proxySimulator) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + break + } + } + } + + private func httpProxyEstablished(_ channel: Channel, upgraders: [HTTPServerProtocolUpgrader]) { + do { + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withServerUpgrade: ( + upgraders: upgraders, + completionHandler: { ctx in + // complete + } + ), + withErrorHandling: true + ) + } catch { + // in case of an while modifying the pipeline we should close the connection + channel.close(mode: .all, promise: nil) + } + } + + func shutdown() throws { + self.isShutdown.store(true, ordering: .relaxed) + try self.group.syncShutdownGracefully() + } +} + +enum HTTPBinError: Error { + case refusedConnection + case invalidProxyRequest +} + +final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = HTTPServerRequestPart + typealias InboundOut = HTTPServerResponsePart + typealias OutboundOut = HTTPServerResponsePart + + + // the promise to succeed, once the proxy connection is setup + let promise: EventLoopPromise + let config: WebsocketBin.ProxyConfig + let expectedAuthorization: String? + + var head: HTTPResponseHead + + init(promise: EventLoopPromise, config: WebsocketBin.ProxyConfig, expectedAuthorization: String?) { + self.promise = promise + self.config = config + self.expectedAuthorization = expectedAuthorization + self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head(let head): + if self.config.tls { + guard head.method == .CONNECT else { + self.head.status = .badRequest + return + } + } else { + guard head.method == .GET else { + self.head.status = .badRequest + return + } + } + + self.config.headVerification(context, head) + + if let expectedAuthorization = self.expectedAuthorization { + guard let authorization = head.headers["proxy-authorization"].first, + expectedAuthorization == authorization else { + self.head.status = .proxyAuthenticationRequired + return + } + } + if !self.config.tls { + context.fireChannelRead(data) + } + + case .body: + () + case .end: + if self.self.config.tls { + context.write(self.wrapOutboundOut(.head(self.head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + if self.head.status == .ok { + if !self.config.tls { + context.fireChannelRead(data) + } + self.promise.succeed(()) + } else { + self.promise.fail(HTTPBinError.invalidProxyRequest) + } + } + } +} diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index 7cc33a92..985cb00b 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -20,8 +20,8 @@ final class WebSocketKitTests: XCTestCase { return } - let promise = elg.next().makePromise(of: String.self) - let closePromise = elg.next().makePromise(of: Void.self) + let promise = elg.any().makePromise(of: String.self) + let closePromise = elg.any().makePromise(of: Void.self) WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in ws.send("hello") ws.onText { ws, string in @@ -39,9 +39,9 @@ final class WebSocketKitTests: XCTestCase { } func testServerClose() throws { - let sendPromise = self.elg.next().makePromise(of: Void.self) - let serverClose = self.elg.next().makePromise(of: Void.self) - let clientClose = self.elg.next().makePromise(of: Void.self) + let sendPromise = self.elg.any().makePromise(of: Void.self) + let serverClose = self.elg.any().makePromise(of: Void.self) + let clientClose = self.elg.any().makePromise(of: Void.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in if text == "close" { @@ -67,9 +67,9 @@ final class WebSocketKitTests: XCTestCase { } func testClientClose() throws { - let sendPromise = self.elg.next().makePromise(of: Void.self) - let serverClose = self.elg.next().makePromise(of: Void.self) - let clientClose = self.elg.next().makePromise(of: Void.self) + let sendPromise = self.elg.any().makePromise(of: Void.self) + let serverClose = self.elg.any().makePromise(of: Void.self) + let clientClose = self.elg.any().makePromise(of: Void.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in ws.send(text) @@ -98,7 +98,7 @@ final class WebSocketKitTests: XCTestCase { } func testImmediateSend() throws { - let promise = self.elg.next().makePromise(of: String.self) + let promise = self.elg.any().makePromise(of: String.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.send("hello") ws.onText { ws, string in @@ -124,8 +124,8 @@ final class WebSocketKitTests: XCTestCase { } func testWebSocketPingPong() throws { - let pingPromise = self.elg.next().makePromise(of: String.self) - let pongPromise = self.elg.next().makePromise(of: String.self) + let pingPromise = self.elg.any().makePromise(of: String.self) + let pongPromise = self.elg.any().makePromise(of: String.self) let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in @@ -159,10 +159,10 @@ final class WebSocketKitTests: XCTestCase { let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in - ws.send(text, opcode: .text, fin: false) - ws.send(" th", opcode: .continuation, fin: false) - ws.send("e mo", opcode: .continuation, fin: false) - ws.send("st", opcode: .continuation, fin: true) + ws.send(.init(string: text), opcode: .text, fin: false) + ws.send(.init(string: " th"), opcode: .continuation, fin: false) + ws.send(.init(string: "e mo"), opcode: .continuation, fin: false) + ws.send(.init(string: "st"), opcode: .continuation, fin: true) } }.bind(host: "localhost", port: 0).wait() @@ -171,12 +171,12 @@ final class WebSocketKitTests: XCTestCase { return } - let promise = elg.next().makePromise(of: String.self) - let closePromise = elg.next().makePromise(of: Void.self) + let promise = elg.any().makePromise(of: String.self) + let closePromise = elg.any().makePromise(of: Void.self) WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in - ws.send("Hel", opcode: .text, fin: false) - ws.send("lo! Vapor r", opcode: .continuation, fin: false) - ws.send("ules", opcode: .continuation, fin: true) + ws.send(.init(string: "Hel"), opcode: .text, fin: false) + ws.send(.init(string: "lo! Vapor r"), opcode: .continuation, fin: false) + ws.send(.init(string: "ules"), opcode: .continuation, fin: true) ws.onText { ws, string in promise.succeed(string) ws.close(promise: closePromise) @@ -188,7 +188,7 @@ final class WebSocketKitTests: XCTestCase { } func testErrorCode() throws { - let promise = self.elg.next().makePromise(of: WebSocketErrorCode.self) + let promise = self.elg.any().makePromise(of: WebSocketErrorCode.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.close(code: .normalClosure, promise: nil) @@ -214,10 +214,10 @@ final class WebSocketKitTests: XCTestCase { } func testHeadersAreSent() throws { - let promiseAuth = self.elg.next().makePromise(of: String.self) + let promiseAuth = self.elg.any().makePromise(of: String.self) // make sure there are no unwanted headers such as `Content-Length` or `Content-Type` - let promiseHasUnwantedHeaders = self.elg.next().makePromise(of: Bool.self) + let promiseHasUnwantedHeaders = self.elg.any().makePromise(of: Bool.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in let headers = req.headers @@ -251,7 +251,7 @@ final class WebSocketKitTests: XCTestCase { } func testQueryParamsAreSent() throws { - let promise = self.elg.next().makePromise(of: String.self) + let promise = self.elg.any().makePromise(of: String.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in promise.succeed(req.uri) @@ -278,7 +278,7 @@ final class WebSocketKitTests: XCTestCase { try XCTSkipIf(true) let port = Int(1337) - let shutdownPromise = self.elg.next().makePromise(of: Void.self) + let shutdownPromise = self.elg.any().makePromise(of: Void.self) let server = try! ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.send("welcome!") @@ -334,10 +334,11 @@ final class WebSocketKitTests: XCTestCase { }.wait() try server.close(mode: .all).wait() + try client.syncShutdown() } func testProxy() throws { - let promise = elg.next().makePromise(of: String.self) + let promise = elg.any().makePromise(of: String.self) let localWebsocketBin: WebsocketBin let verifyProxyHead = { (ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) in @@ -361,7 +362,7 @@ final class WebSocketKitTests: XCTestCase { XCTAssertNoThrow(try localWebsocketBin.shutdown()) } - let closePromise = elg.next().makePromise(of: Void.self) + let closePromise = elg.any().makePromise(of: Void.self) let client = WebSocketClient( eventLoopGroupProvider: .shared(self.elg), @@ -385,10 +386,11 @@ final class WebSocketKitTests: XCTestCase { XCTAssertEqual(try promise.futureResult.wait(), "hello") XCTAssertNoThrow(try closePromise.futureResult.wait()) + try client.syncShutdown() } func testProxyTLS() throws { - let promise = elg.next().makePromise(of: String.self) + let promise = elg.any().makePromise(of: String.self) let (cert, key) = generateSelfSignedCert() let configuration = TLSConfiguration.makeServerConfiguration( @@ -421,7 +423,7 @@ final class WebSocketKitTests: XCTestCase { XCTAssertNoThrow(try localWebsocketBin.shutdown()) } - let closePromise = elg.next().makePromise(of: Void.self) + let closePromise = elg.any().makePromise(of: Void.self) var tlsConfiguration = TLSConfiguration.makeClientConfiguration() tlsConfiguration.certificateVerification = .none @@ -449,384 +451,108 @@ final class WebSocketKitTests: XCTestCase { XCTAssertEqual(try promise.futureResult.wait(), "hello") XCTAssertNoThrow(try closePromise.futureResult.wait()) + try client.syncShutdown() } - - var elg: EventLoopGroup! - override func setUp() { - // needs to be at least two to avoid client / server on same EL timing issues - self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2) - } - override func tearDown() { - try! self.elg.syncShutdownGracefully() + func testAlternateWebsocketConnectMethods() throws { + let server = try ServerBootstrap.webSocket(on: self.elg) { $1.onText { $0.send($1) } }.bind(host: "localhost", port: 0).wait() + let closePromise1 = self.elg.any().makePromise(of: Void.self) + let closePromise2 = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") + } + WebSocket.connect(scheme: "ws", host: "localhost", port: port, proxy: nil, on: self.elg) { ws in + ws.send("hello"); ws.onText { ws, _ in ws.close(promise: closePromise1) } + }.cascadeFailure(to: closePromise1) + WebSocket.connect(to: "ws://localhost:\(port)", proxy: nil, on: self.elg) { ws in + ws.send("hello"); ws.onText { ws, _ in ws.close(promise: closePromise2) } + }.cascadeFailure(to: closePromise2) + XCTAssertNoThrow(try closePromise1.futureResult.wait()) + XCTAssertNoThrow(try closePromise2.futureResult.wait()) + try server.close(mode: .all).wait() } -} - -extension ServerBootstrap { - static func webSocket( - on eventLoopGroup: EventLoopGroup, - tls: Bool = false, - onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () - ) -> ServerBootstrap { - return ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in - if tls { - let (cert, key) = generateSelfSignedCert() - let configuration = TLSConfiguration.makeServerConfiguration( - certificateChain: [.certificate(cert)], - privateKey: .privateKey(key) - ) - let sslContext = try! NIOSSLContext(configuration: configuration) - let handler = NIOSSLServerHandler(context: sslContext) - _ = channel.pipeline.addHandler(handler) + + func testBadURLInWebsocketConnect() async throws { + XCTAssertThrowsError(try WebSocket.connect(to: "%w", on: self.elg, onUpgrade: { _ in }).wait()) { + guard case .invalidURL = $0 as? WebSocketClient.Error else { + return XCTFail("Expected .invalidURL but got \(String(reflecting: $0))") } - let webSocket = NIOWebSocketServerUpgrader( - shouldUpgrade: { channel, req in - return channel.eventLoop.makeSucceededFuture([:]) - }, - upgradePipelineHandler: { channel, req in - return WebSocket.server(on: channel) { ws in - onUpgrade(req, ws) - } - } - ) - return channel.pipeline.configureHTTPServerPipeline( - withServerUpgrade: ( - upgraders: [webSocket], - completionHandler: { ctx in - // complete - } - ) - ) } } -} - -fileprivate extension WebSocket { - func send( - _ data: String, - opcode: WebSocketOpcode, - fin: Bool = true, - promise: EventLoopPromise? = nil - ) { - self.send(raw: ByteBuffer(string: data).readableBytesView, opcode: opcode, fin: fin, promise: promise) - } -} - - - -internal final class WebsocketBin { - enum BindTarget { - case unixDomainSocket(String) - case localhostIPv4RandomPort - case localhostIPv6RandomPort - } - - enum Mode { - // refuses all connections - case refuse - // supports http1.1 connections only, which can be either plain text or encrypted - case http1_1(ssl: Bool = false) - } - - enum Proxy { - case none - case simulate(config: ProxyConfig, authorization: String?) - } - - struct ProxyConfig { - var tls: Bool - let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void - } - - let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - - var port: Int { - return Int(self.serverChannel.localAddress!.port!) - } - - private let mode: Mode - private let sslContext: NIOSSLContext? - private var serverChannel: Channel! - private let isShutdown = ManagedAtomic(false) - - init( - _ mode: Mode = .http1_1(ssl: false), - proxy: Proxy = .none, - bindTarget: BindTarget = .localhostIPv4RandomPort, - sslContext: NIOSSLContext?, - onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () - ) { - self.mode = mode - self.sslContext = sslContext - - let socketAddress: SocketAddress - switch bindTarget { - case .localhostIPv4RandomPort: - socketAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 0) - case .localhostIPv6RandomPort: - socketAddress = try! SocketAddress(ipAddress: "::1", port: 0) - case .unixDomainSocket(let path): - socketAddress = try! SocketAddress(unixDomainSocketPath: path) + + func testOnBinary() throws { + let server = try ServerBootstrap.webSocket(on: self.elg) { $1.onBinary { $0.send($1) } }.bind(host: "localhost", port: 0).wait() + let promise = self.elg.any().makePromise(of: [UInt8].self) + let closePromise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } - - self.serverChannel = try! ServerBootstrap(group: self.group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - do { - - if case .refuse = mode { - throw HTTPBinError.refusedConnection - } - - let webSocket = NIOWebSocketServerUpgrader( - shouldUpgrade: { channel, req in - return channel.eventLoop.makeSucceededFuture([:]) - }, - upgradePipelineHandler: { channel, req in - return WebSocket.server(on: channel) { ws in - onUpgrade(req, ws) - } - } - ) - - // if we need to simulate a proxy, we need to add those handlers first - if case .simulate(config: let config, authorization: let expectedAuthorization) = proxy { - if config.tls { - try self.syncAddTLSHTTPProxyHandlers( - to: channel, - proxyConfig: config, - expectedAuthorization: expectedAuthorization, - upgraders: [webSocket] - ) - } else { - try self.syncAddHTTPProxyHandlers( - to: channel, - proxyConfig: config, - expectedAuthorization: expectedAuthorization, - upgraders: [webSocket] - ) - } - return channel.eventLoop.makeSucceededVoidFuture() - } - - // if a connection has been established, we need to negotiate TLS before - // anything else. Depending on the negotiation, the HTTPHandlers will be added. - if let sslContext = self.sslContext { - try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) - } - - // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly - try channel.pipeline.syncOperations.configureHTTPServerPipeline( - withPipeliningAssistance: true, - withServerUpgrade: ( - upgraders: [webSocket], - completionHandler: { ctx in - // complete - } - ), - withErrorHandling: true - ) - return channel.eventLoop.makeSucceededVoidFuture() - } catch { - return channel.eventLoop.makeFailedFuture(error) - } - }.bind(to: socketAddress).wait() - } - - - // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially - // rather than re-using parts because the requestDecoder stops parsing after a CONNECT request - private func syncAddTLSHTTPProxyHandlers( - to channel: Channel, - proxyConfig: ProxyConfig, - expectedAuthorization: String?, - upgraders: [HTTPServerProtocolUpgrader] - ) throws { - let sync = channel.pipeline.syncOperations - let promise = channel.eventLoop.makePromise(of: Void.self) - - let responseEncoder = HTTPResponseEncoder() - let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) - - try sync.addHandler(responseEncoder) - try sync.addHandler(requestDecoder) - - try sync.addHandler(proxySimulator) - - promise.futureResult.flatMap { _ in - channel.pipeline.removeHandler(proxySimulator) - }.flatMap { _ in - channel.pipeline.removeHandler(responseEncoder) - }.flatMap { _ in - channel.pipeline.removeHandler(requestDecoder) - }.whenComplete { result in - switch result { - case .failure: - channel.close(mode: .all, promise: nil) - case .success: - self.httpProxyEstablished(channel, upgraders: upgraders) - break + WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.send([0x01]) + ws.onBinary { ws, buf in + promise.succeed(.init(buf.readableBytesView)) + ws.close(promise: closePromise) } + }.whenFailure { + promise.fail($0) + closePromise.fail($0) } + XCTAssertEqual(try promise.futureResult.wait(), [0x01]) + XCTAssertNoThrow(try closePromise.futureResult.wait()) + try server.close(mode: .all).wait() } - - - // In the plain-text case we must set up the 'proxy' and the 'server' handlers simultaneously - // so that the combined proxy/upgrade request can be processed by the separate proxy and upgrade handlers - private func syncAddHTTPProxyHandlers( - to channel: Channel, - proxyConfig: ProxyConfig, - expectedAuthorization: String?, - upgraders: [HTTPServerProtocolUpgrader] - ) throws { - let sync = channel.pipeline.syncOperations - let promise = channel.eventLoop.makePromise(of: Void.self) - - let responseEncoder = HTTPResponseEncoder() - let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization) - - let serverPipelineHandler = HTTPServerPipelineHandler() - let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler() - - let extraHTTPHandlers: [RemovableChannelHandler] = [ - requestDecoder, - serverPipelineHandler, - serverProtocolErrorHandler - ] - - try sync.addHandler(responseEncoder) - try sync.addHandler(requestDecoder) - - try sync.addHandler(proxySimulator) - - try sync.addHandler(serverPipelineHandler) - try sync.addHandler(serverProtocolErrorHandler) - - - let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders, - httpEncoder: responseEncoder, - extraHTTPHandlers: extraHTTPHandlers, - upgradeCompletionHandler: { ctx in - // complete - }) - - - try sync.addHandler(upgrader) - - promise.futureResult.flatMap { () -> EventLoopFuture in - channel.pipeline.removeHandler(proxySimulator) - }.whenComplete { result in - switch result { - case .failure: - channel.close(mode: .all, promise: nil) - case .success: - break - } + + func testSendPing() throws { + let server = try ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).wait() + let promise = self.elg.any().makePromise(of: Void.self) + let closePromise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } - } - - private func httpProxyEstablished(_ channel: Channel, upgraders: [HTTPServerProtocolUpgrader]) { - do { - // if a connection has been established, we need to negotiate TLS before - // anything else. Depending on the negotiation, the HTTPHandlers will be added. - if let sslContext = self.sslContext { - try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.sendPing() + ws.onPong { + promise.succeed() + $0.close(promise: closePromise) } - - try channel.pipeline.syncOperations.configureHTTPServerPipeline( - withPipeliningAssistance: true, - withServerUpgrade: ( - upgraders: upgraders, - completionHandler: { ctx in - // complete - } - ), - withErrorHandling: true - ) - } catch { - // in case of an while modifying the pipeline we should close the connection - channel.close(mode: .all, promise: nil) + }.cascadeFailure(to: closePromise) + XCTAssertNoThrow(try promise.futureResult.wait()) + XCTAssertNoThrow(try closePromise.futureResult.wait()) + try server.close(mode: .all).wait() + } + + func testSetPingInterval() throws { + let server = try ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).wait() + let promise = self.elg.any().makePromise(of: Void.self) + let closePromise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } + WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.pingInterval = .milliseconds(100) + ws.onPong { + promise.succeed() + $0.close(promise: closePromise) + } + }.cascadeFailure(to: closePromise) + XCTAssertNoThrow(try promise.futureResult.wait()) + XCTAssertNoThrow(try closePromise.futureResult.wait()) + try server.close(mode: .all).wait() } - - func shutdown() throws { - self.isShutdown.store(true, ordering: .relaxed) - try self.group.syncShutdownGracefully() + + func testCreateNewELGAndShutdown() throws { + let client = WebSocketClient(eventLoopGroupProvider: .createNew) + try client.syncShutdown() } -} - -enum HTTPBinError: Error { - case refusedConnection - case invalidProxyRequest -} - -final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = HTTPServerRequestPart - typealias InboundOut = HTTPServerResponsePart - typealias OutboundOut = HTTPServerResponsePart - - // the promise to succeed, once the proxy connection is setup - let promise: EventLoopPromise - let config: WebsocketBin.ProxyConfig - let expectedAuthorization: String? - - var head: HTTPResponseHead - - init(promise: EventLoopPromise, config: WebsocketBin.ProxyConfig, expectedAuthorization: String?) { - self.promise = promise - self.config = config - self.expectedAuthorization = expectedAuthorization - self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + var elg: EventLoopGroup! + + override func setUp() { + // needs to be at least two to avoid client / server on same EL timing issues + self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2) } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let request = self.unwrapInboundIn(data) - switch request { - case .head(let head): - if self.config.tls { - guard head.method == .CONNECT else { - self.head.status = .badRequest - return - } - } else { - guard head.method == .GET else { - self.head.status = .badRequest - return - } - } - - self.config.headVerification(context, head) - - if let expectedAuthorization = self.expectedAuthorization { - guard let authorization = head.headers["proxy-authorization"].first, - expectedAuthorization == authorization else { - self.head.status = .proxyAuthenticationRequired - return - } - } - if !self.config.tls { - context.fireChannelRead(data) - } - - case .body: - () - case .end: - if self.self.config.tls { - context.write(self.wrapOutboundOut(.head(self.head)), promise: nil) - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) - } - if self.head.status == .ok { - if !self.config.tls { - context.fireChannelRead(data) - } - self.promise.succeed(()) - } else { - self.promise.fail(HTTPBinError.invalidProxyRequest) - } - } + + override func tearDown() { + try! self.elg.syncShutdownGracefully() } } -