From 17c0afbf24f4e288e183bc34317dc97b5cd084bd Mon Sep 17 00:00:00 2001
From: Gwynne Raskind
Date: Mon, 15 May 2023 00:18:10 -0500
Subject: [PATCH] Update for new NIOSSL (#132)
* Bump to Swift 5.6 minimum and update README
* The usual CI cleanup
* Require a more recent NIO, switch to NIOCore instead of NIO, use EventLoopGroup.any() instead of .next(), make EventLoopGroupProvider an alias for NIO's version, use annotated exports
* Ditch unneeded Concurrency checks (which also incidentally enables back-deployment to 10.15)
* Update SSLTestHelpers to reflect the current state of the original upstream (which basically just means using OpaquePointer instead of UMP, partly because EVP_PKEY is no longer visible from CNIOBoringSSL
* Add API breakage allowlist for this branch
* Add test coverage for all the connect methods, binary frames, and ping intervals.
---
...allowlist-branch-update-for-new-niossl.txt | 3 +
.github/workflows/main-codecov.yml | 9 -
.github/workflows/projectboard.yml | 28 +-
.github/workflows/test.yml | 17 +-
Package.swift | 8 +-
README.md | 9 +-
.../Concurrency/WebSocket+Concurrency.swift | 4 -
Sources/WebSocketKit/Exports.swift | 29 +-
.../HTTPUpgradeRequestHandler.swift | 2 +-
Sources/WebSocketKit/WebSocket+Connect.swift | 4 +-
Sources/WebSocketKit/WebSocket.swift | 2 +-
Sources/WebSocketKit/WebSocketClient.swift | 10 +-
Sources/WebSocketKit/WebSocketHandler.swift | 2 +-
.../AsyncWebSocketKitTests.swift | 94 +++-
Tests/WebSocketKitTests/SSLTestHelpers.swift | 6 +-
Tests/WebSocketKitTests/Utilities.swift | 363 +++++++++++++
.../WebSocketKitTests/WebSocketKitTests.swift | 508 ++++--------------
17 files changed, 627 insertions(+), 471 deletions(-)
create mode 100644 .api-breakage/allowlist-branch-update-for-new-niossl.txt
delete mode 100644 .github/workflows/main-codecov.yml
create mode 100644 Tests/WebSocketKitTests/Utilities.swift
diff --git a/.api-breakage/allowlist-branch-update-for-new-niossl.txt b/.api-breakage/allowlist-branch-update-for-new-niossl.txt
new file mode 100644
index 00000000..a32a2389
--- /dev/null
+++ b/.api-breakage/allowlist-branch-update-for-new-niossl.txt
@@ -0,0 +1,3 @@
+API breakage: import NIO has been renamed to import NIOCore
+API breakage: import NIO has been renamed to import NIOPosix
+
diff --git a/.github/workflows/main-codecov.yml b/.github/workflows/main-codecov.yml
deleted file mode 100644
index 1d0fe384..00000000
--- a/.github/workflows/main-codecov.yml
+++ /dev/null
@@ -1,9 +0,0 @@
-name: Update code coverage baselines
-on:
- push: { branches: [ main ] }
-jobs:
- update-main-codecov:
- uses: vapor/ci/.github/workflows/run-unit-tests.yml@reusable-workflows
- with:
- with_coverage: true
- with_tsan: true
diff --git a/.github/workflows/projectboard.yml b/.github/workflows/projectboard.yml
index 0e4e66f6..a0e6d988 100644
--- a/.github/workflows/projectboard.yml
+++ b/.github/workflows/projectboard.yml
@@ -5,27 +5,7 @@ on:
types: [reopened, closed, labeled, unlabeled, assigned, unassigned]
jobs:
- setup_matrix_input:
- runs-on: ubuntu-latest
-
- steps:
- - id: set-matrix
- run: |
- output=$(curl ${{ github.event.issue.url }}/labels | jq '.[] | .name') || output=""
-
- echo '======================'
- echo 'Process incoming data'
- echo '======================'
- json=$(echo $output | sed 's/"\s"/","/g')
- echo $json
- echo "::set-output name=matrix::$(echo $json)"
- outputs:
- issueTags: ${{ steps.set-matrix.outputs.matrix }}
-
- Manage_project_issues:
- needs: setup_matrix_input
- uses: vapor/ci/.github/workflows/issues-to-project-board.yml@main
- with:
- labelsJson: ${{ needs.setup_matrix_input.outputs.issueTags }}
- secrets:
- PROJECT_BOARD_AUTOMATION_PAT: "${{ secrets.PROJECT_BOARD_AUTOMATION_PAT }}"
+ update_project_boards:
+ name: Update project boards
+ uses: vapor/ci/.github/workflows/update-project-boards-for-issue.yml@reusable-workflows
+ secrets: inherit
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index bff42997..b2288dc2 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,11 +1,17 @@
name: test
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
on:
-- pull_request
+ pull_request: { types: [opened, reopened, synchronize, ready_for_review] }
+ push: { branches: [ main ] }
+
jobs:
+
vapor-integration:
+ if: ${{ !(github.event.pull_request.draft || false) }}
runs-on: ubuntu-latest
- container:
- image: swift:5.7
+ container: swift:5.8-jammy
steps:
- name: Check out package
uses: actions/checkout@v3
@@ -14,8 +20,7 @@ jobs:
uses: actions/checkout@v3
with: { repository: 'vapor/vapor', path: 'vapor' }
- name: Use local package in Vapor
- run: |
- swift package --package-path vapor edit websocket-kit --path websocket-kit
+ run: swift package --package-path vapor edit websocket-kit --path websocket-kit
- name: Run Vapor tests
run: swift test --package-path vapor
@@ -24,6 +29,4 @@ jobs:
with:
with_coverage: true
with_tsan: false
- coverage_ignores: '/Tests/'
with_public_api_check: ${{ github.event_name == 'pull_request' }}
-
diff --git a/Package.swift b/Package.swift
index 1fd56471..4841d15c 100644
--- a/Package.swift
+++ b/Package.swift
@@ -1,4 +1,4 @@
-// swift-tools-version:5.4
+// swift-tools-version:5.6
import PackageDescription
let package = Package(
@@ -6,21 +6,21 @@ let package = Package(
platforms: [
.macOS(.v10_15),
.iOS(.v13),
+ .watchOS(.v6),
.tvOS(.v13),
],
products: [
.library(name: "WebSocketKit", targets: ["WebSocketKit"]),
],
dependencies: [
- .package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"),
+ .package(url: "https://github.com/apple/swift-nio.git", from: "2.53.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"),
- .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"),
+ .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.24.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
],
targets: [
.target(name: "WebSocketKit", dependencies: [
- .product(name: "NIO", package: "swift-nio"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
.product(name: "NIOExtras", package: "swift-nio-extras"),
diff --git a/README.md b/README.md
index 71370ce1..2066c308 100644
--- a/README.md
+++ b/README.md
@@ -15,10 +15,13 @@
-
-
+
+
-
+
+
+
+
diff --git a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift
index 6b7ecc3b..d0e9c0d3 100644
--- a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift
+++ b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift
@@ -1,10 +1,8 @@
-#if compiler(>=5.5) && canImport(_Concurrency)
import NIOCore
import NIOWebSocket
import Foundation
import NIOHTTP1
-@available(macOS 12, iOS 15, watchOS 8, tvOS 15, *)
extension WebSocket {
public func send(_ text: S) async throws
where S: Collection, S.Element == Character
@@ -142,5 +140,3 @@ extension WebSocket {
).get()
}
}
-
-#endif
diff --git a/Sources/WebSocketKit/Exports.swift b/Sources/WebSocketKit/Exports.swift
index 9c1931e1..11ce0e49 100644
--- a/Sources/WebSocketKit/Exports.swift
+++ b/Sources/WebSocketKit/Exports.swift
@@ -1,14 +1,27 @@
-#if !BUILDING_DOCC
+#if swift(>=5.8)
-@_exported import struct NIO.ByteBuffer
-@_exported import protocol NIO.Channel
-@_exported import protocol NIO.EventLoop
-@_exported import protocol NIO.EventLoopGroup
-@_exported import struct NIO.EventLoopPromise
-@_exported import class NIO.EventLoopFuture
+@_documentation(visibility: internal) @_exported import struct NIOCore.ByteBuffer
+@_documentation(visibility: internal) @_exported import protocol NIOCore.Channel
+@_documentation(visibility: internal) @_exported import protocol NIOCore.EventLoop
+@_documentation(visibility: internal) @_exported import protocol NIOCore.EventLoopGroup
+@_documentation(visibility: internal) @_exported import struct NIOCore.EventLoopPromise
+@_documentation(visibility: internal) @_exported import class NIOCore.EventLoopFuture
+
+@_documentation(visibility: internal) @_exported import struct NIOHTTP1.HTTPHeaders
+
+@_documentation(visibility: internal) @_exported import struct Foundation.URL
+
+#else
+
+@_exported import struct NIOCore.ByteBuffer
+@_exported import protocol NIOCore.Channel
+@_exported import protocol NIOCore.EventLoop
+@_exported import protocol NIOCore.EventLoopGroup
+@_exported import struct NIOCore.EventLoopPromise
+@_exported import class NIOCore.EventLoopFuture
@_exported import struct NIOHTTP1.HTTPHeaders
@_exported import struct Foundation.URL
-#endif
\ No newline at end of file
+#endif
diff --git a/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift
index 84af52d1..817d67b3 100644
--- a/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift
+++ b/Sources/WebSocketKit/HTTPUpgradeRequestHandler.swift
@@ -1,4 +1,4 @@
-import NIO
+import NIOCore
import NIOHTTP1
final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler {
diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift
index 643fcb6c..ca94540b 100644
--- a/Sources/WebSocketKit/WebSocket+Connect.swift
+++ b/Sources/WebSocketKit/WebSocket+Connect.swift
@@ -20,7 +20,7 @@ extension WebSocket {
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture {
guard let url = URL(string: url) else {
- return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL)
+ return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL)
}
return self.connect(
to: url,
@@ -174,7 +174,7 @@ extension WebSocket {
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture {
guard let url = URL(string: url) else {
- return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL)
+ return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL)
}
let scheme = url.scheme ?? "ws"
return self.connect(
diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift
index ce915c6d..3e6f679c 100644
--- a/Sources/WebSocketKit/WebSocket.swift
+++ b/Sources/WebSocketKit/WebSocket.swift
@@ -1,4 +1,4 @@
-import NIO
+import NIOCore
import NIOWebSocket
import NIOHTTP1
import NIOSSL
diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift
index b5cd072d..08bf48e5 100644
--- a/Sources/WebSocketKit/WebSocketClient.swift
+++ b/Sources/WebSocketKit/WebSocketClient.swift
@@ -1,5 +1,6 @@
import Foundation
-import NIO
+import NIOCore
+import NIOPosix
import NIOConcurrencyHelpers
import NIOExtras
import NIOHTTP1
@@ -18,10 +19,7 @@ public final class WebSocketClient {
}
}
- public enum EventLoopGroupProvider {
- case shared(EventLoopGroup)
- case createNew
- }
+ public typealias EventLoopGroupProvider = NIOEventLoopGroupProvider
public struct Configuration {
public var tlsConfiguration: TLSConfiguration?
@@ -106,7 +104,7 @@ public final class WebSocketClient {
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture {
assert(["ws", "wss"].contains(scheme))
- let upgradePromise = self.group.next().makePromise(of: Void.self)
+ let upgradePromise = self.group.any().makePromise(of: Void.self)
let bootstrap = WebSocketClient.makeBootstrap(on: self.group)
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
.channelInitializer { channel -> EventLoopFuture in
diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift
index 1af0f76a..45f266ce 100644
--- a/Sources/WebSocketKit/WebSocketHandler.swift
+++ b/Sources/WebSocketKit/WebSocketHandler.swift
@@ -1,4 +1,4 @@
-import NIO
+import NIOCore
import NIOWebSocket
extension WebSocket {
diff --git a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift
index 04538f65..e20ffa93 100644
--- a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift
+++ b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift
@@ -1,25 +1,23 @@
-#if compiler(>=5.5) && canImport(_Concurrency)
import XCTest
import NIO
import NIOHTTP1
import NIOWebSocket
@testable import WebSocketKit
-@available(macOS 12, iOS 15, watchOS 8, tvOS 15, *)
final class AsyncWebSocketKitTests: XCTestCase {
func testWebSocketEcho() async throws {
- let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
+ let server = try await ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
ws.send(text)
}
- }.bind(host: "localhost", port: 0).wait()
+ }.bind(host: "localhost", port: 0).get()
guard let port = server.localAddress?.port else {
XCTFail("couldn't get port from \(server.localAddress.debugDescription)")
return
}
- let promise = elg.next().makePromise(of: String.self)
+ let promise = elg.any().makePromise(of: String.self)
try await WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in
do {
@@ -42,7 +40,91 @@ final class AsyncWebSocketKitTests: XCTestCase {
try await server.close(mode: .all)
}
+ func testAlternateWebsocketConnectMethods() async throws {
+ let server = try await ServerBootstrap.webSocket(on: self.elg) { $1.onText { $0.send($1) } }.bind(host: "localhost", port: 0).get()
+ let promise = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
+ }
+ try await WebSocket.connect(scheme: "ws", host: "localhost", port: port, on: self.elg) { (ws) async in
+ do { try await ws.send("hello") } catch { promise.fail(error); try? await ws.close() }
+ ws.onText { ws, _ in
+ promise.succeed(())
+ do { try await ws.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") }
+ }
+ }
+ try await promise.futureResult.get()
+ try await server.close(mode: .all)
+ }
+
+ func testBadURLInWebsocketConnect() async throws {
+ do {
+ try await WebSocket.connect(to: "%w", on: self.elg, onUpgrade: { _ async in })
+ XCTAssertThrowsError({}())
+ } catch {
+ XCTAssertThrowsError(try { throw error }()) {
+ guard case .invalidURL = $0 as? WebSocketClient.Error else {
+ return XCTFail("Expected .invalidURL but got \(String(reflecting: $0))")
+ }
+ }
+ }
+ }
+
+ func testOnBinary() async throws {
+ let server = try await ServerBootstrap.webSocket(on: self.elg) { $1.onBinary { $0.send($1) } }.bind(host: "localhost", port: 0).get()
+ let promise = self.elg.any().makePromise(of: [UInt8].self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
+ }
+ try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
+ do { try await ws.send([0x01]) } catch { promise.fail(error); try? await ws.close() }
+ ws.onBinary { ws, buf in
+ promise.succeed(.init(buf.readableBytesView))
+ do { try await ws.close() }
+ catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") }
+ }
+ }
+ let result = try await promise.futureResult.get()
+ XCTAssertEqual(result, [0x01])
+ try await server.close(mode: .all)
+ }
+
+ func testSendPing() async throws {
+ let server = try await ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).get()
+ let promise = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
+ }
+ try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in
+ do { try await ws.sendPing() } catch { promise.fail(error); try? await ws.close() }
+ ws.onPong {
+ promise.succeed(())
+ do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") }
+ }
+ }
+ try await promise.futureResult.get()
+ try await server.close(mode: .all)
+ }
+
+ func testSetPingInterval() async throws {
+ let server = try await ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).get()
+ let promise = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
+ }
+ try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in
+ ws.pingInterval = .milliseconds(100)
+ ws.onPong {
+ promise.succeed(())
+ do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") }
+ }
+ }
+ try await promise.futureResult.get()
+ try await server.close(mode: .all)
+ }
+
var elg: EventLoopGroup!
+
override func setUp() {
// needs to be at least two to avoid client / server on same EL timing issues
self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2)
@@ -51,5 +133,3 @@ final class AsyncWebSocketKitTests: XCTestCase {
try! self.elg.syncShutdownGracefully()
}
}
-
-#endif
diff --git a/Tests/WebSocketKitTests/SSLTestHelpers.swift b/Tests/WebSocketKitTests/SSLTestHelpers.swift
index 574da640..d515fe3e 100644
--- a/Tests/WebSocketKitTests/SSLTestHelpers.swift
+++ b/Tests/WebSocketKitTests/SSLTestHelpers.swift
@@ -15,6 +15,7 @@
import Foundation
@_implementationOnly import CNIOBoringSSL
+import NIOCore
@testable import NIOSSL
// This function generates a random number suitable for use in an X509
@@ -61,7 +62,7 @@ func randomSerialNumber() -> ASN1_INTEGER {
return asn1int
}
-func generateRSAPrivateKey() -> UnsafeMutablePointer {
+func generateRSAPrivateKey() -> OpaquePointer {
let exponent = CNIOBoringSSL_BN_new()
defer {
CNIOBoringSSL_BN_free(exponent)
@@ -91,7 +92,7 @@ func addExtension(x509: OpaquePointer, nid: CInt, value: String) {
CNIOBoringSSL_X509_EXTENSION_free(ext)
}
-func generateSelfSignedCert(keygenFunction: () -> UnsafeMutablePointer = generateRSAPrivateKey) -> (NIOSSLCertificate, NIOSSLPrivateKey) {
+func generateSelfSignedCert(keygenFunction: () -> OpaquePointer = generateRSAPrivateKey) -> (NIOSSLCertificate, NIOSSLPrivateKey) {
let pkey = keygenFunction()
let x = CNIOBoringSSL_X509_new()!
CNIOBoringSSL_X509_set_version(x, 2)
@@ -139,4 +140,3 @@ func generateSelfSignedCert(keygenFunction: () -> UnsafeMutablePointer
return (NIOSSLCertificate.fromUnsafePointer(takingOwnership: x), NIOSSLPrivateKey.fromUnsafePointer(takingOwnership: pkey))
}
-
diff --git a/Tests/WebSocketKitTests/Utilities.swift b/Tests/WebSocketKitTests/Utilities.swift
new file mode 100644
index 00000000..60043ec4
--- /dev/null
+++ b/Tests/WebSocketKitTests/Utilities.swift
@@ -0,0 +1,363 @@
+import XCTest
+import Atomics
+import NIO
+import NIOExtras
+import NIOHTTP1
+import NIOSSL
+import NIOWebSocket
+@testable import WebSocketKit
+
+extension ServerBootstrap {
+ static func webSocket(
+ on eventLoopGroup: EventLoopGroup,
+ tls: Bool = false,
+ onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> ()
+ ) -> ServerBootstrap {
+ return ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in
+ if tls {
+ let (cert, key) = generateSelfSignedCert()
+ let configuration = TLSConfiguration.makeServerConfiguration(
+ certificateChain: [.certificate(cert)],
+ privateKey: .privateKey(key)
+ )
+ let sslContext = try! NIOSSLContext(configuration: configuration)
+ let handler = NIOSSLServerHandler(context: sslContext)
+ _ = channel.pipeline.addHandler(handler)
+ }
+ let webSocket = NIOWebSocketServerUpgrader(
+ shouldUpgrade: { channel, req in
+ return channel.eventLoop.makeSucceededFuture([:])
+ },
+ upgradePipelineHandler: { channel, req in
+ return WebSocket.server(on: channel) { ws in
+ onUpgrade(req, ws)
+ }
+ }
+ )
+ return channel.pipeline.configureHTTPServerPipeline(
+ withServerUpgrade: (
+ upgraders: [webSocket],
+ completionHandler: { ctx in
+ // complete
+ }
+ )
+ )
+ }
+ }
+}
+
+internal final class WebsocketBin {
+ enum BindTarget {
+ case unixDomainSocket(String)
+ case localhostIPv4RandomPort
+ case localhostIPv6RandomPort
+ }
+
+ enum Mode {
+ // refuses all connections
+ case refuse
+ // supports http1.1 connections only, which can be either plain text or encrypted
+ case http1_1(ssl: Bool = false)
+ }
+
+ enum Proxy {
+ case none
+ case simulate(config: ProxyConfig, authorization: String?)
+ }
+
+ struct ProxyConfig {
+ var tls: Bool
+ let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void
+ }
+
+ let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+
+ var port: Int {
+ return Int(self.serverChannel.localAddress!.port!)
+ }
+
+ private let mode: Mode
+ private let sslContext: NIOSSLContext?
+ private var serverChannel: Channel!
+ private let isShutdown = ManagedAtomic(false)
+
+ init(
+ _ mode: Mode = .http1_1(ssl: false),
+ proxy: Proxy = .none,
+ bindTarget: BindTarget = .localhostIPv4RandomPort,
+ sslContext: NIOSSLContext?,
+ onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> ()
+ ) {
+ self.mode = mode
+ self.sslContext = sslContext
+
+ let socketAddress: SocketAddress
+ switch bindTarget {
+ case .localhostIPv4RandomPort:
+ socketAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 0)
+ case .localhostIPv6RandomPort:
+ socketAddress = try! SocketAddress(ipAddress: "::1", port: 0)
+ case .unixDomainSocket(let path):
+ socketAddress = try! SocketAddress(unixDomainSocketPath: path)
+ }
+
+ self.serverChannel = try! ServerBootstrap(group: self.group)
+ .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
+ .childChannelInitializer { channel in
+ do {
+
+ if case .refuse = mode {
+ throw HTTPBinError.refusedConnection
+ }
+
+ let webSocket = NIOWebSocketServerUpgrader(
+ shouldUpgrade: { channel, req in
+ return channel.eventLoop.makeSucceededFuture([:])
+ },
+ upgradePipelineHandler: { channel, req in
+ return WebSocket.server(on: channel) { ws in
+ onUpgrade(req, ws)
+ }
+ }
+ )
+
+ // if we need to simulate a proxy, we need to add those handlers first
+ if case .simulate(config: let config, authorization: let expectedAuthorization) = proxy {
+ if config.tls {
+ try self.syncAddTLSHTTPProxyHandlers(
+ to: channel,
+ proxyConfig: config,
+ expectedAuthorization: expectedAuthorization,
+ upgraders: [webSocket]
+ )
+ } else {
+ try self.syncAddHTTPProxyHandlers(
+ to: channel,
+ proxyConfig: config,
+ expectedAuthorization: expectedAuthorization,
+ upgraders: [webSocket]
+ )
+ }
+ return channel.eventLoop.makeSucceededVoidFuture()
+ }
+
+ // if a connection has been established, we need to negotiate TLS before
+ // anything else. Depending on the negotiation, the HTTPHandlers will be added.
+ if let sslContext = self.sslContext {
+ try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext))
+ }
+
+ // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly
+ try channel.pipeline.syncOperations.configureHTTPServerPipeline(
+ withPipeliningAssistance: true,
+ withServerUpgrade: (
+ upgraders: [webSocket],
+ completionHandler: { ctx in
+ // complete
+ }
+ ),
+ withErrorHandling: true
+ )
+ return channel.eventLoop.makeSucceededVoidFuture()
+ } catch {
+ return channel.eventLoop.makeFailedFuture(error)
+ }
+ }.bind(to: socketAddress).wait()
+ }
+
+
+ // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially
+ // rather than re-using parts because the requestDecoder stops parsing after a CONNECT request
+ private func syncAddTLSHTTPProxyHandlers(
+ to channel: Channel,
+ proxyConfig: ProxyConfig,
+ expectedAuthorization: String?,
+ upgraders: [HTTPServerProtocolUpgrader]
+ ) throws {
+ let sync = channel.pipeline.syncOperations
+ let promise = channel.eventLoop.makePromise(of: Void.self)
+
+ let responseEncoder = HTTPResponseEncoder()
+ let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
+ let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization)
+
+ try sync.addHandler(responseEncoder)
+ try sync.addHandler(requestDecoder)
+
+ try sync.addHandler(proxySimulator)
+
+ promise.futureResult.flatMap { _ in
+ channel.pipeline.removeHandler(proxySimulator)
+ }.flatMap { _ in
+ channel.pipeline.removeHandler(responseEncoder)
+ }.flatMap { _ in
+ channel.pipeline.removeHandler(requestDecoder)
+ }.whenComplete { result in
+ switch result {
+ case .failure:
+ channel.close(mode: .all, promise: nil)
+ case .success:
+ self.httpProxyEstablished(channel, upgraders: upgraders)
+ break
+ }
+ }
+ }
+
+
+ // In the plain-text case we must set up the 'proxy' and the 'server' handlers simultaneously
+ // so that the combined proxy/upgrade request can be processed by the separate proxy and upgrade handlers
+ private func syncAddHTTPProxyHandlers(
+ to channel: Channel,
+ proxyConfig: ProxyConfig,
+ expectedAuthorization: String?,
+ upgraders: [HTTPServerProtocolUpgrader]
+ ) throws {
+ let sync = channel.pipeline.syncOperations
+ let promise = channel.eventLoop.makePromise(of: Void.self)
+
+ let responseEncoder = HTTPResponseEncoder()
+ let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
+ let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization)
+
+ let serverPipelineHandler = HTTPServerPipelineHandler()
+ let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler()
+
+ let extraHTTPHandlers: [RemovableChannelHandler] = [
+ requestDecoder,
+ serverPipelineHandler,
+ serverProtocolErrorHandler
+ ]
+
+ try sync.addHandler(responseEncoder)
+ try sync.addHandler(requestDecoder)
+
+ try sync.addHandler(proxySimulator)
+
+ try sync.addHandler(serverPipelineHandler)
+ try sync.addHandler(serverProtocolErrorHandler)
+
+
+ let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders,
+ httpEncoder: responseEncoder,
+ extraHTTPHandlers: extraHTTPHandlers,
+ upgradeCompletionHandler: { ctx in
+ // complete
+ })
+
+
+ try sync.addHandler(upgrader)
+
+ promise.futureResult.flatMap { () -> EventLoopFuture in
+ channel.pipeline.removeHandler(proxySimulator)
+ }.whenComplete { result in
+ switch result {
+ case .failure:
+ channel.close(mode: .all, promise: nil)
+ case .success:
+ break
+ }
+ }
+ }
+
+ private func httpProxyEstablished(_ channel: Channel, upgraders: [HTTPServerProtocolUpgrader]) {
+ do {
+ // if a connection has been established, we need to negotiate TLS before
+ // anything else. Depending on the negotiation, the HTTPHandlers will be added.
+ if let sslContext = self.sslContext {
+ try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext))
+ }
+
+ try channel.pipeline.syncOperations.configureHTTPServerPipeline(
+ withPipeliningAssistance: true,
+ withServerUpgrade: (
+ upgraders: upgraders,
+ completionHandler: { ctx in
+ // complete
+ }
+ ),
+ withErrorHandling: true
+ )
+ } catch {
+ // in case of an while modifying the pipeline we should close the connection
+ channel.close(mode: .all, promise: nil)
+ }
+ }
+
+ func shutdown() throws {
+ self.isShutdown.store(true, ordering: .relaxed)
+ try self.group.syncShutdownGracefully()
+ }
+}
+
+enum HTTPBinError: Error {
+ case refusedConnection
+ case invalidProxyRequest
+}
+
+final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
+ typealias InboundIn = HTTPServerRequestPart
+ typealias InboundOut = HTTPServerResponsePart
+ typealias OutboundOut = HTTPServerResponsePart
+
+
+ // the promise to succeed, once the proxy connection is setup
+ let promise: EventLoopPromise
+ let config: WebsocketBin.ProxyConfig
+ let expectedAuthorization: String?
+
+ var head: HTTPResponseHead
+
+ init(promise: EventLoopPromise, config: WebsocketBin.ProxyConfig, expectedAuthorization: String?) {
+ self.promise = promise
+ self.config = config
+ self.expectedAuthorization = expectedAuthorization
+ self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")]))
+ }
+
+ func channelRead(context: ChannelHandlerContext, data: NIOAny) {
+ let request = self.unwrapInboundIn(data)
+ switch request {
+ case .head(let head):
+ if self.config.tls {
+ guard head.method == .CONNECT else {
+ self.head.status = .badRequest
+ return
+ }
+ } else {
+ guard head.method == .GET else {
+ self.head.status = .badRequest
+ return
+ }
+ }
+
+ self.config.headVerification(context, head)
+
+ if let expectedAuthorization = self.expectedAuthorization {
+ guard let authorization = head.headers["proxy-authorization"].first,
+ expectedAuthorization == authorization else {
+ self.head.status = .proxyAuthenticationRequired
+ return
+ }
+ }
+ if !self.config.tls {
+ context.fireChannelRead(data)
+ }
+
+ case .body:
+ ()
+ case .end:
+ if self.self.config.tls {
+ context.write(self.wrapOutboundOut(.head(self.head)), promise: nil)
+ context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
+ }
+ if self.head.status == .ok {
+ if !self.config.tls {
+ context.fireChannelRead(data)
+ }
+ self.promise.succeed(())
+ } else {
+ self.promise.fail(HTTPBinError.invalidProxyRequest)
+ }
+ }
+ }
+}
diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift
index 7cc33a92..985cb00b 100644
--- a/Tests/WebSocketKitTests/WebSocketKitTests.swift
+++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift
@@ -20,8 +20,8 @@ final class WebSocketKitTests: XCTestCase {
return
}
- let promise = elg.next().makePromise(of: String.self)
- let closePromise = elg.next().makePromise(of: Void.self)
+ let promise = elg.any().makePromise(of: String.self)
+ let closePromise = elg.any().makePromise(of: Void.self)
WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in
ws.send("hello")
ws.onText { ws, string in
@@ -39,9 +39,9 @@ final class WebSocketKitTests: XCTestCase {
}
func testServerClose() throws {
- let sendPromise = self.elg.next().makePromise(of: Void.self)
- let serverClose = self.elg.next().makePromise(of: Void.self)
- let clientClose = self.elg.next().makePromise(of: Void.self)
+ let sendPromise = self.elg.any().makePromise(of: Void.self)
+ let serverClose = self.elg.any().makePromise(of: Void.self)
+ let clientClose = self.elg.any().makePromise(of: Void.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
if text == "close" {
@@ -67,9 +67,9 @@ final class WebSocketKitTests: XCTestCase {
}
func testClientClose() throws {
- let sendPromise = self.elg.next().makePromise(of: Void.self)
- let serverClose = self.elg.next().makePromise(of: Void.self)
- let clientClose = self.elg.next().makePromise(of: Void.self)
+ let sendPromise = self.elg.any().makePromise(of: Void.self)
+ let serverClose = self.elg.any().makePromise(of: Void.self)
+ let clientClose = self.elg.any().makePromise(of: Void.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
ws.send(text)
@@ -98,7 +98,7 @@ final class WebSocketKitTests: XCTestCase {
}
func testImmediateSend() throws {
- let promise = self.elg.next().makePromise(of: String.self)
+ let promise = self.elg.any().makePromise(of: String.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.send("hello")
ws.onText { ws, string in
@@ -124,8 +124,8 @@ final class WebSocketKitTests: XCTestCase {
}
func testWebSocketPingPong() throws {
- let pingPromise = self.elg.next().makePromise(of: String.self)
- let pongPromise = self.elg.next().makePromise(of: String.self)
+ let pingPromise = self.elg.any().makePromise(of: String.self)
+ let pongPromise = self.elg.any().makePromise(of: String.self)
let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
@@ -159,10 +159,10 @@ final class WebSocketKitTests: XCTestCase {
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
- ws.send(text, opcode: .text, fin: false)
- ws.send(" th", opcode: .continuation, fin: false)
- ws.send("e mo", opcode: .continuation, fin: false)
- ws.send("st", opcode: .continuation, fin: true)
+ ws.send(.init(string: text), opcode: .text, fin: false)
+ ws.send(.init(string: " th"), opcode: .continuation, fin: false)
+ ws.send(.init(string: "e mo"), opcode: .continuation, fin: false)
+ ws.send(.init(string: "st"), opcode: .continuation, fin: true)
}
}.bind(host: "localhost", port: 0).wait()
@@ -171,12 +171,12 @@ final class WebSocketKitTests: XCTestCase {
return
}
- let promise = elg.next().makePromise(of: String.self)
- let closePromise = elg.next().makePromise(of: Void.self)
+ let promise = elg.any().makePromise(of: String.self)
+ let closePromise = elg.any().makePromise(of: Void.self)
WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in
- ws.send("Hel", opcode: .text, fin: false)
- ws.send("lo! Vapor r", opcode: .continuation, fin: false)
- ws.send("ules", opcode: .continuation, fin: true)
+ ws.send(.init(string: "Hel"), opcode: .text, fin: false)
+ ws.send(.init(string: "lo! Vapor r"), opcode: .continuation, fin: false)
+ ws.send(.init(string: "ules"), opcode: .continuation, fin: true)
ws.onText { ws, string in
promise.succeed(string)
ws.close(promise: closePromise)
@@ -188,7 +188,7 @@ final class WebSocketKitTests: XCTestCase {
}
func testErrorCode() throws {
- let promise = self.elg.next().makePromise(of: WebSocketErrorCode.self)
+ let promise = self.elg.any().makePromise(of: WebSocketErrorCode.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.close(code: .normalClosure, promise: nil)
@@ -214,10 +214,10 @@ final class WebSocketKitTests: XCTestCase {
}
func testHeadersAreSent() throws {
- let promiseAuth = self.elg.next().makePromise(of: String.self)
+ let promiseAuth = self.elg.any().makePromise(of: String.self)
// make sure there are no unwanted headers such as `Content-Length` or `Content-Type`
- let promiseHasUnwantedHeaders = self.elg.next().makePromise(of: Bool.self)
+ let promiseHasUnwantedHeaders = self.elg.any().makePromise(of: Bool.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
let headers = req.headers
@@ -251,7 +251,7 @@ final class WebSocketKitTests: XCTestCase {
}
func testQueryParamsAreSent() throws {
- let promise = self.elg.next().makePromise(of: String.self)
+ let promise = self.elg.any().makePromise(of: String.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
promise.succeed(req.uri)
@@ -278,7 +278,7 @@ final class WebSocketKitTests: XCTestCase {
try XCTSkipIf(true)
let port = Int(1337)
- let shutdownPromise = self.elg.next().makePromise(of: Void.self)
+ let shutdownPromise = self.elg.any().makePromise(of: Void.self)
let server = try! ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.send("welcome!")
@@ -334,10 +334,11 @@ final class WebSocketKitTests: XCTestCase {
}.wait()
try server.close(mode: .all).wait()
+ try client.syncShutdown()
}
func testProxy() throws {
- let promise = elg.next().makePromise(of: String.self)
+ let promise = elg.any().makePromise(of: String.self)
let localWebsocketBin: WebsocketBin
let verifyProxyHead = { (ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) in
@@ -361,7 +362,7 @@ final class WebSocketKitTests: XCTestCase {
XCTAssertNoThrow(try localWebsocketBin.shutdown())
}
- let closePromise = elg.next().makePromise(of: Void.self)
+ let closePromise = elg.any().makePromise(of: Void.self)
let client = WebSocketClient(
eventLoopGroupProvider: .shared(self.elg),
@@ -385,10 +386,11 @@ final class WebSocketKitTests: XCTestCase {
XCTAssertEqual(try promise.futureResult.wait(), "hello")
XCTAssertNoThrow(try closePromise.futureResult.wait())
+ try client.syncShutdown()
}
func testProxyTLS() throws {
- let promise = elg.next().makePromise(of: String.self)
+ let promise = elg.any().makePromise(of: String.self)
let (cert, key) = generateSelfSignedCert()
let configuration = TLSConfiguration.makeServerConfiguration(
@@ -421,7 +423,7 @@ final class WebSocketKitTests: XCTestCase {
XCTAssertNoThrow(try localWebsocketBin.shutdown())
}
- let closePromise = elg.next().makePromise(of: Void.self)
+ let closePromise = elg.any().makePromise(of: Void.self)
var tlsConfiguration = TLSConfiguration.makeClientConfiguration()
tlsConfiguration.certificateVerification = .none
@@ -449,384 +451,108 @@ final class WebSocketKitTests: XCTestCase {
XCTAssertEqual(try promise.futureResult.wait(), "hello")
XCTAssertNoThrow(try closePromise.futureResult.wait())
+ try client.syncShutdown()
}
-
- var elg: EventLoopGroup!
- override func setUp() {
- // needs to be at least two to avoid client / server on same EL timing issues
- self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2)
- }
- override func tearDown() {
- try! self.elg.syncShutdownGracefully()
+ func testAlternateWebsocketConnectMethods() throws {
+ let server = try ServerBootstrap.webSocket(on: self.elg) { $1.onText { $0.send($1) } }.bind(host: "localhost", port: 0).wait()
+ let closePromise1 = self.elg.any().makePromise(of: Void.self)
+ let closePromise2 = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
+ }
+ WebSocket.connect(scheme: "ws", host: "localhost", port: port, proxy: nil, on: self.elg) { ws in
+ ws.send("hello"); ws.onText { ws, _ in ws.close(promise: closePromise1) }
+ }.cascadeFailure(to: closePromise1)
+ WebSocket.connect(to: "ws://localhost:\(port)", proxy: nil, on: self.elg) { ws in
+ ws.send("hello"); ws.onText { ws, _ in ws.close(promise: closePromise2) }
+ }.cascadeFailure(to: closePromise2)
+ XCTAssertNoThrow(try closePromise1.futureResult.wait())
+ XCTAssertNoThrow(try closePromise2.futureResult.wait())
+ try server.close(mode: .all).wait()
}
-}
-
-extension ServerBootstrap {
- static func webSocket(
- on eventLoopGroup: EventLoopGroup,
- tls: Bool = false,
- onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> ()
- ) -> ServerBootstrap {
- return ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in
- if tls {
- let (cert, key) = generateSelfSignedCert()
- let configuration = TLSConfiguration.makeServerConfiguration(
- certificateChain: [.certificate(cert)],
- privateKey: .privateKey(key)
- )
- let sslContext = try! NIOSSLContext(configuration: configuration)
- let handler = NIOSSLServerHandler(context: sslContext)
- _ = channel.pipeline.addHandler(handler)
+
+ func testBadURLInWebsocketConnect() async throws {
+ XCTAssertThrowsError(try WebSocket.connect(to: "%w", on: self.elg, onUpgrade: { _ in }).wait()) {
+ guard case .invalidURL = $0 as? WebSocketClient.Error else {
+ return XCTFail("Expected .invalidURL but got \(String(reflecting: $0))")
}
- let webSocket = NIOWebSocketServerUpgrader(
- shouldUpgrade: { channel, req in
- return channel.eventLoop.makeSucceededFuture([:])
- },
- upgradePipelineHandler: { channel, req in
- return WebSocket.server(on: channel) { ws in
- onUpgrade(req, ws)
- }
- }
- )
- return channel.pipeline.configureHTTPServerPipeline(
- withServerUpgrade: (
- upgraders: [webSocket],
- completionHandler: { ctx in
- // complete
- }
- )
- )
}
}
-}
-
-fileprivate extension WebSocket {
- func send(
- _ data: String,
- opcode: WebSocketOpcode,
- fin: Bool = true,
- promise: EventLoopPromise? = nil
- ) {
- self.send(raw: ByteBuffer(string: data).readableBytesView, opcode: opcode, fin: fin, promise: promise)
- }
-}
-
-
-
-internal final class WebsocketBin {
- enum BindTarget {
- case unixDomainSocket(String)
- case localhostIPv4RandomPort
- case localhostIPv6RandomPort
- }
-
- enum Mode {
- // refuses all connections
- case refuse
- // supports http1.1 connections only, which can be either plain text or encrypted
- case http1_1(ssl: Bool = false)
- }
-
- enum Proxy {
- case none
- case simulate(config: ProxyConfig, authorization: String?)
- }
-
- struct ProxyConfig {
- var tls: Bool
- let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void
- }
-
- let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
-
- var port: Int {
- return Int(self.serverChannel.localAddress!.port!)
- }
-
- private let mode: Mode
- private let sslContext: NIOSSLContext?
- private var serverChannel: Channel!
- private let isShutdown = ManagedAtomic(false)
-
- init(
- _ mode: Mode = .http1_1(ssl: false),
- proxy: Proxy = .none,
- bindTarget: BindTarget = .localhostIPv4RandomPort,
- sslContext: NIOSSLContext?,
- onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> ()
- ) {
- self.mode = mode
- self.sslContext = sslContext
-
- let socketAddress: SocketAddress
- switch bindTarget {
- case .localhostIPv4RandomPort:
- socketAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 0)
- case .localhostIPv6RandomPort:
- socketAddress = try! SocketAddress(ipAddress: "::1", port: 0)
- case .unixDomainSocket(let path):
- socketAddress = try! SocketAddress(unixDomainSocketPath: path)
+
+ func testOnBinary() throws {
+ let server = try ServerBootstrap.webSocket(on: self.elg) { $1.onBinary { $0.send($1) } }.bind(host: "localhost", port: 0).wait()
+ let promise = self.elg.any().makePromise(of: [UInt8].self)
+ let closePromise = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
}
-
- self.serverChannel = try! ServerBootstrap(group: self.group)
- .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
- .childChannelInitializer { channel in
- do {
-
- if case .refuse = mode {
- throw HTTPBinError.refusedConnection
- }
-
- let webSocket = NIOWebSocketServerUpgrader(
- shouldUpgrade: { channel, req in
- return channel.eventLoop.makeSucceededFuture([:])
- },
- upgradePipelineHandler: { channel, req in
- return WebSocket.server(on: channel) { ws in
- onUpgrade(req, ws)
- }
- }
- )
-
- // if we need to simulate a proxy, we need to add those handlers first
- if case .simulate(config: let config, authorization: let expectedAuthorization) = proxy {
- if config.tls {
- try self.syncAddTLSHTTPProxyHandlers(
- to: channel,
- proxyConfig: config,
- expectedAuthorization: expectedAuthorization,
- upgraders: [webSocket]
- )
- } else {
- try self.syncAddHTTPProxyHandlers(
- to: channel,
- proxyConfig: config,
- expectedAuthorization: expectedAuthorization,
- upgraders: [webSocket]
- )
- }
- return channel.eventLoop.makeSucceededVoidFuture()
- }
-
- // if a connection has been established, we need to negotiate TLS before
- // anything else. Depending on the negotiation, the HTTPHandlers will be added.
- if let sslContext = self.sslContext {
- try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext))
- }
-
- // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly
- try channel.pipeline.syncOperations.configureHTTPServerPipeline(
- withPipeliningAssistance: true,
- withServerUpgrade: (
- upgraders: [webSocket],
- completionHandler: { ctx in
- // complete
- }
- ),
- withErrorHandling: true
- )
- return channel.eventLoop.makeSucceededVoidFuture()
- } catch {
- return channel.eventLoop.makeFailedFuture(error)
- }
- }.bind(to: socketAddress).wait()
- }
-
-
- // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially
- // rather than re-using parts because the requestDecoder stops parsing after a CONNECT request
- private func syncAddTLSHTTPProxyHandlers(
- to channel: Channel,
- proxyConfig: ProxyConfig,
- expectedAuthorization: String?,
- upgraders: [HTTPServerProtocolUpgrader]
- ) throws {
- let sync = channel.pipeline.syncOperations
- let promise = channel.eventLoop.makePromise(of: Void.self)
-
- let responseEncoder = HTTPResponseEncoder()
- let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
- let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization)
-
- try sync.addHandler(responseEncoder)
- try sync.addHandler(requestDecoder)
-
- try sync.addHandler(proxySimulator)
-
- promise.futureResult.flatMap { _ in
- channel.pipeline.removeHandler(proxySimulator)
- }.flatMap { _ in
- channel.pipeline.removeHandler(responseEncoder)
- }.flatMap { _ in
- channel.pipeline.removeHandler(requestDecoder)
- }.whenComplete { result in
- switch result {
- case .failure:
- channel.close(mode: .all, promise: nil)
- case .success:
- self.httpProxyEstablished(channel, upgraders: upgraders)
- break
+ WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
+ ws.send([0x01])
+ ws.onBinary { ws, buf in
+ promise.succeed(.init(buf.readableBytesView))
+ ws.close(promise: closePromise)
}
+ }.whenFailure {
+ promise.fail($0)
+ closePromise.fail($0)
}
+ XCTAssertEqual(try promise.futureResult.wait(), [0x01])
+ XCTAssertNoThrow(try closePromise.futureResult.wait())
+ try server.close(mode: .all).wait()
}
-
-
- // In the plain-text case we must set up the 'proxy' and the 'server' handlers simultaneously
- // so that the combined proxy/upgrade request can be processed by the separate proxy and upgrade handlers
- private func syncAddHTTPProxyHandlers(
- to channel: Channel,
- proxyConfig: ProxyConfig,
- expectedAuthorization: String?,
- upgraders: [HTTPServerProtocolUpgrader]
- ) throws {
- let sync = channel.pipeline.syncOperations
- let promise = channel.eventLoop.makePromise(of: Void.self)
-
- let responseEncoder = HTTPResponseEncoder()
- let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
- let proxySimulator = HTTPProxySimulator(promise: promise, config: proxyConfig, expectedAuthorization: expectedAuthorization)
-
- let serverPipelineHandler = HTTPServerPipelineHandler()
- let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler()
-
- let extraHTTPHandlers: [RemovableChannelHandler] = [
- requestDecoder,
- serverPipelineHandler,
- serverProtocolErrorHandler
- ]
-
- try sync.addHandler(responseEncoder)
- try sync.addHandler(requestDecoder)
-
- try sync.addHandler(proxySimulator)
-
- try sync.addHandler(serverPipelineHandler)
- try sync.addHandler(serverProtocolErrorHandler)
-
-
- let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders,
- httpEncoder: responseEncoder,
- extraHTTPHandlers: extraHTTPHandlers,
- upgradeCompletionHandler: { ctx in
- // complete
- })
-
-
- try sync.addHandler(upgrader)
-
- promise.futureResult.flatMap { () -> EventLoopFuture in
- channel.pipeline.removeHandler(proxySimulator)
- }.whenComplete { result in
- switch result {
- case .failure:
- channel.close(mode: .all, promise: nil)
- case .success:
- break
- }
+
+ func testSendPing() throws {
+ let server = try ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).wait()
+ let promise = self.elg.any().makePromise(of: Void.self)
+ let closePromise = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
}
- }
-
- private func httpProxyEstablished(_ channel: Channel, upgraders: [HTTPServerProtocolUpgrader]) {
- do {
- // if a connection has been established, we need to negotiate TLS before
- // anything else. Depending on the negotiation, the HTTPHandlers will be added.
- if let sslContext = self.sslContext {
- try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext))
+ WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
+ ws.sendPing()
+ ws.onPong {
+ promise.succeed()
+ $0.close(promise: closePromise)
}
-
- try channel.pipeline.syncOperations.configureHTTPServerPipeline(
- withPipeliningAssistance: true,
- withServerUpgrade: (
- upgraders: upgraders,
- completionHandler: { ctx in
- // complete
- }
- ),
- withErrorHandling: true
- )
- } catch {
- // in case of an while modifying the pipeline we should close the connection
- channel.close(mode: .all, promise: nil)
+ }.cascadeFailure(to: closePromise)
+ XCTAssertNoThrow(try promise.futureResult.wait())
+ XCTAssertNoThrow(try closePromise.futureResult.wait())
+ try server.close(mode: .all).wait()
+ }
+
+ func testSetPingInterval() throws {
+ let server = try ServerBootstrap.webSocket(on: self.elg) { _, _ in }.bind(host: "localhost", port: 0).wait()
+ let promise = self.elg.any().makePromise(of: Void.self)
+ let closePromise = self.elg.any().makePromise(of: Void.self)
+ guard let port = server.localAddress?.port else {
+ return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))")
}
+ WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
+ ws.pingInterval = .milliseconds(100)
+ ws.onPong {
+ promise.succeed()
+ $0.close(promise: closePromise)
+ }
+ }.cascadeFailure(to: closePromise)
+ XCTAssertNoThrow(try promise.futureResult.wait())
+ XCTAssertNoThrow(try closePromise.futureResult.wait())
+ try server.close(mode: .all).wait()
}
-
- func shutdown() throws {
- self.isShutdown.store(true, ordering: .relaxed)
- try self.group.syncShutdownGracefully()
+
+ func testCreateNewELGAndShutdown() throws {
+ let client = WebSocketClient(eventLoopGroupProvider: .createNew)
+ try client.syncShutdown()
}
-}
-
-enum HTTPBinError: Error {
- case refusedConnection
- case invalidProxyRequest
-}
-
-final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
- typealias InboundIn = HTTPServerRequestPart
- typealias InboundOut = HTTPServerResponsePart
- typealias OutboundOut = HTTPServerResponsePart
-
- // the promise to succeed, once the proxy connection is setup
- let promise: EventLoopPromise
- let config: WebsocketBin.ProxyConfig
- let expectedAuthorization: String?
-
- var head: HTTPResponseHead
-
- init(promise: EventLoopPromise, config: WebsocketBin.ProxyConfig, expectedAuthorization: String?) {
- self.promise = promise
- self.config = config
- self.expectedAuthorization = expectedAuthorization
- self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")]))
+ var elg: EventLoopGroup!
+
+ override func setUp() {
+ // needs to be at least two to avoid client / server on same EL timing issues
+ self.elg = MultiThreadedEventLoopGroup(numberOfThreads: 2)
}
-
- func channelRead(context: ChannelHandlerContext, data: NIOAny) {
- let request = self.unwrapInboundIn(data)
- switch request {
- case .head(let head):
- if self.config.tls {
- guard head.method == .CONNECT else {
- self.head.status = .badRequest
- return
- }
- } else {
- guard head.method == .GET else {
- self.head.status = .badRequest
- return
- }
- }
-
- self.config.headVerification(context, head)
-
- if let expectedAuthorization = self.expectedAuthorization {
- guard let authorization = head.headers["proxy-authorization"].first,
- expectedAuthorization == authorization else {
- self.head.status = .proxyAuthenticationRequired
- return
- }
- }
- if !self.config.tls {
- context.fireChannelRead(data)
- }
-
- case .body:
- ()
- case .end:
- if self.self.config.tls {
- context.write(self.wrapOutboundOut(.head(self.head)), promise: nil)
- context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
- }
- if self.head.status == .ok {
- if !self.config.tls {
- context.fireChannelRead(data)
- }
- self.promise.succeed(())
- } else {
- self.promise.fail(HTTPBinError.invalidProxyRequest)
- }
- }
+
+ override func tearDown() {
+ try! self.elg.syncShutdownGracefully()
}
}
-