From d1a3a97029553d8bb207ea10343e6bc004322ddf Mon Sep 17 00:00:00 2001 From: Austin Schey Date: Fri, 13 Dec 2024 20:30:15 -0800 Subject: [PATCH] add SqliteValueRef variant that takes a borrowed sqlite value pointer --- sqlx-sqlite/src/connection/preupdate_hook.rs | 11 +- sqlx-sqlite/src/value.rs | 139 +++++++++++++------ tests/sqlite/sqlite.rs | 99 +++++++------ 3 files changed, 154 insertions(+), 95 deletions(-) diff --git a/sqlx-sqlite/src/connection/preupdate_hook.rs b/sqlx-sqlite/src/connection/preupdate_hook.rs index 8df40e16b9..fcc0fe0bc3 100644 --- a/sqlx-sqlite/src/connection/preupdate_hook.rs +++ b/sqlx-sqlite/src/connection/preupdate_hook.rs @@ -1,6 +1,6 @@ use super::SqliteOperation; use crate::type_info::DataType; -use crate::{SqliteError, SqliteTypeInfo, SqliteValue}; +use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef}; use libsqlite3_sys::{ sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new, @@ -77,7 +77,7 @@ impl<'a> PreupdateHookResult<'a> { /// Gets the value of the row being updated/deleted at the specified index. /// Returns an error if called from an insert operation or the index is out of bounds. - pub fn get_old_column_value(&self, i: i32) -> Result { + pub fn get_old_column_value(&self, i: i32) -> Result, PreupdateError> { if self.operation == SqliteOperation::Insert { return Err(PreupdateError::InvalidOperation); } @@ -92,7 +92,7 @@ impl<'a> PreupdateHookResult<'a> { /// Gets the value of the row being inserted/updated at the specified index. /// Returns an error if called from a delete operation or the index is out of bounds. - pub fn get_new_column_value(&self, i: i32) -> Result { + pub fn get_new_column_value(&self, i: i32) -> Result, PreupdateError> { if self.operation == SqliteOperation::Delete { return Err(PreupdateError::InvalidOperation); } @@ -116,12 +116,13 @@ impl<'a> PreupdateHookResult<'a> { &self, ret: i32, p_value: *mut sqlite3_value, - ) -> Result { + ) -> Result, PreupdateError> { if ret != SQLITE_OK { return Err(PreupdateError::Database(SqliteError::new(self.db))); } let data_type = DataType::from_code(sqlite3_value_type(p_value)); - Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type))) + // SAFETY: SQLite will free the sqlite3_value when the callback returns + Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type))) } } diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 967b3f7476..31d24b4b24 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::marker::PhantomData; use std::ptr::NonNull; use std::slice::from_raw_parts; use std::str::from_utf8; @@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo}; enum SqliteValueData<'r> { Value(&'r SqliteValue), + BorrowedHandle(ValueHandle<'r>), } pub struct SqliteValueRef<'r>(SqliteValueData<'r>); @@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> { Self(SqliteValueData::Value(value)) } + // SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop. + // The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it. + #[allow(unused)] + pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { + debug_assert!(!value.is_null()); + let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info); + Self(SqliteValueData::BorrowedHandle(handle)) + } + // NOTE: `int()` is deliberately omitted because it will silently truncate a wider value, // which is likely to cause bugs: // https://github.com/launchbadge/sqlx/issues/3179 // (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161 pub(super) fn int64(&self) -> i64 { - match self.0 { - SqliteValueData::Value(v) => v.int64(), + match &self.0 { + SqliteValueData::Value(v) => v.0.int64(), + SqliteValueData::BorrowedHandle(v) => v.int64(), } } pub(super) fn double(&self) -> f64 { - match self.0 { - SqliteValueData::Value(v) => v.double(), + match &self.0 { + SqliteValueData::Value(v) => v.0.double(), + SqliteValueData::BorrowedHandle(v) => v.double(), } } pub(super) fn blob(&self) -> &'r [u8] { - match self.0 { - SqliteValueData::Value(v) => v.blob(), + match &self.0 { + SqliteValueData::Value(v) => v.0.blob(), + SqliteValueData::BorrowedHandle(v) => v.blob(), } } pub(super) fn text(&self) -> Result<&'r str, BoxDynError> { - match self.0 { - SqliteValueData::Value(v) => v.text(), + match &self.0 { + SqliteValueData::Value(v) => v.0.text(), + SqliteValueData::BorrowedHandle(v) => v.text(), } } } @@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { type Database = Sqlite; fn to_owned(&self) -> SqliteValue { - match self.0 { - SqliteValueData::Value(v) => v.clone(), + match &self.0 { + SqliteValueData::Value(v) => (*v).clone(), + SqliteValueData::BorrowedHandle(v) => unsafe { + SqliteValue::new(v.value.as_ptr(), v.type_info.clone()) + }, } } fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { - match self.0 { + match &self.0 { SqliteValueData::Value(v) => v.type_info(), + SqliteValueData::BorrowedHandle(v) => v.type_info(), } } fn is_null(&self) -> bool { - match self.0 { + match &self.0 { SqliteValueData::Value(v) => v.is_null(), + SqliteValueData::BorrowedHandle(v) => v.is_null(), } } } #[derive(Clone)] -pub struct SqliteValue { - pub(crate) handle: Arc, - pub(crate) type_info: SqliteTypeInfo, -} +pub struct SqliteValue(Arc>); -pub(crate) struct ValueHandle(NonNull); +pub(crate) struct ValueHandle<'a> { + value: NonNull, + type_info: SqliteTypeInfo, + _phantom: PhantomData<&'a ()>, + free_on_drop: bool, +} // SAFE: only protected value objects are stored in SqliteValue -unsafe impl Send for ValueHandle {} -unsafe impl Sync for ValueHandle {} +unsafe impl<'a> Send for ValueHandle<'a> {} +unsafe impl<'a> Sync for ValueHandle<'a> {} -impl SqliteValue { - pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { - debug_assert!(!value.is_null()); +impl ValueHandle<'static> { + unsafe fn new_owned(value: NonNull, type_info: SqliteTypeInfo) -> Self { + Self { + value, + type_info, + free_on_drop: true, + _phantom: PhantomData, + } + } +} +impl<'a> ValueHandle<'a> { + unsafe fn new_borrowed(value: NonNull, type_info: SqliteTypeInfo) -> Self { Self { + value, type_info, - handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup( - value, - )))), + free_on_drop: false, + _phantom: PhantomData, } } fn type_info_opt(&self) -> Option { - let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) }); + let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) }); if let DataType::Null = dt { None @@ -112,15 +143,15 @@ impl SqliteValue { } fn int64(&self) -> i64 { - unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) } + unsafe { sqlite3_value_int64(self.value.as_ptr()) } } fn double(&self) -> f64 { - unsafe { sqlite3_value_double(self.handle.0.as_ptr()) } + unsafe { sqlite3_value_double(self.value.as_ptr()) } } - fn blob(&self) -> &[u8] { - let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) }; + fn blob<'b>(&self) -> &'b [u8] { + let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) }; // This likely means UB in SQLite itself or our usage of it; // signed integer overflow is UB in the C standard. @@ -133,23 +164,15 @@ impl SqliteValue { return &[]; } - let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8; + let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8; debug_assert!(!ptr.is_null()); unsafe { from_raw_parts(ptr, len) } } - fn text(&self) -> Result<&str, BoxDynError> { + fn text<'b>(&self) -> Result<&'b str, BoxDynError> { Ok(from_utf8(self.blob())?) } -} - -impl Value for SqliteValue { - type Database = Sqlite; - - fn as_ref(&self) -> SqliteValueRef<'_> { - SqliteValueRef::value(self) - } fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { self.type_info_opt() @@ -158,18 +181,46 @@ impl Value for SqliteValue { } fn is_null(&self) -> bool { - unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL } + unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL } } } -impl Drop for ValueHandle { +impl<'a> Drop for ValueHandle<'a> { fn drop(&mut self) { - unsafe { - sqlite3_value_free(self.0.as_ptr()); + if self.free_on_drop { + unsafe { + sqlite3_value_free(self.value.as_ptr()); + } } } } +impl SqliteValue { + // SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop. + pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { + debug_assert!(!value.is_null()); + let handle = + ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info); + Self(Arc::new(handle)) + } +} + +impl Value for SqliteValue { + type Database = Sqlite; + + fn as_ref(&self) -> SqliteValueRef<'_> { + SqliteValueRef::value(self) + } + + fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { + self.0.type_info() + } + + fn is_null(&self) -> bool { + self.0.is_null() + } +} + // #[cfg(feature = "any")] // impl<'r> From> for crate::any::AnyValueRef<'r> { // #[inline] diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 043520740b..d78e1151a9 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -3,11 +3,11 @@ use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; use sqlx::Decode; -use sqlx::Value; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx::{Value, ValueRef}; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -989,15 +989,12 @@ async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { assert_eq!(0, result.get_query_depth()); assert_eq!( 4, - >::decode(result.get_new_column_value(0).unwrap().as_ref(),) - .unwrap() + >::decode(result.get_new_column_value(0).unwrap()).unwrap() ); assert_eq!( "Hello, World", - >::decode( - result.get_new_column_value(1).unwrap().as_ref(), - ) - .unwrap() + >::decode(result.get_new_column_value(1).unwrap()) + .unwrap() ); // out of bounds access should return an error assert!(result.get_new_column_value(4).is_err()); @@ -1042,13 +1039,11 @@ async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { assert_eq!(0, result.get_query_depth()); assert_eq!( 5, - >::decode(result.get_old_column_value(0).unwrap().as_ref(),) - .unwrap() + >::decode(result.get_old_column_value(0).unwrap()).unwrap() ); assert_eq!( "Hello, World", - >::decode(result.get_old_column_value(1).unwrap().as_ref(),) - .unwrap() + >::decode(result.get_old_column_value(1).unwrap()).unwrap() ); // out of bounds access should return an error assert!(result.get_old_column_value(4).is_err()); @@ -1074,50 +1069,54 @@ async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { .execute(&mut conn) .await?; static CALLED: AtomicBool = AtomicBool::new(false); + let sqlite_value_stored: Arc>> = Default::default(); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); - conn.lock_handle().await?.set_preupdate_hook(move |result| { - assert_eq!(state, "test"); - assert_eq!(result.operation, SqliteOperation::Update); - assert_eq!(result.database, "main"); - assert_eq!(result.table, "tweet"); + conn.lock_handle().await?.set_preupdate_hook({ + let sqlite_value_stored = sqlite_value_stored.clone(); + move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Update); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); - assert_eq!(4, result.get_column_count()); - assert_eq!(4, result.get_column_count()); + assert_eq!(4, result.get_column_count()); + assert_eq!(4, result.get_column_count()); - assert_eq!(2, result.get_old_row_id().unwrap()); - assert_eq!(2, result.get_new_row_id().unwrap()); + assert_eq!(2, result.get_old_row_id().unwrap()); + assert_eq!(2, result.get_new_row_id().unwrap()); - assert_eq!(0, result.get_query_depth()); - assert_eq!(0, result.get_query_depth()); + assert_eq!(0, result.get_query_depth()); + assert_eq!(0, result.get_query_depth()); - assert_eq!( - 6, - >::decode(result.get_old_column_value(0).unwrap().as_ref(),) - .unwrap() - ); - assert_eq!( - 6, - >::decode(result.get_new_column_value(0).unwrap().as_ref(),) - .unwrap() - ); + assert_eq!( + 6, + >::decode(result.get_old_column_value(0).unwrap()).unwrap() + ); + assert_eq!( + 6, + >::decode(result.get_new_column_value(0).unwrap()).unwrap() + ); - assert_eq!( - "Hello, World", - >::decode(result.get_old_column_value(1).unwrap().as_ref(),) - .unwrap() - ); - assert_eq!( - "Hello, World2", - >::decode(result.get_new_column_value(1).unwrap().as_ref(),) - .unwrap() - ); + assert_eq!( + "Hello, World", + >::decode(result.get_old_column_value(1).unwrap()) + .unwrap() + ); + assert_eq!( + "Hello, World2", + >::decode(result.get_new_column_value(1).unwrap()) + .unwrap() + ); + *sqlite_value_stored.lock().unwrap() = + Some(result.get_old_column_value(0).unwrap().to_owned()); - // out of bounds access should return an error - assert!(result.get_old_column_value(4).is_err()); - assert!(result.get_new_column_value(4).is_err()); + // out of bounds access should return an error + assert!(result.get_old_column_value(4).is_err()); + assert!(result.get_new_column_value(4).is_err()); - CALLED.store(true, Ordering::Relaxed); + CALLED.store(true, Ordering::Relaxed); + } }); let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6") @@ -1129,6 +1128,14 @@ async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { let _ = sqlx::query("DELETE FROM tweet where id = 6") .execute(&mut conn) .await?; + // Ensure that taking an owned SqliteValue maintains a valid reference after the callback returns + assert_eq!( + 6, + >::decode( + sqlite_value_stored.lock().unwrap().take().unwrap().as_ref() + ) + .unwrap() + ); Ok(()) }