Skip to content

Commit

Permalink
feat: implement Encode,Decode,Type for Arc<str> and Arc<[u8]>
Browse files Browse the repository at this point in the history
  • Loading branch information
joeydewaal committed Jan 12, 2025
1 parent 676e11e commit b5c5d68
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 0 deletions.
14 changes: 14 additions & 0 deletions sqlx-mysql/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -83,3 +85,15 @@ impl Decode<'_, MySql> for Vec<u8> {
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

impl Encode<'_, MySql> for Arc<[u8]> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<MySql>>::encode(&**self, buf)
}
}

impl Decode<'_, MySql> for Arc<[u8]> {
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&[u8] as Decode<MySql>>::decode(value).map(Into::into)
}
}
13 changes: 13 additions & 0 deletions sqlx-mysql/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::protocol::text::{ColumnFlags, ColumnType};
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
use std::borrow::Cow;
use std::sync::Arc;

impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
Expand Down Expand Up @@ -114,3 +115,15 @@ impl<'r> Decode<'r, MySql> for Cow<'r, str> {
value.as_str().map(Cow::Borrowed)
}
}

impl Encode<'_, MySql> for Arc<str> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&str as Encode<MySql>>::encode(&**self, buf)
}
}

impl Decode<'_, MySql> for Arc<str> {
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(Into::into)
}
}
21 changes: 21 additions & 0 deletions sqlx-postgres/src/types/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use sqlx_core::bytes::Buf;
use sqlx_core::types::Text;
use std::borrow::Cow;
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -192,6 +193,17 @@ where
}
}

impl<'q, T> Encode<'q, Postgres> for Arc<[T]>
where
for<'a> &'a [T]: Encode<'q, Postgres>,
T: Encode<'q, Postgres>,
{
#[inline]
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&[T] as Encode<Postgres>>::encode_by_ref(&self.as_ref(), buf)
}
}

impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
Expand Down Expand Up @@ -354,3 +366,12 @@ where
}
}
}

impl<'r, T> Decode<'r, Postgres> for Arc<[T]>
where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
<Vec<T> as Decode<Postgres>>::decode(value).map(Into::into)
}
}
23 changes: 23 additions & 0 deletions sqlx-postgres/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -28,6 +30,12 @@ impl PgHasArrayType for Vec<u8> {
}
}

impl PgHasArrayType for Arc<[u8]> {
fn array_type_info() -> PgTypeInfo {
<[&[u8]] as Type<Postgres>>::type_info()
}
}

impl<const N: usize> PgHasArrayType for [u8; N] {
fn array_type_info() -> PgTypeInfo {
<[&[u8]] as Type<Postgres>>::type_info()
Expand Down Expand Up @@ -60,6 +68,12 @@ impl<const N: usize> Encode<'_, Postgres> for [u8; N] {
}
}

impl Encode<'_, Postgres> for Arc<[u8]> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<Postgres>>::encode(self, buf)
}
}

impl<'r> Decode<'r, Postgres> for &'r [u8] {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
Expand Down Expand Up @@ -110,3 +124,12 @@ impl<const N: usize> Decode<'_, Postgres> for [u8; N] {
Ok(bytes)
}
}

impl Decode<'_, Postgres> for Arc<[u8]> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => value.as_bytes()?.into(),
PgValueFormat::Text => hex::decode(text_hex_decode_input(value)?)?.into(),
})
}
}
23 changes: 23 additions & 0 deletions sqlx-postgres/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::types::array_compatible;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};
use std::borrow::Cow;
use std::sync::Arc;

impl Type<Postgres> for str {
fn type_info() -> PgTypeInfo {
Expand Down Expand Up @@ -94,6 +95,16 @@ impl PgHasArrayType for String {
}
}

impl PgHasArrayType for Arc<str> {
fn array_type_info() -> PgTypeInfo {
<&str as PgHasArrayType>::array_type_info()
}

fn array_compatible(ty: &PgTypeInfo) -> bool {
<&str as PgHasArrayType>::array_compatible(ty)
}
}

impl Encode<'_, Postgres> for &'_ str {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
buf.extend(self.as_bytes());
Expand Down Expand Up @@ -123,6 +134,12 @@ impl Encode<'_, Postgres> for String {
}
}

impl Encode<'_, Postgres> for Arc<str> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&str as Encode<Postgres>>::encode(&**self, buf)
}
}

impl<'r> Decode<'r, Postgres> for &'r str {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_str()
Expand All @@ -146,3 +163,9 @@ impl Decode<'_, Postgres> for String {
Ok(value.as_str()?.to_owned())
}
}

impl Decode<'_, Postgres> for Arc<str> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?.into())
}
}
24 changes: 24 additions & 0 deletions sqlx-sqlite/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -101,3 +102,26 @@ impl<'r> Decode<'r, Sqlite> for Vec<u8> {
Ok(value.blob().to_owned())
}
}

impl<'q> Encode<'q, Sqlite> for Arc<[u8]> {
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.to_vec())));

Ok(IsNull::No)
}

fn encode_by_ref(
&self,
args: &mut Vec<SqliteArgumentValue<'q>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.to_vec())));

Ok(IsNull::No)
}
}

impl<'r> Decode<'r, Sqlite> for Arc<[u8]> {
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(value.blob().into())
}
}
24 changes: 24 additions & 0 deletions sqlx-sqlite/src/types/str.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -122,3 +123,26 @@ impl<'r> Decode<'r, Sqlite> for Cow<'r, str> {
value.text().map(Cow::Borrowed)
}
}

impl<'q> Encode<'q, Sqlite> for Arc<str> {
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(Cow::Owned(self.to_string())));

Ok(IsNull::No)
}

fn encode_by_ref(
&self,
args: &mut Vec<SqliteArgumentValue<'q>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(Cow::Owned(self.to_string())));

Ok(IsNull::No)
}
}

impl<'r> Decode<'r, Sqlite> for Arc<str> {
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
value.text().map(Into::into)
}
}
31 changes: 31 additions & 0 deletions tests/mysql/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ extern crate time_ as time;
use std::net::SocketAddr;
#[cfg(feature = "rust_decimal")]
use std::str::FromStr;
use std::sync::Arc;

use sqlx::mysql::MySql;
use sqlx::{Executor, Row};
Expand Down Expand Up @@ -384,3 +385,33 @@ CREATE TEMPORARY TABLE user_login (

Ok(())
}

#[sqlx_macros::test]
async fn test_arc_str() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let name: Arc<str> = "Harold".into();

let username: Arc<str> = sqlx::query_scalar("SELECT ? AS username")
.bind(&name)
.fetch_one(&mut conn)
.await?;

assert!(username == name);
Ok(())
}

#[sqlx_macros::test]
async fn test_arc_slice() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let name: Arc<[u8]> = [5, 0].into();

let username: Arc<[u8]> = sqlx::query_scalar("SELECT ?")
.bind(&name)
.fetch_one(&mut conn)
.await?;

assert!(username == name);
Ok(())
}
48 changes: 48 additions & 0 deletions tests/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::sync::Arc;

use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange};
use sqlx::postgres::Postgres;
use sqlx_macros::FromRow;
use sqlx_test::{new, test_decode_type, test_prepared_type, test_type};

use sqlx_core::executor::Executor;
Expand Down Expand Up @@ -673,6 +674,21 @@ async fn test_arc() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_arc_str() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;

let name: Arc<str> = "Harold".into();

let username: Arc<str> = sqlx::query_scalar("SELECT $1 AS username")
.bind(&name)
.fetch_one(&mut conn)
.await?;

assert!(username == name);
Ok(())
}

#[sqlx_macros::test]
async fn test_cow() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
Expand All @@ -688,6 +704,21 @@ async fn test_cow() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_arc_slice() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;

let name: Arc<[u8]> = [5, 0].into();

let username: Arc<[u8]> = sqlx::query_scalar("SELECT $1")
.bind(&name)
.fetch_one(&mut conn)
.await?;

assert!(username == name);
Ok(())
}

#[sqlx_macros::test]
async fn test_box() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
Expand All @@ -713,3 +744,20 @@ async fn test_rc() -> anyhow::Result<()> {
assert!(user_age == 1);
Ok(())
}

#[sqlx_macros::test]
async fn test_arc_slice_2() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;

#[derive(FromRow)]
struct Nested {
inner: Arc<[i32]>,
}

let username: Nested = sqlx::query_as("SELECT ARRAY[1, 2, 3]::INT4[] as inner")
.fetch_one(&mut conn)
.await?;

assert!(username.inner.as_ref() == &[1, 2, 3]);
Ok(())
}
Loading

0 comments on commit b5c5d68

Please sign in to comment.