Skip to content

Commit

Permalink
Add authorization to proxy (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
vkill authored and weissi committed Sep 8, 2019
1 parent 244aea6 commit 47de4bb
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 26 deletions.
11 changes: 7 additions & 4 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ public class HTTPClient {
switch self.configuration.proxy {
case .none:
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration)
case .some:
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration)
case .some(let proxy):
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration, proxy: proxy)
}
}.flatMap {
if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) {
Expand Down Expand Up @@ -383,8 +383,8 @@ extension HTTPClient.Configuration {
}

private extension ChannelPipeline {
func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler<HTTPResponseDecoder>, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in
func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler<HTTPResponseDecoder>, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?, proxy: HTTPClient.Configuration.Proxy?) -> EventLoopFuture<Void> {
let handler = HTTPClientProxyHandler(host: request.host, port: request.port, authorization: proxy?.authorization, onConnect: { channel in
channel.pipeline.removeHandler(decoder).flatMap {
return channel.pipeline.addHandler(
ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)),
Expand Down Expand Up @@ -428,6 +428,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case chunkedSpecifiedMultipleTimes
case invalidProxyResponse
case contentLengthMissing
case proxyAuthenticationRequired
}

private var code: Code
Expand Down Expand Up @@ -464,4 +465,6 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse)
/// Request does not contain `Content-Length` header.
public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing)
/// Proxy Authentication Required
public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired)
}
23 changes: 21 additions & 2 deletions Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,26 @@ public extension HTTPClient.Configuration {
public var host: String
/// Specifies Proxy server port.
public var port: Int
/// Specifies Proxy server authorization.
public var authorization: HTTPClient.Authorization?

/// Create proxy.
///
/// - parameters:
/// - host: proxy server host.
/// - port: proxy server port.
public static func server(host: String, port: Int) -> Proxy {
return .init(host: host, port: port)
return .init(host: host, port: port, authorization: nil)
}

/// Create proxy.
///
/// - parameters:
/// - host: proxy server host.
/// - port: proxy server port.
/// - authorization: proxy server authorization.
public static func server(host: String, port: Int, authorization: HTTPClient.Authorization? = nil) -> Proxy {
return .init(host: host, port: port, authorization: authorization)
}
}
}
Expand All @@ -61,14 +73,16 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan

private let host: String
private let port: Int
private let authorization: HTTPClient.Authorization?
private var onConnect: (Channel) -> EventLoopFuture<Void>
private var writeBuffer: CircularBuffer<WriteItem>
private var readBuffer: CircularBuffer<NIOAny>
private var readState: ReadState

init(host: String, port: Int, onConnect: @escaping (Channel) -> EventLoopFuture<Void>) {
init(host: String, port: Int, authorization: HTTPClient.Authorization?, onConnect: @escaping (Channel) -> EventLoopFuture<Void>) {
self.host = host
self.port = port
self.authorization = authorization
self.onConnect = onConnect
self.writeBuffer = .init()
self.readBuffer = .init()
Expand All @@ -87,6 +101,8 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
// inbound proxies) will switch to tunnel mode immediately after the
// blank line that concludes the successful response's header section
break
case 407:
context.fireErrorCaught(HTTPClientError.proxyAuthenticationRequired)
default:
// Any response other than a successful response
// indicates that the tunnel has not yet been formed and that the
Expand Down Expand Up @@ -150,6 +166,9 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
uri: "\(self.host):\(self.port)"
)
head.headers.add(name: "proxy-connection", value: "keep-alive")
if let authorization = authorization {
head.headers.add(name: "proxy-authorization", value: authorization.headerValue)
}
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
context.flush()
Expand Down
35 changes: 35 additions & 0 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,41 @@ extension HTTPClient {
self.body = body
}
}

/// HTTP authentication
public struct Authorization {
private enum Scheme {
case Basic(String)
case Bearer(String)
}

private let scheme: Scheme

private init(scheme: Scheme) {
self.scheme = scheme
}

public static func basic(username: String, password: String) -> HTTPClient.Authorization {
return .basic(credentials: Data("\(username):\(password)".utf8).base64EncodedString())
}

public static func basic(credentials: String) -> HTTPClient.Authorization {
return .init(scheme: .Basic(credentials))
}

public static func bearer(tokens: String) -> HTTPClient.Authorization {
return .init(scheme: .Bearer(tokens))
}

public var headerValue: String {
switch self.scheme {
case .Basic(let credentials):
return "Basic \(credentials)"
case .Bearer(let tokens):
return "Bearer \(tokens)"
}
}
}
}

internal class ResponseAccumulator: HTTPClientResponseDelegate {
Expand Down
54 changes: 34 additions & 20 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ internal class HttpBin {
.childChannelInitializer { channel in
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
if let simulateProxy = simulateProxy {
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
let responseEncoder = HTTPResponseEncoder()
let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))

return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first)
} else {
return channel.eventLoop.makeSucceededFuture(())
}
Expand All @@ -138,43 +141,54 @@ internal class HttpBin {
}

final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
typealias OutboundOut = ByteBuffer
typealias InboundIn = HTTPServerRequestPart
typealias InboundOut = HTTPServerResponsePart
typealias OutboundOut = HTTPServerResponsePart

enum Option {
case plaintext
case tls
}

let option: Option
let encoder: HTTPResponseEncoder
let decoder: ByteToMessageHandler<HTTPRequestDecoder>
var head: HTTPResponseHead

init(option: Option) {
init(option: Option, encoder: HTTPResponseEncoder, decoder: ByteToMessageHandler<HTTPRequestDecoder>) {
self.option = option
self.encoder = encoder
self.decoder = decoder
self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0"), ("Connection", "close")]))
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = """
HTTP/1.1 200 OK\r\n\
Content-Length: 0\r\n\
Connection: close\r\n\
\r\n
"""
var buffer = self.unwrapInboundIn(data)
let request = buffer.readString(length: buffer.readableBytes)!
if request.hasPrefix("CONNECT") {
var buffer = context.channel.allocator.buffer(capacity: 0)
buffer.writeString(response)
context.write(self.wrapInboundOut(buffer), promise: nil)
context.flush()
let request = self.unwrapInboundIn(data)
switch request {
case .head(let head):
guard head.method == .CONNECT else {
fatalError("Expected a CONNECT request")
}
if head.headers.contains(name: "proxy-authorization") {
if head.headers["proxy-authorization"].first != "Basic YWxhZGRpbjpvcGVuc2VzYW1l" {
self.head.status = .proxyAuthenticationRequired
}
}
case .body:
()
case .end:
context.write(self.wrapOutboundOut(.head(self.head)), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)

context.channel.pipeline.removeHandler(self, promise: nil)
context.channel.pipeline.removeHandler(self.decoder, promise: nil)
context.channel.pipeline.removeHandler(self.encoder, promise: nil)

switch self.option {
case .tls:
_ = HttpBin.configureTLS(channel: context.channel)
case .plaintext: break
}
} else {
fatalError("Expected a CONNECT request")
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ extension HTTPClientTests {
("testReadTimeout", testReadTimeout),
("testDeadline", testDeadline),
("testCancel", testCancel),
("testHTTPClientAuthorization", testHTTPClientAuthorization),
("testProxyPlaintext", testProxyPlaintext),
("testProxyTLS", testProxyTLS),
("testProxyPlaintextWithCorrectlyAuthorization", testProxyPlaintextWithCorrectlyAuthorization),
("testProxyPlaintextWithIncorrectlyAuthorization", testProxyPlaintextWithIncorrectlyAuthorization),
("testUploadStreaming", testUploadStreaming),
("testNoContentLengthForSSLUncleanShutdown", testNoContentLengthForSSLUncleanShutdown),
("testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown", testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown),
Expand Down
39 changes: 39 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ class HTTPClientTests: XCTestCase {
}
}

func testHTTPClientAuthorization() {
var authorization = HTTPClient.Authorization.basic(username: "aladdin", password: "opensesame")
XCTAssertEqual(authorization.headerValue, "Basic YWxhZGRpbjpvcGVuc2VzYW1l")

authorization = HTTPClient.Authorization.bearer(tokens: "mF_9.B5f-4.1JqM")
XCTAssertEqual(authorization.headerValue, "Bearer mF_9.B5f-4.1JqM")
}

func testProxyPlaintext() throws {
let httpBin = HttpBin(simulateProxy: .plaintext)
let httpClient = HTTPClient(
Expand Down Expand Up @@ -321,6 +329,37 @@ class HTTPClientTests: XCTestCase {
XCTAssertEqual(res.status, .ok)
}

func testProxyPlaintextWithCorrectlyAuthorization() throws {
let httpBin = HttpBin(simulateProxy: .plaintext)
let httpClient = HTTPClient(
eventLoopGroupProvider: .createNew,
configuration: .init(proxy: .server(host: "localhost", port: httpBin.port, authorization: .basic(username: "aladdin", password: "opensesame")))
)
defer {
try! httpClient.syncShutdown()
httpBin.shutdown()
}
let res = try httpClient.get(url: "http://test/ok").wait()
XCTAssertEqual(res.status, .ok)
}

func testProxyPlaintextWithIncorrectlyAuthorization() throws {
let httpBin = HttpBin(simulateProxy: .plaintext)
let httpClient = HTTPClient(
eventLoopGroupProvider: .createNew,
configuration: .init(proxy: .server(host: "localhost", port: httpBin.port, authorization: .basic(username: "aladdin", password: "opensesamefoo")))
)
defer {
try! httpClient.syncShutdown()
httpBin.shutdown()
}
XCTAssertThrowsError(try httpClient.get(url: "http://test/ok").wait(), "Should fail") { error in
guard case let error = error as? HTTPClientError, error == .proxyAuthenticationRequired else {
return XCTFail("Should fail with HTTPClientError.proxyAuthenticationRequired")
}
}
}

func testUploadStreaming() throws {
let httpBin = HttpBin()
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)
Expand Down

0 comments on commit 47de4bb

Please sign in to comment.