Skip to content

Commit b5c5d68

Browse files
committed
feat: implement Encode,Decode,Type for Arc<str> and Arc<[u8]>
1 parent 676e11e commit b5c5d68

File tree

10 files changed

+252
-0
lines changed

10 files changed

+252
-0
lines changed

sqlx-mysql/src/types/bytes.rs

+14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::decode::Decode;
24
use crate::encode::{Encode, IsNull};
35
use crate::error::BoxDynError;
@@ -83,3 +85,15 @@ impl Decode<'_, MySql> for Vec<u8> {
8385
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
8486
}
8587
}
88+
89+
impl Encode<'_, MySql> for Arc<[u8]> {
90+
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
91+
<&[u8] as Encode<MySql>>::encode(&**self, buf)
92+
}
93+
}
94+
95+
impl Decode<'_, MySql> for Arc<[u8]> {
96+
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
97+
<&[u8] as Decode<MySql>>::decode(value).map(Into::into)
98+
}
99+
}

sqlx-mysql/src/types/str.rs

+13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::protocol::text::{ColumnFlags, ColumnType};
66
use crate::types::Type;
77
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
88
use std::borrow::Cow;
9+
use std::sync::Arc;
910

1011
impl Type<MySql> for str {
1112
fn type_info() -> MySqlTypeInfo {
@@ -114,3 +115,15 @@ impl<'r> Decode<'r, MySql> for Cow<'r, str> {
114115
value.as_str().map(Cow::Borrowed)
115116
}
116117
}
118+
119+
impl Encode<'_, MySql> for Arc<str> {
120+
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
121+
<&str as Encode<MySql>>::encode(&**self, buf)
122+
}
123+
}
124+
125+
impl Decode<'_, MySql> for Arc<str> {
126+
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
127+
<&str as Decode<MySql>>::decode(value).map(Into::into)
128+
}
129+
}

sqlx-postgres/src/types/array.rs

+21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use sqlx_core::bytes::Buf;
22
use sqlx_core::types::Text;
33
use std::borrow::Cow;
4+
use std::sync::Arc;
45

56
use crate::decode::Decode;
67
use crate::encode::{Encode, IsNull};
@@ -192,6 +193,17 @@ where
192193
}
193194
}
194195

196+
impl<'q, T> Encode<'q, Postgres> for Arc<[T]>
197+
where
198+
for<'a> &'a [T]: Encode<'q, Postgres>,
199+
T: Encode<'q, Postgres>,
200+
{
201+
#[inline]
202+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
203+
<&[T] as Encode<Postgres>>::encode_by_ref(&self.as_ref(), buf)
204+
}
205+
}
206+
195207
impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
196208
where
197209
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
@@ -354,3 +366,12 @@ where
354366
}
355367
}
356368
}
369+
370+
impl<'r, T> Decode<'r, Postgres> for Arc<[T]>
371+
where
372+
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
373+
{
374+
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
375+
<Vec<T> as Decode<Postgres>>::decode(value).map(Into::into)
376+
}
377+
}

sqlx-postgres/src/types/bytes.rs

+23
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::decode::Decode;
24
use crate::encode::{Encode, IsNull};
35
use crate::error::BoxDynError;
@@ -28,6 +30,12 @@ impl PgHasArrayType for Vec<u8> {
2830
}
2931
}
3032

33+
impl PgHasArrayType for Arc<[u8]> {
34+
fn array_type_info() -> PgTypeInfo {
35+
<[&[u8]] as Type<Postgres>>::type_info()
36+
}
37+
}
38+
3139
impl<const N: usize> PgHasArrayType for [u8; N] {
3240
fn array_type_info() -> PgTypeInfo {
3341
<[&[u8]] as Type<Postgres>>::type_info()
@@ -60,6 +68,12 @@ impl<const N: usize> Encode<'_, Postgres> for [u8; N] {
6068
}
6169
}
6270

71+
impl Encode<'_, Postgres> for Arc<[u8]> {
72+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
73+
<&[u8] as Encode<Postgres>>::encode(self, buf)
74+
}
75+
}
76+
6377
impl<'r> Decode<'r, Postgres> for &'r [u8] {
6478
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
6579
match value.format() {
@@ -110,3 +124,12 @@ impl<const N: usize> Decode<'_, Postgres> for [u8; N] {
110124
Ok(bytes)
111125
}
112126
}
127+
128+
impl Decode<'_, Postgres> for Arc<[u8]> {
129+
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
130+
Ok(match value.format() {
131+
PgValueFormat::Binary => value.as_bytes()?.into(),
132+
PgValueFormat::Text => hex::decode(text_hex_decode_input(value)?)?.into(),
133+
})
134+
}
135+
}

sqlx-postgres/src/types/str.rs

+23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::types::array_compatible;
55
use crate::types::Type;
66
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};
77
use std::borrow::Cow;
8+
use std::sync::Arc;
89

910
impl Type<Postgres> for str {
1011
fn type_info() -> PgTypeInfo {
@@ -94,6 +95,16 @@ impl PgHasArrayType for String {
9495
}
9596
}
9697

98+
impl PgHasArrayType for Arc<str> {
99+
fn array_type_info() -> PgTypeInfo {
100+
<&str as PgHasArrayType>::array_type_info()
101+
}
102+
103+
fn array_compatible(ty: &PgTypeInfo) -> bool {
104+
<&str as PgHasArrayType>::array_compatible(ty)
105+
}
106+
}
107+
97108
impl Encode<'_, Postgres> for &'_ str {
98109
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
99110
buf.extend(self.as_bytes());
@@ -123,6 +134,12 @@ impl Encode<'_, Postgres> for String {
123134
}
124135
}
125136

137+
impl Encode<'_, Postgres> for Arc<str> {
138+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
139+
<&str as Encode<Postgres>>::encode(&**self, buf)
140+
}
141+
}
142+
126143
impl<'r> Decode<'r, Postgres> for &'r str {
127144
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
128145
value.as_str()
@@ -146,3 +163,9 @@ impl Decode<'_, Postgres> for String {
146163
Ok(value.as_str()?.to_owned())
147164
}
148165
}
166+
167+
impl Decode<'_, Postgres> for Arc<str> {
168+
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
169+
Ok(value.as_str()?.into())
170+
}
171+
}

sqlx-sqlite/src/types/bytes.rs

+24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::borrow::Cow;
2+
use std::sync::Arc;
23

34
use crate::decode::Decode;
45
use crate::encode::{Encode, IsNull};
@@ -101,3 +102,26 @@ impl<'r> Decode<'r, Sqlite> for Vec<u8> {
101102
Ok(value.blob().to_owned())
102103
}
103104
}
105+
106+
impl<'q> Encode<'q, Sqlite> for Arc<[u8]> {
107+
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
108+
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.to_vec())));
109+
110+
Ok(IsNull::No)
111+
}
112+
113+
fn encode_by_ref(
114+
&self,
115+
args: &mut Vec<SqliteArgumentValue<'q>>,
116+
) -> Result<IsNull, BoxDynError> {
117+
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.to_vec())));
118+
119+
Ok(IsNull::No)
120+
}
121+
}
122+
123+
impl<'r> Decode<'r, Sqlite> for Arc<[u8]> {
124+
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
125+
Ok(value.blob().into())
126+
}
127+
}

sqlx-sqlite/src/types/str.rs

+24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::borrow::Cow;
2+
use std::sync::Arc;
23

34
use crate::decode::Decode;
45
use crate::encode::{Encode, IsNull};
@@ -122,3 +123,26 @@ impl<'r> Decode<'r, Sqlite> for Cow<'r, str> {
122123
value.text().map(Cow::Borrowed)
123124
}
124125
}
126+
127+
impl<'q> Encode<'q, Sqlite> for Arc<str> {
128+
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
129+
args.push(SqliteArgumentValue::Text(Cow::Owned(self.to_string())));
130+
131+
Ok(IsNull::No)
132+
}
133+
134+
fn encode_by_ref(
135+
&self,
136+
args: &mut Vec<SqliteArgumentValue<'q>>,
137+
) -> Result<IsNull, BoxDynError> {
138+
args.push(SqliteArgumentValue::Text(Cow::Owned(self.to_string())));
139+
140+
Ok(IsNull::No)
141+
}
142+
}
143+
144+
impl<'r> Decode<'r, Sqlite> for Arc<str> {
145+
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
146+
value.text().map(Into::into)
147+
}
148+
}

tests/mysql/types.rs

+31
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ extern crate time_ as time;
33
use std::net::SocketAddr;
44
#[cfg(feature = "rust_decimal")]
55
use std::str::FromStr;
6+
use std::sync::Arc;
67

78
use sqlx::mysql::MySql;
89
use sqlx::{Executor, Row};
@@ -384,3 +385,33 @@ CREATE TEMPORARY TABLE user_login (
384385

385386
Ok(())
386387
}
388+
389+
#[sqlx_macros::test]
390+
async fn test_arc_str() -> anyhow::Result<()> {
391+
let mut conn = new::<MySql>().await?;
392+
393+
let name: Arc<str> = "Harold".into();
394+
395+
let username: Arc<str> = sqlx::query_scalar("SELECT ? AS username")
396+
.bind(&name)
397+
.fetch_one(&mut conn)
398+
.await?;
399+
400+
assert!(username == name);
401+
Ok(())
402+
}
403+
404+
#[sqlx_macros::test]
405+
async fn test_arc_slice() -> anyhow::Result<()> {
406+
let mut conn = new::<MySql>().await?;
407+
408+
let name: Arc<[u8]> = [5, 0].into();
409+
410+
let username: Arc<[u8]> = sqlx::query_scalar("SELECT ?")
411+
.bind(&name)
412+
.fetch_one(&mut conn)
413+
.await?;
414+
415+
assert!(username == name);
416+
Ok(())
417+
}

tests/postgres/types.rs

+48
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::sync::Arc;
88

99
use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange};
1010
use sqlx::postgres::Postgres;
11+
use sqlx_macros::FromRow;
1112
use sqlx_test::{new, test_decode_type, test_prepared_type, test_type};
1213

1314
use sqlx_core::executor::Executor;
@@ -673,6 +674,21 @@ async fn test_arc() -> anyhow::Result<()> {
673674
Ok(())
674675
}
675676

677+
#[sqlx_macros::test]
678+
async fn test_arc_str() -> anyhow::Result<()> {
679+
let mut conn = new::<Postgres>().await?;
680+
681+
let name: Arc<str> = "Harold".into();
682+
683+
let username: Arc<str> = sqlx::query_scalar("SELECT $1 AS username")
684+
.bind(&name)
685+
.fetch_one(&mut conn)
686+
.await?;
687+
688+
assert!(username == name);
689+
Ok(())
690+
}
691+
676692
#[sqlx_macros::test]
677693
async fn test_cow() -> anyhow::Result<()> {
678694
let mut conn = new::<Postgres>().await?;
@@ -688,6 +704,21 @@ async fn test_cow() -> anyhow::Result<()> {
688704
Ok(())
689705
}
690706

707+
#[sqlx_macros::test]
708+
async fn test_arc_slice() -> anyhow::Result<()> {
709+
let mut conn = new::<Postgres>().await?;
710+
711+
let name: Arc<[u8]> = [5, 0].into();
712+
713+
let username: Arc<[u8]> = sqlx::query_scalar("SELECT $1")
714+
.bind(&name)
715+
.fetch_one(&mut conn)
716+
.await?;
717+
718+
assert!(username == name);
719+
Ok(())
720+
}
721+
691722
#[sqlx_macros::test]
692723
async fn test_box() -> anyhow::Result<()> {
693724
let mut conn = new::<Postgres>().await?;
@@ -713,3 +744,20 @@ async fn test_rc() -> anyhow::Result<()> {
713744
assert!(user_age == 1);
714745
Ok(())
715746
}
747+
748+
#[sqlx_macros::test]
749+
async fn test_arc_slice_2() -> anyhow::Result<()> {
750+
let mut conn = new::<Postgres>().await?;
751+
752+
#[derive(FromRow)]
753+
struct Nested {
754+
inner: Arc<[i32]>,
755+
}
756+
757+
let username: Nested = sqlx::query_as("SELECT ARRAY[1, 2, 3]::INT4[] as inner")
758+
.fetch_one(&mut conn)
759+
.await?;
760+
761+
assert!(username.inner.as_ref() == &[1, 2, 3]);
762+
Ok(())
763+
}

0 commit comments

Comments
 (0)