Skip to content

Commit

Permalink
add support to bind to a specific device (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnnzhou authored Sep 30, 2024
1 parent 7a72794 commit 9d1f281
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 2 deletions.
8 changes: 8 additions & 0 deletions Sources/WebSocketKit/WebSocket+ChannelOption.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import Foundation
import NIOCore

extension ChannelOption where Self == ChannelOptions.Types.SocketOption {
public static func ipv6Option(_ name: NIOBSDSocket.Option) -> Self {
.init(level: .ipv6, name: name)
}
}
13 changes: 13 additions & 0 deletions Sources/WebSocketKit/WebSocket+Connect.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ extension WebSocket {
/// - Parameters:
/// - url: URL for the WebSocket server.
/// - headers: Headers to send to the WebSocket server.
/// - queueSize: the size of the buffer queue.
/// - deviceName: the device to which the data will be sent.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
Expand All @@ -46,6 +48,7 @@ extension WebSocket {
to url: URL,
headers: HTTPHeaders = [:],
queueSize: Int? = nil,
deviceName: String? = nil,
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
onUpgrade: @Sendable @escaping (WebSocket) -> ()
Expand All @@ -59,6 +62,7 @@ extension WebSocket {
query: url.query,
queueSize: queueSize,
headers: headers,
deviceName: deviceName,
configuration: configuration,
on: eventLoopGroup,
onUpgrade: onUpgrade
Expand All @@ -74,6 +78,7 @@ extension WebSocket {
/// - path: Path component of the URI for the WebSocket server.
/// - query: Query component of the URI for the WebSocket server.
/// - headers: Headers to send to the WebSocket server.
/// - deviceName: the device to which the data will be sent.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
Expand All @@ -87,6 +92,7 @@ extension WebSocket {
query: String? = nil,
queueSize: Int? = nil,
headers: HTTPHeaders = [:],
deviceName: String? = nil,
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
onUpgrade: @Sendable @escaping (WebSocket) -> ()
Expand All @@ -102,6 +108,7 @@ extension WebSocket {
query: query,
headers: headers,
maxQueueSize: queueSize,
deviceName: deviceName,
onUpgrade: onUpgrade
)
}
Expand All @@ -119,6 +126,7 @@ extension WebSocket {
/// - proxyPort: Port on which to connect to the proxy server.
/// - proxyHeaders: Headers to send to the proxy server.
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
/// - deviceName: the device to which the data will be sent.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
Expand All @@ -135,6 +143,7 @@ extension WebSocket {
proxyPort: Int? = nil,
proxyHeaders: HTTPHeaders = [:],
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
deviceName: String? = nil,
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
onUpgrade: @Sendable @escaping (WebSocket) -> ()
Expand All @@ -153,6 +162,7 @@ extension WebSocket {
proxyPort: proxyPort,
proxyHeaders: proxyHeaders,
proxyConnectDeadline: proxyConnectDeadline,
deviceName: deviceName,
onUpgrade: onUpgrade
)
}
Expand All @@ -166,6 +176,7 @@ extension WebSocket {
/// - proxyPort: Port on which to connect to the proxy server.
/// - proxyHeaders: Headers to send to the proxy server.
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
/// - deviceName: the device to which the data will be sent.
/// - configuration: Configuration for the WebSocket client.
/// - eventLoopGroup: Event loop group to be used by the WebSocket client.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
Expand All @@ -178,6 +189,7 @@ extension WebSocket {
proxyPort: Int? = nil,
proxyHeaders: HTTPHeaders = [:],
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
deviceName: String? = nil,
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
onUpgrade: @Sendable @escaping (WebSocket) -> ()
Expand All @@ -197,6 +209,7 @@ extension WebSocket {
proxyPort: proxyPort,
proxyHeaders: proxyHeaders,
proxyConnectDeadline: proxyConnectDeadline,
deviceName: deviceName,
on: eventLoopGroup,
onUpgrade: onUpgrade
)
Expand Down
11 changes: 11 additions & 0 deletions Sources/WebSocketKit/WebSocket+SocketOptions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import Foundation
import NIOCore

extension NIOBSDSocket.Option {
#if canImport(Darwin)
public static let ip_bound_if: NIOBSDSocket.Option = Self(rawValue: IP_BOUND_IF)
public static let ipv6_bound_if: NIOBSDSocket.Option = Self(rawValue: IPV6_BOUND_IF)
#elseif canImport(Glibc)
public static let so_bindtodevice = Self(rawValue: SO_BINDTODEVICE)
#endif
}
56 changes: 54 additions & 2 deletions Sources/WebSocketKit/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public final class WebSocketClient: Sendable {
case invalidURL
case invalidResponseStatus(HTTPResponseHead)
case alreadyShutdown
case invalidAddress
public var errorDescription: String? {
return "\(self)"
}
Expand Down Expand Up @@ -85,9 +86,10 @@ public final class WebSocketClient: Sendable {
query: String? = nil,
headers: HTTPHeaders = [:],
maxQueueSize: Int? = nil,
deviceName: String? = nil,
onUpgrade: @Sendable @escaping (WebSocket) -> Void
) -> EventLoopFuture<Void> {
self.connect(scheme: scheme, host: host, port: port, path: path, query: query, maxQueueSize: maxQueueSize, headers: headers, proxy: nil, onUpgrade: onUpgrade)
self.connect(scheme: scheme, host: host, port: port, path: path, query: query, maxQueueSize: maxQueueSize, headers: headers, proxy: nil, deviceName: deviceName, onUpgrade: onUpgrade)
}

/// Establish a WebSocket connection via a proxy server.
Expand All @@ -103,6 +105,7 @@ public final class WebSocketClient: Sendable {
/// - proxyPort: Port on which to connect to the proxy server.
/// - proxyHeaders: Headers to send to the proxy server.
/// - proxyConnectDeadline: Deadline for establishing the proxy connection.
/// - deviceName: the device to which the data will be sent.
/// - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
/// - Returns: A future which completes when the connection to the origin server is established.
@preconcurrency
Expand All @@ -118,6 +121,7 @@ public final class WebSocketClient: Sendable {
proxyPort: Int? = nil,
proxyHeaders: HTTPHeaders = [:],
proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
deviceName: String? = nil,
onUpgrade: @Sendable @escaping (WebSocket) -> Void
) -> EventLoopFuture<Void> {
assert(["ws", "wss"].contains(scheme))
Expand All @@ -140,6 +144,51 @@ public final class WebSocketClient: Sendable {
}
}

let resolvedAddress: SocketAddress
do {
resolvedAddress = try SocketAddress.makeAddressResolvingHost(host, port: port)
} catch {
return channel.eventLoop.makeFailedFuture(error)
}

var bindDevice: NIONetworkDevice?
do {
for device in try System.enumerateDevices() {
if device.name == deviceName, let address = device.address {
switch (address.protocol, resolvedAddress.protocol) {
case (.inet, .inet), (.inet6, .inet6):
bindDevice = device
default:
continue
}
}
if bindDevice != nil {
break
}
}
} catch {
return channel.eventLoop.makeFailedFuture(error)
}

func bindToDevice() -> EventLoopFuture<Void> {
if let device = bindDevice {
#if canImport(Darwin)
switch device.address {
case .v4:
return channel.setOption(.ipOption(.ip_bound_if), value: CInt(device.interfaceIndex))
case .v6:
return channel.setOption(.ipv6Option(.ipv6_bound_if), value: CInt(device.interfaceIndex))
default:
return channel.eventLoop.makeFailedFuture(WebSocketClient.Error.invalidAddress)
}
#elseif canImport(Glibc)
return channel.setOption(.socketOption(.so_bindtodevice), value: device.interfaceIndex)
#endif
} else {
return channel.eventLoop.makeSucceededVoidFuture()
}
}

let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler(
host: host,
path: uri,
Expand Down Expand Up @@ -177,11 +226,12 @@ public final class WebSocketClient: Sendable {
return channel.pipeline.close(mode: .all)
}
}

return channel.pipeline.addHTTPClientHandlers(
leftOverBytesStrategy: .forwardBytes,
withClientUpgrade: config
).flatMap {
return bindToDevice()
}.flatMap {
if let maxQueueSize = maxQueueSize {
return channel.setOption(ChannelOptions.writeBufferWaterMark, value: .init(low: maxQueueSize, high: maxQueueSize))
}
Expand Down Expand Up @@ -228,6 +278,8 @@ public final class WebSocketClient: Sendable {
return channel.setOption(ChannelOptions.writeBufferWaterMark, value: .init(low: maxQueueSize, high: maxQueueSize))
}
return channel.eventLoop.makeSucceededVoidFuture()
}.flatMap {
return bindToDevice()
}.whenComplete { result in
switch result {
case .success:
Expand Down

0 comments on commit 9d1f281

Please sign in to comment.