Skip to content

Commit

Permalink
add SqliteValueRef variant that takes a borrowed sqlite value pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
aschey committed Dec 14, 2024
1 parent 46df6f8 commit d1a3a97
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 95 deletions.
11 changes: 6 additions & 5 deletions sqlx-sqlite/src/connection/preupdate_hook.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::SqliteOperation;
use crate::type_info::DataType;
use crate::{SqliteError, SqliteTypeInfo, SqliteValue};
use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef};

use libsqlite3_sys::{
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new,
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<'a> PreupdateHookResult<'a> {

/// Gets the value of the row being updated/deleted at the specified index.
/// Returns an error if called from an insert operation or the index is out of bounds.
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValue, PreupdateError> {
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
if self.operation == SqliteOperation::Insert {
return Err(PreupdateError::InvalidOperation);
}
Expand All @@ -92,7 +92,7 @@ impl<'a> PreupdateHookResult<'a> {

/// Gets the value of the row being inserted/updated at the specified index.
/// Returns an error if called from a delete operation or the index is out of bounds.
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValue, PreupdateError> {
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
if self.operation == SqliteOperation::Delete {
return Err(PreupdateError::InvalidOperation);
}
Expand All @@ -116,12 +116,13 @@ impl<'a> PreupdateHookResult<'a> {
&self,
ret: i32,
p_value: *mut sqlite3_value,
) -> Result<SqliteValue, PreupdateError> {
) -> Result<SqliteValueRef<'a>, PreupdateError> {
if ret != SQLITE_OK {
return Err(PreupdateError::Database(SqliteError::new(self.db)));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
// SAFETY: SQLite will free the sqlite3_value when the callback returns
Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type)))
}
}

Expand Down
139 changes: 95 additions & 44 deletions sqlx-sqlite/src/value.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::slice::from_raw_parts;
use std::str::from_utf8;
Expand All @@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo};

enum SqliteValueData<'r> {
Value(&'r SqliteValue),
BorrowedHandle(ValueHandle<'r>),
}

pub struct SqliteValueRef<'r>(SqliteValueData<'r>);
Expand All @@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> {
Self(SqliteValueData::Value(value))
}

// SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop.
// The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it.
#[allow(unused)]
pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
debug_assert!(!value.is_null());
let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info);
Self(SqliteValueData::BorrowedHandle(handle))
}

// NOTE: `int()` is deliberately omitted because it will silently truncate a wider value,
// which is likely to cause bugs:
// https://github.com/launchbadge/sqlx/issues/3179
// (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161
pub(super) fn int64(&self) -> i64 {
match self.0 {
SqliteValueData::Value(v) => v.int64(),
match &self.0 {
SqliteValueData::Value(v) => v.0.int64(),
SqliteValueData::BorrowedHandle(v) => v.int64(),
}
}

pub(super) fn double(&self) -> f64 {
match self.0 {
SqliteValueData::Value(v) => v.double(),
match &self.0 {
SqliteValueData::Value(v) => v.0.double(),
SqliteValueData::BorrowedHandle(v) => v.double(),
}
}

pub(super) fn blob(&self) -> &'r [u8] {
match self.0 {
SqliteValueData::Value(v) => v.blob(),
match &self.0 {
SqliteValueData::Value(v) => v.0.blob(),
SqliteValueData::BorrowedHandle(v) => v.blob(),
}
}

pub(super) fn text(&self) -> Result<&'r str, BoxDynError> {
match self.0 {
SqliteValueData::Value(v) => v.text(),
match &self.0 {
SqliteValueData::Value(v) => v.0.text(),
SqliteValueData::BorrowedHandle(v) => v.text(),
}
}
}
Expand All @@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> {
type Database = Sqlite;

fn to_owned(&self) -> SqliteValue {
match self.0 {
SqliteValueData::Value(v) => v.clone(),
match &self.0 {
SqliteValueData::Value(v) => (*v).clone(),
SqliteValueData::BorrowedHandle(v) => unsafe {
SqliteValue::new(v.value.as_ptr(), v.type_info.clone())
},
}
}

fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
match self.0 {
match &self.0 {
SqliteValueData::Value(v) => v.type_info(),
SqliteValueData::BorrowedHandle(v) => v.type_info(),
}
}

fn is_null(&self) -> bool {
match self.0 {
match &self.0 {
SqliteValueData::Value(v) => v.is_null(),
SqliteValueData::BorrowedHandle(v) => v.is_null(),
}
}
}

#[derive(Clone)]
pub struct SqliteValue {
pub(crate) handle: Arc<ValueHandle>,
pub(crate) type_info: SqliteTypeInfo,
}
pub struct SqliteValue(Arc<ValueHandle<'static>>);

pub(crate) struct ValueHandle(NonNull<sqlite3_value>);
pub(crate) struct ValueHandle<'a> {
value: NonNull<sqlite3_value>,
type_info: SqliteTypeInfo,
_phantom: PhantomData<&'a ()>,
free_on_drop: bool,
}

// SAFE: only protected value objects are stored in SqliteValue
unsafe impl Send for ValueHandle {}
unsafe impl Sync for ValueHandle {}
unsafe impl<'a> Send for ValueHandle<'a> {}
unsafe impl<'a> Sync for ValueHandle<'a> {}

impl SqliteValue {
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
debug_assert!(!value.is_null());
impl ValueHandle<'static> {
unsafe fn new_owned(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
Self {
value,
type_info,
free_on_drop: true,
_phantom: PhantomData,
}
}
}

impl<'a> ValueHandle<'a> {
unsafe fn new_borrowed(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
Self {
value,
type_info,
handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup(
value,
)))),
free_on_drop: false,
_phantom: PhantomData,
}
}

fn type_info_opt(&self) -> Option<SqliteTypeInfo> {
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) });
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) });

if let DataType::Null = dt {
None
Expand All @@ -112,15 +143,15 @@ impl SqliteValue {
}

fn int64(&self) -> i64 {
unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) }
unsafe { sqlite3_value_int64(self.value.as_ptr()) }
}

fn double(&self) -> f64 {
unsafe { sqlite3_value_double(self.handle.0.as_ptr()) }
unsafe { sqlite3_value_double(self.value.as_ptr()) }
}

fn blob(&self) -> &[u8] {
let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) };
fn blob<'b>(&self) -> &'b [u8] {
let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) };

// This likely means UB in SQLite itself or our usage of it;
// signed integer overflow is UB in the C standard.
Expand All @@ -133,23 +164,15 @@ impl SqliteValue {
return &[];
}

let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8;
let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8;
debug_assert!(!ptr.is_null());

unsafe { from_raw_parts(ptr, len) }
}

fn text(&self) -> Result<&str, BoxDynError> {
fn text<'b>(&self) -> Result<&'b str, BoxDynError> {
Ok(from_utf8(self.blob())?)
}
}

impl Value for SqliteValue {
type Database = Sqlite;

fn as_ref(&self) -> SqliteValueRef<'_> {
SqliteValueRef::value(self)
}

fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
self.type_info_opt()
Expand All @@ -158,18 +181,46 @@ impl Value for SqliteValue {
}

fn is_null(&self) -> bool {
unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL }
unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL }
}
}

impl Drop for ValueHandle {
impl<'a> Drop for ValueHandle<'a> {
fn drop(&mut self) {
unsafe {
sqlite3_value_free(self.0.as_ptr());
if self.free_on_drop {
unsafe {
sqlite3_value_free(self.value.as_ptr());
}
}
}
}

impl SqliteValue {
// SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop.
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
debug_assert!(!value.is_null());
let handle =
ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info);
Self(Arc::new(handle))
}
}

impl Value for SqliteValue {
type Database = Sqlite;

fn as_ref(&self) -> SqliteValueRef<'_> {
SqliteValueRef::value(self)
}

fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
self.0.type_info()
}

fn is_null(&self) -> bool {
self.0.is_null()
}
}

// #[cfg(feature = "any")]
// impl<'r> From<SqliteValueRef<'r>> for crate::any::AnyValueRef<'r> {
// #[inline]
Expand Down
Loading

0 comments on commit d1a3a97

Please sign in to comment.