Skip to content

Commit

Permalink
Add support for proxying in WebsocketClient (#130)
Browse files Browse the repository at this point in the history
* WebsocketClient supports proxying

Added support for TLS and plain text proxying of Websocket traffic.

In the TLS case a CONNECT header is first sent establishing the proxied
traffic.

In the plain text case the modified URI in the initial upgrade request
header indicates to the proxy server that the traffic is to be proxied.

Use `NIOWebSocketFrameAggregator` to handle aggregating frame
fragments. This brings with it more protections e.g. against memory
exhaustion.

Accompanying config has been added to support this change.

* Reduce allocations and copies in WebSocket.send

Reduce allocation and copies necessary to send `ByteBuffer` and `ByteBufferView` through `WebSocket.send`.

In fact sending `ByteBuffer` or `ByteBufferView` doesn’t require any allocation or copy of the data. Sending a `String` now correctly pre allocates the `ByteBuffer` if multibyte characters are present in the `String`.

Remove custom random websocket mask generation which would only generate bytes between `UInt8.min..<UInt8.max`, therefore excluding `UInt8.max` aka `255`.

* add DocC comments

* DocC comments for new APIs
  • Loading branch information
rnro authored Apr 11, 2023
1 parent 2b88859 commit 2166cbe
Show file tree
Hide file tree
Showing 8 changed files with 911 additions and 65 deletions.
9 changes: 9 additions & 0 deletions NOTICES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ This product contains a derivation of `NIOSSLTestHelpers.swift` from SwiftNIO SS
* https://www.apache.org/licenses/LICENSE-2.0
* HOMEPAGE:
* https://github.com/apple/swift-nio-ssl

---

This product contains derivations of "HTTPProxySimulator" and "HTTPBin" test utils from AsyncHTTPClient.

* LICENSE (Apache License 2.0):
* https://www.apache.org/licenses/LICENSE-2.0
* HOMEPAGE:
* https://github.com/swift-server/async-http-client
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.33.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.16.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
Expand All @@ -22,6 +23,7 @@ let package = Package(
.product(name: "NIO", package: "swift-nio"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOConcurrencyHelpers", package: "swift-nio"),
.product(name: "NIOExtras", package: "swift-nio-extras"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "NIOHTTP1", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import NIO
import NIOHTTP1

final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHandler {
final class HTTPUpgradeRequestHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
typealias OutboundOut = HTTPClientRequestPart

Expand All @@ -11,6 +11,8 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
let headers: HTTPHeaders
let upgradePromise: EventLoopPromise<Void>

private var requestSent = false

init(host: String, path: String, query: String?, headers: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
self.host = host
self.path = path
Expand All @@ -20,10 +22,33 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
}

func channelActive(context: ChannelHandlerContext) {
self.sendRequest(context: context)
context.fireChannelActive()
}

func handlerAdded(context: ChannelHandlerContext) {
if context.channel.isActive {
self.sendRequest(context: context)
}
}

private func sendRequest(context: ChannelHandlerContext) {
if self.requestSent {
// we might run into this handler twice, once in handlerAdded and once in channelActive.
return
}
self.requestSent = true

var headers = self.headers
headers.add(name: "Host", value: self.host)

var uri = self.path.hasPrefix("/") ? self.path : "/" + self.path
var uri: String
if self.path.hasPrefix("/") || self.path.hasPrefix("ws://") || self.path.hasPrefix("wss://") {
uri = self.path
} else {
uri = "/" + self.path
}

if let query = self.query {
uri += "?\(query)"
}
Expand All @@ -43,10 +68,13 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
// `NIOHTTPClientUpgradeHandler` should consume the first response in the success case,
// any response we see here indicates a failure. Report the failure and tidy up at the end of the response.
let clientResponse = self.unwrapInboundIn(data)
switch clientResponse {
case .head(let responseHead):
self.upgradePromise.fail(WebSocketClient.Error.invalidResponseStatus(responseHead))
let error = WebSocketClient.Error.invalidResponseStatus(responseHead)
self.upgradePromise.fail(error)
case .body: break
case .end:
context.close(promise: nil)
Expand Down
125 changes: 125 additions & 0 deletions Sources/WebSocketKit/WebSocket+Connect.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ import NIOHTTP1
import Foundation

extension WebSocket {
/// Establish a WebSocket connection.
///
/// - Parameters:
/// - url: URL for the WebSocket server.
/// - headers: Headers to send to the WebSocket server.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
/// - Returns: An future which completes when the connection to the WebSocket server is established.
public static func connect(
to url: String,
headers: HTTPHeaders = [:],
Expand All @@ -22,6 +31,15 @@ extension WebSocket {
)
}

/// Establish a WebSocket connection.
///
/// - Parameters:
/// - url: URL for the WebSocket server.
/// - headers: Headers to send to the WebSocket server.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
/// - Returns: An future which completes when the connection to the WebSocket server is established.
public static func connect(
to url: URL,
headers: HTTPHeaders = [:],
Expand All @@ -43,6 +61,19 @@ extension WebSocket {
)
}

/// Establish a WebSocket connection.
///
/// - Parameters:
/// - scheme: Scheme component of the URI for the WebSocket server.
/// - host: Host component of the URI for the WebSocket server.
/// - port: Port on which to connect to the WebSocket server.
/// - path: Path component of the URI for the WebSocket server.
/// - query: Query component of the URI for the WebSocket server.
/// - headers: Headers to send to the WebSocket server.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
/// - Returns: An future which completes when the connection to the WebSocket server is established.
public static func connect(
scheme: String = "ws",
host: String,
Expand All @@ -67,4 +98,98 @@ extension WebSocket {
onUpgrade: onUpgrade
)
}

/// Establish a WebSocket connection via a proxy server.
///
/// - Parameters:
/// - scheme: Scheme component of the URI for the origin server.
/// - host: Host component of the URI for the origin server.
/// - port: Port on which to connect to the origin server.
/// - path: Path component of the URI for the origin server.
/// - query: Query component of the URI for the origin server.
/// - headers: Headers to send to the origin server.
/// - proxy: Host component of the URI for the proxy server.
/// - proxyPort: Port on which to connect to the proxy server.
/// - proxyHeaders: Headers to send to the proxy server.
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
/// - Returns: An future which completes when the connection to the origin server is established.
public static func connect(
scheme: String = "ws",
host: String,
port: Int = 80,
path: String = "/",
query: String? = nil,
headers: HTTPHeaders = [:],
proxy: String?,
proxyPort: Int? = nil,
proxyHeaders: HTTPHeaders = [:],
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return WebSocketClient(
eventLoopGroupProvider: .shared(eventLoopGroup),
configuration: configuration
).connect(
scheme: scheme,
host: host,
port: port,
path: path,
query: query,
headers: headers,
proxy: proxy,
proxyPort: proxyPort,
proxyHeaders: proxyHeaders,
proxyConnectDeadline: proxyConnectDeadline,
onUpgrade: onUpgrade
)
}


/// Description
/// - Parameters:
/// - url: URL for the origin server.
/// - headers: Headers to send to the origin server.
/// - proxy: Host component of the URI for the proxy server.
/// - proxyPort: Port on which to connect to the proxy server.
/// - proxyHeaders: Headers to send to the proxy server.
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
/// - Returns: An future which completes when the connection to the origin server is established.
public static func connect(
to url: String,
headers: HTTPHeaders = [:],
proxy: String?,
proxyPort: Int? = nil,
proxyHeaders: HTTPHeaders = [:],
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
guard let url = URL(string: url) else {
return eventLoopGroup.next().makeFailedFuture(WebSocketClient.Error.invalidURL)
}
let scheme = url.scheme ?? "ws"
return self.connect(
scheme: scheme,
host: url.host ?? "localhost",
port: url.port ?? (scheme == "wss" ? 443 : 80),
path: url.path,
query: url.query,
headers: headers,
proxy: proxy,
proxyPort: proxyPort,
proxyHeaders: proxyHeaders,
proxyConnectDeadline: proxyConnectDeadline,
on: eventLoopGroup,
onUpgrade: onUpgrade
)
}
}
56 changes: 34 additions & 22 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ public final class WebSocket {
self.channel.closeFuture
}

private let channel: Channel
@usableFromInline
/* private but @usableFromInline */
internal let channel: Channel
private var onTextCallback: (WebSocket, String) -> ()
private var onBinaryCallback: (WebSocket, ByteBuffer) -> ()
private var onPongCallback: (WebSocket) -> ()
Expand Down Expand Up @@ -64,10 +66,10 @@ public final class WebSocket {
}

/// If set, this will trigger automatic pings on the connection. If ping is not answered before
/// the next ping is sent, then the WebSocket will be presumed innactive and will be closed
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed
/// automatically.
/// These pings can also be used to keep the WebSocket alive if there is some other timeout
/// mechanism shutting down innactive connections, such as a Load Balancer deployed in
/// mechanism shutting down inactive connections, such as a Load Balancer deployed in
/// front of the server.
public var pingInterval: TimeAmount? {
didSet {
Expand All @@ -82,13 +84,13 @@ public final class WebSocket {
}
}

@inlinable
public func send<S>(_ text: S, promise: EventLoopPromise<Void>? = nil)
where S: Collection, S.Element == Character
{
let string = String(text)
var buffer = channel.allocator.buffer(capacity: text.count)
buffer.writeString(string)
self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise)
let buffer = channel.allocator.buffer(string: string)
self.send(buffer, opcode: .text, fin: true, promise: promise)

}

Expand All @@ -105,6 +107,7 @@ public final class WebSocket {
)
}

@inlinable
public func send<Data>(
raw data: Data,
opcode: WebSocketOpcode,
Expand All @@ -113,13 +116,32 @@ public final class WebSocket {
)
where Data: DataProtocol
{
var buffer = channel.allocator.buffer(capacity: data.count)
buffer.writeBytes(data)
if let byteBufferView = data as? ByteBufferView {
// optimisation: converting from `ByteBufferView` to `ByteBuffer` doesn't allocate or copy any data
send(ByteBuffer(byteBufferView), opcode: opcode, fin: fin, promise: promise)
} else {
let buffer = channel.allocator.buffer(bytes: data)
send(buffer, opcode: opcode, fin: fin, promise: promise)
}
}

/// Send the provided data in a WebSocket frame.
/// - Parameters:
/// - data: Data to be sent.
/// - opcode: Frame opcode.
/// - fin: The value of the fin bit.
/// - promise: A promise to be completed when the write is complete.
public func send(
_ data: ByteBuffer,
opcode: WebSocketOpcode = .binary,
fin: Bool = true,
promise: EventLoopPromise<Void>? = nil
) {
let frame = WebSocketFrame(
fin: fin,
opcode: opcode,
maskKey: self.makeMaskKey(),
data: buffer
data: data
)
self.channel.writeAndFlush(frame, promise: promise)
}
Expand Down Expand Up @@ -164,11 +186,7 @@ public final class WebSocket {
func makeMaskKey() -> WebSocketMaskingKey? {
switch type {
case .client:
var bytes: [UInt8] = []
for _ in 0..<4 {
bytes.append(.random(in: .min ..< .max))
}
return WebSocketMaskingKey(bytes)
return WebSocketMaskingKey.random()
case .server:
return nil
}
Expand Down Expand Up @@ -237,14 +255,8 @@ public final class WebSocket {
frameSequence.append(frame)
self.frameSequence = frameSequence
case .continuation:
// we must have an existing sequence
if var frameSequence = self.frameSequence {
// append this frame and update
frameSequence.append(frame)
self.frameSequence = frameSequence
} else {
self.close(code: .protocolError, promise: nil)
}
/// continuations are filtered by ``NIOWebSocketFrameAggregator``
preconditionFailure("We will never receive a continuation frame")
default:
// We ignore all other frames.
break
Expand Down
Loading

0 comments on commit 2166cbe

Please sign in to comment.