diff --git a/Cargo.lock b/Cargo.lock index 2da47afa54..36cb94422e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1177,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -1914,7 +1914,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -3986,7 +3986,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] [[package]] @@ -4815,7 +4815,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4856,7 +4856,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", "synstructure", ] @@ -4899,5 +4899,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.87", ] diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 0470abb5ec..6b09d318f7 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -143,7 +143,10 @@ where pub trait WithSocket { type Output; - fn with_socket(self, socket: S) -> Self::Output; + fn with_socket( + self, + socket: S, + ) -> impl std::future::Future + Send; } pub struct SocketIntoBox; @@ -151,7 +154,7 @@ pub struct SocketIntoBox; impl WithSocket for SocketIntoBox { type Output = Box; - fn with_socket(self, socket: S) -> Self::Output { + async fn with_socket(self, socket: S) -> Self::Output { Box::new(socket) } } @@ -197,7 +200,7 @@ pub async fn connect_tcp( let stream = TcpStream::connect((host, port)).await?; stream.set_nodelay(true)?; - return Ok(with_socket.with_socket(stream)); + return Ok(with_socket.with_socket(stream).await); } #[cfg(feature = "_rt-async-std")] @@ -217,7 +220,7 @@ pub async fn connect_tcp( Ok(s) }); match stream { - Ok(stream) => return Ok(with_socket.with_socket(stream)), + Ok(stream) => return Ok(with_socket.with_socket(stream).await), Err(e) => last_err = Some(e), } } @@ -255,7 +258,7 @@ pub async fn connect_uds, Ws: WithSocket>( let stream = UnixStream::connect(path).await?; - return Ok(with_socket.with_socket(stream)); + return Ok(with_socket.with_socket(stream).await); } #[cfg(feature = "_rt-async-std")] @@ -265,7 +268,7 @@ pub async fn connect_uds, Ws: WithSocket>( let stream = Async::::connect(path).await?; - Ok(with_socket.with_socket(stream)) + Ok(with_socket.with_socket(stream).await) } #[cfg(not(feature = "_rt-async-std"))] diff --git a/sqlx-core/src/net/tls/mod.rs b/sqlx-core/src/net/tls/mod.rs index b49708b22e..3e9fd9b9a0 100644 --- a/sqlx-core/src/net/tls/mod.rs +++ b/sqlx-core/src/net/tls/mod.rs @@ -75,10 +75,14 @@ where Ws: WithSocket, { #[cfg(feature = "_tls-native-tls")] - return Ok(with_socket.with_socket(tls_native_tls::handshake(socket, config).await?)); + return Ok(with_socket + .with_socket(tls_native_tls::handshake(socket, config).await?) + .await); #[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] - return Ok(with_socket.with_socket(tls_rustls::handshake(socket, config).await?)); + return Ok(with_socket + .with_socket(tls_rustls::handshake(socket, config).await?) + .await); #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))] { diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 468478e550..0623a0556c 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -1,6 +1,5 @@ use bytes::buf::Buf; use bytes::Bytes; -use futures_core::future::BoxFuture; use crate::collation::{CharSet, Collation}; use crate::common::StatementCache; @@ -22,7 +21,7 @@ impl MySqlConnection { None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, }; - let stream = handshake.await?; + let stream = handshake?; Ok(Self { inner: Box::new(MySqlConnectionInner { @@ -187,9 +186,9 @@ impl<'a> DoHandshake<'a> { } impl<'a> WithSocket for DoHandshake<'a> { - type Output = BoxFuture<'a, Result>; + type Output = Result; - fn with_socket(self, socket: S) -> Self::Output { - Box::pin(self.do_handshake(socket)) + async fn with_socket(self, socket: S) -> Self::Output { + self.do_handshake(socket).await } } diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index 22b3c487b5..eb077c621b 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -94,7 +94,7 @@ pub(super) async fn maybe_upgrade( impl WithSocket for MapStream { type Output = MySqlStream; - fn with_socket(self, socket: S) -> Self::Output { + async fn with_socket(self, socket: S) -> Self::Output { MySqlStream { socket: BufferedSocket::new(Box::new(socket)), server_version: self.server_version, diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 6e7afb53d0..9303e602e5 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -10,7 +10,6 @@ use crate::types::Json; use crate::types::Oid; use crate::HashMap; use crate::{PgColumn, PgConnection, PgTypeInfo}; -use futures_core::future::BoxFuture; use smallvec::SmallVec; use sqlx_core::query_builder::QueryBuilder; use std::sync::Arc; @@ -169,7 +168,8 @@ impl PgConnection { // fallback to asking the database directly for a type name if should_fetch { - let info = self.fetch_type_by_oid(oid).await?; + // we're boxing this future here so we can use async recursion + let info = Box::pin(async { self.fetch_type_by_oid(oid).await }).await?; // cache the type name <-> oid relationship in a paired hashmap // so we don't come down this road again @@ -190,19 +190,18 @@ impl PgConnection { } } - fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result> { - Box::pin(async move { - let (name, typ_type, category, relation_id, element, base_type): ( - String, - i8, - i8, - Oid, - Oid, - Oid, - ) = query_as( - // Converting the OID to `regtype` and then `text` will give us the name that - // the type will need to be found at by search_path. - "SELECT oid::regtype::text, \ + async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result { + let (name, typ_type, category, relation_id, element, base_type): ( + String, + i8, + i8, + Oid, + Oid, + Oid, + ) = query_as( + // Converting the OID to `regtype` and then `text` will give us the name that + // the type will need to be found at by search_path. + "SELECT oid::regtype::text, \ typtype, \ typcategory, \ typrelid, \ @@ -210,54 +209,51 @@ impl PgConnection { typbasetype \ FROM pg_catalog.pg_type \ WHERE oid = $1", - ) - .bind(oid) - .fetch_one(&mut *self) - .await?; - - let typ_type = TypType::try_from(typ_type); - let category = TypCategory::try_from(category); - - match (typ_type, category) { - (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, - - (Ok(TypType::Base), Ok(TypCategory::Array)) => { - Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Array( - self.maybe_fetch_type_info_by_oid(element, true).await?, - ), - name: name.into(), - oid, - })))) - } - - (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => { - Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Pseudo, - name: name.into(), - oid, - })))) - } + ) + .bind(oid) + .fetch_one(&mut *self) + .await?; - (Ok(TypType::Range), Ok(TypCategory::Range)) => { - self.fetch_range_by_oid(oid, name).await - } + let typ_type = TypType::try_from(typ_type); + let category = TypCategory::try_from(category); - (Ok(TypType::Enum), Ok(TypCategory::Enum)) => { - self.fetch_enum_by_oid(oid, name).await - } + match (typ_type, category) { + (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, - (Ok(TypType::Composite), Ok(TypCategory::Composite)) => { - self.fetch_composite_by_oid(oid, relation_id, name).await - } + (Ok(TypType::Base), Ok(TypCategory::Array)) => { + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Array( + self.maybe_fetch_type_info_by_oid(element, true).await?, + ), + name: name.into(), + oid, + })))) + } - _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Simple, + (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => { + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Pseudo, name: name.into(), oid, - })))), + })))) + } + + (Ok(TypType::Range), Ok(TypCategory::Range)) => { + self.fetch_range_by_oid(oid, name).await } - }) + + (Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await, + + (Ok(TypType::Composite), Ok(TypCategory::Composite)) => { + self.fetch_composite_by_oid(oid, relation_id, name).await + } + + _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Simple, + name: name.into(), + oid, + })))), + } } async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result { @@ -280,15 +276,14 @@ ORDER BY enumsortorder })))) } - fn fetch_composite_by_oid( + async fn fetch_composite_by_oid( &mut self, oid: Oid, relation_id: Oid, name: String, - ) -> BoxFuture<'_, Result> { - Box::pin(async move { - let raw_fields: Vec<(String, Oid)> = query_as( - r#" + ) -> Result { + let raw_fields: Vec<(String, Oid)> = query_as( + r#" SELECT attname, atttypid FROM pg_catalog.pg_attribute WHERE attrelid = $1 @@ -296,69 +291,60 @@ AND NOT attisdropped AND attnum > 0 ORDER BY attnum "#, - ) - .bind(relation_id) - .fetch_all(&mut *self) - .await?; + ) + .bind(relation_id) + .fetch_all(&mut *self) + .await?; - let mut fields = Vec::new(); + let mut fields = Vec::new(); - for (field_name, field_oid) in raw_fields.into_iter() { - let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?; + for (field_name, field_oid) in raw_fields.into_iter() { + let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?; - fields.push((field_name, field_type)); - } + fields.push((field_name, field_type)); + } - Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - oid, - name: name.into(), - kind: PgTypeKind::Composite(Arc::from(fields)), - })))) - }) + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Composite(Arc::from(fields)), + })))) } - fn fetch_domain_by_oid( + async fn fetch_domain_by_oid( &mut self, oid: Oid, base_type: Oid, name: String, - ) -> BoxFuture<'_, Result> { - Box::pin(async move { - let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?; + ) -> Result { + let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?; - Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - oid, - name: name.into(), - kind: PgTypeKind::Domain(base_type), - })))) - }) + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Domain(base_type), + })))) } - fn fetch_range_by_oid( - &mut self, - oid: Oid, - name: String, - ) -> BoxFuture<'_, Result> { - Box::pin(async move { - let element_oid: Oid = query_scalar( - r#" + async fn fetch_range_by_oid(&mut self, oid: Oid, name: String) -> Result { + let element_oid: Oid = query_scalar( + r#" SELECT rngsubtype FROM pg_catalog.pg_range WHERE rngtypid = $1 "#, - ) - .bind(oid) - .fetch_one(&mut *self) - .await?; + ) + .bind(oid) + .fetch_one(&mut *self) + .await?; - let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?; + let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?; - Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Range(element), - name: name.into(), - oid, - })))) - }) + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Range(element), + name: name.into(), + oid, + })))) } pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result { diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index f165899248..e8a1aedc47 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -42,12 +42,12 @@ pub struct PgStream { impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { - let socket_future = match options.fetch_socket() { + let socket_result = match options.fetch_socket() { Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, }; - let socket = socket_future.await?; + let socket = socket_result?; Ok(Self { inner: BufferedSocket::new(socket), diff --git a/sqlx-postgres/src/connection/tls.rs b/sqlx-postgres/src/connection/tls.rs index 04bab793a7..16b7333bf5 100644 --- a/sqlx-postgres/src/connection/tls.rs +++ b/sqlx-postgres/src/connection/tls.rs @@ -1,5 +1,3 @@ -use futures_core::future::BoxFuture; - use crate::error::Error; use crate::net::tls::{self, TlsConfig}; use crate::net::{Socket, SocketIntoBox, WithSocket}; @@ -10,10 +8,10 @@ use crate::{PgConnectOptions, PgSslMode}; pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); impl<'a> WithSocket for MaybeUpgradeTls<'a> { - type Output = BoxFuture<'a, crate::Result>>; + type Output = crate::Result>; - fn with_socket(self, socket: S) -> Self::Output { - Box::pin(maybe_upgrade(socket, self.0)) + async fn with_socket(self, socket: S) -> Self::Output { + maybe_upgrade(socket, self.0).await } }