sqlite: fix a couple segfaults (#1351)

* sqlite: use Arc instead of Copy-able StatementHandle

This guarantees that StatementHandle is never used after calling
`sqlite3_finalize`. Now `sqlite3_finalize` is only called when
StatementHandle is dropped.

(cherry picked from commit 5eebc05dc3)

* sqlite: use Weak pointer to StatementHandle in the worker

Otherwise some tests fail to close connection.

(cherry picked from commit 5461eeeee3)

* Fix segfault due to race condition in sqlite (#1300)

(cherry picked from commit bb62cf767e)

* fix(sqlite): run `sqlite3_reset()` in `StatementWorker`

this avoids possible race conditions without using a mutex

* fix(sqlite): have `StatementWorker` keep a strong ref to `ConnectionHandle`

this should prevent the database handle from being finalized before all statement handles
have been finalized

* fix(sqlite/test): make `concurrent_resets_dont_segfault` runtime-agnostic

Co-authored-by: link2xt <link2xt@testrun.org>
Co-authored-by: Adam Cigánek <adam.ciganek@gmail.com>
This commit is contained in:
Austin Bonander 2021-08-16 14:39:45 -07:00 committed by GitHub
parent 55c603e9e7
commit 71388a7ef2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 224 additions and 80 deletions

View file

@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>(
// fallback to [column_decltype]
if !stepped && stmt.read_only() {
stepped = true;
let _ = conn.worker.step(*stmt).await;
let _ = conn.worker.step(stmt).await;
}
let mut ty = stmt.column_type_info(col);

View file

@ -87,7 +87,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
// https://www.sqlite.org/c3ref/extended_result_codes.html
unsafe {
// NOTE: ignore the failure here
sqlite3_extended_result_codes(handle.0.as_ptr(), 1);
sqlite3_extended_result_codes(handle.as_ptr(), 1);
}
// Configure a busy timeout
@ -99,7 +99,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
let ms =
i32::try_from(busy_timeout.as_millis()).expect("Given busy timeout value is too big.");
status = unsafe { sqlite3_busy_timeout(handle.0.as_ptr(), ms) };
status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) };
if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))));
@ -109,8 +109,8 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<SqliteCo
})?;
Ok(SqliteConnection {
worker: StatementWorker::new(handle.to_ref()),
handle,
worker: StatementWorker::new(),
statements: StatementCache::new(options.statement_cache_capacity),
statement: None,
transaction_depth: 0,

View file

@ -4,7 +4,7 @@ use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::logger::QueryLogger;
use crate::sqlite::connection::describe::describe;
use crate::sqlite::statement::{StatementHandle, VirtualStatement};
use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement};
use crate::sqlite::{
Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
SqliteTypeInfo,
@ -16,7 +16,8 @@ use libsqlite3_sys::sqlite3_last_insert_rowid;
use std::borrow::Cow;
use std::sync::Arc;
fn prepare<'a>(
async fn prepare<'a>(
worker: &mut StatementWorker,
statements: &'a mut StatementCache<VirtualStatement>,
statement: &'a mut Option<VirtualStatement>,
query: &str,
@ -39,7 +40,7 @@ fn prepare<'a>(
if exists {
// as this statement has been executed before, we reset before continuing
// this also causes any rows that are from the statement to be inflated
statement.reset();
statement.reset(worker).await?;
}
Ok(statement)
@ -61,19 +62,25 @@ fn bind(
/// A structure holding sqlite statement handle and resetting the
/// statement when it is dropped.
struct StatementResetter {
handle: StatementHandle,
struct StatementResetter<'a> {
handle: Arc<StatementHandle>,
worker: &'a mut StatementWorker,
}
impl StatementResetter {
fn new(handle: StatementHandle) -> Self {
Self { handle }
impl<'a> StatementResetter<'a> {
fn new(worker: &'a mut StatementWorker, handle: &Arc<StatementHandle>) -> Self {
Self {
worker,
handle: Arc::clone(handle),
}
}
}
impl Drop for StatementResetter {
impl Drop for StatementResetter<'_> {
fn drop(&mut self) {
self.handle.reset();
// this method is designed to eagerly send the reset command
// so we don't need to await or spawn it
let _ = self.worker.reset(&self.handle);
}
}
@ -103,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;
// prepare statement object (or checkout from cache)
let stmt = prepare(statements, statement, sql, persistent)?;
let stmt = prepare(worker, statements, statement, sql, persistent).await?;
// keep track of how many arguments we have bound
let mut num_arguments = 0;
@ -113,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
// is dropped. `StatementResetter` will reliably reset the
// statement even if the stream returned from `fetch_many`
// is dropped early.
let _resetter = StatementResetter::new(*stmt);
let resetter = StatementResetter::new(worker, stmt);
// bind values to the statement
num_arguments += bind(stmt, &arguments, num_arguments)?;
@ -125,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
let s = worker.step(*stmt).await?;
let s = resetter.worker.step(stmt).await?;
match s {
Either::Left(changes) => {
@ -145,7 +152,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
Either::Right(()) => {
let (row, weak_values_ref) = SqliteRow::current(
*stmt,
&stmt,
columns,
column_names
);
@ -188,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;
// prepare statement object (or checkout from cache)
let virtual_stmt = prepare(statements, statement, sql, persistent)?;
let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?;
// keep track of how many arguments we have bound
let mut num_arguments = 0;
@ -205,18 +212,18 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
// invoke [sqlite3_step] on the dedicated worker thread
// this will move us forward one row or finish the statement
match worker.step(*stmt).await? {
match worker.step(stmt).await? {
Either::Left(_) => (),
Either::Right(()) => {
let (row, weak_values_ref) =
SqliteRow::current(*stmt, columns, column_names);
SqliteRow::current(stmt, columns, column_names);
*last_row_values = Some(weak_values_ref);
logger.increment_rows();
virtual_stmt.reset();
virtual_stmt.reset(worker).await?;
return Ok(Some(row));
}
}
@ -238,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
handle: ref mut conn,
ref mut statements,
ref mut statement,
ref mut worker,
..
} = self;
// prepare statement object (or checkout from cache)
let statement = prepare(statements, statement, sql, true)?;
let statement = prepare(worker, statements, statement, sql, true).await?;
let mut parameters = 0;
let mut columns = None;

View file

@ -3,11 +3,23 @@ use std::ptr::NonNull;
use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK};
use crate::sqlite::SqliteError;
use std::sync::Arc;
/// Managed handle to the raw SQLite3 database handle.
/// The database handle will be closed when this is dropped.
/// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist.
#[derive(Debug)]
pub(crate) struct ConnectionHandle(pub(super) NonNull<sqlite3>);
pub(crate) struct ConnectionHandle(Arc<HandleInner>);
/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own
/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()`
/// or `sqlite3_reset()`.
///
/// Note that this does *not* actually give access to the database handle!
pub(crate) struct ConnectionHandleRef(Arc<HandleInner>);
// Wrapper for `*mut sqlite3` which finalizes the handle on-drop.
#[derive(Debug)]
struct HandleInner(NonNull<sqlite3>);
// A SQLite3 handle is safe to send between threads, provided not more than
// one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is
@ -20,19 +32,32 @@ pub(crate) struct ConnectionHandle(pub(super) NonNull<sqlite3>);
unsafe impl Send for ConnectionHandle {}
// SAFETY: `Arc<T>` normally only implements `Send` where `T: Sync` because it allows
// concurrent access.
//
// However, in this case we're only using `Arc` to prevent the database handle from being
// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus
// should *not* actually provide access to the database handle.
unsafe impl Send for ConnectionHandleRef {}
impl ConnectionHandle {
#[inline]
pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self {
Self(NonNull::new_unchecked(ptr))
Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr))))
}
#[inline]
pub(crate) fn as_ptr(&self) -> *mut sqlite3 {
self.0.as_ptr()
self.0 .0.as_ptr()
}
/// Creates a `ConnectionHandleRef` that shares ownership of the same inner
/// handle by cloning the `Arc`; the raw `sqlite3` pointer itself is not copied out.
#[inline]
pub(crate) fn to_ref(&self) -> ConnectionHandleRef {
ConnectionHandleRef(Arc::clone(&self.0))
}
}
impl Drop for ConnectionHandle {
impl Drop for HandleInner {
fn drop(&mut self) {
unsafe {
// https://sqlite.org/c3ref/close.html

View file

@ -17,7 +17,7 @@ mod executor;
mod explain;
mod handle;
pub(crate) use handle::ConnectionHandle;
pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef};
/// A connection to a [Sqlite] database.
pub struct SqliteConnection {

View file

@ -23,7 +23,7 @@ pub struct SqliteRow {
// IF the user drops the Row before iterating the stream (so
// nearly all of our internal stream iterators), the executor moves on; otherwise,
// it actually inflates this row with a list of owned sqlite3 values.
pub(crate) statement: StatementHandle,
pub(crate) statement: Arc<StatementHandle>,
pub(crate) values: Arc<AtomicPtr<SqliteValue>>,
pub(crate) num_values: usize,
@ -48,7 +48,7 @@ impl SqliteRow {
// returns a weak reference to an atomic list where the executor should inflate if its going
// to increment the statement with [step]
pub(crate) fn current(
statement: StatementHandle,
statement: &Arc<StatementHandle>,
columns: &Arc<Vec<SqliteColumn>>,
column_names: &Arc<HashMap<UStr, usize>>,
) -> (Self, Weak<AtomicPtr<SqliteValue>>) {
@ -57,7 +57,7 @@ impl SqliteRow {
let size = statement.column_count();
let row = Self {
statement,
statement: Arc::clone(statement),
values,
num_values: size,
columns: Arc::clone(columns),

View file

@ -1,5 +1,6 @@
use std::ffi::c_void;
use std::ffi::CStr;
use std::os::raw::{c_char, c_int};
use std::ptr;
use std::ptr::NonNull;
@ -9,21 +10,22 @@ use std::str::{from_utf8, from_utf8_unchecked};
use libsqlite3_sys::{
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name,
sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes,
sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype,
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name,
sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type,
sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt,
sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK,
SQLITE_TRANSIENT, SQLITE_UTF8,
sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob,
sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name,
sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64,
sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name,
sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset,
sqlite3_sql, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata,
sqlite3_value, SQLITE_DONE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW, SQLITE_TRANSIENT,
SQLITE_UTF8,
};
use crate::error::{BoxDynError, Error};
use crate::sqlite::type_info::DataType;
use crate::sqlite::{SqliteError, SqliteTypeInfo};
#[derive(Debug, Copy, Clone)]
pub(crate) struct StatementHandle(pub(super) NonNull<sqlite3_stmt>);
#[derive(Debug)]
pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);
// access to SQLite3 statement handles are safe to send and share between threads
// as long as the `sqlite3_step` call is serialized.
@ -32,6 +34,14 @@ unsafe impl Send for StatementHandle {}
unsafe impl Sync for StatementHandle {}
impl StatementHandle {
/// Wraps a non-null raw `sqlite3_stmt` pointer in a `StatementHandle`.
///
/// The pointer is finalized with `sqlite3_finalize` when the returned handle is
/// dropped, so the caller must not finalize it separately.
pub(super) fn new(ptr: NonNull<sqlite3_stmt>) -> Self {
Self(ptr)
}
/// Returns the raw `sqlite3_stmt` pointer held by this handle, for passing to
/// SQLite C API calls. Ownership stays with the handle.
pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt {
self.0.as_ptr()
}
#[inline]
pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 {
// O(c) access to the connection handle for this statement handle
@ -280,7 +290,25 @@ impl StatementHandle {
Ok(from_utf8(self.column_blob(index))?)
}
pub(crate) fn reset(&self) {
unsafe { sqlite3_reset(self.0.as_ptr()) };
pub(crate) fn clear_bindings(&self) {
unsafe { sqlite3_clear_bindings(self.0.as_ptr()) };
}
}
// Finalizes the underlying `sqlite3_stmt` when the handle is dropped.
//
// Only `SQLITE_MISUSE` is acted upon (by panicking); any other status returned by
// `sqlite3_finalize` is ignored here.
impl Drop for StatementHandle {
fn drop(&mut self) {
// SAFETY: we have exclusive access to the `StatementHandle` here
unsafe {
// https://sqlite.org/c3ref/finalize.html
let status = sqlite3_finalize(self.0.as_ptr());
if status == SQLITE_MISUSE {
// Panic in case of detected misuse of SQLite API.
//
// sqlite3_finalize returns it at least in the
// case of detected double free, i.e. calling
// sqlite3_finalize on already finalized
// statement.
panic!("Detected sqlite3_finalize misuse.");
}
}
}
}

View file

@ -3,13 +3,12 @@
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::sqlite::connection::ConnectionHandle;
use crate::sqlite::statement::StatementHandle;
use crate::sqlite::statement::{StatementHandle, StatementWorker};
use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue};
use crate::HashMap;
use bytes::{Buf, Bytes};
use libsqlite3_sys::{
sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset,
sqlite3_stmt, SQLITE_MISUSE, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
};
use smallvec::SmallVec;
use std::i32;
@ -31,7 +30,7 @@ pub(crate) struct VirtualStatement {
// underlying sqlite handles for each inner statement
// a SQL query string in SQLite is broken up into N statements
// we use a [`SmallVec`] to optimize for the most likely case of a single statement
pub(crate) handles: SmallVec<[StatementHandle; 1]>,
pub(crate) handles: SmallVec<[Arc<StatementHandle>; 1]>,
// each set of columns
pub(crate) columns: SmallVec<[Arc<Vec<SqliteColumn>>; 1]>,
@ -92,7 +91,7 @@ fn prepare(
query.advance(n);
if let Some(handle) = NonNull::new(statement_handle) {
return Ok(Some(StatementHandle(handle)));
return Ok(Some(StatementHandle::new(handle)));
}
}
@ -126,7 +125,7 @@ impl VirtualStatement {
conn: &mut ConnectionHandle,
) -> Result<
Option<(
&StatementHandle,
&Arc<StatementHandle>,
&mut Arc<Vec<SqliteColumn>>,
&Arc<HashMap<UStr, usize>>,
&mut Option<Weak<AtomicPtr<SqliteValue>>>,
@ -159,7 +158,7 @@ impl VirtualStatement {
column_names.insert(name, i);
}
self.handles.push(statement);
self.handles.push(Arc::new(statement));
self.columns.push(Arc::new(columns));
self.column_names.push(Arc::new(column_names));
self.last_row_values.push(None);
@ -177,20 +176,20 @@ impl VirtualStatement {
)))
}
pub(crate) fn reset(&mut self) {
pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> {
self.index = 0;
for (i, handle) in self.handles.iter().enumerate() {
SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take());
unsafe {
// Reset A Prepared Statement Object
// https://www.sqlite.org/c3ref/reset.html
// https://www.sqlite.org/c3ref/clear_bindings.html
sqlite3_reset(handle.0.as_ptr());
sqlite3_clear_bindings(handle.0.as_ptr());
}
// Reset A Prepared Statement Object
// https://www.sqlite.org/c3ref/reset.html
// https://www.sqlite.org/c3ref/clear_bindings.html
worker.reset(handle).await?;
handle.clear_bindings();
}
Ok(())
}
}
@ -198,20 +197,6 @@ impl Drop for VirtualStatement {
fn drop(&mut self) {
for (i, handle) in self.handles.drain(..).enumerate() {
SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take());
unsafe {
// https://sqlite.org/c3ref/finalize.html
let status = sqlite3_finalize(handle.0.as_ptr());
if status == SQLITE_MISUSE {
// Panic in case of detected misuse of SQLite API.
//
// sqlite3_finalize returns it at least in the
// case of detected double free, i.e. calling
// sqlite3_finalize on already finalized
// statement.
panic!("Detected sqlite3_finalize misuse.");
}
}
}
}
}

View file

@ -3,9 +3,14 @@ use crate::sqlite::statement::StatementHandle;
use crossbeam_channel::{unbounded, Sender};
use either::Either;
use futures_channel::oneshot;
use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW};
use std::sync::{Arc, Weak};
use std::thread;
use crate::sqlite::connection::ConnectionHandleRef;
use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW};
use std::future::Future;
// Each SQLite connection has a dedicated thread.
// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
@ -18,31 +23,60 @@ pub(crate) struct StatementWorker {
enum StatementWorkerCommand {
Step {
statement: StatementHandle,
statement: Weak<StatementHandle>,
tx: oneshot::Sender<Result<Either<u64, ()>, Error>>,
},
Reset {
statement: Weak<StatementHandle>,
tx: oneshot::Sender<()>,
},
}
impl StatementWorker {
pub(crate) fn new() -> Self {
pub(crate) fn new(conn: ConnectionHandleRef) -> Self {
let (tx, rx) = unbounded();
thread::spawn(move || {
for cmd in rx {
match cmd {
StatementWorkerCommand::Step { statement, tx } => {
let status = unsafe { sqlite3_step(statement.0.as_ptr()) };
let statement = if let Some(statement) = statement.upgrade() {
statement
} else {
// statement is already finalized, the sender shouldn't be expecting a response
continue;
};
let resp = match status {
// SAFETY: only the `StatementWorker` calls this function
let status = unsafe { sqlite3_step(statement.as_ptr()) };
let result = match status {
SQLITE_ROW => Ok(Either::Right(())),
SQLITE_DONE => Ok(Either::Left(statement.changes())),
_ => Err(statement.last_error().into()),
};
let _ = tx.send(resp);
let _ = tx.send(result);
}
StatementWorkerCommand::Reset { statement, tx } => {
if let Some(statement) = statement.upgrade() {
// SAFETY: this must be the only place we call `sqlite3_reset`
unsafe { sqlite3_reset(statement.as_ptr()) };
// `sqlite3_reset()` always returns either `SQLITE_OK`
// or the last error code for the statement,
// which should have already been handled;
// so it's assumed the return value is safe to ignore.
//
// https://www.sqlite.org/c3ref/reset.html
let _ = tx.send(());
}
}
}
}
// SAFETY: we need to make sure a strong ref to `conn` always outlives anything in `rx`
drop(conn);
});
Self { tx }
@ -50,14 +84,47 @@ impl StatementWorker {
pub(crate) async fn step(
&mut self,
statement: StatementHandle,
statement: &Arc<StatementHandle>,
) -> Result<Either<u64, ()>, Error> {
let (tx, rx) = oneshot::channel();
self.tx
.send(StatementWorkerCommand::Step { statement, tx })
.send(StatementWorkerCommand::Step {
statement: Arc::downgrade(statement),
tx,
})
.map_err(|_| Error::WorkerCrashed)?;
rx.await.map_err(|_| Error::WorkerCrashed)?
}
/// Send a command to the worker to execute `sqlite3_reset()` next.
///
/// This method is written to execute the sending of the command eagerly so
/// you do not need to await the returned future unless you want to.
///
/// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error
/// in the statement execution which should have already been handled from `step()`.
///
/// Only a `Weak` reference to `statement` is placed in the command, so queuing a
/// reset does not by itself keep the statement handle alive.
pub(crate) fn reset(
&mut self,
statement: &Arc<StatementHandle>,
) -> impl Future<Output = Result<(), Error>> {
// execute the sending eagerly so we don't need to spawn the future
let (tx, rx) = oneshot::channel();
// the outcome of the send is captured now and only surfaced if the future is awaited
let send_res = self
.tx
.send(StatementWorkerCommand::Reset {
statement: Arc::downgrade(statement),
tx,
})
.map_err(|_| Error::WorkerCrashed);
async move {
send_res?;
// wait for the response
rx.await.map_err(|_| Error::WorkerCrashed)
}
}
}

Binary file not shown.

View file

@ -536,3 +536,34 @@ async fn it_resets_prepared_statement_after_fetch_many() -> anyhow::Result<()> {
Ok(())
}
// https://github.com/launchbadge/sqlx/issues/1300
//
// Regression test: opens an in-memory database, creates a table, then spawns a task
// that performs 1000 inserts on the connection. The spawned task is never joined; the
// test body only sleeps for 1ms and returns.
// NOTE(review): presumably the point is that the test ends (and the runtime tears down)
// while inserts are still in flight -- confirm against the linked issue.
#[sqlx_macros::test]
async fn concurrent_resets_dont_segfault() {
use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions};
use std::{str::FromStr, time::Duration};
let mut conn = SqliteConnectOptions::from_str(":memory:")
.unwrap()
.connect()
.await
.unwrap();
sqlx::query("CREATE TABLE stuff (name INTEGER, value INTEGER)")
.execute(&mut conn)
.await
.unwrap();
// `conn` is moved into the spawned task; the returned handle is intentionally discarded
sqlx_rt::spawn(async move {
for i in 0..1000 {
sqlx::query("INSERT INTO stuff (name, value) VALUES (?, ?)")
.bind(i)
.bind(0)
.execute(&mut conn)
.await
.unwrap();
}
});
// return almost immediately, without waiting for the inserts to finish
sqlx_rt::sleep(Duration::from_millis(1)).await;
}