mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
Add sqlite commit and rollback hooks (#3500)
* fix: Derive clone for SqliteOperation * feat: Add sqlite commit and rollback hooks --------- Co-authored-by: John Smith <asserta4@gmail.com>
This commit is contained in:
parent
419877d734
commit
daeb87bef1
3 changed files with 236 additions and 4 deletions
|
@ -296,6 +296,8 @@ impl EstablishParams {
|
||||||
log_settings: self.log_settings.clone(),
|
log_settings: self.log_settings.clone(),
|
||||||
progress_handler_callback: None,
|
progress_handler_callback: None,
|
||||||
update_hook_callback: None,
|
update_hook_callback: None,
|
||||||
|
commit_hook_callback: None,
|
||||||
|
rollback_hook_callback: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,8 @@ use futures_core::future::BoxFuture;
|
||||||
use futures_intrusive::sync::MutexGuard;
|
use futures_intrusive::sync::MutexGuard;
|
||||||
use futures_util::future;
|
use futures_util::future;
|
||||||
use libsqlite3_sys::{
|
use libsqlite3_sys::{
|
||||||
sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT,
|
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
|
||||||
SQLITE_UPDATE,
|
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) use handle::ConnectionHandle;
|
pub(crate) use handle::ConnectionHandle;
|
||||||
|
@ -63,7 +63,7 @@ pub struct LockedSqliteHandle<'a> {
|
||||||
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
|
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
|
||||||
unsafe impl Send for Handler {}
|
unsafe impl Send for Handler {}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||||
pub enum SqliteOperation {
|
pub enum SqliteOperation {
|
||||||
Insert,
|
Insert,
|
||||||
Update,
|
Update,
|
||||||
|
@ -91,6 +91,12 @@ pub struct UpdateHookResult<'a> {
|
||||||
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
|
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
|
||||||
unsafe impl Send for UpdateHookHandler {}
|
unsafe impl Send for UpdateHookHandler {}
|
||||||
|
|
||||||
|
pub(crate) struct CommitHookHandler(NonNull<dyn FnMut() -> bool + Send + 'static>);
|
||||||
|
unsafe impl Send for CommitHookHandler {}
|
||||||
|
|
||||||
|
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);
|
||||||
|
unsafe impl Send for RollbackHookHandler {}
|
||||||
|
|
||||||
pub(crate) struct ConnectionState {
|
pub(crate) struct ConnectionState {
|
||||||
pub(crate) handle: ConnectionHandle,
|
pub(crate) handle: ConnectionHandle,
|
||||||
|
|
||||||
|
@ -106,6 +112,10 @@ pub(crate) struct ConnectionState {
|
||||||
progress_handler_callback: Option<Handler>,
|
progress_handler_callback: Option<Handler>,
|
||||||
|
|
||||||
update_hook_callback: Option<UpdateHookHandler>,
|
update_hook_callback: Option<UpdateHookHandler>,
|
||||||
|
|
||||||
|
commit_hook_callback: Option<CommitHookHandler>,
|
||||||
|
|
||||||
|
rollback_hook_callback: Option<RollbackHookHandler>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectionState {
|
impl ConnectionState {
|
||||||
|
@ -127,6 +137,24 @@ impl ConnectionState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn remove_commit_hook(&mut self) {
|
||||||
|
if let Some(mut handler) = self.commit_hook_callback.take() {
|
||||||
|
unsafe {
|
||||||
|
sqlite3_commit_hook(self.handle.as_ptr(), None, ptr::null_mut());
|
||||||
|
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn remove_rollback_hook(&mut self) {
|
||||||
|
if let Some(mut handler) = self.rollback_hook_callback.take() {
|
||||||
|
unsafe {
|
||||||
|
sqlite3_rollback_hook(self.handle.as_ptr(), None, ptr::null_mut());
|
||||||
|
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Statements {
|
pub(crate) struct Statements {
|
||||||
|
@ -284,6 +312,31 @@ extern "C" fn update_hook<F>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C" fn commit_hook<F>(callback: *mut c_void) -> c_int
|
||||||
|
where
|
||||||
|
F: FnMut() -> bool,
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
let r = catch_unwind(|| {
|
||||||
|
let callback: *mut F = callback.cast::<F>();
|
||||||
|
(*callback)()
|
||||||
|
});
|
||||||
|
c_int::from(!r.unwrap_or_default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" fn rollback_hook<F>(callback: *mut c_void)
|
||||||
|
where
|
||||||
|
F: FnMut(),
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
let _ = catch_unwind(|| {
|
||||||
|
let callback: *mut F = callback.cast::<F>();
|
||||||
|
(*callback)()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl LockedSqliteHandle<'_> {
|
impl LockedSqliteHandle<'_> {
|
||||||
/// Returns the underlying sqlite3* connection handle.
|
/// Returns the underlying sqlite3* connection handle.
|
||||||
///
|
///
|
||||||
|
@ -368,6 +421,61 @@ impl LockedSqliteHandle<'_> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback
|
||||||
|
/// returns `false`, then the operation is turned into a ROLLBACK.
|
||||||
|
///
|
||||||
|
/// Only a single commit hook may be defined at one time per database connection; setting a new commit hook
|
||||||
|
/// overrides the old one.
|
||||||
|
///
|
||||||
|
/// The commit hook callback must not do anything that will modify the database connection that invoked
|
||||||
|
/// the commit hook. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections
|
||||||
|
/// in this context.
|
||||||
|
///
|
||||||
|
/// See https://www.sqlite.org/c3ref/commit_hook.html
|
||||||
|
pub fn set_commit_hook<F>(&mut self, callback: F)
|
||||||
|
where
|
||||||
|
F: FnMut() -> bool + Send + 'static,
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
let callback_boxed = Box::new(callback);
|
||||||
|
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
|
||||||
|
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
|
||||||
|
let handler = callback.as_ptr() as *mut _;
|
||||||
|
self.guard.remove_commit_hook();
|
||||||
|
self.guard.commit_hook_callback = Some(CommitHookHandler(callback));
|
||||||
|
|
||||||
|
sqlite3_commit_hook(
|
||||||
|
self.as_raw_handle().as_mut(),
|
||||||
|
Some(commit_hook::<F>),
|
||||||
|
handler,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets a rollback hook that is invoked whenever a transaction rollback occurs. The rollback callback is not
|
||||||
|
/// invoked if a transaction is automatically rolled back because the database connection is closed.
|
||||||
|
///
|
||||||
|
/// See https://www.sqlite.org/c3ref/commit_hook.html
|
||||||
|
pub fn set_rollback_hook<F>(&mut self, callback: F)
|
||||||
|
where
|
||||||
|
F: FnMut() + Send + 'static,
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
let callback_boxed = Box::new(callback);
|
||||||
|
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
|
||||||
|
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
|
||||||
|
let handler = callback.as_ptr() as *mut _;
|
||||||
|
self.guard.remove_rollback_hook();
|
||||||
|
self.guard.rollback_hook_callback = Some(RollbackHookHandler(callback));
|
||||||
|
|
||||||
|
sqlite3_rollback_hook(
|
||||||
|
self.as_raw_handle().as_mut(),
|
||||||
|
Some(rollback_hook::<F>),
|
||||||
|
handler,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
|
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
|
||||||
pub fn remove_progress_handler(&mut self) {
|
pub fn remove_progress_handler(&mut self) {
|
||||||
self.guard.remove_progress_handler();
|
self.guard.remove_progress_handler();
|
||||||
|
@ -376,6 +484,14 @@ impl LockedSqliteHandle<'_> {
|
||||||
pub fn remove_update_hook(&mut self) {
|
pub fn remove_update_hook(&mut self) {
|
||||||
self.guard.remove_update_hook();
|
self.guard.remove_update_hook();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn remove_commit_hook(&mut self) {
|
||||||
|
self.guard.remove_commit_hook();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_rollback_hook(&mut self) {
|
||||||
|
self.guard.remove_rollback_hook();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for ConnectionState {
|
impl Drop for ConnectionState {
|
||||||
|
@ -384,6 +500,8 @@ impl Drop for ConnectionState {
|
||||||
self.statements.clear();
|
self.statements.clear();
|
||||||
self.remove_progress_handler();
|
self.remove_progress_handler();
|
||||||
self.remove_update_hook();
|
self.remove_update_hook();
|
||||||
|
self.remove_commit_hook();
|
||||||
|
self.remove_rollback_hook();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -806,7 +806,7 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> {
|
||||||
assert_eq!(result.operation, SqliteOperation::Insert);
|
assert_eq!(result.operation, SqliteOperation::Insert);
|
||||||
assert_eq!(result.database, "main");
|
assert_eq!(result.database, "main");
|
||||||
assert_eq!(result.table, "tweet");
|
assert_eq!(result.table, "tweet");
|
||||||
assert_eq!(result.rowid, 3);
|
assert_eq!(result.rowid, 2);
|
||||||
});
|
});
|
||||||
|
|
||||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
|
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
|
||||||
|
@ -848,3 +848,115 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul
|
||||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_query_with_commit_hook() -> anyhow::Result<()> {
|
||||||
|
let mut conn = new::<Sqlite>().await?;
|
||||||
|
|
||||||
|
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||||
|
let state = format!("test");
|
||||||
|
conn.lock_handle().await?.set_commit_hook(move || {
|
||||||
|
assert_eq!(state, "test");
|
||||||
|
false
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut tx = conn.begin().await?;
|
||||||
|
sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )")
|
||||||
|
.execute(&mut *tx)
|
||||||
|
.await?;
|
||||||
|
match tx.commit().await {
|
||||||
|
Err(sqlx::Error::Database(err)) => {
|
||||||
|
assert_eq!(err.message(), String::from("constraint failed"))
|
||||||
|
}
|
||||||
|
_ => panic!("expected an error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_multiple_set_commit_hook_calls_drop_old_handler() -> anyhow::Result<()> {
|
||||||
|
let ref_counted_object = Arc::new(0);
|
||||||
|
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut conn = new::<Sqlite>().await?;
|
||||||
|
|
||||||
|
let o = ref_counted_object.clone();
|
||||||
|
conn.lock_handle().await?.set_commit_hook(move || {
|
||||||
|
println!("{o:?}");
|
||||||
|
true
|
||||||
|
});
|
||||||
|
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
let o = ref_counted_object.clone();
|
||||||
|
conn.lock_handle().await?.set_commit_hook(move || {
|
||||||
|
println!("{o:?}");
|
||||||
|
true
|
||||||
|
});
|
||||||
|
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
let o = ref_counted_object.clone();
|
||||||
|
conn.lock_handle().await?.set_commit_hook(move || {
|
||||||
|
println!("{o:?}");
|
||||||
|
true
|
||||||
|
});
|
||||||
|
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
conn.lock_handle().await?.remove_commit_hook();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_query_with_rollback_hook() -> anyhow::Result<()> {
|
||||||
|
let mut conn = new::<Sqlite>().await?;
|
||||||
|
|
||||||
|
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||||
|
let state = format!("test");
|
||||||
|
conn.lock_handle().await?.set_rollback_hook(move || {
|
||||||
|
assert_eq!(state, "test");
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut tx = conn.begin().await?;
|
||||||
|
sqlx::query("INSERT INTO tweet ( id, text ) VALUES (5, 'Hello, World' )")
|
||||||
|
.execute(&mut *tx)
|
||||||
|
.await?;
|
||||||
|
tx.rollback().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Result<()> {
|
||||||
|
let ref_counted_object = Arc::new(0);
|
||||||
|
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut conn = new::<Sqlite>().await?;
|
||||||
|
|
||||||
|
let o = ref_counted_object.clone();
|
||||||
|
conn.lock_handle().await?.set_rollback_hook(move || {
|
||||||
|
println!("{o:?}");
|
||||||
|
});
|
||||||
|
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
let o = ref_counted_object.clone();
|
||||||
|
conn.lock_handle().await?.set_rollback_hook(move || {
|
||||||
|
println!("{o:?}");
|
||||||
|
});
|
||||||
|
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
let o = ref_counted_object.clone();
|
||||||
|
conn.lock_handle().await?.set_rollback_hook(move || {
|
||||||
|
println!("{o:?}");
|
||||||
|
});
|
||||||
|
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||||
|
|
||||||
|
conn.lock_handle().await?.remove_rollback_hook();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue