diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 5256f09c..3bdd6bfe 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -112,7 +112,7 @@ impl State { Open { local, remote } } } - HalfClosedRemote(AwaitingHeaders) | ReservedLocal => { + HalfClosedRemote(AwaitingHeaders | Streaming) | ReservedLocal => { if eos { Closed(Cause::EndStream) } else { diff --git a/src/server.rs b/src/server.rs index b00bc086..9f98ece6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1124,6 +1124,26 @@ impl SendResponse { .map_err(Into::into) } + /// Send a non-final 1xx response to a client request. + /// + /// The [`SendResponse`] instance is already associated with a received + /// request. This function may only be called if [`send_reset`] or + /// [`send_response`] has not been previously called. + /// + /// [`SendResponse`]: # + /// [`send_reset`]: #method.send_reset + /// [`send_response`]: #method.send_response + /// + /// # Panics + /// + /// If a "final" response has already been sent, or if the stream has been reset. + pub fn send_info(&mut self, response: Response<()>) -> Result<(), crate::Error> { + assert!(response.status().is_informational()); + self.inner + .send_response(response, false) + .map_err(Into::into) + } + /// Push a request and response to the client /// /// On success, a [`SendResponse`] instance is returned. diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index c1af5419..2ba47443 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -104,6 +104,61 @@ async fn serve_request() { join(client, srv).await; } +#[tokio::test] +async fn serve_request_expect_continue() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .field(http::header::EXPECT, "100-continue") + .request("POST", "https://example.com/"), + ) + .await; + client.recv_frame(frames::headers(1).response(100)).await; + client + .send_frame(frames::data(1, "hello world").eos()) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + assert_eq!(req.method(), &http::Method::POST); + assert_eq!( + req.headers().get(http::header::EXPECT), + Some(&http::HeaderValue::from_static("100-continue")) + ); + + let connection_fut = poll_fn(|cx| srv.poll_closed(cx).map(Result::ok)); + let test_fut = async move { + stream.send_continue().unwrap(); + + let mut body = req.into_body(); + assert_eq!( + body.next().await.unwrap().unwrap(), + Bytes::from_static(b"hello world") + ); + assert!(body.next().await.is_none()); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + }; + + join(connection_fut, test_fut).await; + }; + + join(client, srv).await; +} + #[tokio::test] async fn serve_connect() { h2_support::trace_init!();