mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
feat: Add set_update_hook on SqliteConnection (#3260)
* feat: Add set_update_hook on SqliteConnection * refactor: Address PR comments * fix: Expose UpdateHookResult for public use --------- Co-authored-by: John Smith <asserta4@gmail.com>
This commit is contained in:
parent
8b7f352be8
commit
0ea90881c1
4 changed files with 152 additions and 4 deletions
|
@ -294,6 +294,7 @@ impl EstablishParams {
|
|||
transaction_depth: 0,
|
||||
log_settings: self.log_settings.clone(),
|
||||
progress_handler_callback: None,
|
||||
update_hook_callback: None
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use std::cmp::Ordering;
|
||||
use std::ffi::CStr;
|
||||
use std::fmt::Write;
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::os::raw::{c_int, c_void};
|
||||
|
@ -8,7 +9,10 @@ use std::ptr::NonNull;
|
|||
use futures_core::future::BoxFuture;
|
||||
use futures_intrusive::sync::MutexGuard;
|
||||
use futures_util::future;
|
||||
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
|
||||
use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT,
|
||||
SQLITE_UPDATE,
|
||||
};
|
||||
|
||||
pub(crate) use handle::ConnectionHandle;
|
||||
use sqlx_core::common::StatementCache;
|
||||
|
@ -58,6 +62,34 @@ pub struct LockedSqliteHandle<'a> {
|
|||
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
|
||||
unsafe impl Send for Handler {}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum SqliteOperation {
|
||||
Insert,
|
||||
Update,
|
||||
Delete,
|
||||
Unknown(i32),
|
||||
}
|
||||
|
||||
impl From<i32> for SqliteOperation {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
SQLITE_INSERT => SqliteOperation::Insert,
|
||||
SQLITE_UPDATE => SqliteOperation::Update,
|
||||
SQLITE_DELETE => SqliteOperation::Delete,
|
||||
code => SqliteOperation::Unknown(code),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UpdateHookResult<'a> {
|
||||
pub operation: SqliteOperation,
|
||||
pub database: &'a str,
|
||||
pub table: &'a str,
|
||||
pub rowid: i64,
|
||||
}
|
||||
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
|
||||
unsafe impl Send for UpdateHookHandler {}
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) handle: ConnectionHandle,
|
||||
|
||||
|
@ -71,6 +103,8 @@ pub(crate) struct ConnectionState {
|
|||
/// Stores the progress handler set on the current connection. If the handler returns `false`,
|
||||
/// the query is interrupted.
|
||||
progress_handler_callback: Option<Handler>,
|
||||
|
||||
update_hook_callback: Option<UpdateHookHandler>,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
|
@ -78,7 +112,16 @@ impl ConnectionState {
|
|||
pub(crate) fn remove_progress_handler(&mut self) {
|
||||
if let Some(mut handler) = self.progress_handler_callback.take() {
|
||||
unsafe {
|
||||
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
|
||||
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut());
|
||||
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn remove_update_hook(&mut self) {
|
||||
if let Some(mut handler) = self.update_hook_callback.take() {
|
||||
unsafe {
|
||||
sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut());
|
||||
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||
}
|
||||
}
|
||||
|
@ -215,6 +258,31 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
extern "C" fn update_hook<F>(
|
||||
callback: *mut c_void,
|
||||
op_code: c_int,
|
||||
database: *const i8,
|
||||
table: *const i8,
|
||||
rowid: i64,
|
||||
) where
|
||||
F: FnMut(UpdateHookResult),
|
||||
{
|
||||
unsafe {
|
||||
let _ = catch_unwind(|| {
|
||||
let callback: *mut F = callback.cast::<F>();
|
||||
let operation: SqliteOperation = op_code.into();
|
||||
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
|
||||
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
|
||||
(*callback)(UpdateHookResult {
|
||||
operation,
|
||||
database,
|
||||
table,
|
||||
rowid,
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl LockedSqliteHandle<'_> {
|
||||
/// Returns the underlying sqlite3* connection handle.
|
||||
///
|
||||
|
@ -279,10 +347,34 @@ impl LockedSqliteHandle<'_> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn set_update_hook<F>(&mut self, callback: F)
|
||||
where
|
||||
F: FnMut(UpdateHookResult) + Send + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let callback_boxed = Box::new(callback);
|
||||
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
|
||||
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
|
||||
let handler = callback.as_ptr() as *mut _;
|
||||
self.guard.remove_update_hook();
|
||||
self.guard.update_hook_callback = Some(UpdateHookHandler(callback));
|
||||
|
||||
sqlite3_update_hook(
|
||||
self.as_raw_handle().as_mut(),
|
||||
Some(update_hook::<F>),
|
||||
handler,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
|
||||
pub fn remove_progress_handler(&mut self) {
|
||||
self.guard.remove_progress_handler();
|
||||
}
|
||||
|
||||
pub fn remove_update_hook(&mut self) {
|
||||
self.guard.remove_update_hook();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnectionState {
|
||||
|
@ -290,6 +382,7 @@ impl Drop for ConnectionState {
|
|||
// explicitly drop statements before the connection handle is dropped
|
||||
self.statements.clear();
|
||||
self.remove_progress_handler();
|
||||
self.remove_update_hook();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ use std::sync::atomic::AtomicBool;
|
|||
|
||||
pub use arguments::{SqliteArgumentValue, SqliteArguments};
|
||||
pub use column::SqliteColumn;
|
||||
pub use connection::{LockedSqliteHandle, SqliteConnection};
|
||||
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
|
||||
pub use database::Sqlite;
|
||||
pub use error::SqliteError;
|
||||
pub use options::{
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use futures::TryStreamExt;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_xoshiro::Xoshiro256PlusPlus;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
|
||||
use sqlx::{
|
||||
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
|
||||
SqliteConnection, SqlitePool, Statement, TypeInfo,
|
||||
|
@ -794,3 +794,57 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
|
|||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_update_hook() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_update_hook(move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Insert);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
assert_eq!(result.rowid, 3);
|
||||
});
|
||||
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Result<()> {
|
||||
let ref_counted_object = Arc::new(0);
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
{
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_update_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_update_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_update_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
conn.lock_handle().await?.remove_update_hook();
|
||||
}
|
||||
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue