From 0ba2f8a548e8c67e75e36069830b6ee1e2222e41 Mon Sep 17 00:00:00 2001 From: Joannis Orlandos Date: Thu, 25 May 2023 18:32:36 +0200 Subject: [PATCH] Revert "Sendable Take 2 (#136)" This reverts commit a8f3ec804f84cce2384d4b52db3267e4d51c9cda. --- .../allowlist-branch-sendable-take-2.txt | 14 -- .github/workflows/test.yml | 2 + .gitignore | 1 - Package.swift | 4 +- .../Concurrency/WebSocket+Concurrency.swift | 52 +++--- Sources/WebSocketKit/Exports.swift | 4 + Sources/WebSocketKit/WebSocket+Connect.swift | 15 +- Sources/WebSocketKit/WebSocket.swift | 163 ++++++++---------- Sources/WebSocketKit/WebSocketClient.swift | 33 ++-- Sources/WebSocketKit/WebSocketHandler.swift | 16 +- .../AsyncWebSocketKitTests.swift | 52 +++--- Tests/WebSocketKitTests/SSLTestHelpers.swift | 1 - .../WebSocketKitTests/WebSocketKitTests.swift | 51 +++--- 13 files changed, 175 insertions(+), 233 deletions(-) delete mode 100644 .api-breakage/allowlist-branch-sendable-take-2.txt diff --git a/.api-breakage/allowlist-branch-sendable-take-2.txt b/.api-breakage/allowlist-branch-sendable-take-2.txt deleted file mode 100644 index a2985625..00000000 --- a/.api-breakage/allowlist-branch-sendable-take-2.txt +++ /dev/null @@ -1,14 +0,0 @@ -API breakage: func WebSocket.onText(_:) is now with @preconcurrency -API breakage: func WebSocket.onBinary(_:) is now with @preconcurrency -API breakage: func WebSocket.onPong(_:) is now with @preconcurrency -API breakage: func WebSocket.onPing(_:) is now with @preconcurrency -API breakage: func WebSocket.connect(to:headers:configuration:on:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.connect(scheme:host:port:path:query:headers:configuration:on:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.connect(scheme:host:port:path:query:headers:proxy:proxyPort:proxyHeaders:proxyConnectDeadline:configuration:on:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.connect(to:headers:proxy:proxyPort:proxyHeaders:proxyConnectDeadline:configuration:on:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.client(on:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.client(on:config:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.server(on:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocket.server(on:config:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocketClient.connect(scheme:host:port:path:query:headers:onUpgrade:) is now with @preconcurrency -API breakage: func WebSocketClient.connect(scheme:host:port:path:query:headers:proxy:proxyPort:proxyHeaders:proxyConnectDeadline:onUpgrade:) is now with @preconcurrency diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7fe9308e..b2288dc2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,7 +5,9 @@ concurrency: on: pull_request: { types: [opened, reopened, synchronize, ready_for_review] } push: { branches: [ main ] } + jobs: + vapor-integration: if: ${{ !(github.event.pull_request.draft || false) }} runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 630ed81f..68b8b308 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,3 @@ DerivedData .swiftpm Package.resolved -.devcontainer/ diff --git a/Package.swift b/Package.swift index 86fea2af..4841d15c 100644 --- a/Package.swift +++ b/Package.swift @@ -16,8 +16,8 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.53.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.24.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.16.0"), - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.1.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), ], targets: [ .target(name: "WebSocketKit", dependencies: [ diff --git a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift index 56085bef..d0e9c0d3 100644 --- a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift +++ b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift @@ -40,52 +40,44 @@ extension WebSocket { try await close(code: code).get() } - @preconcurrency public func onText(_ callback: @Sendable @escaping (WebSocket, String) async -> ()) { - self.eventLoop.execute { - self.onText { socket, text in - Task { - await callback(socket, text) - } + public func onText(_ callback: @escaping (WebSocket, String) async -> ()) { + onText { socket, text in + Task { + await callback(socket, text) } } } - @preconcurrency public func onBinary(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) { - self.eventLoop.execute { - self.onBinary { socket, binary in - Task { - await callback(socket, binary) - } + public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) async -> ()) { + onBinary { socket, binary in + Task { + await callback(socket, binary) } } } - @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) async -> ()) { - self.eventLoop.execute { - self.onPong { socket in - Task { - await callback(socket) - } + public func onPong(_ callback: @escaping (WebSocket) async -> ()) { + onPong { socket in + Task { + await callback(socket) } } } - @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) async -> ()) { - self.eventLoop.execute { - self.onPing { socket in - Task { - await callback(socket) - } + public func onPing(_ callback: @escaping (WebSocket) async -> ()) { + onPing { socket in + Task { + await callback(socket) } } } - @preconcurrency public static func connect( + public static func connect( to url: String, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) async -> () + onUpgrade: @escaping (WebSocket) async -> () ) async throws { return try await self.connect( to: url, @@ -100,12 +92,12 @@ extension WebSocket { ).get() } - @preconcurrency public static func connect( + public static func connect( to url: URL, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) async -> () + onUpgrade: @escaping (WebSocket) async -> () ) async throws { return try await self.connect( to: url, @@ -120,7 +112,7 @@ extension WebSocket { ).get() } - @preconcurrency public static func connect( + public static func connect( scheme: String = "ws", host: String, port: Int = 80, @@ -129,7 +121,7 @@ extension WebSocket { headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) async -> () + onUpgrade: @escaping (WebSocket) async -> () ) async throws { return try await self.connect( scheme: scheme, diff --git a/Sources/WebSocketKit/Exports.swift b/Sources/WebSocketKit/Exports.swift index b853992f..11ce0e49 100644 --- a/Sources/WebSocketKit/Exports.swift +++ b/Sources/WebSocketKit/Exports.swift @@ -6,7 +6,9 @@ @_documentation(visibility: internal) @_exported import protocol NIOCore.EventLoopGroup @_documentation(visibility: internal) @_exported import struct NIOCore.EventLoopPromise @_documentation(visibility: internal) @_exported import class NIOCore.EventLoopFuture + @_documentation(visibility: internal) @_exported import struct NIOHTTP1.HTTPHeaders + @_documentation(visibility: internal) @_exported import struct Foundation.URL #else @@ -17,7 +19,9 @@ @_exported import protocol NIOCore.EventLoopGroup @_exported import struct NIOCore.EventLoopPromise @_exported import class NIOCore.EventLoopFuture + @_exported import struct NIOHTTP1.HTTPHeaders + @_exported import struct Foundation.URL #endif diff --git a/Sources/WebSocketKit/WebSocket+Connect.swift b/Sources/WebSocketKit/WebSocket+Connect.swift index 546401a4..ca94540b 100644 --- a/Sources/WebSocketKit/WebSocket+Connect.swift +++ b/Sources/WebSocketKit/WebSocket+Connect.swift @@ -12,13 +12,12 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the WebSocket server is established. - @preconcurrency public static func connect( to url: String, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL) @@ -41,13 +40,12 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the WebSocket server is established. - @preconcurrency public static func connect( to url: URL, headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { let scheme = url.scheme ?? "ws" return self.connect( @@ -76,7 +74,6 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the WebSocket server is established. - @preconcurrency public static func connect( scheme: String = "ws", host: String, @@ -86,7 +83,7 @@ extension WebSocket { headers: HTTPHeaders = [:], configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { return WebSocketClient( eventLoopGroupProvider: .shared(eventLoopGroup), @@ -119,7 +116,6 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the origin server is established. - @preconcurrency public static func connect( scheme: String = "ws", host: String, @@ -133,7 +129,7 @@ extension WebSocket { proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { return WebSocketClient( eventLoopGroupProvider: .shared(eventLoopGroup), @@ -166,7 +162,6 @@ extension WebSocket { /// - eventLoopGroup: Event loop group to be used by the WebSocket client. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the origin server is established. - @preconcurrency public static func connect( to url: String, headers: HTTPHeaders = [:], @@ -176,7 +171,7 @@ extension WebSocket { proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, configuration: WebSocketClient.Configuration = .init(), on eventLoopGroup: EventLoopGroup, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { guard let url = URL(string: url) else { return eventLoopGroup.any().makeFailedFuture(WebSocketClient.Error.invalidURL) diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index 245d97a2..6ab6b6fd 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -4,10 +4,9 @@ import NIOHTTP1 import NIOSSL import Foundation import NIOFoundationCompat -import NIOConcurrencyHelpers -public final class WebSocket: Sendable { - enum PeerType: Sendable { +public final class WebSocket { + enum PeerType { case server case client } @@ -19,11 +18,7 @@ public final class WebSocket: Sendable { public var isClosed: Bool { !self.channel.isActive } - public var closeCode: WebSocketErrorCode? { - _closeCode.withLockedValue { $0 } - } - - private let _closeCode: NIOLockedValueBox + public private(set) var closeCode: WebSocketErrorCode? public var onClose: EventLoopFuture { self.channel.closeFuture @@ -32,46 +27,42 @@ public final class WebSocket: Sendable { @usableFromInline /* private but @usableFromInline */ internal let channel: Channel - private let onTextCallback: NIOLoopBoundBox<@Sendable (WebSocket, String) -> ()> - private let onBinaryCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()> - private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()> - private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()> + private var onTextCallback: (WebSocket, String) -> () + private var onBinaryCallback: (WebSocket, ByteBuffer) -> () + private var onPongCallback: (WebSocket) -> () + private var onPingCallback: (WebSocket) -> () + private var frameSequence: WebSocketFrameSequence? private let type: PeerType - private let waitingForPong: NIOLockedValueBox - private let waitingForClose: NIOLockedValueBox - private let scheduledTimeoutTask: NIOLockedValueBox?> - private let frameSequence: NIOLockedValueBox - private let _pingInterval: NIOLockedValueBox + private var waitingForPong: Bool + private var waitingForClose: Bool + private var scheduledTimeoutTask: Scheduled? init(channel: Channel, type: PeerType) { self.channel = channel self.type = type - self.onTextCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) - self.onBinaryCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) - self.onPongCallback = .init({ _ in }, eventLoop: channel.eventLoop) - self.onPingCallback = .init({ _ in }, eventLoop: channel.eventLoop) - self.waitingForPong = .init(false) - self.waitingForClose = .init(false) - self.scheduledTimeoutTask = .init(nil) - self._closeCode = .init(nil) - self.frameSequence = .init(nil) - self._pingInterval = .init(nil) + self.onTextCallback = { _, _ in } + self.onBinaryCallback = { _, _ in } + self.onPongCallback = { _ in } + self.onPingCallback = { _ in } + self.waitingForPong = false + self.waitingForClose = false + self.scheduledTimeoutTask = nil } - @preconcurrency public func onText(_ callback: @Sendable @escaping (WebSocket, String) -> ()) { - self.onTextCallback.value = callback + public func onText(_ callback: @escaping (WebSocket, String) -> ()) { + self.onTextCallback = callback } - @preconcurrency public func onBinary(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) { - self.onBinaryCallback.value = callback + public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) { + self.onBinaryCallback = callback } - @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) { - self.onPongCallback.value = callback + public func onPong(_ callback: @escaping (WebSocket) -> ()) { + self.onPongCallback = callback } - @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) { - self.onPingCallback.value = callback + public func onPing(_ callback: @escaping (WebSocket) -> ()) { + self.onPingCallback = callback } /// If set, this will trigger automatic pings on the connection. If ping is not answered before @@ -81,18 +72,14 @@ public final class WebSocket: Sendable { /// mechanism shutting down inactive connections, such as a Load Balancer deployed in /// front of the server. public var pingInterval: TimeAmount? { - get { - return _pingInterval.withLockedValue { $0 } - } - set { - _pingInterval.withLockedValue { $0 = newValue } - if newValue != nil { - if scheduledTimeoutTask.withLockedValue({ $0 == nil }) { - waitingForPong.withLockedValue { $0 = false } + didSet { + if pingInterval != nil { + if scheduledTimeoutTask == nil { + waitingForPong = false self.pingAndScheduleNextTimeoutTask() } } else { - scheduledTimeoutTask.withLockedValue { $0?.cancel() } + scheduledTimeoutTask?.cancel() } } } @@ -173,12 +160,12 @@ public final class WebSocket: Sendable { promise?.succeed(()) return } - guard !self.waitingForClose.withLockedValue({ $0 }) else { + guard !self.waitingForClose else { promise?.succeed(()) return } - self.waitingForClose.withLockedValue { $0 = true } - self._closeCode.withLockedValue { $0 = code } + self.waitingForClose = true + self.closeCode = code let codeAsInt = UInt16(webSocketErrorCode: code) let codeToSend: WebSocketErrorCode @@ -210,7 +197,7 @@ public final class WebSocket: Sendable { func handle(incoming frame: WebSocketFrame) { switch frame.opcode { case .connectionClose: - if self.waitingForClose.withLockedValue({ $0 }) { + if self.waitingForClose { // peer confirmed close, time to close channel self.channel.close(mode: .all, promise: nil) } else { @@ -236,7 +223,7 @@ public final class WebSocket: Sendable { if let maskingKey = maskingKey { frameData.webSocketUnmask(maskingKey) } - self.onPingCallback.value(self) + self.onPingCallback(self) self.send( raw: frameData.readableBytesView, opcode: .pong, @@ -253,19 +240,22 @@ public final class WebSocket: Sendable { if let maskingKey = maskingKey { frameData.webSocketUnmask(maskingKey) } - self.waitingForPong.withLockedValue { $0 = false } - self.onPongCallback.value(self) + self.waitingForPong = false + self.onPongCallback(self) } else { self.close(code: .protocolError, promise: nil) } case .text, .binary: // create a new frame sequence or use existing - self.frameSequence.withLockedValue { currentFrameSequence in - var frameSequence = currentFrameSequence ?? .init(type: frame.opcode) - // append this frame and update the sequence - frameSequence.append(frame) - currentFrameSequence = frameSequence + var frameSequence: WebSocketFrameSequence + if let existing = self.frameSequence { + frameSequence = existing + } else { + frameSequence = WebSocketFrameSequence(type: frame.opcode) } + // append this frame and update the sequence + frameSequence.append(frame) + self.frameSequence = frameSequence case .continuation: /// continuations are filtered by ``NIOWebSocketFrameAggregator`` preconditionFailure("We will never receive a continuation frame") @@ -276,29 +266,26 @@ public final class WebSocket: Sendable { // if this frame was final and we have a non-nil frame sequence, // output it to the websocket and clear storage - self.frameSequence.withLockedValue { currentFrameSequence in - if let frameSequence = currentFrameSequence, frame.fin { - switch frameSequence.type { - case .binary: - self.onBinaryCallback.value(self, frameSequence.binaryBuffer) - case .text: - self.onTextCallback.value(self, frameSequence.textBuffer) - case .ping, .pong: - assertionFailure("Control frames never have a frameSequence") - default: break - } - currentFrameSequence = nil + if let frameSequence = self.frameSequence, frame.fin { + switch frameSequence.type { + case .binary: + self.onBinaryCallback(self, frameSequence.binaryBuffer) + case .text: + self.onTextCallback(self, frameSequence.textBuffer) + case .ping, .pong: + assertionFailure("Control frames never have a frameSequence") + default: break } + self.frameSequence = nil } } - @Sendable private func pingAndScheduleNextTimeoutTask() { guard channel.isActive, let pingInterval = pingInterval else { return } - if waitingForPong.withLockedValue({ $0 }) { + if waitingForPong { // We never received a pong from our last ping, so the connection has timed out let promise = self.eventLoop.makePromise(of: Void.self) self.close(code: .unknown(1006), promise: promise) @@ -311,13 +298,11 @@ public final class WebSocket: Sendable { } } else { self.sendPing() - self.waitingForPong.withLockedValue { $0 = true } - self.scheduledTimeoutTask.withLockedValue { - $0 = self.eventLoop.scheduleTask( - deadline: .now() + pingInterval, - self.pingAndScheduleNextTimeoutTask - ) - } + self.waitingForPong = true + self.scheduledTimeoutTask = self.eventLoop.scheduleTask( + deadline: .now() + pingInterval, + self.pingAndScheduleNextTimeoutTask + ) } } @@ -326,31 +311,27 @@ public final class WebSocket: Sendable { } } -private struct WebSocketFrameSequence: Sendable { +private struct WebSocketFrameSequence { var binaryBuffer: ByteBuffer var textBuffer: String - let type: WebSocketOpcode - let lock: NIOLock + var type: WebSocketOpcode init(type: WebSocketOpcode) { self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0) self.textBuffer = .init() self.type = type - self.lock = .init() } mutating func append(_ frame: WebSocketFrame) { - self.lock.withLockVoid { - var data = frame.unmaskedData - switch type { - case .binary: - self.binaryBuffer.writeBuffer(&data) - case .text: - if let string = data.readString(length: data.readableBytes) { - self.textBuffer += string - } - default: break + var data = frame.unmaskedData + switch type { + case .binary: + self.binaryBuffer.writeBuffer(&data) + case .text: + if let string = data.readString(length: data.readableBytes) { + self.textBuffer += string } + default: break } } } diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index df96d6bb..08bf48e5 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -9,7 +9,7 @@ import NIOSSL import NIOTransportServices import Atomics -public final class WebSocketClient: Sendable { +public final class WebSocketClient { public enum Error: Swift.Error, LocalizedError { case invalidURL case invalidResponseStatus(HTTPResponseHead) @@ -21,7 +21,7 @@ public final class WebSocketClient: Sendable { public typealias EventLoopGroupProvider = NIOEventLoopGroupProvider - public struct Configuration: Sendable { + public struct Configuration { public var tlsConfiguration: TLSConfiguration? public var maxFrameSize: Int @@ -63,7 +63,6 @@ public final class WebSocketClient: Sendable { self.configuration = configuration } - @preconcurrency public func connect( scheme: String, host: String, @@ -71,7 +70,7 @@ public final class WebSocketClient: Sendable { path: String = "/", query: String? = nil, headers: HTTPHeaders = [:], - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { self.connect(scheme: scheme, host: host, port: port, path: path, query: query, headers: headers, proxy: nil, onUpgrade: onUpgrade) } @@ -91,7 +90,6 @@ public final class WebSocketClient: Sendable { /// - proxyConnectDeadline: Deadline for establishing the proxy connection. /// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`. /// - Returns: An future which completes when the connection to the origin server is established. - @preconcurrency public func connect( scheme: String, host: String, @@ -103,7 +101,7 @@ public final class WebSocketClient: Sendable { proxyPort: Int? = nil, proxyHeaders: HTTPHeaders = [:], proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { assert(["ws", "wss"].contains(scheme)) let upgradePromise = self.group.any().makePromise(of: Void.self) @@ -132,7 +130,6 @@ public final class WebSocketClient: Sendable { headers: upgradeRequestHeaders, upgradePromise: upgradePromise ) - let httpUpgradeRequestHandlerBox = NIOLoopBound(httpUpgradeRequestHandler, eventLoop: channel.eventLoop) let websocketUpgrader = NIOWebSocketClientUpgrader( maxFrameSize: self.configuration.maxFrameSize, @@ -146,10 +143,9 @@ public final class WebSocketClient: Sendable { upgraders: [websocketUpgrader], completionHandler: { context in upgradePromise.succeed(()) - channel.pipeline.removeHandler(httpUpgradeRequestHandlerBox.value, promise: nil) + channel.pipeline.removeHandler(httpUpgradeRequestHandler, promise: nil) } ) - let configBox = NIOLoopBound(config, eventLoop: channel.eventLoop) if proxy == nil || scheme == "ws" { if scheme == "wss" { @@ -167,15 +163,15 @@ public final class WebSocketClient: Sendable { leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config ).flatMap { - channel.pipeline.addHandler(httpUpgradeRequestHandlerBox.value) + channel.pipeline.addHandler(httpUpgradeRequestHandler) } } // TLS + proxy // we need to handle connecting with an additional CONNECT request let proxyEstablishedPromise = channel.eventLoop.makePromise(of: Void.self) - let encoder = NIOLoopBound(HTTPRequestEncoder(), eventLoop: channel.eventLoop) - let decoder = NIOLoopBound(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)), eventLoop: channel.eventLoop) + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) var connectHeaders = proxyHeaders connectHeaders.add(name: "Host", value: host) @@ -192,17 +188,17 @@ public final class WebSocketClient: Sendable { // They are then removed upon completion only to be re-added in `addHTTPClientHandlers`. // This is done because the HTTP decoder is not valid after an upgrade, the CONNECT request being counted as one. do { - try channel.pipeline.syncOperations.addHandler(encoder.value) - try channel.pipeline.syncOperations.addHandler(decoder.value) + try channel.pipeline.syncOperations.addHandler(encoder) + try channel.pipeline.syncOperations.addHandler(decoder) try channel.pipeline.syncOperations.addHandler(proxyRequestHandler) } catch { return channel.eventLoop.makeFailedFuture(error) } proxyEstablishedPromise.futureResult.flatMap { - channel.pipeline.removeHandler(decoder.value) + channel.pipeline.removeHandler(decoder) }.flatMap { - channel.pipeline.removeHandler(encoder.value) + channel.pipeline.removeHandler(encoder) }.whenComplete { result in switch result { case .success: @@ -213,9 +209,9 @@ public final class WebSocketClient: Sendable { try channel.pipeline.syncOperations.addHandler(tlsHandler) try channel.pipeline.syncOperations.addHTTPClientHandlers( leftOverBytesStrategy: .forwardBytes, - withClientUpgrade: configBox.value + withClientUpgrade: config ) - try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandlerBox.value) + try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandler) } catch { channel.pipeline.close(mode: .all, promise: nil) } @@ -234,7 +230,6 @@ public final class WebSocketClient: Sendable { } } - @Sendable private func makeTLSHandler(tlsConfiguration: TLSConfiguration?, host: String) throws -> NIOSSLClientHandler { let context = try NIOSSLContext( configuration: self.configuration.tlsConfiguration ?? .makeClientConfiguration() diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index 6e333dc3..45f266ce 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -4,7 +4,7 @@ import NIOWebSocket extension WebSocket { /// Stores configuration for a WebSocket client/server instance - public struct Configuration: Sendable { + public struct Configuration { /// Defends against small payloads in frame aggregation. /// See `NIOWebSocketFrameAggregator` for details. public var minNonFinalFragmentSize: Int @@ -33,10 +33,9 @@ extension WebSocket { /// - channel: NIO channel which the client will use to communicate. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. - @preconcurrency public static func client( on channel: Channel, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .client, with: Configuration(), onUpgrade: onUpgrade) } @@ -47,11 +46,10 @@ extension WebSocket { /// - config: Configuration for the client channel handlers. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. - @preconcurrency public static func client( on channel: Channel, config: Configuration, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .client, with: config, onUpgrade: onUpgrade) } @@ -61,10 +59,9 @@ extension WebSocket { /// - channel: NIO channel which the server will use to communicate. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. - @preconcurrency public static func server( on channel: Channel, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .server, with: Configuration(), onUpgrade: onUpgrade) } @@ -75,11 +72,10 @@ extension WebSocket { /// - config: Configuration for the server channel handlers. /// - onUpgrade: An escaping closure to be executed the channel is configured with the WebSocket handlers. /// - Returns: An future which completes when the WebSocket connection to the server is established. - @preconcurrency public static func server( on channel: Channel, config: Configuration, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { return self.configure(on: channel, as: .server, with: config, onUpgrade: onUpgrade) } @@ -88,7 +84,7 @@ extension WebSocket { on channel: Channel, as type: PeerType, with config: Configuration, - onUpgrade: @Sendable @escaping (WebSocket) -> () + onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { let webSocket = WebSocket(channel: channel, type: type) diff --git a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift index e662b2c6..e20ffa93 100644 --- a/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/AsyncWebSocketKitTests.swift @@ -5,12 +5,6 @@ import NIOWebSocket @testable import WebSocketKit final class AsyncWebSocketKitTests: XCTestCase { - - override func setUp() async throws { - // Handy for catching hangs in the tests. See https://github.com/apple/swift-corelibs-xctest/issues/422#issuecomment-1310952437 - fflush(stdout) - } - func testWebSocketEcho() async throws { let server = try await ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in @@ -27,15 +21,15 @@ final class AsyncWebSocketKitTests: XCTestCase { try await WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in do { + try await ws.send("hello") ws.onText { ws, string in + promise.succeed(string) do { try await ws.close() } catch { XCTFail("Failed to close websocket, error: \(error)") } - promise.succeed(string) } - try await ws.send("hello") } catch { promise.fail(error) } @@ -45,6 +39,23 @@ final class AsyncWebSocketKitTests: XCTestCase { XCTAssertEqual(result, "hello") try await server.close(mode: .all) } + + func testAlternateWebsocketConnectMethods() async throws { + let server = try await ServerBootstrap.webSocket(on: self.elg) { $1.onText { $0.send($1) } }.bind(host: "localhost", port: 0).get() + let promise = self.elg.any().makePromise(of: Void.self) + guard let port = server.localAddress?.port else { + return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") + } + try await WebSocket.connect(scheme: "ws", host: "localhost", port: port, on: self.elg) { (ws) async in + do { try await ws.send("hello") } catch { promise.fail(error); try? await ws.close() } + ws.onText { ws, _ in + promise.succeed(()) + do { try await ws.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } + } + } + try await promise.futureResult.get() + try await server.close(mode: .all) + } func testBadURLInWebsocketConnect() async throws { do { @@ -66,17 +77,11 @@ final class AsyncWebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + do { try await ws.send([0x01]) } catch { promise.fail(error); try? await ws.close() } ws.onBinary { ws, buf in + promise.succeed(.init(buf.readableBytesView)) do { try await ws.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } - promise.succeed(.init(buf.readableBytesView)) - } - - do { - try await ws.send([0x01]) - } catch { - try? await ws.close() - promise.fail(error); } } let result = try await promise.futureResult.get() @@ -91,19 +96,10 @@ final class AsyncWebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in + do { try await ws.sendPing() } catch { promise.fail(error); try? await ws.close() } ws.onPong { - do { - try await $0.close() - } catch { - XCTFail("Failed to close websocket: \(String(reflecting: error))") - } promise.succeed(()) - } - do { - try await ws.sendPing() - } catch { - try? await ws.close() - promise.fail(error) + do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } } } try await promise.futureResult.get() @@ -119,8 +115,8 @@ final class AsyncWebSocketKitTests: XCTestCase { try await WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { (ws) async in ws.pingInterval = .milliseconds(100) ws.onPong { - do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } promise.succeed(()) + do { try await $0.close() } catch { XCTFail("Failed to close websocket: \(String(reflecting: error))") } } } try await promise.futureResult.get() diff --git a/Tests/WebSocketKitTests/SSLTestHelpers.swift b/Tests/WebSocketKitTests/SSLTestHelpers.swift index a6776aab..d515fe3e 100644 --- a/Tests/WebSocketKitTests/SSLTestHelpers.swift +++ b/Tests/WebSocketKitTests/SSLTestHelpers.swift @@ -18,7 +18,6 @@ import Foundation import NIOCore @testable import NIOSSL - // This function generates a random number suitable for use in an X509 // serial field. This needs to be a positive number less than 2^159 // (such that it will fit into 20 ASN.1 bytes). diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index af3d3cf7..985cb00b 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -8,10 +8,6 @@ import NIOWebSocket @testable import WebSocketKit final class WebSocketKitTests: XCTestCase { - override func setUp() async throws { - fflush(stdout) - } - func testWebSocketEcho() throws { let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in ws.onText { ws, text in @@ -27,11 +23,11 @@ final class WebSocketKitTests: XCTestCase { let promise = elg.any().makePromise(of: String.self) let closePromise = elg.any().makePromise(of: Void.self) WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in + ws.send("hello") ws.onText { ws, string in - ws.close(promise: closePromise) promise.succeed(string) + ws.close(promise: closePromise) } - ws.send("hello") }.cascadeFailure(to: promise) try XCTAssertEqual(promise.futureResult.wait(), "hello") XCTAssertNoThrow(try closePromise.futureResult.wait()) @@ -60,8 +56,8 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.onClose.cascade(to: clientClose) ws.send("close", promise: sendPromise) + ws.onClose.cascade(to: clientClose) }.cascadeFailure(to: sendPromise) XCTAssertNoThrow(try sendPromise.futureResult.wait()) @@ -87,12 +83,12 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.send("close", promise: sendPromise) ws.onText { ws, text in if text == "close" { ws.close(promise: clientClose) } } - ws.send("close", promise: sendPromise) }.cascadeFailure(to: sendPromise) XCTAssertNoThrow(try sendPromise.futureResult.wait()) @@ -104,11 +100,11 @@ final class WebSocketKitTests: XCTestCase { func testImmediateSend() throws { let promise = self.elg.any().makePromise(of: String.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in + ws.send("hello") ws.onText { ws, string in promise.succeed(string) ws.close(promise: nil) } - ws.send("hello") }.bind(host: "localhost", port: 0).wait() guard let port = server.localAddress?.port else { @@ -144,11 +140,11 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.send(raw: pingPongData.readableBytesView, opcode: .ping) ws.onPong { ws in pongPromise.succeed("pong") ws.close(promise: nil) } - ws.send(raw: pingPongData.readableBytesView, opcode: .ping) }.cascadeFailure(to: pongPromise) try XCTAssertEqual(pingPromise.futureResult.wait(), "ping") @@ -178,13 +174,13 @@ final class WebSocketKitTests: XCTestCase { let promise = elg.any().makePromise(of: String.self) let closePromise = elg.any().makePromise(of: Void.self) WebSocket.connect(to: "ws://localhost:\(port)", on: elg) { ws in - ws.onText { ws, string in - ws.close(promise: closePromise) - promise.succeed(string) - } ws.send(.init(string: "Hel"), opcode: .text, fin: false) ws.send(.init(string: "lo! Vapor r"), opcode: .continuation, fin: false) ws.send(.init(string: "ules"), opcode: .continuation, fin: true) + ws.onText { ws, string in + promise.succeed(string) + ws.close(promise: closePromise) + } }.cascadeFailure(to: promise) try XCTAssertEqual(promise.futureResult.wait(), "Hello! Vapor rules the most") XCTAssertNoThrow(try closePromise.futureResult.wait()) @@ -208,8 +204,8 @@ final class WebSocketKitTests: XCTestCase { ws.send("goodbye") } ws.onClose.whenSuccess { - XCTAssertEqual(ws.closeCode, WebSocketErrorCode.normalClosure) promise.succeed(ws.closeCode!) + XCTAssertEqual(ws.closeCode, WebSocketErrorCode.normalClosure) } }.cascadeFailure(to: promise) @@ -232,8 +228,9 @@ final class WebSocketKitTests: XCTestCase { headers.contains(name: "Content-Length") || headers.contains(name: "Content-Type") ) - ws.close(promise: nil) promiseHasUnwantedHeaders.succeed(hasUnwantedHeaders) + + ws.close(promise: nil) }.bind(host: "localhost", port: 0).wait() guard let port = server.localAddress?.port else { @@ -257,8 +254,8 @@ final class WebSocketKitTests: XCTestCase { let promise = self.elg.any().makePromise(of: String.self) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in - ws.close(promise: nil) promise.succeed(req.uri) + ws.close(promise: nil) }.bind(host: "localhost", port: 0).wait() guard let port = server.localAddress?.port else { @@ -284,6 +281,8 @@ final class WebSocketKitTests: XCTestCase { let shutdownPromise = self.elg.any().makePromise(of: Void.self) let server = try! ServerBootstrap.webSocket(on: self.elg) { req, ws in + ws.send("welcome!") + ws.onClose.whenComplete { print("ws.onClose done: \($0)") } @@ -300,8 +299,6 @@ final class WebSocketKitTests: XCTestCase { ws.send(text.reversed()) } } - - ws.send("welcome!") }.bind(host: "localhost", port: port).wait() print("Serving at ws://localhost:\(port)") @@ -382,8 +379,8 @@ final class WebSocketKitTests: XCTestCase { ) { ws in ws.send("hello") ws.onText { ws, string in - ws.close(promise: closePromise) promise.succeed(string) + ws.close(promise: closePromise) } }.cascadeFailure(to: promise) @@ -445,11 +442,11 @@ final class WebSocketKitTests: XCTestCase { proxyPort: localWebsocketBin.port, proxyHeaders: HTTPHeaders([("proxy-authorization", "token amFwcGxlc2VlZDpwYXNzMTIz")]) ) { ws in + ws.send("hello") ws.onText { ws, string in - ws.close(promise: closePromise) promise.succeed(string) + ws.close(promise: closePromise) } - ws.send("hello") }.cascadeFailure(to: promise) XCTAssertEqual(try promise.futureResult.wait(), "hello") @@ -491,11 +488,11 @@ final class WebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.send([0x01]) ws.onBinary { ws, buf in - ws.close(promise: closePromise) promise.succeed(.init(buf.readableBytesView)) + ws.close(promise: closePromise) } - ws.send([0x01]) }.whenFailure { promise.fail($0) closePromise.fail($0) @@ -513,11 +510,11 @@ final class WebSocketKitTests: XCTestCase { return XCTFail("couldn't get port from \(String(reflecting: server.localAddress))") } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in + ws.sendPing() ws.onPong { - $0.close(promise: closePromise) promise.succeed() + $0.close(promise: closePromise) } - ws.sendPing() }.cascadeFailure(to: closePromise) XCTAssertNoThrow(try promise.futureResult.wait()) XCTAssertNoThrow(try closePromise.futureResult.wait()) @@ -534,8 +531,8 @@ final class WebSocketKitTests: XCTestCase { WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in ws.pingInterval = .milliseconds(100) ws.onPong { - $0.close(promise: closePromise) promise.succeed() + $0.close(promise: closePromise) } }.cascadeFailure(to: closePromise) XCTAssertNoThrow(try promise.futureResult.wait())