mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
sqlite: fix a couple segfaults (#1351)
* sqlite: use Arc instead of Copy-able StatementHandle This guarantees that StatementHandle is never used after calling `sqlite3_finalize`. Now `sqlite3_finalize` is only called when StatementHandle is dropped. (cherry picked from commit5eebc05dc3
) * sqlite: use Weak poiter to StatementHandle in the worker Otherwise some tests fail to close connection. (cherry picked from commit5461eeeee3
) * Fix segfault due to race condition in sqlite (#1300) (cherry picked from commitbb62cf767e
) * fix(sqlite): run `sqlite3_reset()` in `StatementWorker` this avoids possible race conditions without using a mutex * fix(sqlite): have `StatementWorker` keep a strong ref to `ConnectionHandle` this should prevent the database handle from being finalized before all statement handles have been finalized * fix(sqlite/test): make `concurrent_resets_dont_segfault` runtime-agnostic Co-authored-by: link2xt <link2xt@testrun.org> Co-authored-by: Adam Cigánek <adam.ciganek@gmail.com>
This commit is contained in:
parent
55c603e9e7
commit
71388a7ef2
11 changed files with 224 additions and 80 deletions
|
@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>(
|
|||
// fallback to [column_decltype]
|
||||
if !stepped && stmt.read_only() {
|
||||
stepped = true;
|
||||
let _ = conn.worker.step(*stmt).await;
|
||||
let _ = conn.worker.step(stmt).await;
|
||||
}
|
||||
|
||||
let mut ty = stmt.column_type_info(col);
|
||||
|
|
|
@ -87,7 +87,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
|
|||
// https://www.sqlite.org/c3ref/extended_result_codes.html
|
||||
unsafe {
|
||||
// NOTE: ignore the failure here
|
||||
sqlite3_extended_result_codes(handle.0.as_ptr(), 1);
|
||||
sqlite3_extended_result_codes(handle.as_ptr(), 1);
|
||||
}
|
||||
|
||||
// Configure a busy timeout
|
||||
|
@ -99,7 +99,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
|
|||
let ms =
|
||||
i32::try_from(busy_timeout.as_millis()).expect("Given busy timeout value is too big.");
|
||||
|
||||
status = unsafe { sqlite3_busy_timeout(handle.0.as_ptr(), ms) };
|
||||
status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) };
|
||||
|
||||
if status != SQLITE_OK {
|
||||
return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))));
|
||||
|
@ -109,8 +109,8 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
|
|||
})?;
|
||||
|
||||
Ok(SqliteConnection {
|
||||
worker: StatementWorker::new(handle.to_ref()),
|
||||
handle,
|
||||
worker: StatementWorker::new(),
|
||||
statements: StatementCache::new(options.statement_cache_capacity),
|
||||
statement: None,
|
||||
transaction_depth: 0,
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::error::Error;
|
|||
use crate::executor::{Execute, Executor};
|
||||
use crate::logger::QueryLogger;
|
||||
use crate::sqlite::connection::describe::describe;
|
||||
use crate::sqlite::statement::{StatementHandle, VirtualStatement};
|
||||
use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement};
|
||||
use crate::sqlite::{
|
||||
Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
|
||||
SqliteTypeInfo,
|
||||
|
@ -16,7 +16,8 @@ use libsqlite3_sys::sqlite3_last_insert_rowid;
|
|||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn prepare<'a>(
|
||||
async fn prepare<'a>(
|
||||
worker: &mut StatementWorker,
|
||||
statements: &'a mut StatementCache<VirtualStatement>,
|
||||
statement: &'a mut Option<VirtualStatement>,
|
||||
query: &str,
|
||||
|
@ -39,7 +40,7 @@ fn prepare<'a>(
|
|||
if exists {
|
||||
// as this statement has been executed before, we reset before continuing
|
||||
// this also causes any rows that are from the statement to be inflated
|
||||
statement.reset();
|
||||
statement.reset(worker).await?;
|
||||
}
|
||||
|
||||
Ok(statement)
|
||||
|
@ -61,19 +62,25 @@ fn bind(
|
|||
|
||||
/// A structure holding sqlite statement handle and resetting the
|
||||
/// statement when it is dropped.
|
||||
struct StatementResetter {
|
||||
handle: StatementHandle,
|
||||
struct StatementResetter<'a> {
|
||||
handle: Arc<StatementHandle>,
|
||||
worker: &'a mut StatementWorker,
|
||||
}
|
||||
|
||||
impl StatementResetter {
|
||||
fn new(handle: StatementHandle) -> Self {
|
||||
Self { handle }
|
||||
impl<'a> StatementResetter<'a> {
|
||||
fn new(worker: &'a mut StatementWorker, handle: &Arc<StatementHandle>) -> Self {
|
||||
Self {
|
||||
worker,
|
||||
handle: Arc::clone(handle),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StatementResetter {
|
||||
impl Drop for StatementResetter<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.handle.reset();
|
||||
// this method is designed to eagerly send the reset command
|
||||
// so we don't need to await or spawn it
|
||||
let _ = self.worker.reset(&self.handle);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
} = self;
|
||||
|
||||
// prepare statement object (or checkout from cache)
|
||||
let stmt = prepare(statements, statement, sql, persistent)?;
|
||||
let stmt = prepare(worker, statements, statement, sql, persistent).await?;
|
||||
|
||||
// keep track of how many arguments we have bound
|
||||
let mut num_arguments = 0;
|
||||
|
@ -113,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
// is dropped. `StatementResetter` will reliably reset the
|
||||
// statement even if the stream returned from `fetch_many`
|
||||
// is dropped early.
|
||||
let _resetter = StatementResetter::new(*stmt);
|
||||
let resetter = StatementResetter::new(worker, stmt);
|
||||
|
||||
// bind values to the statement
|
||||
num_arguments += bind(stmt, &arguments, num_arguments)?;
|
||||
|
@ -125,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
|
||||
// invoke [sqlite3_step] on the dedicated worker thread
|
||||
// this will move us forward one row or finish the statement
|
||||
let s = worker.step(*stmt).await?;
|
||||
let s = resetter.worker.step(stmt).await?;
|
||||
|
||||
match s {
|
||||
Either::Left(changes) => {
|
||||
|
@ -145,7 +152,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
|
||||
Either::Right(()) => {
|
||||
let (row, weak_values_ref) = SqliteRow::current(
|
||||
*stmt,
|
||||
&stmt,
|
||||
columns,
|
||||
column_names
|
||||
);
|
||||
|
@ -188,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
} = self;
|
||||
|
||||
// prepare statement object (or checkout from cache)
|
||||
let virtual_stmt = prepare(statements, statement, sql, persistent)?;
|
||||
let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?;
|
||||
|
||||
// keep track of how many arguments we have bound
|
||||
let mut num_arguments = 0;
|
||||
|
@ -205,18 +212,18 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
|
||||
// invoke [sqlite3_step] on the dedicated worker thread
|
||||
// this will move us forward one row or finish the statement
|
||||
match worker.step(*stmt).await? {
|
||||
match worker.step(stmt).await? {
|
||||
Either::Left(_) => (),
|
||||
|
||||
Either::Right(()) => {
|
||||
let (row, weak_values_ref) =
|
||||
SqliteRow::current(*stmt, columns, column_names);
|
||||
SqliteRow::current(stmt, columns, column_names);
|
||||
|
||||
*last_row_values = Some(weak_values_ref);
|
||||
|
||||
logger.increment_rows();
|
||||
|
||||
virtual_stmt.reset();
|
||||
virtual_stmt.reset(worker).await?;
|
||||
return Ok(Some(row));
|
||||
}
|
||||
}
|
||||
|
@ -238,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
|||
handle: ref mut conn,
|
||||
ref mut statements,
|
||||
ref mut statement,
|
||||
ref mut worker,
|
||||
..
|
||||
} = self;
|
||||
|
||||
// prepare statement object (or checkout from cache)
|
||||
let statement = prepare(statements, statement, sql, true)?;
|
||||
let statement = prepare(worker, statements, statement, sql, true).await?;
|
||||
|
||||
let mut parameters = 0;
|
||||
let mut columns = None;
|
||||
|
|
|
@ -3,11 +3,23 @@ use std::ptr::NonNull;
|
|||
use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK};
|
||||
|
||||
use crate::sqlite::SqliteError;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Managed handle to the raw SQLite3 database handle.
|
||||
/// The database handle will be closed when this is dropped.
|
||||
/// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ConnectionHandle(pub(super) NonNull<sqlite3>);
|
||||
pub(crate) struct ConnectionHandle(Arc<HandleInner>);
|
||||
|
||||
/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own
|
||||
/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()`
|
||||
/// or `sqlite3_reset()`.
|
||||
///
|
||||
/// Note that this does *not* actually give access to the database handle!
|
||||
pub(crate) struct ConnectionHandleRef(Arc<HandleInner>);
|
||||
|
||||
// Wrapper for `*mut sqlite3` which finalizes the handle on-drop.
|
||||
#[derive(Debug)]
|
||||
struct HandleInner(NonNull<sqlite3>);
|
||||
|
||||
// A SQLite3 handle is safe to send between threads, provided not more than
|
||||
// one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is
|
||||
|
@ -20,19 +32,32 @@ pub(crate) struct ConnectionHandle(pub(super) NonNull<sqlite3>);
|
|||
|
||||
unsafe impl Send for ConnectionHandle {}
|
||||
|
||||
// SAFETY: `Arc<T>` normally only implements `Send` where `T: Sync` because it allows
|
||||
// concurrent access.
|
||||
//
|
||||
// However, in this case we're only using `Arc` to prevent the database handle from being
|
||||
// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus
|
||||
// should *not* actually provide access to the database handle.
|
||||
unsafe impl Send for ConnectionHandleRef {}
|
||||
|
||||
impl ConnectionHandle {
|
||||
#[inline]
|
||||
pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self {
|
||||
Self(NonNull::new_unchecked(ptr))
|
||||
Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr))))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn as_ptr(&self) -> *mut sqlite3 {
|
||||
self.0.as_ptr()
|
||||
self.0 .0.as_ptr()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn to_ref(&self) -> ConnectionHandleRef {
|
||||
ConnectionHandleRef(Arc::clone(&self.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnectionHandle {
|
||||
impl Drop for HandleInner {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
// https://sqlite.org/c3ref/close.html
|
||||
|
|
|
@ -17,7 +17,7 @@ mod executor;
|
|||
mod explain;
|
||||
mod handle;
|
||||
|
||||
pub(crate) use handle::ConnectionHandle;
|
||||
pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef};
|
||||
|
||||
/// A connection to a [Sqlite] database.
|
||||
pub struct SqliteConnection {
|
||||
|
|
|
@ -23,7 +23,7 @@ pub struct SqliteRow {
|
|||
// IF the user drops the Row before iterating the stream (so
|
||||
// nearly all of our internal stream iterators), the executor moves on; otherwise,
|
||||
// it actually inflates this row with a list of owned sqlite3 values.
|
||||
pub(crate) statement: StatementHandle,
|
||||
pub(crate) statement: Arc<StatementHandle>,
|
||||
|
||||
pub(crate) values: Arc<AtomicPtr<SqliteValue>>,
|
||||
pub(crate) num_values: usize,
|
||||
|
@ -48,7 +48,7 @@ impl SqliteRow {
|
|||
// returns a weak reference to an atomic list where the executor should inflate if its going
|
||||
// to increment the statement with [step]
|
||||
pub(crate) fn current(
|
||||
statement: StatementHandle,
|
||||
statement: &Arc<StatementHandle>,
|
||||
columns: &Arc<Vec<SqliteColumn>>,
|
||||
column_names: &Arc<HashMap<UStr, usize>>,
|
||||
) -> (Self, Weak<AtomicPtr<SqliteValue>>) {
|
||||
|
@ -57,7 +57,7 @@ impl SqliteRow {
|
|||
let size = statement.column_count();
|
||||
|
||||
let row = Self {
|
||||
statement,
|
||||
statement: Arc::clone(statement),
|
||||
values,
|
||||
num_values: size,
|
||||
columns: Arc::clone(columns),
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use std::ffi::c_void;
|
||||
use std::ffi::CStr;
|
||||
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::ptr;
|
||||
use std::ptr::NonNull;
|
||||
|
@ -9,21 +10,22 @@ use std::str::{from_utf8, from_utf8_unchecked};
|
|||
use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
|
||||
sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name,
|
||||
sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes,
|
||||
sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype,
|
||||
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name,
|
||||
sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type,
|
||||
sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt,
|
||||
sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK,
|
||||
SQLITE_TRANSIENT, SQLITE_UTF8,
|
||||
sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob,
|
||||
sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name,
|
||||
sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64,
|
||||
sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name,
|
||||
sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset,
|
||||
sqlite3_sql, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata,
|
||||
sqlite3_value, SQLITE_DONE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW, SQLITE_TRANSIENT,
|
||||
SQLITE_UTF8,
|
||||
};
|
||||
|
||||
use crate::error::{BoxDynError, Error};
|
||||
use crate::sqlite::type_info::DataType;
|
||||
use crate::sqlite::{SqliteError, SqliteTypeInfo};
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub(crate) struct StatementHandle(pub(super) NonNull<sqlite3_stmt>);
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);
|
||||
|
||||
// access to SQLite3 statement handles are safe to send and share between threads
|
||||
// as long as the `sqlite3_step` call is serialized.
|
||||
|
@ -32,6 +34,14 @@ unsafe impl Send for StatementHandle {}
|
|||
unsafe impl Sync for StatementHandle {}
|
||||
|
||||
impl StatementHandle {
|
||||
pub(super) fn new(ptr: NonNull<sqlite3_stmt>) -> Self {
|
||||
Self(ptr)
|
||||
}
|
||||
|
||||
pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt {
|
||||
self.0.as_ptr()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 {
|
||||
// O(c) access to the connection handle for this statement handle
|
||||
|
@ -280,7 +290,25 @@ impl StatementHandle {
|
|||
Ok(from_utf8(self.column_blob(index))?)
|
||||
}
|
||||
|
||||
pub(crate) fn reset(&self) {
|
||||
unsafe { sqlite3_reset(self.0.as_ptr()) };
|
||||
pub(crate) fn clear_bindings(&self) {
|
||||
unsafe { sqlite3_clear_bindings(self.0.as_ptr()) };
|
||||
}
|
||||
}
|
||||
impl Drop for StatementHandle {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: we have exclusive access to the `StatementHandle` here
|
||||
unsafe {
|
||||
// https://sqlite.org/c3ref/finalize.html
|
||||
let status = sqlite3_finalize(self.0.as_ptr());
|
||||
if status == SQLITE_MISUSE {
|
||||
// Panic in case of detected misuse of SQLite API.
|
||||
//
|
||||
// sqlite3_finalize returns it at least in the
|
||||
// case of detected double free, i.e. calling
|
||||
// sqlite3_finalize on already finalized
|
||||
// statement.
|
||||
panic!("Detected sqlite3_finalize misuse.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,13 +3,12 @@
|
|||
use crate::error::Error;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::sqlite::connection::ConnectionHandle;
|
||||
use crate::sqlite::statement::StatementHandle;
|
||||
use crate::sqlite::statement::{StatementHandle, StatementWorker};
|
||||
use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue};
|
||||
use crate::HashMap;
|
||||
use bytes::{Buf, Bytes};
|
||||
use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset,
|
||||
sqlite3_stmt, SQLITE_MISUSE, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
|
||||
sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
|
||||
};
|
||||
use smallvec::SmallVec;
|
||||
use std::i32;
|
||||
|
@ -31,7 +30,7 @@ pub(crate) struct VirtualStatement {
|
|||
// underlying sqlite handles for each inner statement
|
||||
// a SQL query string in SQLite is broken up into N statements
|
||||
// we use a [`SmallVec`] to optimize for the most likely case of a single statement
|
||||
pub(crate) handles: SmallVec<[StatementHandle; 1]>,
|
||||
pub(crate) handles: SmallVec<[Arc<StatementHandle>; 1]>,
|
||||
|
||||
// each set of columns
|
||||
pub(crate) columns: SmallVec<[Arc<Vec<SqliteColumn>>; 1]>,
|
||||
|
@ -92,7 +91,7 @@ fn prepare(
|
|||
query.advance(n);
|
||||
|
||||
if let Some(handle) = NonNull::new(statement_handle) {
|
||||
return Ok(Some(StatementHandle(handle)));
|
||||
return Ok(Some(StatementHandle::new(handle)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -126,7 +125,7 @@ impl VirtualStatement {
|
|||
conn: &mut ConnectionHandle,
|
||||
) -> Result<
|
||||
Option<(
|
||||
&StatementHandle,
|
||||
&Arc<StatementHandle>,
|
||||
&mut Arc<Vec<SqliteColumn>>,
|
||||
&Arc<HashMap<UStr, usize>>,
|
||||
&mut Option<Weak<AtomicPtr<SqliteValue>>>,
|
||||
|
@ -159,7 +158,7 @@ impl VirtualStatement {
|
|||
column_names.insert(name, i);
|
||||
}
|
||||
|
||||
self.handles.push(statement);
|
||||
self.handles.push(Arc::new(statement));
|
||||
self.columns.push(Arc::new(columns));
|
||||
self.column_names.push(Arc::new(column_names));
|
||||
self.last_row_values.push(None);
|
||||
|
@ -177,20 +176,20 @@ impl VirtualStatement {
|
|||
)))
|
||||
}
|
||||
|
||||
pub(crate) fn reset(&mut self) {
|
||||
pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> {
|
||||
self.index = 0;
|
||||
|
||||
for (i, handle) in self.handles.iter().enumerate() {
|
||||
SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take());
|
||||
|
||||
unsafe {
|
||||
// Reset A Prepared Statement Object
|
||||
// https://www.sqlite.org/c3ref/reset.html
|
||||
// https://www.sqlite.org/c3ref/clear_bindings.html
|
||||
sqlite3_reset(handle.0.as_ptr());
|
||||
sqlite3_clear_bindings(handle.0.as_ptr());
|
||||
}
|
||||
worker.reset(handle).await?;
|
||||
handle.clear_bindings();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,20 +197,6 @@ impl Drop for VirtualStatement {
|
|||
fn drop(&mut self) {
|
||||
for (i, handle) in self.handles.drain(..).enumerate() {
|
||||
SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take());
|
||||
|
||||
unsafe {
|
||||
// https://sqlite.org/c3ref/finalize.html
|
||||
let status = sqlite3_finalize(handle.0.as_ptr());
|
||||
if status == SQLITE_MISUSE {
|
||||
// Panic in case of detected misuse of SQLite API.
|
||||
//
|
||||
// sqlite3_finalize returns it at least in the
|
||||
// case of detected double free, i.e. calling
|
||||
// sqlite3_finalize on already finalized
|
||||
// statement.
|
||||
panic!("Detected sqlite3_finalize misuse.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,14 @@ use crate::sqlite::statement::StatementHandle;
|
|||
use crossbeam_channel::{unbounded, Sender};
|
||||
use either::Either;
|
||||
use futures_channel::oneshot;
|
||||
use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::thread;
|
||||
|
||||
use crate::sqlite::connection::ConnectionHandleRef;
|
||||
|
||||
use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW};
|
||||
use std::future::Future;
|
||||
|
||||
// Each SQLite connection has a dedicated thread.
|
||||
|
||||
// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
|
||||
|
@ -18,31 +23,60 @@ pub(crate) struct StatementWorker {
|
|||
|
||||
enum StatementWorkerCommand {
|
||||
Step {
|
||||
statement: StatementHandle,
|
||||
statement: Weak<StatementHandle>,
|
||||
tx: oneshot::Sender<Result<Either<u64, ()>, Error>>,
|
||||
},
|
||||
Reset {
|
||||
statement: Weak<StatementHandle>,
|
||||
tx: oneshot::Sender<()>,
|
||||
},
|
||||
}
|
||||
|
||||
impl StatementWorker {
|
||||
pub(crate) fn new() -> Self {
|
||||
pub(crate) fn new(conn: ConnectionHandleRef) -> Self {
|
||||
let (tx, rx) = unbounded();
|
||||
|
||||
thread::spawn(move || {
|
||||
for cmd in rx {
|
||||
match cmd {
|
||||
StatementWorkerCommand::Step { statement, tx } => {
|
||||
let status = unsafe { sqlite3_step(statement.0.as_ptr()) };
|
||||
let statement = if let Some(statement) = statement.upgrade() {
|
||||
statement
|
||||
} else {
|
||||
// statement is already finalized, the sender shouldn't be expecting a response
|
||||
continue;
|
||||
};
|
||||
|
||||
let resp = match status {
|
||||
// SAFETY: only the `StatementWorker` calls this function
|
||||
let status = unsafe { sqlite3_step(statement.as_ptr()) };
|
||||
let result = match status {
|
||||
SQLITE_ROW => Ok(Either::Right(())),
|
||||
SQLITE_DONE => Ok(Either::Left(statement.changes())),
|
||||
_ => Err(statement.last_error().into()),
|
||||
};
|
||||
|
||||
let _ = tx.send(resp);
|
||||
let _ = tx.send(result);
|
||||
}
|
||||
StatementWorkerCommand::Reset { statement, tx } => {
|
||||
if let Some(statement) = statement.upgrade() {
|
||||
// SAFETY: this must be the only place we call `sqlite3_reset`
|
||||
unsafe { sqlite3_reset(statement.as_ptr()) };
|
||||
|
||||
// `sqlite3_reset()` always returns either `SQLITE_OK`
|
||||
// or the last error code for the statement,
|
||||
// which should have already been handled;
|
||||
// so it's assumed the return value is safe to ignore.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/reset.html
|
||||
|
||||
let _ = tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: we need to make sure a strong ref to `conn` always outlives anything in `rx`
|
||||
drop(conn);
|
||||
});
|
||||
|
||||
Self { tx }
|
||||
|
@ -50,14 +84,47 @@ impl StatementWorker {
|
|||
|
||||
pub(crate) async fn step(
|
||||
&mut self,
|
||||
statement: StatementHandle,
|
||||
statement: &Arc<StatementHandle>,
|
||||
) -> Result<Either<u64, ()>, Error> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
self.tx
|
||||
.send(StatementWorkerCommand::Step { statement, tx })
|
||||
.send(StatementWorkerCommand::Step {
|
||||
statement: Arc::downgrade(statement),
|
||||
tx,
|
||||
})
|
||||
.map_err(|_| Error::WorkerCrashed)?;
|
||||
|
||||
rx.await.map_err(|_| Error::WorkerCrashed)?
|
||||
}
|
||||
|
||||
/// Send a command to the worker to execute `sqlite3_reset()` next.
|
||||
///
|
||||
/// This method is written to execute the sending of the command eagerly so
|
||||
/// you do not need to await the returned future unless you want to.
|
||||
///
|
||||
/// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error
|
||||
/// in the statement execution which should have already been handled from `step()`.
|
||||
pub(crate) fn reset(
|
||||
&mut self,
|
||||
statement: &Arc<StatementHandle>,
|
||||
) -> impl Future<Output = Result<(), Error>> {
|
||||
// execute the sending eagerly so we don't need to spawn the future
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let send_res = self
|
||||
.tx
|
||||
.send(StatementWorkerCommand::Reset {
|
||||
statement: Arc::downgrade(statement),
|
||||
tx,
|
||||
})
|
||||
.map_err(|_| Error::WorkerCrashed);
|
||||
|
||||
async move {
|
||||
send_res?;
|
||||
|
||||
// wait for the response
|
||||
rx.await.map_err(|_| Error::WorkerCrashed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Binary file not shown.
|
@ -536,3 +536,34 @@ async fn it_resets_prepared_statement_after_fetch_many() -> anyhow::Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/launchbadge/sqlx/issues/1300
|
||||
#[sqlx_macros::test]
|
||||
async fn concurrent_resets_dont_segfault() {
|
||||
use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions};
|
||||
use std::{str::FromStr, time::Duration};
|
||||
|
||||
let mut conn = SqliteConnectOptions::from_str(":memory:")
|
||||
.unwrap()
|
||||
.connect()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
sqlx::query("CREATE TABLE stuff (name INTEGER, value INTEGER)")
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
sqlx_rt::spawn(async move {
|
||||
for i in 0..1000 {
|
||||
sqlx::query("INSERT INTO stuff (name, value) VALUES (?, ?)")
|
||||
.bind(i)
|
||||
.bind(0)
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
sqlx_rt::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue