Skip to content

Commit

Permalink
fix 784: dont crash on huge in-memory bodies (#785)
Browse files Browse the repository at this point in the history
fixes #784 

`writeChunks` had 3 bugs:
1. An actually wrong `UnsafeMutableTransferBox` -> removed that type
which should never be created
2. A loooong future chain (instead of one final promise) -> implemented
3. Potentially infinite recursion which lead to the crash in #784) ->
fixed too
  • Loading branch information
weissi authored Nov 26, 2024
1 parent bdaa3b1 commit 2119f0d
Show file tree
Hide file tree
Showing 12 changed files with 322 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ extension Transaction {
case finished(error: Error?)
}

fileprivate enum RequestStreamState {
fileprivate enum RequestStreamState: Sendable {
case requestHeadSent
case producing
case paused(continuation: CheckedContinuation<Void, Error>?)
case finished
}

fileprivate enum ResponseStreamState {
fileprivate enum ResponseStreamState: Sendable {
// Waiting for response head. Valid transitions to: streamingBody.
case waitingForResponseHead
// streaming response body. Valid transitions to: finished.
Expand Down
74 changes: 40 additions & 34 deletions Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ import NIOHTTP1
import NIOSSL

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
@usableFromInline final class Transaction: @unchecked Sendable {
@usableFromInline
final class Transaction:
// until NIOLockedValueBox learns `sending` because StateMachine cannot be Sendable
@unchecked Sendable
{
let logger: Logger

let request: HTTPClientRequest.Prepared
Expand All @@ -28,8 +32,7 @@ import NIOSSL
let preferredEventLoop: EventLoop
let requestOptions: RequestOptions

private let stateLock = NIOLock()
private var state: StateMachine
private let state: NIOLockedValueBox<StateMachine>

init(
request: HTTPClientRequest.Prepared,
Expand All @@ -44,7 +47,7 @@ import NIOSSL
self.logger = logger
self.connectionDeadline = connectionDeadline
self.preferredEventLoop = preferredEventLoop
self.state = StateMachine(responseContinuation)
self.state = NIOLockedValueBox(StateMachine(responseContinuation))
}

func cancel() {
Expand All @@ -56,8 +59,8 @@ import NIOSSL
private func writeOnceAndOneTimeOnly(byteBuffer: ByteBuffer) {
// This method is synchronously invoked after sending the request head. For this reason we
// can make a number of assumptions, how the state machine will react.
let writeAction = self.stateLock.withLock {
self.state.writeNextRequestPart()
let writeAction = self.state.withLockedValue { state in
state.writeNextRequestPart()
}

switch writeAction {
Expand Down Expand Up @@ -99,30 +102,33 @@ import NIOSSL

struct BreakTheWriteLoopError: Swift.Error {}

// FIXME: Refactor this to not use `self.state.unsafe`.
private func writeRequestBodyPart(_ part: ByteBuffer) async throws {
self.stateLock.lock()
switch self.state.writeNextRequestPart() {
self.state.unsafe.lock()
switch self.state.unsafe.withValueAssumingLockIsAcquired({ state in state.writeNextRequestPart() }) {
case .writeAndContinue(let executor):
self.stateLock.unlock()
self.state.unsafe.unlock()
executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil)

case .writeAndWait(let executor):
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
self.state.waitForRequestBodyDemand(continuation: continuation)
self.stateLock.unlock()
self.state.unsafe.withValueAssumingLockIsAcquired({ state in
state.waitForRequestBodyDemand(continuation: continuation)
})
self.state.unsafe.unlock()

executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil)
}

case .fail:
self.stateLock.unlock()
self.state.unsafe.unlock()
throw BreakTheWriteLoopError()
}
}

private func requestBodyStreamFinished() {
let finishAction = self.stateLock.withLock {
self.state.finishRequestBodyStream()
let finishAction = self.state.withLockedValue { state in
state.finishRequestBodyStream()
}

switch finishAction {
Expand Down Expand Up @@ -150,8 +156,8 @@ extension Transaction: HTTPSchedulableRequest {
var requiredEventLoop: EventLoop? { nil }

func requestWasQueued(_ scheduler: HTTPRequestScheduler) {
self.stateLock.withLock {
self.state.requestWasQueued(scheduler)
self.state.withLockedValue { state in
state.requestWasQueued(scheduler)
}
}
}
Expand All @@ -165,8 +171,8 @@ extension Transaction: HTTPExecutableRequest {
// MARK: Request

func willExecuteRequest(_ executor: HTTPRequestExecutor) {
let action = self.stateLock.withLock {
self.state.willExecuteRequest(executor)
let action = self.state.withLockedValue { state in
state.willExecuteRequest(executor)
}

switch action {
Expand All @@ -183,8 +189,8 @@ extension Transaction: HTTPExecutableRequest {
func requestHeadSent() {}

func resumeRequestBodyStream() {
let action = self.stateLock.withLock {
self.state.resumeRequestBodyStream()
let action = self.state.withLockedValue { state in
state.resumeRequestBodyStream()
}

switch action {
Expand Down Expand Up @@ -214,16 +220,16 @@ extension Transaction: HTTPExecutableRequest {
}

func pauseRequestBodyStream() {
self.stateLock.withLock {
self.state.pauseRequestBodyStream()
self.state.withLockedValue { state in
state.pauseRequestBodyStream()
}
}

// MARK: Response

func receiveResponseHead(_ head: HTTPResponseHead) {
let action = self.stateLock.withLock {
self.state.receiveResponseHead(head, delegate: self)
let action = self.state.withLockedValue { state in
state.receiveResponseHead(head, delegate: self)
}

switch action {
Expand All @@ -243,8 +249,8 @@ extension Transaction: HTTPExecutableRequest {
}

func receiveResponseBodyParts(_ buffer: CircularBuffer<ByteBuffer>) {
let action = self.stateLock.withLock {
self.state.receiveResponseBodyParts(buffer)
let action = self.state.withLockedValue { state in
state.receiveResponseBodyParts(buffer)
}
switch action {
case .none:
Expand All @@ -260,8 +266,8 @@ extension Transaction: HTTPExecutableRequest {
}

func succeedRequest(_ buffer: CircularBuffer<ByteBuffer>?) {
let succeedAction = self.stateLock.withLock {
self.state.succeedRequest(buffer)
let succeedAction = self.state.withLockedValue { state in
state.succeedRequest(buffer)
}
switch succeedAction {
case .finishResponseStream(let source, let finalResponse):
Expand All @@ -276,8 +282,8 @@ extension Transaction: HTTPExecutableRequest {
}

func fail(_ error: Error) {
let action = self.stateLock.withLock {
self.state.fail(error)
let action = self.state.withLockedValue { state in
state.fail(error)
}
self.performFailAction(action)
}
Expand All @@ -304,8 +310,8 @@ extension Transaction: HTTPExecutableRequest {
}

func deadlineExceeded() {
let action = self.stateLock.withLock {
self.state.deadlineExceeded()
let action = self.state.withLockedValue { state in
state.deadlineExceeded()
}
self.performDeadlineExceededAction(action)
}
Expand All @@ -329,8 +335,8 @@ extension Transaction: HTTPExecutableRequest {
extension Transaction: NIOAsyncSequenceProducerDelegate {
@usableFromInline
func produceMore() {
let action = self.stateLock.withLock {
self.state.produceMore()
let action = self.state.withLockedValue { state in
state.produceMore()
}
switch action {
case .none:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,9 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler {
}
}

@available(*, unavailable)
extension HTTP2ClientRequestHandler: Sendable {}

extension HTTP2ClientRequestHandler: HTTPRequestExecutor {
func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise<Void>?) {
if self.eventLoop.inEventLoop {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ protocol HTTPConnectionPoolDelegate {
func connectionPoolDidShutdown(_ pool: HTTPConnectionPool, unclean: Bool)
}

final class HTTPConnectionPool {
final class HTTPConnectionPool:
// TODO: Refactor to use `NIOLockedValueBox` which will allow this to be checked
@unchecked Sendable
{
private let stateLock = NIOLock()
private var _state: StateMachine
/// The connection idle timeout timers. Protected by the stateLock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ import NIOSSL
///
/// Use this handle to cancel the request, while it is waiting for a free connection, to execute the request.
/// This protocol is only intended to be implemented by the `HTTPConnectionPool`.
protocol HTTPRequestScheduler {
protocol HTTPRequestScheduler: Sendable {
/// Informs the task queuer that a request has been cancelled.
func cancelRequest(_: HTTPSchedulableRequest)
}
Expand Down
11 changes: 5 additions & 6 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,21 +222,20 @@ public class HTTPClient {
"""
)
}
let errorStorageLock = NIOLock()
let errorStorage: UnsafeMutableTransferBox<Error?> = .init(nil)
let errorStorage: NIOLockedValueBox<Error?> = NIOLockedValueBox(nil)
let continuation = DispatchWorkItem {}
self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown"))
{ error in
if let error = error {
errorStorageLock.withLock {
errorStorage.wrappedValue = error
errorStorage.withLockedValue { errorStorage in
errorStorage = error
}
}
continuation.perform()
}
continuation.wait()
try errorStorageLock.withLock {
if let error = errorStorage.wrappedValue {
try errorStorage.withLockedValue { errorStorage in
if let error = errorStorage {
throw error
}
}
Expand Down
64 changes: 53 additions & 11 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,67 @@ extension HTTPClient {
}

@inlinable
func writeChunks<Bytes: Collection>(of bytes: Bytes, maxChunkSize: Int) -> EventLoopFuture<Void>
where Bytes.Element == UInt8 {
let iterator = UnsafeMutableTransferBox(bytes.chunks(ofCount: maxChunkSize).makeIterator())
guard let chunk = iterator.wrappedValue.next() else {
func writeChunks<Bytes: Collection>(
of bytes: Bytes,
maxChunkSize: Int
) -> EventLoopFuture<Void> where Bytes.Element == UInt8 {
// `StreamWriter` is has design issues, for example
// - https://github.com/swift-server/async-http-client/issues/194
// - https://github.com/swift-server/async-http-client/issues/264
// - We're not told the EventLoop the task runs on and the user is free to return whatever EL they
// want.
// One important consideration then is that we must lock around the iterator because we could be hopping
// between threads.
typealias Iterator = EnumeratedSequence<ChunksOfCountCollection<Bytes>>.Iterator
typealias Chunk = (offset: Int, element: ChunksOfCountCollection<Bytes>.Element)

func makeIteratorAndFirstChunk(
bytes: Bytes
) -> (
iterator: NIOLockedValueBox<Iterator>,
chunk: Chunk
)? {
var iterator = bytes.chunks(ofCount: maxChunkSize).enumerated().makeIterator()
guard let chunk = iterator.next() else {
return nil
}

return (NIOLockedValueBox(iterator), chunk)
}

guard let (iterator, chunk) = makeIteratorAndFirstChunk(bytes: bytes) else {
return self.write(IOData.byteBuffer(.init()))
}

@Sendable // can't use closure here as we recursively call ourselves which closures can't do
func writeNextChunk(_ chunk: Bytes.SubSequence) -> EventLoopFuture<Void> {
if let nextChunk = iterator.wrappedValue.next() {
return self.write(.byteBuffer(ByteBuffer(bytes: chunk))).flatMap {
writeNextChunk(nextChunk)
}
func writeNextChunk(_ chunk: Chunk, allDone: EventLoopPromise<Void>) {
if let nextElement = iterator.withLockedValue({ $0.next() }) {
self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).map {
let index = nextElement.offset
if (index + 1) % 4 == 0 {
// Let's not stack-overflow if the futures insta-complete which they at least in HTTP/2
// mode.
// Also, we must frequently return to the EventLoop because we may get the pause signal
// from another thread. If we fail to do that promptly, we may balloon our body chunks
// into memory.
allDone.futureResult.eventLoop.execute {
writeNextChunk(nextElement, allDone: allDone)
}
} else {
writeNextChunk(nextElement, allDone: allDone)
}
}.cascadeFailure(to: allDone)
} else {
return self.write(.byteBuffer(ByteBuffer(bytes: chunk)))
self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).cascade(to: allDone)
}
}

return writeNextChunk(chunk)
// HACK (again, we're not told the right EventLoop): Let's write 0 bytes to make the user tell us...
return self.write(.byteBuffer(ByteBuffer())).flatMapWithEventLoop { (_, loop) in
let allDone = loop.makePromise(of: Void.self)
writeNextChunk(chunk, allDone: allDone)
return allDone.futureResult
}
}
}

Expand Down
29 changes: 0 additions & 29 deletions Sources/AsyncHTTPClient/UnsafeTransfer.swift

This file was deleted.

Loading

0 comments on commit 2119f0d

Please sign in to comment.