From fa5b1cec027e17267b394f097937b386c22585c6 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 16 Oct 2024 17:24:20 +0200 Subject: [PATCH] http client: add `max_concurrent_requests` (#1473) --- client/http-client/src/client.rs | 38 ++++++++++++++++++++++++++++---- core/src/client/mod.rs | 2 -- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/client/http-client/src/client.rs b/client/http-client/src/client.rs index 3f72a14276..ccbf6408e0 100644 --- a/client/http-client/src/client.rs +++ b/client/http-client/src/client.rs @@ -43,6 +43,7 @@ use jsonrpsee_core::traits::ToRpcParams; use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES}; use jsonrpsee_types::{ErrorObject, InvalidRequestId, ResponseSuccess, TwoPointZero}; use serde::de::DeserializeOwned; +use tokio::sync::Semaphore; use tower::layer::util::Identity; use tower::{Layer, Service}; use tracing::instrument; @@ -78,7 +79,6 @@ pub struct HttpClientBuilder { max_request_size: u32, max_response_size: u32, request_timeout: Duration, - max_concurrent_requests: usize, #[cfg(feature = "tls")] certificate_store: CertificateStore, id_kind: IdKind, @@ -86,6 +86,7 @@ pub struct HttpClientBuilder { headers: HeaderMap, service_builder: tower::ServiceBuilder, tcp_no_delay: bool, + max_concurrent_requests: Option, } impl HttpClientBuilder { @@ -107,6 +108,12 @@ impl HttpClientBuilder { self } + /// Set the maximum number of concurrent requests. Default disabled. + pub fn max_concurrent_requests(mut self, max_concurrent_requests: usize) -> Self { + self.max_concurrent_requests = Some(max_concurrent_requests); + self + } + /// Force to use the rustls native certificate store. /// /// Since multiple certificate stores can be optionally enabled, this option will @@ -216,12 +223,12 @@ impl HttpClientBuilder { id_kind: self.id_kind, headers: self.headers, max_log_length: self.max_log_length, - max_concurrent_requests: self.max_concurrent_requests, max_request_size: self.max_request_size, max_response_size: self.max_response_size, service_builder, request_timeout: self.request_timeout, tcp_no_delay: self.tcp_no_delay, + max_concurrent_requests: self.max_concurrent_requests, } } } @@ -263,7 +270,16 @@ where .build(target) .map_err(|e| Error::Transport(e.into()))?; - Ok(HttpClient { transport, id_manager: Arc::new(RequestIdManager::new(id_kind)), request_timeout }) + let request_guard = self + .max_concurrent_requests + .map(|max_concurrent_requests| Arc::new(Semaphore::new(max_concurrent_requests))); + + Ok(HttpClient { + transport, + id_manager: Arc::new(RequestIdManager::new(id_kind)), + request_timeout, + request_guard, + }) } } @@ -273,7 +289,6 @@ impl Default for HttpClientBuilder { max_request_size: TEN_MB_SIZE_BYTES, max_response_size: TEN_MB_SIZE_BYTES, request_timeout: Duration::from_secs(60), - max_concurrent_requests: 256, #[cfg(feature = "tls")] certificate_store: CertificateStore::Native, id_kind: IdKind::Number, @@ -281,6 +296,7 @@ impl Default for HttpClientBuilder { headers: HeaderMap::new(), service_builder: tower::ServiceBuilder::new(), tcp_no_delay: true, + max_concurrent_requests: None, } } } @@ -301,6 +317,8 @@ pub struct HttpClient { request_timeout: Duration, /// Request ID manager. id_manager: Arc, + /// Concurrent requests limit guard. + request_guard: Option>, } impl HttpClient { @@ -324,6 +342,10 @@ where where Params: ToRpcParams + Send, { + let _permit = match self.request_guard.as_ref() { + Some(permit) => permit.acquire().await.ok(), + None => None, + }; let params = params.to_rpc_params()?; let notif = serde_json::to_string(&NotificationSer::borrowed(&method, params.as_deref())).map_err(Error::ParseError)?; @@ -343,6 +365,10 @@ where R: DeserializeOwned, Params: ToRpcParams + Send, { + let _permit = match self.request_guard.as_ref() { + Some(permit) => permit.acquire().await.ok(), + None => None, + }; let id = self.id_manager.next_request_id(); let params = params.to_rpc_params()?; @@ -378,6 +404,10 @@ where where R: DeserializeOwned + fmt::Debug + 'a, { + let _permit = match self.request_guard.as_ref() { + Some(permit) => permit.acquire().await.ok(), + None => None, + }; let batch = batch.build()?; let id = self.id_manager.next_request_id(); let id_range = generate_batch_id_range(id, batch.len() as u64)?; diff --git a/core/src/client/mod.rs b/core/src/client/mod.rs index c2337eec67..cea2442eaa 100644 --- a/core/src/client/mod.rs +++ b/core/src/client/mod.rs @@ -468,8 +468,6 @@ impl RequestIdManager { } /// Attempts to get the next request ID. - /// - /// Fails if request limit has been exceeded. pub fn next_request_id(&self) -> Id<'static> { self.id_kind.into_id(self.current_id.next()) }