diff --git a/tests/h2-support/src/future_ext.rs b/tests/h2-support/src/future_ext.rs index 9f659b344..cca18c66e 100644 --- a/tests/h2-support/src/future_ext.rs +++ b/tests/h2-support/src/future_ext.rs @@ -1,7 +1,9 @@ -use futures::FutureExt; +use futures::{FutureExt, TryFuture}; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::task::{Context, Poll, Wake, Waker}; /// Future extension helpers that are useful for tests pub trait TestFuture: Future { @@ -15,9 +17,140 @@ pub trait TestFuture: Future { { Drive { driver: self, - future: Box::pin(other), + future: other.wakened(), } } + + fn wakened(self) -> Wakened + where + Self: Sized, + { + Wakened { + future: Box::pin(self), + woken: Arc::new(AtomicBool::new(true)), + } + } +} + +/// Wraps futures::future::join to ensure that the futures are only polled if they are woken. +pub fn join( + future1: Fut1, + future2: Fut2, +) -> futures::future::Join, Wakened> +where + Fut1: Future, + Fut2: Future, +{ + futures::future::join(future1.wakened(), future2.wakened()) +} + +/// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken. +pub fn join3( + future1: Fut1, + future2: Fut2, + future3: Fut3, +) -> futures::future::Join3, Wakened, Wakened> +where + Fut1: Future, + Fut2: Future, + Fut3: Future, +{ + futures::future::join3(future1.wakened(), future2.wakened(), future3.wakened()) +} + +/// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken. +pub fn join4( + future1: Fut1, + future2: Fut2, + future3: Fut3, + future4: Fut4, +) -> futures::future::Join4, Wakened, Wakened, Wakened> +where + Fut1: Future, + Fut2: Future, + Fut3: Future, + Fut4: Future, +{ + futures::future::join4( + future1.wakened(), + future2.wakened(), + future3.wakened(), + future4.wakened(), + ) +} + +/// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken. +pub fn try_join( + future1: Fut1, + future2: Fut2, +) -> futures::future::TryJoin, Wakened> +where + Fut1: futures::future::TryFuture + Future, + Fut2: Future, + Wakened: futures::future::TryFuture, + Wakened: futures::future::TryFuture as TryFuture>::Error>, +{ + futures::future::try_join(future1.wakened(), future2.wakened()) +} + +/// Wraps futures::future::select to ensure that the futures are only polled if they are woken. +pub fn select(future1: A, future2: B) -> futures::future::Select, Wakened> +where + A: Future + Unpin, + B: Future + Unpin, +{ + futures::future::select(future1.wakened(), future2.wakened()) +} + +/// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken. +pub fn join_all(iter: I) -> futures::future::JoinAll> +where + I: IntoIterator, + I::Item: Future, +{ + futures::future::join_all(iter.into_iter().map(|f| f.wakened())) +} + +/// A future that only polls the inner future if it has been woken (after the initial poll). +pub struct Wakened { + future: Pin>, + woken: Arc, +} + +/// A future that only polls the inner future if it has been woken (after the initial poll). +impl Future for Wakened +where + T: Future, +{ + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + if !this.woken.load(std::sync::atomic::Ordering::SeqCst) { + return Poll::Pending; + } + this.woken.store(false, std::sync::atomic::Ordering::SeqCst); + let my_waker = IfWokenWaker { + inner: cx.waker().clone(), + wakened: this.woken.clone(), + }; + let my_waker = Arc::new(my_waker).into(); + let mut cx = Context::from_waker(&my_waker); + this.future.as_mut().poll(&mut cx) + } +} + +impl Wake for IfWokenWaker { + fn wake(self: Arc) { + self.wakened + .store(true, std::sync::atomic::Ordering::SeqCst); + self.inner.wake_by_ref(); + } +} + +struct IfWokenWaker { + inner: Waker, + wakened: Arc, } impl TestFuture for T {} @@ -29,7 +162,7 @@ impl TestFuture for T {} /// This is useful for H2 futures that also require the connection to be polled. pub struct Drive<'a, T, U> { driver: &'a mut T, - future: Pin>, + future: Wakened, } impl<'a, T, U> Future for Drive<'a, T, U> diff --git a/tests/h2-support/src/prelude.rs b/tests/h2-support/src/prelude.rs index c40a518da..4338143fd 100644 --- a/tests/h2-support/src/prelude.rs +++ b/tests/h2-support/src/prelude.rs @@ -35,7 +35,7 @@ pub use {bytes, futures, http, tokio::io as tokio_io, tracing, tracing_subscribe pub use futures::{Future, Sink, Stream}; // And our Future extensions -pub use super::future_ext::TestFuture; +pub use super::future_ext::{join, join3, join4, join_all, select, try_join, TestFuture}; // Our client_ext helpers pub use super::client_ext::SendRequestExt; diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 261fe65fc..e914d4843 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1,4 +1,4 @@ -use futures::future::{join, join_all, ready, select, Either}; +use futures::future::{ready, Either}; use futures::stream::FuturesUnordered; use futures::StreamExt; use h2_support::prelude::*; @@ -849,7 +849,7 @@ async fn recv_too_big_headers() { }; let client = async move { - let (mut client, conn) = client::Builder::new() + let (mut client, mut conn) = client::Builder::new() .max_header_list_size(10) .handshake::<_, Bytes>(io) .await @@ -863,10 +863,10 @@ async fn recv_too_big_headers() { let req1 = client.send_request(request, true); // Spawn tasks to ensure that the error wakes up tasks that are blocked // waiting for a response. - let req1 = tokio::spawn(async move { + let req1 = async move { let err = req1.expect("send_request").0.await.expect_err("response1"); assert_eq!(err.reason(), Some(Reason::PROTOCOL_ERROR)); - }); + }; let request = Request::builder() .uri("https://http2.akamai.com/") @@ -874,19 +874,12 @@ async fn recv_too_big_headers() { .unwrap(); let req2 = client.send_request(request, true); - let req2 = tokio::spawn(async move { + let req2 = async move { let err = req2.expect("send_request").0.await.expect_err("response2"); assert_eq!(err.reason(), Some(Reason::PROTOCOL_ERROR)); - }); + }; - let conn = tokio::spawn(async move { - conn.await.expect("client"); - }); - for err in join_all([req1, req2, conn]).await { - if let Some(err) = err.err().and_then(|err| err.try_into_panic().ok()) { - std::panic::resume_unwind(err); - } - } + conn.drive(join(req1, req2)).await; }; join(srv, client).await; diff --git a/tests/h2-tests/tests/codec_read.rs b/tests/h2-tests/tests/codec_read.rs index d955e186b..489d16daf 100644 --- a/tests/h2-tests/tests/codec_read.rs +++ b/tests/h2-tests/tests/codec_read.rs @@ -1,4 +1,3 @@ -use futures::future::join; use h2_support::prelude::*; #[tokio::test] diff --git a/tests/h2-tests/tests/codec_write.rs b/tests/h2-tests/tests/codec_write.rs index 0b85a2238..04627cdc9 100644 --- a/tests/h2-tests/tests/codec_write.rs +++ b/tests/h2-tests/tests/codec_write.rs @@ -1,4 +1,3 @@ -use futures::future::join; use h2_support::prelude::*; #[tokio::test] diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index dbb933286..e3caaff5f 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -1,4 +1,3 @@ -use futures::future::{join, join4}; use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; use h2_support::util::yield_once; diff --git a/tests/h2-tests/tests/ping_pong.rs b/tests/h2-tests/tests/ping_pong.rs index 0f93578cc..2132c7acf 100644 --- a/tests/h2-tests/tests/ping_pong.rs +++ b/tests/h2-tests/tests/ping_pong.rs @@ -1,5 +1,4 @@ use futures::channel::oneshot; -use futures::future::join; use futures::StreamExt; use h2_support::assert_ping; use h2_support::prelude::*; diff --git a/tests/h2-tests/tests/prioritization.rs b/tests/h2-tests/tests/prioritization.rs index 11d2c2ccf..dd4ed9fea 100644 --- a/tests/h2-tests/tests/prioritization.rs +++ b/tests/h2-tests/tests/prioritization.rs @@ -1,4 +1,3 @@ -use futures::future::{join, select}; use futures::{pin_mut, FutureExt, StreamExt}; use h2_support::prelude::*; diff --git a/tests/h2-tests/tests/push_promise.rs b/tests/h2-tests/tests/push_promise.rs index 61fb05433..b2d8fde0a 100644 --- a/tests/h2-tests/tests/push_promise.rs +++ b/tests/h2-tests/tests/push_promise.rs @@ -1,6 +1,4 @@ -use std::iter::FromIterator; - -use futures::{future::join, FutureExt as _, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; #[tokio::test] @@ -52,15 +50,9 @@ async fn recv_push_works() { let ps: Vec<_> = p.collect().await; assert_eq!(1, ps.len()) }; - // Use a FuturesUnordered to poll both tasks but only poll them - // if they have been notified. - let tasks = futures::stream::FuturesUnordered::from_iter([ - check_resp_status.boxed(), - check_pushed_response.boxed(), - ]) - .collect::<()>(); - - h2.drive(tasks).await; + + h2.drive(join(check_resp_status, check_pushed_response)) + .await; }; join(mock, h2).await; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 2075d22ba..c1af54198 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1,6 +1,5 @@ #![deny(warnings)] -use futures::future::join; use futures::StreamExt; use h2_support::prelude::*; use tokio::io::AsyncWriteExt; diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index 05a96a0f5..9a377d798 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1,6 +1,6 @@ #![deny(warnings)] -use futures::future::{join, join3, lazy, try_join}; +use futures::future::lazy; use futures::{FutureExt, StreamExt, TryStreamExt}; use h2_support::prelude::*; use h2_support::util::yield_once;