From 53fe0639a98903858d0196b699720decb42aee7b Mon Sep 17 00:00:00 2001 From: Thomas Krajacic Date: Mon, 29 May 2023 12:48:27 +0200 Subject: [PATCH] Pass a copy of the control frame buffer to ping/pong callbacks (#116) * Pass a copy of the control frame buffer to callbacks * Add back old API Since the new and old methods are only overloads and share the same name, the 'renamed' parameter of the deprecation warning doesn't help. * Allow specifying payload when sending ping * Remove default value in favor of method forwarding This preserves the signature of the original method and doesn't break the API * Apply suggestions from code review New APIs should use safe code --------- Co-authored-by: Tim Condon <0xTim@users.noreply.github.com> --- .../Concurrency/WebSocket+Concurrency.swift | 32 +++++++++++++++++-- Sources/WebSocketKit/WebSocket.swift | 32 +++++++++++++------ .../WebSocketKitTests/WebSocketKitTests.swift | 7 ++-- 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift index 56085bef..a98afdd0 100644 --- a/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift +++ b/Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift @@ -19,8 +19,12 @@ extension WebSocket { } public func sendPing() async throws { + try await sendPing(Data()) + } + + public func sendPing(_ data: Data) async throws { let promise = eventLoop.makePromise(of: Void.self) - sendPing(promise: promise) + sendPing(data, promise: promise) return try await promise.futureResult.get() } @@ -60,9 +64,20 @@ extension WebSocket { } } + public func onPong(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) { + self.eventLoop.execute { + self.onPong { socket, data in + Task { + await callback(socket, data) + } + } + } + } + + @available(*, deprecated, message: "Please use `onPong { socket, data in /* … */ }` with the additional `data` parameter.") @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) async -> ()) { self.eventLoop.execute { - self.onPong { socket in + self.onPong { socket, _ in Task { await callback(socket) } @@ -70,9 +85,20 @@ extension WebSocket { } } + public func onPing(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) { + self.eventLoop.execute { + self.onPing { socket, data in + Task { + await callback(socket, data) + } + } + } + } + + @available(*, deprecated, message: "Please use `onPing { socket, data in /* … */ }` with the additional `data` parameter.") @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) async -> ()) { self.eventLoop.execute { - self.onPing { socket in + self.onPing { socket, _ in Task { await callback(socket) } diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index 245d97a2..e0c11f55 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -34,8 +34,8 @@ public final class WebSocket: Sendable { internal let channel: Channel private let onTextCallback: NIOLoopBoundBox<@Sendable (WebSocket, String) -> ()> private let onBinaryCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()> - private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()> - private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()> + private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()> + private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()> private let type: PeerType private let waitingForPong: NIOLockedValueBox private let waitingForClose: NIOLockedValueBox @@ -48,8 +48,8 @@ public final class WebSocket: Sendable { self.type = type self.onTextCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) self.onBinaryCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) - self.onPongCallback = .init({ _ in }, eventLoop: channel.eventLoop) - self.onPingCallback = .init({ _ in }, eventLoop: channel.eventLoop) + self.onPongCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) + self.onPingCallback = .init({ _, _ in }, eventLoop: channel.eventLoop) self.waitingForPong = .init(false) self.waitingForClose = .init(false) self.scheduledTimeoutTask = .init(nil) @@ -66,13 +66,23 @@ public final class WebSocket: Sendable { self.onBinaryCallback.value = callback } - @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) { + public func onPong(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) { self.onPongCallback.value = callback } + + @available(*, deprecated, message: "Please use `onPong { socket, data in /* … */ }` with the additional `data` parameter.") + @preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) { + self.onPongCallback.value = { ws, _ in callback(ws) } + } - @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) { + public func onPing(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) { self.onPingCallback.value = callback } + + @available(*, deprecated, message: "Please use `onPing { socket, data in /* … */ }` with the additional `data` parameter.") + @preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) { + self.onPingCallback.value = { ws, _ in callback(ws) } + } /// If set, this will trigger automatic pings on the connection. If ping is not answered before /// the next ping is sent, then the WebSocket will be presumed inactive and will be closed @@ -112,8 +122,12 @@ public final class WebSocket: Sendable { } public func sendPing(promise: EventLoopPromise? = nil) { + sendPing(Data(), promise: promise) + } + + public func sendPing(_ data: Data, promise: EventLoopPromise? = nil) { self.send( - raw: Data(), + raw: data, opcode: .ping, fin: true, promise: promise @@ -236,7 +250,7 @@ public final class WebSocket: Sendable { if let maskingKey = maskingKey { frameData.webSocketUnmask(maskingKey) } - self.onPingCallback.value(self) + self.onPingCallback.value(self, ByteBuffer(buffer: frameData)) self.send( raw: frameData.readableBytesView, opcode: .pong, @@ -254,7 +268,7 @@ public final class WebSocket: Sendable { frameData.webSocketUnmask(maskingKey) } self.waitingForPong.withLockedValue { $0 = false } - self.onPongCallback.value(self) + self.onPongCallback.value(self, ByteBuffer(buffer: frameData)) } else { self.close(code: .protocolError, promise: nil) } diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index af3d3cf7..9fa402ba 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -133,7 +133,8 @@ final class WebSocketKitTests: XCTestCase { let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8) let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in - ws.onPing { ws in + ws.onPing { ws, data in + XCTAssertEqual(pingPongData, data) pingPromise.succeed("ping") } }.bind(host: "localhost", port: 0).wait() @@ -144,7 +145,9 @@ final class WebSocketKitTests: XCTestCase { } WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in - ws.onPong { ws in + ws.sendPing(Data(pingPongData.readableBytesView)) + ws.onPong { ws, data in + XCTAssertEqual(pingPongData, data) pongPromise.succeed("pong") ws.close(promise: nil) }