feat: Add set_update_hook on SqliteConnection (#3260)

* feat: Add set_update_hook on SqliteConnection

* refactor: Address PR comments

* fix: Expose UpdateHookResult for public use

---------

Co-authored-by: John Smith <asserta4@gmail.com>
This commit is contained in:
gridbox 2024-06-05 22:06:15 -04:00 committed by GitHub
parent 8b7f352be8
commit 0ea90881c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 152 additions and 4 deletions

View file

@ -294,6 +294,7 @@ impl EstablishParams {
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None
})
}
}

View file

@ -1,4 +1,5 @@
use std::cmp::Ordering;
use std::ffi::CStr;
use std::fmt::Write;
use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void};
@ -8,7 +9,10 @@ use std::ptr::NonNull;
use futures_core::future::BoxFuture;
use futures_intrusive::sync::MutexGuard;
use futures_util::future;
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
use libsqlite3_sys::{
sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT,
SQLITE_UPDATE,
};
pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
@ -58,6 +62,34 @@ pub struct LockedSqliteHandle<'a> {
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
unsafe impl Send for Handler {}
#[derive(Debug, PartialEq, Eq)]
pub enum SqliteOperation {
Insert,
Update,
Delete,
Unknown(i32),
}
impl From<i32> for SqliteOperation {
fn from(value: i32) -> Self {
match value {
SQLITE_INSERT => SqliteOperation::Insert,
SQLITE_UPDATE => SqliteOperation::Update,
SQLITE_DELETE => SqliteOperation::Delete,
code => SqliteOperation::Unknown(code),
}
}
}
pub struct UpdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
pub rowid: i64,
}
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}
pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,
@ -71,6 +103,8 @@ pub(crate) struct ConnectionState {
/// Stores the progress handler set on the current connection. If the handler returns `false`,
/// the query is interrupted.
progress_handler_callback: Option<Handler>,
update_hook_callback: Option<UpdateHookHandler>,
}
impl ConnectionState {
@ -78,7 +112,16 @@ impl ConnectionState {
pub(crate) fn remove_progress_handler(&mut self) {
if let Some(mut handler) = self.progress_handler_callback.take() {
unsafe {
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}
pub(crate) fn remove_update_hook(&mut self) {
if let Some(mut handler) = self.update_hook_callback.take() {
unsafe {
sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
@ -215,6 +258,31 @@ where
}
}
extern "C" fn update_hook<F>(
callback: *mut c_void,
op_code: c_int,
database: *const i8,
table: *const i8,
rowid: i64,
) where
F: FnMut(UpdateHookResult),
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
(*callback)(UpdateHookResult {
operation,
database,
table,
rowid,
})
});
}
}
impl LockedSqliteHandle<'_> {
/// Returns the underlying sqlite3* connection handle.
///
@ -279,10 +347,34 @@ impl LockedSqliteHandle<'_> {
}
}
pub fn set_update_hook<F>(&mut self, callback: F)
where
F: FnMut(UpdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_update_hook();
self.guard.update_hook_callback = Some(UpdateHookHandler(callback));
sqlite3_update_hook(
self.as_raw_handle().as_mut(),
Some(update_hook::<F>),
handler,
);
}
}
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
pub fn remove_progress_handler(&mut self) {
self.guard.remove_progress_handler();
}
pub fn remove_update_hook(&mut self) {
self.guard.remove_update_hook();
}
}
impl Drop for ConnectionState {
@ -290,6 +382,7 @@ impl Drop for ConnectionState {
// explicitly drop statements before the connection handle is dropped
self.statements.clear();
self.remove_progress_handler();
self.remove_update_hook();
}
}

View file

@ -33,7 +33,7 @@ use std::sync::atomic::AtomicBool;
pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use column::SqliteColumn;
pub use connection::{LockedSqliteHandle, SqliteConnection};
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
pub use database::Sqlite;
pub use error::SqliteError;
pub use options::{

View file

@ -1,7 +1,7 @@
use futures::TryStreamExt;
use rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256PlusPlus;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
use sqlx::{
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
SqliteConnection, SqlitePool, Statement, TypeInfo,
@ -794,3 +794,57 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}
#[sqlx_macros::test]
async fn test_query_with_update_hook() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
let state = format!("test");
conn.lock_handle().await?.set_update_hook(move |result| {
assert_eq!(state, "test");
assert_eq!(result.operation, SqliteOperation::Insert);
assert_eq!(result.database, "main");
assert_eq!(result.table, "tweet");
assert_eq!(result.rowid, 3);
});
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
.execute(&mut conn)
.await?;
Ok(())
}
#[sqlx_macros::test]
async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Result<()> {
let ref_counted_object = Arc::new(0);
assert_eq!(1, Arc::strong_count(&ref_counted_object));
{
let mut conn = new::<Sqlite>().await?;
let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));
let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));
let o = ref_counted_object.clone();
conn.lock_handle().await?.set_update_hook(move |_| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));
conn.lock_handle().await?.remove_update_hook();
}
assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}