diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift index 49d734df0..53c3c3bf6 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift @@ -38,6 +38,7 @@ extension Transaction { case requestHeadSent case producing case paused(continuation: CheckedContinuation?) + case endForwarded case finished } @@ -97,7 +98,8 @@ extension Transaction { bodyStreamContinuation: CheckedContinuation? ) - case failRequestStreamContinuation(CheckedContinuation, Error) + case failRequestStreamContinuation(CheckedContinuation, Error, HTTPRequestExecutor) + case cancelExecutor(HTTPRequestExecutor) } mutating func fail(_ error: Error) -> FailAction { @@ -135,7 +137,7 @@ extension Transaction { bodyStreamContinuation: continuation ) - case .requestHeadSent, .finished, .producing, .paused(continuation: .none): + case .requestHeadSent, .endForwarded, .finished, .producing, .paused(continuation: .none): self.state = .finished(error: error) return .failResponseHead( context.continuation, @@ -156,12 +158,29 @@ extension Transaction { context.executor, bodyStreamContinuation: bodyStreamContinuation ) - case .finished, .producing, .requestHeadSent: + case .endForwarded, .finished, .producing, .requestHeadSent: return .failResponseStream(source, error, context.executor, bodyStreamContinuation: nil) } - case .finished(error: _), - .executing(_, _, .finished): + case .executing(let context, let requestStreamState, .finished): + // an error occured after full response received, but before the full request was sent + self.state = .finished(error: error) + switch requestStreamState { + case .paused(let bodyStreamContinuation): + if let bodyStreamContinuation { + return .failRequestStreamContinuation( + bodyStreamContinuation, + error, + context.executor + ) + } else { + return .cancelExecutor(context.executor) + } + case .endForwarded, .finished, .producing, .requestHeadSent: + return .cancelExecutor(context.executor) + } + + case .finished(error: _): return .none } } @@ -232,7 +251,7 @@ extension Transaction { self.state = .executing(context, .producing, responseState) return .resumeStream(continuation) - case .executing(_, .finished, _): + case .executing(_, .endForwarded, _), .executing(_, .finished, _): // the channels writability changed to writable after we have forwarded all the // request bytes. Can be ignored. return .none @@ -254,6 +273,7 @@ extension Transaction { self.state = .executing(context, .paused(continuation: nil), responseSteam) case .executing(_, .paused, _), + .executing(_, .endForwarded, _), .executing(_, .finished, _), .finished: // the channels writability changed to paused after we have already forwarded all @@ -298,7 +318,7 @@ extension Transaction { "A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)" ) - case .finished, .executing(_, .finished, _): + case .finished, .executing(_, .endForwarded, _), .executing(_, .finished, _): return .fail } } @@ -309,6 +329,7 @@ extension Transaction { .queued, .deadlineExceededWhileQueued, .executing(_, .requestHeadSent, _), + .executing(_, .endForwarded, _), .executing(_, .finished, _): preconditionFailure( "A request stream can only produce, if the request was started. Invalid state: \(self.state)" @@ -343,6 +364,7 @@ extension Transaction { case .initialized, .queued, .deadlineExceededWhileQueued, + .executing(_, .endForwarded, _), .executing(_, .finished, _): preconditionFailure("Invalid state: \(self.state)") @@ -355,23 +377,36 @@ extension Transaction { .executing(let context, .paused(continuation: .none), let responseState), .executing(let context, .requestHeadSent, let responseState): - switch responseState { - case .finished: - // if the response stream has already finished before the request, we must succeed - // the final continuation. - self.state = .finished(error: nil) - return .forwardStreamFinished(context.executor) - - case .waitingForResponseHead, .streamingBody: - self.state = .executing(context, .finished, responseState) - return .forwardStreamFinished(context.executor) - } + self.state = .executing(context, .endForwarded, responseState) + return .forwardStreamFinished(context.executor) case .finished: return .none } } + mutating func requestBodyStreamSent() { + switch self.state { + case .initialized, + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _), + .executing(_, .finished, _), + .executing(_, .producing, _), + .executing(_, .paused, _): + preconditionFailure("Invalid state: \(self.state)") + + case .executing(_, .endForwarded, .finished): + self.state = .finished(error: nil) + + case .executing(let context, .endForwarded, let responseState): + self.state = .executing(context, .finished, responseState) + + case .finished: + break + } + } + // MARK: - Response - enum ReceiveResponseHeadAction { @@ -482,7 +517,7 @@ extension Transaction { switch requestState { case .finished: self.state = .finished(error: nil) - case .paused, .producing, .requestHeadSent: + case .paused, .producing, .requestHeadSent, .endForwarded: self.state = .executing(context, requestState, .finished) } return .finishResponseStream(source, finalBody: newChunks) @@ -497,6 +532,15 @@ extension Transaction { } } + mutating func httpResponseStreamTerminated() -> FailAction { + switch self.state { + case .executing(_, _, .finished), .finished: + return .none + default: + return self.fail(HTTPClientError.cancelled) + } + } + enum DeadlineExceededAction { case none case cancelSchedulerOnly(scheduler: HTTPRequestScheduler) @@ -538,7 +582,7 @@ extension Transaction { executor: context.executor, bodyStreamContinuation: continuation ) - case .requestHeadSent, .finished, .producing, .paused(continuation: .none): + case .requestHeadSent, .endForwarded, .finished, .producing, .paused(continuation: .none): self.state = .finished(error: error) return .cancel( requestContinuation: context.continuation, diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index c8bf54b09..1921ddf34 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -157,7 +157,7 @@ final class Transaction: break case .forwardStreamFinished(let executor): - executor.finishRequestBodyStream(self, promise: nil) + executor.finishRequestBodyStream(trailers: nil, request: self, promise: nil) } return } @@ -206,7 +206,9 @@ extension Transaction: HTTPExecutableRequest { } } - func requestHeadSent() {} + func requestHeadSent() { + // protocol requirement. Intentionally not needed. + } func resumeRequestBodyStream() { let action = self.state.withLockedValue { state in @@ -245,6 +247,12 @@ extension Transaction: HTTPExecutableRequest { } } + func requestBodyStreamSent() { + self.state.withLockedValue { state in + state.requestBodyStreamSent() + } + } + // MARK: Response func receiveResponseHead(_ head: HTTPResponseHead) { @@ -302,6 +310,13 @@ extension Transaction: HTTPExecutableRequest { } } + func httpResponseStreamTerminated() { + let action = self.state.withLockedValue { state in + state.httpResponseStreamTerminated() + } + self.performFailAction(action) + } + func fail(_ error: Error) { let action = self.state.withLockedValue { state in state.fail(error) @@ -325,8 +340,12 @@ extension Transaction: HTTPExecutableRequest { requestBodyStreamContinuation?.resume(throwing: error) executor.cancelRequest(self) - case .failRequestStreamContinuation(let bodyStreamContinuation, let error): + case .failRequestStreamContinuation(let bodyStreamContinuation, let error, let executor): bodyStreamContinuation.resume(throwing: error) + executor.cancelRequest(self) + + case .cancelExecutor(let executor): + executor.cancelRequest(self) } } @@ -369,6 +388,6 @@ extension Transaction: NIOAsyncSequenceProducerDelegate { @usableFromInline func didTerminate() { - self.fail(HTTPClientError.cancelled) + self.httpResponseStreamTerminated() } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 02ebab916..df7169fc4 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -242,8 +242,45 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .sendBodyPart(let part, let writePromise): context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise) - case .sendRequestEnd(let writePromise): - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) + case .sendRequestEnd(let trailers, let writePromise, let finalAction): + + let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) + // We need to defer succeeding the old request to avoid ordering issues + + writePromise.futureResult.hop(to: context.eventLoop).assumeIsolated().whenComplete { result in + guard let oldRequest = self.request else { + // in the meantime an error might have happened, which is why this request is + // not reference anymore. + return + } + oldRequest.requestBodyStreamSent() + switch result { + case .success: + // If our final action is not `none`, that means we've already received + // the complete response. As a result, once we've uploaded all the body parts + // we need to tell the pool that the connection is idle or, if we were asked to + // close when we're done, send the close. Either way, we then succeed the request + + switch finalAction { + case .none: + break + + case .informConnectionIsIdle: + self.request = nil + self.onConnectionIdle() + + case .close: + self.request = nil + context.close(promise: nil) + } + + case .failure(let error): + context.close(promise: nil) + oldRequest.fail(error) + } + } + + context.writeAndFlush(self.wrapOutboundOut(.end(trailers)), promise: writePromise) if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(readTimeoutAction, context: context) @@ -300,7 +337,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(buffer) - case .succeedRequest(let finalAction, let buffer): + case .forwardResponseEnd(let finalAction, let buffer, let trailers): // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet @@ -312,41 +349,22 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // other way around. let oldRequest = self.request! - self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close: + self.request = nil context.close(promise: nil) - oldRequest.receiveResponseEnd(buffer, trailers: nil) - case .sendRequestEnd(let writePromise, let shouldClose): - let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) - // We need to defer succeeding the old request to avoid ordering issues - writePromise.futureResult.hop(to: context.eventLoop).assumeIsolated().whenComplete { result in - switch result { - case .success: - // If our final action was `sendRequestEnd`, that means we've already received - // the complete response. As a result, once we've uploaded all the body parts - // we need to tell the pool that the connection is idle or, if we were asked to - // close when we're done, send the close. Either way, we then succeed the request - if shouldClose { - context.close(promise: nil) - } else { - self.onConnectionIdle() - } - - oldRequest.receiveResponseEnd(buffer, trailers: nil) - case .failure(let error): - context.close(promise: nil) - oldRequest.fail(error) - } - } + oldRequest.receiveResponseEnd(buffer, trailers: trailers) + + case .none: + oldRequest.receiveResponseEnd(buffer, trailers: trailers) - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) case .informConnectionIsIdle: + self.request = nil self.onConnectionIdle() - oldRequest.receiveResponseEnd(buffer, trailers: nil) + oldRequest.receiveResponseEnd(buffer, trailers: trailers) } case .failRequest(let error, let finalAction): @@ -484,14 +502,18 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - fileprivate func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + fileprivate func finishRequestBodyStream0( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamFinished(promise: promise) + let action = self.state.requestStreamFinished(trailers: trailers, promise: promise) self.run(action, context: context) } @@ -545,9 +567,9 @@ extension HTTP1ClientChannelHandler { } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + func finishRequestBodyStream(trailers: HTTPHeaders?, request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.loopBound.execute { - $0.finishRequestBodyStream0(request, promise: promise) + $0.finishRequestBodyStream0(trailers: trailers, request: request, promise: promise) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index 2cde1df3f..e341e0984 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -27,18 +27,12 @@ struct HTTP1ConnectionStateMachine { } enum Action { - /// A action to execute, when we consider a request "done". + /// An additional action to execute, when either the response or request stream has finished. enum FinalSuccessfulStreamAction { + /// Nothing todo + case none /// Close the connection case close - /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, - /// as soon as we wrote the request end onto the wire. - /// - /// The promise is an optional write promise. - /// - /// `shouldClose` records whether we have attached a Connection: close header to this request, and so the connection should - /// be terminated - case sendRequestEnd(EventLoopPromise?, shouldClose: Bool) /// Inform an observer that the connection has become idle case informConnectionIsIdle } @@ -63,7 +57,7 @@ struct HTTP1ConnectionStateMachine { startIdleTimer: Bool ) case sendBodyPart(IOData, EventLoopPromise?) - case sendRequestEnd(EventLoopPromise?) + case sendRequestEnd(trailers: HTTPHeaders?, EventLoopPromise?, FinalSuccessfulStreamAction) case failSendBodyPart(Error, EventLoopPromise?) case failSendStreamFinished(Error, EventLoopPromise?) @@ -72,9 +66,9 @@ struct HTTP1ConnectionStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) + case forwardResponseEnd(FinalSuccessfulStreamAction, CircularBuffer, HTTPHeaders?) case failRequest(Error, FinalFailedStreamAction) - case succeedRequest(FinalSuccessfulStreamAction, CircularBuffer) case read case close @@ -224,13 +218,13 @@ struct HTTP1ConnectionStateMachine { } } - mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { + mutating func requestStreamFinished(trailers: HTTPHeaders?, promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamFinished(promise: promise) + let action = requestStateMachine.requestStreamFinished(trailers: trailers, promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } @@ -433,13 +427,34 @@ extension HTTP1ConnectionStateMachine.State { return .resumeRequestBodyStream case .sendBodyPart(let part, let writePromise): return .sendBodyPart(part, writePromise) - case .sendRequestEnd(let writePromise): - return .sendRequestEnd(writePromise) + case .sendRequestEnd(trailers: let trailers, let writePromise, let finalAction): + guard case .inRequest(_, close: let close) = self else { + fatalError("Invalid state: \(self)") + } + + let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction + switch finalAction { + case .close: + self = .closing + newFinalAction = .close + case .requestDone: + if close { + self = .closing + newFinalAction = .close + } else { + self = .idle + newFinalAction = .informConnectionIsIdle + } + case .none: + newFinalAction = .none + } + return .sendRequestEnd(trailers: trailers, writePromise, newFinalAction) + case .forwardResponseHead(let head, let pauseRequestBodyStream): return .forwardResponseHead(head, pauseRequestBodyStream: pauseRequestBodyStream) case .forwardResponseBodyParts(let parts): return .forwardResponseBodyParts(parts) - case .succeedRequest(let finalAction, let finalParts): + case .forwardResponseEnd(let finalAction, let finalParts, let trailers): guard case .inRequest(_, close: let close) = self else { fatalError("Invalid state: \(self)") } @@ -449,14 +464,19 @@ extension HTTP1ConnectionStateMachine.State { case .close: self = .closing newFinalAction = .close - case .sendRequestEnd(let writePromise): - self = .idle - newFinalAction = .sendRequestEnd(writePromise, shouldClose: close) + case .requestDone: + if close { + self = .closing + newFinalAction = .close + } else { + self = .idle + newFinalAction = .informConnectionIsIdle + } case .none: - self = .idle - newFinalAction = close ? .close : .informConnectionIsIdle + // request is ongoing. request stream is still alive + newFinalAction = .none } - return .succeedRequest(newFinalAction, finalParts) + return .forwardResponseEnd(newFinalAction, finalParts, trailers) case .failRequest(let error, let finalAction): switch self { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index b010b672e..78e3ad9e1 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -196,8 +196,14 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { case .sendBodyPart(let data, let writePromise): context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: writePromise) - case .sendRequestEnd(let writePromise): - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) + case .sendRequestEnd(let trailers, let writePromise, let finalAction): + let promise = writePromise ?? context.eventLoop.makePromise(of: Void.self) + let request = self.request! + promise.futureResult.whenSuccess { + request.requestBodyStreamSent() + } + + context.writeAndFlush(self.wrapOutboundOut(.end(trailers)), promise: promise) if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(readTimeoutAction, context: context) @@ -206,6 +212,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(writeTimeoutAction, context: context) } + self.runSuccessfulFinalAction(finalAction, context: context) case .read: context.read() @@ -247,10 +254,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // the right result for HTTP/1). In the h2 case we MUST always close. self.runFailedFinalAction(finalAction, context: context, error: error) - case .succeedRequest(let finalAction, let finalParts): + case .forwardResponseEnd(let finalAction, let finalParts, let trailers): // We can force unwrap the request here, as we have just validated in the state machine, // that the request object is still present. - self.request!.receiveResponseEnd(finalParts, trailers: nil) + self.request!.receiveResponseEnd(finalParts, trailers: trailers) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) @@ -277,14 +284,11 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { context: ChannelHandlerContext ) { switch action { - case .close, .none: + case .close, .none, .requestDone: // The actions returned here come from an `HTTPRequestStateMachine` that assumes http/1.1 // semantics. For this reason we can ignore the close here, since an h2 stream is closed // after every request anyway. break - - case .sendRequestEnd(let writePromise): - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) } } @@ -399,13 +403,17 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + private func finishRequestBodyStream0( + trailers: HTTPHeaders?, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return } - let action = self.state.requestStreamFinished(promise: promise) + let action = self.state.requestStreamFinished(trailers: trailers, promise: promise) self.run(action, context: context) } @@ -455,9 +463,9 @@ extension HTTP2ClientRequestHandler { } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + func finishRequestBodyStream(trailers: HTTPHeaders?, request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.loopBound.execute { - $0.finishRequestBodyStream0(request, promise: promise) + $0.finishRequestBodyStream0(trailers: trailers, request: request, promise: promise) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index 0635c7978..8cf31b6e2 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -185,7 +185,7 @@ protocol HTTPRequestExecutor: Sendable { /// Signals that the request body stream has finished /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func finishRequestBodyStream(_ task: HTTPExecutableRequest, promise: EventLoopPromise?) + func finishRequestBodyStream(trailers: HTTPHeaders?, request: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that more bytes from response body stream can be consumed. /// @@ -244,6 +244,11 @@ protocol HTTPExecutableRequest: AnyObject, Sendable { /// This will be called on the Channel's EventLoop. Do **not block** during your execution! func pauseRequestBodyStream() + /// Will be called by the ChannelHandler to indicate that the request body stream has been sent. + /// + /// This will be called on the Channel's EventLoop. Do **not block** during your execution! + func requestBodyStreamSent() + /// Receive a response head. /// /// Please note that `receiveResponseHead` and `receiveResponseBodyPart` may diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index e06389360..cef736063 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -73,7 +73,7 @@ struct HTTPRequestStateMachine { } enum Action { - /// A action to execute, when we consider a successful request "done". + /// A action to execute, when we consider a request or response stream "done". enum FinalSuccessfulRequestAction { /// Close the connection case close @@ -81,7 +81,7 @@ struct HTTPRequestStateMachine { /// as soon as we wrote the request end onto the wire. /// /// The promise is an optional write promise. - case sendRequestEnd(EventLoopPromise?) + case requestDone /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. /// This might happen if the request is cancelled, or the request failed the soundness check. case none @@ -102,7 +102,7 @@ struct HTTPRequestStateMachine { startIdleTimer: Bool ) case sendBodyPart(IOData, EventLoopPromise?) - case sendRequestEnd(EventLoopPromise?) + case sendRequestEnd(trailers: HTTPHeaders?, EventLoopPromise?, FinalSuccessfulRequestAction) case failSendBodyPart(Error, EventLoopPromise?) case failSendStreamFinished(Error, EventLoopPromise?) @@ -111,9 +111,9 @@ struct HTTPRequestStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) + case forwardResponseEnd(FinalSuccessfulRequestAction, CircularBuffer, HTTPHeaders?) case failRequest(Error, FinalFailedRequestAction) - case succeedRequest(FinalSuccessfulRequestAction, CircularBuffer) case read case wait @@ -353,7 +353,7 @@ struct HTTPRequestStateMachine { } } - mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { + mutating func requestStreamFinished(trailers: HTTPHeaders?, promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable, @@ -370,7 +370,7 @@ struct HTTPRequestStateMachine { } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd(promise) + return .sendRequestEnd(trailers: trailers, promise, .none) case .running( .streaming(let expectedBodyLength, let sentBodyBytes, _), @@ -385,7 +385,7 @@ struct HTTPRequestStateMachine { } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd(promise) + return .sendRequestEnd(trailers: trailers, promise, .none) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { @@ -395,7 +395,7 @@ struct HTTPRequestStateMachine { } self.state = .finished - return .succeedRequest(.sendRequestEnd(promise), .init()) + return .sendRequestEnd(trailers: trailers, promise, .requestDone) case .failed(let error): return .failSendStreamFinished(error, promise) @@ -497,8 +497,8 @@ struct HTTPRequestStateMachine { return self.receivedHTTPResponseHead(head) case .body(let body): return self.receivedHTTPResponseBodyPart(body) - case .end: - return self.receivedHTTPResponseEnd() + case .end(let trailers): + return self.receivedHTTPResponseEnd(trailers: trailers) } } @@ -618,7 +618,7 @@ struct HTTPRequestStateMachine { } } - private mutating func receivedHTTPResponseEnd() -> Action { + private mutating func receivedHTTPResponseEnd(trailers: HTTPHeaders?) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: preconditionFailure( @@ -648,7 +648,8 @@ struct HTTPRequestStateMachine { ), .endReceived ) - return .forwardResponseBodyParts(remainingBuffer) + return .forwardResponseEnd(.none, remainingBuffer, trailers) + case .close: // If we receive a `.close` as a connectionAction from the responseStreamState // this means, that the response end was signaled by a connection close. Since @@ -671,7 +672,7 @@ struct HTTPRequestStateMachine { // connection should be closed anyway. let (remainingBuffer, _) = responseStreamState.end() state = .finished - return .succeedRequest(.close, remainingBuffer) + return .forwardResponseEnd(.close, remainingBuffer, trailers) } case .running(.endSent, .receivingBody(_, var responseStreamState)): @@ -680,9 +681,9 @@ struct HTTPRequestStateMachine { state = .finished switch action { case .none: - return .succeedRequest(.none, remainingBuffer) + return .forwardResponseEnd(.requestDone, remainingBuffer, trailers) case .close: - return .succeedRequest(.close, remainingBuffer) + return .forwardResponseEnd(.close, remainingBuffer, trailers) } } diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index abfad7312..7accdc51a 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -59,12 +59,15 @@ extension RequestBag { case initialized(RedirectHandler?) case buffering(CircularBuffer, next: Next) case waitingForRemote + case endReceived } private var state: State + private let requestFramingMetadata: RequestFramingMetadata - init(redirectHandler: RedirectHandler?) { + init(redirectHandler: RedirectHandler?, requestFramingMetadata: RequestFramingMetadata) { self.state = .initialized(redirectHandler) + self.requestFramingMetadata = requestFramingMetadata } } } @@ -100,6 +103,20 @@ extension RequestBag.StateMachine { case none } + mutating func requestHeadSent() { + switch self.state { + case .initialized: + fatalError() + case .executing(let executor, .initialized, let responseStream): + if self.requestFramingMetadata.body == .fixedSize(0) { + self.state = .executing(executor, .finished, responseStream) + } + + default: + break + } + } + mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction { switch self.state { case .initialized(let redirectHandler), .queued(_, let redirectHandler): @@ -143,7 +160,9 @@ extension RequestBag.StateMachine { // request bytes. Can be ignored. return .none - case .executing(_, .initialized, .buffering), .executing(_, .initialized, .waitingForRemote): + case .executing(_, .initialized, .buffering), + .executing(_, .initialized, .waitingForRemote), + .executing(_, .initialized, .endReceived): preconditionFailure("Invalid states: Response can not be received before request") case .redirected: @@ -239,6 +258,7 @@ extension RequestBag.StateMachine { enum FinishAction { case forwardStreamFinished(HTTPRequestExecutor, EventLoopPromise?) + case forwardStreamFinishedAndSucceedTask(HTTPRequestExecutor, EventLoopPromise?) case forwardStreamFailureAndFailTask(HTTPRequestExecutor, Error, EventLoopPromise?) case none } @@ -254,8 +274,15 @@ extension RequestBag.StateMachine { case .producing: switch result { case .success: - self.state = .executing(executor, .finished, responseState) - return .forwardStreamFinished(executor, nil) + switch responseState { + case .initialized, .buffering, .waitingForRemote: + self.state = .executing(executor, .finished, responseState) + return .forwardStreamFinished(executor, nil) + case .endReceived: + self.state = .finished(error: nil) + return .forwardStreamFinishedAndSucceedTask(executor, nil) + } + case .failure(let error): self.state = .finished(error: error) return .forwardStreamFailureAndFailTask(executor, error, nil) @@ -264,8 +291,15 @@ extension RequestBag.StateMachine { case .paused(let promise): switch result { case .success: - self.state = .executing(executor, .finished, responseState) - return .forwardStreamFinished(executor, promise) + switch responseState { + case .initialized, .buffering, .waitingForRemote: + self.state = .executing(executor, .finished, responseState) + return .forwardStreamFinished(executor, promise) + case .endReceived: + self.state = .finished(error: nil) + return .forwardStreamFinishedAndSucceedTask(executor, promise) + } + case .failure(let error): self.state = .finished(error: error) return .forwardStreamFailureAndFailTask(executor, error, promise) @@ -346,7 +380,7 @@ extension RequestBag.StateMachine { switch self.state { case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") - case .executing(_, _, .initialized): + case .executing(_, _, .initialized), .executing(_, _, .endReceived): preconditionFailure("If we receive a response body, we must have received a head before") case .executing(let executor, let requestState, .buffering(var currentBuffer, next: let next)): @@ -405,7 +439,7 @@ extension RequestBag.StateMachine { switch self.state { case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") - case .executing(_, _, .initialized): + case .executing(_, _, .initialized), .executing(_, _, .endReceived): preconditionFailure("If we receive a response body, we must have received a head before") case .executing(let executor, let requestState, .buffering(var buffer, next: let next)): @@ -426,8 +460,15 @@ extension RequestBag.StateMachine { case .executing(let executor, let requestState, .waitingForRemote): guard var newChunks = newChunks, !newChunks.isEmpty else { - self.state = .finished(error: nil) - return .succeedRequest + switch requestState { + case .initialized, .paused, .producing: + self.state = .executing(executor, requestState, .endReceived) + return .none + + case .finished: + self.state = .finished(error: nil) + return .succeedRequest + } } let first = newChunks.removeFirst() @@ -484,7 +525,7 @@ extension RequestBag.StateMachine { self.state = .finished(error: error) return .failTask(error, executorToCancel: executor) - case .executing(_, _, .waitingForRemote): + case .executing(_, _, .waitingForRemote), .executing(_, _, .endReceived): preconditionFailure( "Invalid state... We just returned from a consumption function. We can't already be waiting" ) @@ -550,6 +591,10 @@ extension RequestBag.StateMachine { "Invalid state... We just returned from a consumption function. We can't already be waiting" ) + case .executing(_, _, .endReceived): + // we can't succeed the request here, as we have not sent all request parts. + return .doNothing + case .redirected: return .doNothing @@ -614,10 +659,9 @@ extension RequestBag.StateMachine { case .executing(let executor, _, .buffering(_, next: .error(_))): // this would override another error, let's keep the first one return .cancelExecutor(executor) - case .executing(let executor, _, .initialized): - self.state = .finished(error: error) - return .failTask(error, nil, executor) - case .executing(let executor, _, .waitingForRemote): + case .executing(let executor, _, .initialized), + .executing(let executor, _, .waitingForRemote), + .executing(let executor, _, .endReceived): self.state = .finished(error: error) return .failTask(error, nil, executor) case .redirected: diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index a743f0814..9f2886bcb 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -100,9 +100,13 @@ final class RequestBag: Sendabl self.eventLoopPreference = eventLoopPreference self.task = task + let (head, metadata) = try request.createRequestHead() + self.requestHead = head + self.requestFramingMetadata = metadata + let loopBoundState = LoopBoundState( request: request, - state: StateMachine(redirectHandler: redirectHandler), + state: StateMachine(redirectHandler: redirectHandler, requestFramingMetadata: metadata), consumeBodyPartStackDepth: 0, tracing: task.tracing ) @@ -111,10 +115,6 @@ final class RequestBag: Sendabl self.requestOptions = requestOptions self.delegate = delegate - let (head, metadata) = try request.createRequestHead() - self.requestHead = head - self.requestFramingMetadata = metadata - self.tlsConfiguration = request.tlsConfiguration self.task.taskDelegate = self @@ -150,9 +150,11 @@ final class RequestBag: Sendabl } private func requestHeadSent0() { + self.loopBoundState.value.state.requestHeadSent() + self.delegate.didSendRequestHead(task: self.task, self.requestHead) - if self.loopBoundState.value.request.body == nil { + if self.requestFramingMetadata.body == .fixedSize(0) { self.delegate.didSendRequest(task: self.task) } } @@ -187,6 +189,10 @@ final class RequestBag: Sendabl self.loopBoundState.value.state.pauseRequestBodyStream() } + private func requestBodyStreamSent0() { + + } + private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.writeNextRequestPart0(part) @@ -230,7 +236,26 @@ final class RequestBag: Sendabl promise.futureResult.whenSuccess { self.delegate.didSendRequest(task: self.task) } - writer.finishRequestBodyStream(self, promise: promise) + writer.finishRequestBodyStream(trailers: nil, request: self, promise: promise) + + case .forwardStreamFinishedAndSucceedTask(let writer, let writerPromise): + let promise = writerPromise ?? self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenComplete { result in + switch result { + case .success: + self.delegate.didSendRequest(task: self.task) + do { + let response = try self.delegate.didFinishRequest(task: self.task) + self.task.promise.succeed(response) + } catch { + self.task.promise.fail(error) + } + + case .failure(let error): + self.task.promise.fail(error) + } + } + writer.finishRequestBodyStream(trailers: nil, request: self, promise: promise) case .forwardStreamFailureAndFailTask(let writer, let error, let promise): writer.cancelRequest(self) @@ -518,6 +543,16 @@ extension RequestBag: HTTPExecutableRequest { } } + func requestBodyStreamSent() { + if self.task.eventLoop.inEventLoop { + self.requestBodyStreamSent0() + } else { + self.task.eventLoop.execute { + self.requestBodyStreamSent0() + } + } + } + func receiveResponseHead(_ head: HTTPResponseHead) { if self.task.eventLoop.inEventLoop { self.receiveResponseHead0(head) diff --git a/Tests/AsyncHTTPClientTests/BidirectionalStreamingTests.swift b/Tests/AsyncHTTPClientTests/BidirectionalStreamingTests.swift new file mode 100644 index 000000000..68adda983 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/BidirectionalStreamingTests.swift @@ -0,0 +1,98 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2026 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import NIOCore +import NIOHTTP1 +import Testing + +@Suite("Request and response streaming") +struct BidirectionalStreamingTests { + @Test func requestStreamCanOutliveResponse() async throws { + let (bodyStream, bodyWriteContinuation) = AsyncStream.makeStream(of: AsyncStream.self) + let httpBin = HTTPBin(.http1_1(ssl: false)) { _ in + let (stream, continuation) = AsyncStream.makeStream(of: ByteBuffer.self) + bodyWriteContinuation.yield(stream) + return HTTPRequestStreamingChannel(bodyStreamContinuation: continuation) + } + + defer { #expect(throws: Never.self) { try httpBin.shutdown() } } + + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) + + defer { #expect(throws: Never.self) { try httpClient.syncShutdown() } } + + var request = HTTPClientRequest(url: "http://localhost:\(httpBin.port)") + let (stream, continuation) = AsyncStream.makeStream(of: ByteBuffer.self) + request.body = .stream(stream, length: .unknown) + + await #expect(throws: Never.self) { + let response = try await httpClient.execute(request, timeout: .seconds(60), logger: nil) + var iterator = response.body.makeAsyncIterator() + #expect(try await iterator.next() == nil) // response is finished. + } + + var bodyStreamIterator = bodyStream.makeAsyncIterator() + + let serverRequestStream = await bodyStreamIterator.next() + guard let serverRequestStream else { + Issue.record("Could not get the server request stream") + return + } + var receivedWritesIterator = serverRequestStream.makeAsyncIterator() + + let payload1 = ByteBuffer(string: "Hello World! 1") + continuation.yield(payload1) + #expect(await receivedWritesIterator.next() == payload1) + let payload2 = ByteBuffer(string: "Hello World! 2") + continuation.yield(payload2) + #expect(await receivedWritesIterator.next() == payload2) + let payload3 = ByteBuffer(string: "Hello World! 3") + continuation.yield(payload3) + #expect(await receivedWritesIterator.next() == payload3) + let payload4 = ByteBuffer(string: "Hello World! 4") + continuation.yield(payload4) + #expect(await receivedWritesIterator.next() == payload4) + continuation.finish() + #expect(await receivedWritesIterator.next() == nil) + } +} + +final class HTTPRequestStreamingChannel: ChannelInboundHandler & AHCTestSendableMetatype { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + let bodyStreamContinuation: AsyncStream.Continuation + + init(bodyStreamContinuation: AsyncStream.Continuation) { + self.bodyStreamContinuation = bodyStreamContinuation + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let reqPart = self.unwrapInboundIn(data) + switch reqPart { + case .head: + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + case .body(let body): + self.bodyStreamContinuation.yield(body) + case .end: + self.bodyStreamContinuation.finish() + @unknown default: + Issue.record("Unhandled case: \(reqPart)") + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 0d871b7dc..1987ca29f 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -897,6 +897,46 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { // and ensure that the state machine can tolerate this embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + + func testSendingAndReceivingTrailers() async throws { + let eventLoop = EmbeddedEventLoop() + let handler = HTTP1ClientChannelHandler( + eventLoop: eventLoop, + backgroundLogger: Logger(label: "no-op", factory: SwiftLogNoOpLogHandler.init), + connectionIdLoggerMetadata: "test connection" + ) + let channel = EmbeddedChannel(handlers: [handler], loop: eventLoop) + XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()) + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + let request = MockHTTPExecutableRequest( + head: .init(version: .http1_1, method: .POST, uri: "http://localhost/"), + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .stream), + raiseErrorIfUnimplementedMethodIsCalled: false + ) + + let executor = handler.requestExecutor + request.resumeRequestBodyStreamCallback = { + executor.writeRequestBodyPart(.byteBuffer(.init(string: "Hello World")), request: request, promise: nil) + executor.finishRequestBodyStream(trailers: ["trailer": "foo"], request: request, promise: nil) + } + + request.receiveResponseEndCallback = { (_, trailers) in + XCTAssertEqual(trailers, ["trailer": "bar"]) + } + + channel.write(request, promise: nil) + + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .head(request.requestHead)) + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .body(.byteBuffer(.init(string: "Hello World")))) + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .end(["trailer": "foo"])) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(.init(string: "Foo Bar")))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(["trailer": "bar"]))) + + XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent, .resumeRequestBodyStream, .requestBodySent, .receiveResponseHead, .receiveResponseEnd]) + } } final class TestBackpressureWriter: Sendable { diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index 1c6e9659f..2718c828e 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -52,7 +52,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) + XCTAssertEqual(state.requestStreamFinished(trailers: nil, promise: nil), .sendRequestEnd(trailers: nil, nil, .none)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual( @@ -61,7 +61,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) + XCTAssertEqual( + state.channelRead(.end(nil)), + .forwardResponseEnd(.informConnectionIsIdle, [responseBody], nil) + ) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -96,7 +99,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.informConnectionIsIdle, [], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) } @@ -140,7 +143,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [responseBody], nil)) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) } @@ -163,7 +166,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [responseBody], nil)) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) } @@ -190,7 +193,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) + XCTAssertEqual( + state.channelRead(.end(nil)), + .forwardResponseEnd(.informConnectionIsIdle, [responseBody], nil) + ) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) } @@ -214,7 +220,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [responseBody], nil)) } func testNIOTriggersChannelActiveTwice() { @@ -367,7 +373,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) } func testWeDontCrashAfterEarlyHintsAndConnectionClose() { @@ -445,10 +451,10 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { return lhsData == rhsData case ( - .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), - .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + .forwardResponseEnd(let lhsFinalAction, let lhsFinalBuffer, let lhsTrailers), + .forwardResponseEnd(let rhsFinalAction, let rhsFinalBuffer, let rhsTrailers) ): - return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer + return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer && lhsTrailers == rhsTrailers case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): return lhsFinalAction == rhsFinalAction @@ -468,24 +474,6 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { } } -extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equatable { - public static func == ( - lhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction, - rhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction - ) -> Bool { - switch (lhs, rhs) { - case (.close, .close): - return true - case (sendRequestEnd(let lhsPromise, let lhsShouldClose), sendRequestEnd(let rhsPromise, let rhsShouldClose)): - return lhsPromise?.futureResult == rhsPromise?.futureResult && lhsShouldClose == rhsShouldClose - case (informConnectionIsIdle, informConnectionIsIdle): - return true - default: - return false - } - } -} - extension HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction: Equatable { public static func == ( lhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction, diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 71f7f3d1a..879345b41 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -576,4 +576,41 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { channel.writeAndFlush(request, promise: nil) XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } + + func testSendingAndReceivingTrailers() async throws { + let eventLoop = EmbeddedEventLoop() + let handler = HTTP2ClientRequestHandler(eventLoop: eventLoop) + let channel = EmbeddedChannel(handlers: [handler], loop: eventLoop) + XCTAssertNoThrow(try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait()) + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + let request = MockHTTPExecutableRequest( + head: .init(version: .http1_1, method: .POST, uri: "http://localhost/"), + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .stream), + raiseErrorIfUnimplementedMethodIsCalled: false + ) + + let executor = handler.requestExecutor + request.resumeRequestBodyStreamCallback = { + executor.writeRequestBodyPart(.byteBuffer(.init(string: "Hello World")), request: request, promise: nil) + executor.finishRequestBodyStream(trailers: ["trailer": "foo"], request: request, promise: nil) + } + + request.receiveResponseEndCallback = { (_, trailers) in + XCTAssertEqual(trailers, ["trailer": "bar"]) + } + + channel.write(request, promise: nil) + + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .head(request.requestHead)) + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .body(.byteBuffer(.init(string: "Hello World")))) + XCTAssertEqual(try channel.readOutbound(as: HTTPClientRequestPart.self), .end(["trailer": "foo"])) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(.init(string: "Foo Bar")))) + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(["trailer": "bar"]))) + + XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent, .resumeRequestBodyStream, .requestBodySent, .receiveResponseHead, .receiveResponseEnd]) + } + } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift index 99a61fe47..6fbdda385 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift @@ -126,6 +126,10 @@ final private class MockScheduledRequest: HTTPSchedulableRequest { preconditionFailure("Unimplemented") } + func requestBodyStreamSent() { + preconditionFailure("Unimplemented") + } + func receiveResponseHead(_: HTTPResponseHead) { preconditionFailure("Unimplemented") } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 8fe879745..c03b23141 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -37,7 +37,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]), nil)) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -77,7 +77,7 @@ class HTTPRequestStateMachineTests: XCTestCase { // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) + XCTAssertEqual(state.requestStreamFinished(trailers: nil, promise: nil), .sendRequestEnd(trailers: nil, nil, .none)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual( @@ -86,7 +86,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, .init([responseBody]), nil)) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -132,7 +132,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamFinished(trailers: nil, promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { @@ -169,7 +169,7 @@ class HTTPRequestStateMachineTests: XCTestCase { "Expected to drop all stream data after having received a response head, with status >= 300" ) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) XCTAssertEqual( state.requestStreamPartReceived(part, promise: nil), @@ -178,7 +178,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) XCTAssertEqual( - state.requestStreamFinished(promise: nil), + state.requestStreamFinished(trailers: nil, promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300" ) @@ -230,7 +230,7 @@ class HTTPRequestStateMachineTests: XCTestCase { "Expected to drop all stream data after having received a response head, with status >= 300" ) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) XCTAssertEqual( state.requestStreamPartReceived(part, promise: nil), @@ -239,7 +239,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) XCTAssertEqual( - state.requestStreamFinished(promise: nil), + state.requestStreamFinished(trailers: nil, promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300" ) @@ -267,13 +267,13 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.none, [], nil)) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .succeedRequest(.sendRequestEnd(nil), .init())) + XCTAssertEqual(state.requestStreamFinished(trailers: nil, promise: nil), .sendRequestEnd(trailers: nil, nil, .requestDone)) XCTAssertEqual( state.requestStreamPartReceived(part2, promise: nil), @@ -308,9 +308,9 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) - XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) + XCTAssertEqual(state.requestStreamFinished(trailers: nil, promise: nil), .sendRequestEnd(trailers: nil, nil, .none)) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) } func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200() { @@ -335,11 +335,11 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.none, [], nil)) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) - state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamFinished(trailers: nil, promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -368,7 +368,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) - state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamFinished(trailers: nil, promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) } @@ -387,7 +387,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [responseBody], nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -430,7 +430,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -467,7 +467,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(part2)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([part2]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [part2], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -513,7 +513,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) } @@ -551,7 +551,7 @@ class HTTPRequestStateMachineTests: XCTestCase { ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [responseBody], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) } @@ -630,7 +630,7 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) } @@ -649,7 +649,7 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.idleReadTimeoutTriggered(), .wait, "A read timeout that fires to late must be ignored") } @@ -667,7 +667,7 @@ class HTTPRequestStateMachineTests: XCTestCase { state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false) ) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.requestDone, [], nil)) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -705,7 +705,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [], nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -729,7 +729,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [body])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [body], nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -951,7 +951,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(part3)), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [part1, part2, part3])) + XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseEnd(.close, [part1, part2, part3], nil)) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelInactive(), .wait) @@ -973,8 +973,8 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult - case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): - return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.sendRequestEnd(let lhsTrailers, let lhsPromise, let lhsAction), .sendRequestEnd(let rhsTrailers, let rhsPromise, let rhsAction)): + return lhsTrailers == rhsTrailers && lhsPromise?.futureResult == rhsPromise?.futureResult && lhsAction == rhsAction case (.pauseRequestBodyStream, .pauseRequestBodyStream): return true @@ -991,10 +991,10 @@ extension HTTPRequestStateMachine.Action: Equatable { return lhsData == rhsData case ( - .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), - .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + .forwardResponseEnd(let lhsFinalAction, let lhsFinalBuffer, let lhsTrailers), + .forwardResponseEnd(let rhsFinalAction, let rhsFinalBuffer, let rhsTrailers) ): - return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer + return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer && lhsTrailers == rhsTrailers case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): return lhsFinalAction == rhsFinalAction @@ -1023,27 +1023,6 @@ extension HTTPRequestStateMachine.Action: Equatable { } } -extension HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction: Equatable { - public static func == ( - lhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, - rhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction - ) -> Bool { - switch (lhs, rhs) { - case (.close, close): - return true - - case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): - return lhsPromise?.futureResult == rhsPromise?.futureResult - - case (.none, .none): - return true - - default: - return false - } - } -} - extension HTTPRequestStateMachine.Action.FinalFailedRequestAction: Equatable { public static func == ( lhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction, diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift index 756758131..bd8f0736a 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift @@ -749,6 +749,10 @@ final class MockHTTPScheduableRequest: HTTPSchedulableRequest { preconditionFailure("Unimplemented") } + func requestBodyStreamSent() { + preconditionFailure("Unimplemented") + } + func receiveResponseHead(_: HTTPResponseHead) { preconditionFailure("Unimplemented") } diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift index 61b59b7b7..227bbeff3 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift @@ -28,9 +28,10 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { case requestHeadSent case resumeRequestBodyStream case pauseRequestBodyStream + case requestBodySent case receiveResponseHead case receiveResponseBodyParts - case succeedRequest + case receiveResponseEnd case fail } @@ -38,9 +39,10 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { case requestHeadSent case resumeRequestBodyStream case pauseRequestBodyStream + case requestBodySent case receiveResponseHead(HTTPResponseHead) case receiveResponseBodyParts(CircularBuffer) - case succeedRequest(CircularBuffer?) + case receiveResponseEnd(CircularBuffer?, HTTPHeaders?) case fail(Error) var kind: Kind { @@ -49,9 +51,10 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { case .requestHeadSent: return .requestHeadSent case .resumeRequestBodyStream: return .resumeRequestBodyStream case .pauseRequestBodyStream: return .pauseRequestBodyStream + case .requestBodySent: return .requestBodySent case .receiveResponseHead: return .receiveResponseHead case .receiveResponseBodyParts: return .receiveResponseBodyParts - case .succeedRequest: return .succeedRequest + case .receiveResponseEnd: return .receiveResponseEnd case .fail: return .fail } } @@ -69,14 +72,64 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { private let file: StaticString private let line: UInt - let willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil - let requestHeadSentCallback: (@Sendable () -> Void)? = nil - let resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil - let pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil - let receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil - let receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil - let succeedRequestCallback: (@Sendable (CircularBuffer?) -> Void)? = nil - let failCallback: (@Sendable (Error) -> Void)? = nil + struct Callbacks { + var willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil + var requestHeadSentCallback: (@Sendable () -> Void)? = nil + var resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + var pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + var requestBodyStreamSentCallback: (@Sendable () -> Void)? = nil + var receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil + var receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil + var receiveResponseEndCallback: (@Sendable (CircularBuffer?, HTTPHeaders?) -> Void)? = nil + var failCallback: (@Sendable (Error) -> Void)? = nil + } + + let callbacks: NIOLockedValueBox = .init(.init()) + + var willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? { + get { self.callbacks.withLockedValue { $0.willExecuteRequestCallback } } + set { self.callbacks.withLockedValue { $0.willExecuteRequestCallback = newValue } } + } + + var requestHeadSentCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.requestHeadSentCallback } } + set { self.callbacks.withLockedValue { $0.requestHeadSentCallback = newValue } } + } + + var resumeRequestBodyStreamCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.resumeRequestBodyStreamCallback } } + set { self.callbacks.withLockedValue { $0.resumeRequestBodyStreamCallback = newValue } } + } + + var pauseRequestBodyStreamCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.pauseRequestBodyStreamCallback } } + set { self.callbacks.withLockedValue { $0.pauseRequestBodyStreamCallback = newValue } } + } + + var requestBodyStreamSentCallback: (@Sendable () -> Void)? { + get { self.callbacks.withLockedValue { $0.requestBodyStreamSentCallback } } + set { self.callbacks.withLockedValue { $0.requestBodyStreamSentCallback = newValue } } + } + + var receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? { + get { self.callbacks.withLockedValue { $0.receiveResponseHeadCallback } } + set { self.callbacks.withLockedValue { $0.receiveResponseHeadCallback = newValue } } + } + + var receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? { + get { self.callbacks.withLockedValue { $0.receiveResponseBodyPartsCallback } } + set { self.callbacks.withLockedValue { $0.receiveResponseBodyPartsCallback = newValue } } + } + + var receiveResponseEndCallback: (@Sendable (CircularBuffer?, HTTPHeaders?) -> Void)? { + get { self.callbacks.withLockedValue { $0.receiveResponseEndCallback } } + set { self.callbacks.withLockedValue { $0.receiveResponseEndCallback = newValue } } + } + + var failCallback: (@Sendable (Error) -> Void)? { + get { self.callbacks.withLockedValue { $0.failCallback } } + set { self.callbacks.withLockedValue { $0.failCallback = newValue } } + } /// captures all ``HTTPExecutableRequest`` method calls in the order of occurrence, including arguments. /// If you are not interested in the arguments you can use `events.map(\.kind)` to get all events without arguments. @@ -141,6 +194,14 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { pauseRequestBodyStreamCallback() } + func requestBodyStreamSent() { + self.events.append(.requestBodySent) + guard let requestBodyStreamSentCallback = self.requestBodyStreamSentCallback else { + return self.calledUnimplementedMethod(#function) + } + requestBodyStreamSentCallback() + } + func receiveResponseHead(_ head: HTTPResponseHead) { self.events.append(.receiveResponseHead(head)) guard let receiveResponseHeadCallback = receiveResponseHeadCallback else { @@ -158,11 +219,11 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { } func receiveResponseEnd(_ buffer: CircularBuffer?, trailers: HTTPHeaders?) { - self.events.append(.succeedRequest(buffer)) - guard let succeedRequestCallback = succeedRequestCallback else { + self.events.append(.receiveResponseEnd(buffer, trailers)) + guard let receiveResponseEndCallback = self.receiveResponseEndCallback else { return self.calledUnimplementedMethod(#function) } - succeedRequestCallback(buffer) + receiveResponseEndCallback(buffer, trailers) } func fail(_ error: Error) { diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index e5d9caa8e..a4397ff5c 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -14,6 +14,7 @@ import NIOConcurrencyHelpers import NIOCore +import NIOHTTP1 @testable import AsyncHTTPClient @@ -212,7 +213,7 @@ extension MockRequestExecutor: HTTPRequestExecutor { promise?.succeed(()) } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + func finishRequestBodyStream(trailers: HTTPHeaders?, request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.endOfStream, request: request) promise?.succeed(()) } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index e68bc3f2a..51621c7a6 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -602,10 +602,13 @@ final class RequestBagTests: XCTestCase { }.always { result in XCTAssertTrue(firstWriteSuccess.withLockedValue { $0 }) - guard case .failure(let error) = result else { - return XCTFail("Expected the second write to fail") + switch result { + case .success: + // upload can now continue even after we have received the response end. + break + case .failure(let failure): + XCTFail("Unexpected error: \(failure)") } - XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } } ) @@ -641,9 +644,11 @@ final class RequestBagTests: XCTestCase { bag.receiveResponseHead(.init(version: .http1_1, status: .movedPermanently)) XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) bag.receiveResponseEnd([], trailers: nil) + XCTAssertEqual(delegate.hitDidReceiveResponse, 0) // if we now write our second part of the response this should fail the backpressure promise writeSecondPartPromise.succeed(()) + XCTAssertEqual(delegate.hitDidReceiveResponse, 1) XCTAssertEqual(delegate.receivedHead?.status, .movedPermanently) XCTAssertNoThrow(try bag.task.futureResult.wait()) diff --git a/Tests/AsyncHTTPClientTests/TransactionTests.swift b/Tests/AsyncHTTPClientTests/TransactionTests.swift index dda216975..b936b7155 100644 --- a/Tests/AsyncHTTPClientTests/TransactionTests.swift +++ b/Tests/AsyncHTTPClientTests/TransactionTests.swift @@ -98,7 +98,8 @@ final class TransactionTests: XCTestCase { } func finishRequestBodyStream( - _ task: AsyncHTTPClient.HTTPExecutableRequest, + trailers: HTTPHeaders?, + request: AsyncHTTPClient.HTTPExecutableRequest, promise: NIOCore.EventLoopPromise? ) { XCTFail()