diff --git a/Package@swift-4.0.swift b/Package@swift-4.0.swift index 215b92a..f7a060e 100644 --- a/Package@swift-4.0.swift +++ b/Package@swift-4.0.swift @@ -22,7 +22,7 @@ import PackageDescription var dependencies: [Package.Dependency] = [ .package(url: "https://github.com/IBM-Swift/LoggerAPI.git", from: "1.7.3"), .package(url: "https://github.com/IBM-Swift/BlueSocket.git", from: "1.0.0"), - .package(url: "https://github.com/IBM-Swift/CCurl.git", from: "1.0.0"), + .package(url: "https://github.com/IBM-Swift/CCurl.git", from: "1.1.0"), .package(url: "https://github.com/IBM-Swift/BlueSSLService.git", from: "1.0.0") ] diff --git a/Package@swift-4.1.swift b/Package@swift-4.1.swift index 215b92a..f7a060e 100644 --- a/Package@swift-4.1.swift +++ b/Package@swift-4.1.swift @@ -22,7 +22,7 @@ import PackageDescription var dependencies: [Package.Dependency] = [ .package(url: "https://github.com/IBM-Swift/LoggerAPI.git", from: "1.7.3"), .package(url: "https://github.com/IBM-Swift/BlueSocket.git", from: "1.0.0"), - .package(url: "https://github.com/IBM-Swift/CCurl.git", from: "1.0.0"), + .package(url: "https://github.com/IBM-Swift/CCurl.git", from: "1.1.0"), .package(url: "https://github.com/IBM-Swift/BlueSSLService.git", from: "1.0.0") ] diff --git a/Sources/CCurl/shim.h b/Sources/CCurl/shim.h index 4b56b6f..ff48e4d 100644 --- a/Sources/CCurl/shim.h +++ b/Sources/CCurl/shim.h @@ -18,6 +18,7 @@ #define CurlHelpers_h #import +#import #define CURL_TRUE 1 #define CURL_FALSE 0 @@ -81,5 +82,16 @@ static inline CURLcode curlHelperGetInfoLong(CURL *curl, CURLINFO info, long *da return curl_easy_getinfo(curl, info, data); } +static inline CURLMcode curlHelperSetMultiOpt(CURLM *curlMulti, CURLMoption option, long data) { + return curl_multi_setopt(curlMulti, option, data); +} + +static inline CURLcode curlHelperSetUnixSocketPath(CURL *curl, const char *data) { +#ifdef CURL_VERSION_UNIX_SOCKETS + return curl_easy_setopt(curl, CURLOPT_UNIX_SOCKET_PATH, data); +#else + return CURLE_NOT_BUILT_IN; +#endif +} #endif /* CurlHelpers_h */ diff --git a/Sources/KituraNet/ClientRequest.swift b/Sources/KituraNet/ClientRequest.swift index 0141bd7..f20d66f 100644 --- a/Sources/KituraNet/ClientRequest.swift +++ b/Sources/KituraNet/ClientRequest.swift @@ -148,6 +148,10 @@ public class ClientRequest { /// Should HTTP/2 protocol be used private var useHTTP2 = false + + /// The Unix domain socket path used for the request + private var unixDomainSocketPath: String? = nil + /// Data that represents the "HTTP/2 " header status line prefix fileprivate static let Http2StatusLineVersion = "HTTP/2 ".data(using: .utf8)! @@ -209,7 +213,7 @@ public class ClientRequest { /// If present, the client will try to use HTTP/2 protocol for the connection. case useHTTP2 - + } /** @@ -238,10 +242,12 @@ public class ClientRequest { /// Initializes a `ClientRequest` instance /// - /// - Parameter options: An array of `Options' describing the request + /// - Parameter options: An array of `Options' describing the request. + /// - Parameter unixDomainSocketPath: Specifies the path of a Unix domain socket that the client should connect to. /// - Parameter callback: The closure of type `Callback` to be used for the callback. - init(options: [Options], callback: @escaping Callback) { + init(options: [Options], unixDomainSocketPath: String? = nil, callback: @escaping Callback) { + self.unixDomainSocketPath = unixDomainSocketPath self.callback = callback var theSchema = "http://" @@ -562,6 +568,10 @@ public class ClientRequest { if useHTTP2 { curlHelperSetOptInt(handle!, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_0) } + + if let socketPath = unixDomainSocketPath?.cString(using: .utf8) { + curlHelperSetUnixSocketPath(handle!, UnsafePointer(socketPath)) + } } /// Sets the HTTP method and Content-Length in libCurl diff --git a/Sources/KituraNet/HTTP/HTTP.swift b/Sources/KituraNet/HTTP/HTTP.swift index 9ba96c4..2c1db86 100644 --- a/Sources/KituraNet/HTTP/HTTP.swift +++ b/Sources/KituraNet/HTTP/HTTP.swift @@ -115,18 +115,20 @@ public class HTTP { Create a new `ClientRequest` using a list of options. - Parameter options: a list of `ClientRequest.Options`. - - Parameter callback: closure to run after the request. + - Parameter unixDomainSocketPath: the path of a Unix domain socket that this client should connect to (defaults to `nil`). + - Parameter callback: The closure to run after the request completes. The `ClientResponse?` parameter allows access to the response from the server. - Returns: a `ClientRequest` instance ### Usage Example: ### ````swift - let request = HTTP.request([ClientRequest.Options]) {response in - ... + let myOptions: [ClientRequest.Options] = [.hostname("localhost"), .port("8080")] + let request = HTTP.request(myOptions) { response in + // Process the ClientResponse } ```` */ - public static func request(_ options: [ClientRequest.Options], callback: @escaping ClientRequest.Callback) -> ClientRequest { - return ClientRequest(options: options, callback: callback) + public static func request(_ options: [ClientRequest.Options], unixDomainSocketPath: String? = nil, callback: @escaping ClientRequest.Callback) -> ClientRequest { + return ClientRequest(options: options, unixDomainSocketPath: unixDomainSocketPath, callback: callback) } /** diff --git a/Sources/KituraNet/HTTP/HTTPServer.swift b/Sources/KituraNet/HTTP/HTTPServer.swift index 899fe68..72500a4 100644 --- a/Sources/KituraNet/HTTP/HTTPServer.swift +++ b/Sources/KituraNet/HTTP/HTTPServer.swift @@ -52,16 +52,18 @@ public class HTTPServer: Server { */ public var delegate: ServerDelegate? - /** - Port number for listening for new connections. - - ### Usage Example: ### - ````swift - httpServer.port = 8080 - ```` - */ + /// The TCP port on which this server listens for new connections. If `nil`, this server does not listen on a TCP socket. public private(set) var port: Int? + /// The Unix domain socket path on which this server listens for new connections. If `nil`, this server does not listen on a Unix socket. + public private(set) var unixDomainSocketPath: String? + + /// The types of listeners we currently support. + private enum ListenerType { + case inet(Int) + case unix(String) + } + /** A server state @@ -145,20 +147,45 @@ public class HTTPServer: Server { } /** - Listens for connections on a socket. + Listens for connections on a TCP socket. ### Usage Example: ### ````swift try server.listen(on: 8080) ```` - - Parameter on: Port number for new connections, e.g. 8080 + - Parameter port: Port number for new connections, e.g. 8080 */ public func listen(on port: Int) throws { self.port = port + try listen(.inet(port)) + } + + /** + Listens for connections on a Unix socket. + + ### Usage Example: ### + ````swift + try server.listen(unixDomainSocketPath: "/my/path") + ```` + + - Parameter unixDomainSocketPath: Unix socket path for new connections, eg. "/my/path" + */ + public func listen(unixDomainSocketPath: String) throws { + self.unixDomainSocketPath = unixDomainSocketPath + try listen(.unix(unixDomainSocketPath)) + } + + private func listen(_ listener: ListenerType) throws { do { - let socket = try Socket.create() - self.listenSocket = socket + let socket: Socket + switch listener { + case .inet: + socket = try Socket.create(family: .inet) + case .unix: + socket = try Socket.create(family: .unix) + } + self.listenSocket = socket // If SSL config has been created, // create and attach the SSLService delegate to the socket @@ -166,21 +193,28 @@ public class HTTPServer: Server { socket.delegate = try SSLService(usingConfiguration: sslConfig); } - try socket.listen(on: port, maxBacklogSize: maxPendingConnections, allowPortReuse: self.allowPortReuse) + let listenerDescription: String + switch listener { + case .inet(let port): + try socket.listen(on: port, maxBacklogSize: maxPendingConnections, allowPortReuse: self.allowPortReuse) + // If a random (ephemeral) port number was requested, get the listening port + let listeningPort = Int(socket.listeningPort) + if listeningPort != port { + self.port = listeningPort + // We should only expect a different port if the requested port was zero. + if port != 0 { + Log.error("Listening port \(listeningPort) does not match requested port \(port)") + } + } + listenerDescription = "port \(listeningPort)" + case .unix(let path): + try socket.listen(on: path, maxBacklogSize: maxPendingConnections) + listenerDescription = "path \(path)" + } let socketManager = IncomingSocketManager() self.socketManager = socketManager - // If a random (ephemeral) port number was requested, get the listening port - let listeningPort = Int(socket.listeningPort) - if listeningPort != port { - self.port = listeningPort - // We should only expect a different port if the requested port was zero. - if port != 0 { - Log.error("Listening port \(listeningPort) does not match requested port \(port)") - } - } - if let delegate = socket.delegate { #if os(Linux) // Add the list of supported ALPN protocols to the SSLServiceDelegate @@ -189,11 +223,11 @@ public class HTTPServer: Server { } #endif - Log.info("Listening on port \(self.port!) (delegate: \(delegate))") - Log.verbose("Options for port \(self.port!): delegate: \(delegate), maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)") + Log.info("Listening on \(listenerDescription) (delegate: \(delegate))") + Log.verbose("Options for \(listenerDescription): delegate: \(delegate), maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)") } else { - Log.info("Listening on port \(self.port!)") - Log.verbose("Options for port \(self.port!): maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)") + Log.info("Listening on \(listenerDescription)") + Log.verbose("Options for \(listenerDescription): maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)") } // set synchronously to avoid contention in back to back server start/stop calls @@ -234,6 +268,26 @@ public class HTTPServer: Server { return server } + /** + Static method to create a new HTTP server and have it listen for connections on a Unix domain socket. + + ### Usage Example: ### + ````swift + let server = HTTPServer.listen(unixDomainSocketPath: "/my/path", delegate: self) + ```` + + - Parameter unixDomainSocketPath: The path of the Unix domain socket that this server should listen on. + - Parameter delegate: The delegate handler for HTTP connections. + + - Returns: A new instance of a `HTTPServer`. + */ + public static func listen(unixDomainSocketPath: String, delegate: ServerDelegate?) throws -> HTTPServer { + let server = HTTP.createServer() + server.delegate = delegate + try server.listen(unixDomainSocketPath: unixDomainSocketPath) + return server + } + /** Listen for connections on a socket. diff --git a/Tests/KituraNetTests/KituraNetTest.swift b/Tests/KituraNetTests/KituraNetTest.swift index f288a8b..2054110 100644 --- a/Tests/KituraNetTests/KituraNetTest.swift +++ b/Tests/KituraNetTests/KituraNetTest.swift @@ -34,6 +34,7 @@ class KituraNetTest: XCTestCase { var useSSL = useSSLDefault var port = portDefault + var unixDomainSocketPath: String? = nil static let sslConfig: SSLService.Configuration = { let sslConfigDir = URL(fileURLWithPath: #file).appendingPathComponent("../SSLConfig") @@ -61,7 +62,15 @@ class KituraNetTest: XCTestCase { func doTearDown() { } - func startServer(_ delegate: ServerDelegate?, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> HTTPServer { + /// Start a server listening on a specified TCP port or Unix socket path. + /// - Parameter delegate: The ServerDelegate that will handle requests to this server + /// - Parameter port: The TCP port number to listen on + /// - Parameter socketPath: The Unix socket path to listen on + /// - Parameter useSSL: Whether to listen using SSL + /// - Parameter allowPortReuse: Whether to allow the TCP port to be reused by other listeners + /// - Returns: an HTTPServer instance. + /// - Throws: an error if the server fails to listen on the specified port or path. + func startServer(_ delegate: ServerDelegate?, port: Int = portDefault, unixDomainSocketPath: String? = nil, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> HTTPServer { let server = HTTP.createServer() server.delegate = delegate @@ -69,10 +78,14 @@ class KituraNetTest: XCTestCase { if useSSL { server.sslConfig = KituraNetTest.sslConfig } - try server.listen(on: port) + if let unixDomainSocketPath = unixDomainSocketPath { + try server.listen(unixDomainSocketPath: unixDomainSocketPath) + } else { + try server.listen(on: port) + } return server } - + /// Convenience function for starting an HTTPServer on an ephemeral port, /// returning the a tuple containing the server and the port it is listening on. func startEphemeralServer(_ delegate: ServerDelegate?, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> (server: HTTPServer, port: Int) { @@ -115,6 +128,35 @@ class KituraNetTest: XCTestCase { } } + func performServerTest(_ delegate: ServerDelegate?, unixDomainSocketPath: String, useSSL: Bool = useSSLDefault, + line: Int = #line, asyncTasks: (XCTestExpectation) -> Void...) { + + do { + self.useSSL = useSSL + self.unixDomainSocketPath = unixDomainSocketPath + + let server: HTTPServer = try startServer(delegate, unixDomainSocketPath: unixDomainSocketPath, useSSL: useSSL) + defer { + server.stop() + } + + let requestQueue = DispatchQueue(label: "Request queue") + for (index, asyncTask) in asyncTasks.enumerated() { + let expectation = self.expectation(line: line, index: index) + requestQueue.async() { + asyncTask(expectation) + } + } + + // wait for timeout or for all created expectations to be fulfilled + waitExpectation(timeout: 10) { error in + XCTAssertNil(error); + } + } catch { + XCTFail("Error: \(error)") + } + } + func performFastCGIServerTest(_ delegate: ServerDelegate?, port: Int = portDefault, allowPortReuse: Bool = portReuseDefault, line: Int = #line, asyncTasks: (XCTestExpectation) -> Void...) { @@ -145,7 +187,7 @@ class KituraNetTest: XCTestCase { } } - func performRequest(_ method: String, path: String, close: Bool=true, callback: @escaping ClientRequest.Callback, + func performRequest(_ method: String, path: String, unixDomainSocketPath: String? = nil, close: Bool=true, callback: @escaping ClientRequest.Callback, headers: [String: String]? = nil, requestModifier: ((ClientRequest) -> Void)? = nil) { var allHeaders = [String: String]() @@ -163,7 +205,7 @@ class KituraNetTest: XCTestCase { options.append(.disableSSLVerification) } - let req = HTTP.request(options, callback: callback) + let req = HTTP.request(options, unixDomainSocketPath: unixDomainSocketPath, callback: callback) if let requestModifier = requestModifier { requestModifier(req) } diff --git a/Tests/KituraNetTests/UnixSocketTests.swift b/Tests/KituraNetTests/UnixSocketTests.swift new file mode 100644 index 0000000..1e886c5 --- /dev/null +++ b/Tests/KituraNetTests/UnixSocketTests.swift @@ -0,0 +1,110 @@ +/** + * Copyright IBM Corporation 2019 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + **/ + +import Foundation +import Dispatch +import LoggerAPI + +import XCTest + +@testable import KituraNet +import Socket + +class UnixSocketTests: KituraNetTest { + + static var allTests : [(String, (UnixSocketTests) -> () throws -> Void)] { + return [ + ("testUnixSockets", testUnixSockets), + ] + } + + // Socket file path for Unix socket tests + private var socketFilePath: String = "" + + override func setUp() { + doSetUp() + // Create a path for Unix socket tests + socketFilePath = uniqueTemporaryFilePath() + } + + override func tearDown() { + doTearDown() + // Clean up temporary file + removeTemporaryFilePath(socketFilePath) + } + + // Generates a unique temporary file path suitable for use as a Unix domain socket. + // On Linux, a path is returned within /tmp + // On MacOS, a path is returned within /var/folders + func uniqueTemporaryFilePath() -> String { + #if os(Linux) + let temporaryDirectory = "/tmp" + #else + var temporaryDirectory: String + if #available(OSX 10.12, *) { + temporaryDirectory = FileManager.default.temporaryDirectory.path + } else { + temporaryDirectory = "/tmp" + } + #endif + return temporaryDirectory + "/" + String(ProcessInfo.processInfo.globallyUniqueString.prefix(20)) + } + + // Delete a temporary file path. + func removeTemporaryFilePath(_ path: String) { + let fileURL = URL(fileURLWithPath: path) + let fm = FileManager.default + do { + try fm.removeItem(at: fileURL) + } catch { + XCTFail("Unable to remove \(path): \(error.localizedDescription)") + } + } + + let unixDelegate = TestUnixSocketServerDelegate() + + /// Test that we can start a server on a Unix socket, and then make a ClientRequest + /// to that socket. The TestUnixSocketServerDelegate.handle() function will verify + /// that the incoming request's socket is a unix socket. + func testUnixSockets() { + performServerTest(unixDelegate, unixDomainSocketPath: socketFilePath) { expectation in + self.performRequest("get", path: "/banana", unixDomainSocketPath: self.socketFilePath, callback: { response in + XCTAssertEqual(response?.statusCode, HTTPStatusCode.OK, "Status code wasn't .OK was \(String(describing: response?.statusCode))") + expectation.fulfill() + }) + } + } + + class TestUnixSocketServerDelegate: ServerDelegate { + func handle(request: ServerRequest, response: ServerResponse) { + guard let request = request as? HTTPServerRequest else { + return XCTFail("Request was not an HTTPServerRequest") + } + guard let socketSignature = request.signature else { + return XCTFail("Socket signature missing") + } + XCTAssertEqual(socketSignature.protocolFamily, Socket.ProtocolFamily.unix, "Socket was not a Unix socket") + XCTAssertEqual(request.method.lowercased(), "get") + response.statusCode = .OK + do { + try response.end(text: "OK") + } catch { + XCTFail("Error sending response: \(error)") + } + } + } + +} diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index d2b8953..caca361 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -43,6 +43,7 @@ extension Sequence { } XCTMain([ + testCase(UnixSocketTests.allTests.shuffled()), testCase(ClientE2ETests.allTests.shuffled()), testCase(ClientRequestTests.allTests.shuffled()), testCase(FastCGIProtocolTests.allTests.shuffled()),