diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 14b37e22..81825f40 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -839,10 +839,7 @@ impl Prioritize { }), None => { if let Some(reason) = stream.state.get_scheduled_reset() { - let stream_id = stream.id; - stream - .state - .set_reset(stream_id, reason, Initiator::Library); + stream.set_reset(reason, Initiator::Library); let frame = frame::Reset::new(stream.id, reason); Frame::Reset(frame) diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 46cb87cd..baa0bb97 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -296,7 +296,7 @@ impl Recv { let is_open = stream.state.ensure_recv_open()?; if is_open { - stream.recv_task = Some(cx.waker().clone()); + stream.push_task = Some(cx.waker().clone()); Poll::Pending } else { Poll::Ready(None) @@ -760,6 +760,7 @@ impl Recv { .pending_recv .push_back(&mut self.buffer, Event::Headers(Server(req))); stream.notify_recv(); + stream.notify_push(); Ok(()) } @@ -814,6 +815,7 @@ impl Recv { stream.notify_send(); stream.notify_recv(); + stream.notify_push(); Ok(()) } @@ -826,6 +828,7 @@ impl Recv { // If a receiver is waiting, notify it stream.notify_send(); stream.notify_recv(); + stream.notify_push(); } pub fn go_away(&mut self, last_processed_id: StreamId) { @@ -837,6 +840,7 @@ impl Recv { stream.state.recv_eof(); stream.notify_send(); stream.notify_recv(); + stream.notify_push(); } pub(super) fn clear_recv_buffer(&mut self, stream: &mut Stream) { diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 997b0fa4..2a7abba0 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -206,10 +206,7 @@ impl Send { } // Transition the state to reset no matter what. - stream.state.set_reset(stream_id, reason, initiator); - // Notify the recv task if it's waiting, because it'll - // want to hear about the reset. - stream.notify_recv(); + stream.set_reset(reason, initiator); // If closed AND the send queue is flushed, then the stream cannot be // reset explicitly, either. Implicit resets can still be queued. diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index e139da97..239536ee 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -1,3 +1,5 @@ +use crate::Reason; + use super::*; use std::task::{Context, Waker}; @@ -104,6 +106,9 @@ pub(super) struct Stream { /// Task tracking receiving frames pub recv_task: Option, + /// Task tracking pushed promises. + pub push_task: Option, + /// The stream's pending push promises pub pending_push_promises: store::Queue, @@ -186,6 +191,7 @@ impl Stream { pending_recv: buffer::Deque::new(), is_recv: true, recv_task: None, + push_task: None, pending_push_promises: store::Queue::new(), content_length: ContentLength::Omitted, } @@ -369,6 +375,20 @@ impl Stream { task.wake(); } } + + pub(super) fn notify_push(&mut self) { + if let Some(task) = self.push_task.take() { + task.wake(); + } + } + + /// Set the stream's state to `Closed` with the given reason and initiator. + /// Notify the send and receive tasks, if they exist. + pub(super) fn set_reset(&mut self, reason: Reason, initiator: Initiator ) { + self.state.set_reset(self.id, reason, initiator); + self.notify_push(); + self.notify_recv(); + } } impl store::Next for NextAccept { diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index e6c9ed8a..91a8f209 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -825,7 +825,7 @@ impl Inner { let parent = &mut self.store.resolve(parent_key); parent.pending_push_promises = ppp; - parent.notify_recv(); + parent.notify_push(); }; Ok(()) diff --git a/tests/h2-tests/tests/push_promise.rs b/tests/h2-tests/tests/push_promise.rs index 94c1154e..c2138edc 100644 --- a/tests/h2-tests/tests/push_promise.rs +++ b/tests/h2-tests/tests/push_promise.rs @@ -1,5 +1,6 @@ -use futures::future::join; -use futures::{StreamExt, TryStreamExt}; +use std::iter::FromIterator; + +use futures::{future::join, FutureExt as _, StreamExt, TryStreamExt}; use h2_support::prelude::*; #[tokio::test] @@ -51,9 +52,15 @@ async fn recv_push_works() { let ps: Vec<_> = p.collect().await; assert_eq!(1, ps.len()) }; - - h2.drive(join(check_resp_status, check_pushed_response)) - .await; + // Use a FuturesUnordered to poll both tasks but only poll them + // if they have been notified. + let tasks = futures::stream::FuturesUnordered::from_iter([ + check_resp_status.boxed(), + check_pushed_response.boxed(), + ]) + .collect::<()>(); + + h2.drive(tasks).await; }; join(mock, h2).await;