sqlite: implement support for multiple statements

This commit is contained in:
Ryan Leckey 2020-03-14 17:04:02 -07:00
parent 0130fe1479
commit 63ef32189d
10 changed files with 256 additions and 91 deletions

View file

@ -4,7 +4,7 @@ use std::os::raw::c_int;
use libsqlite3_sys::{
sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
sqlite3_bind_null, sqlite3_bind_text, SQLITE_OK,
sqlite3_bind_null, sqlite3_bind_text, SQLITE_OK, SQLITE_TRANSIENT,
};
use crate::arguments::Arguments;
@ -13,6 +13,7 @@ use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::Sqlite;
use crate::sqlite::SqliteError;
use crate::types::Type;
use core::mem;
#[derive(Debug, Clone)]
pub enum SqliteArgumentValue {
@ -33,7 +34,22 @@ pub enum SqliteArgumentValue {
#[derive(Default)]
pub struct SqliteArguments {
pub(super) values: Vec<SqliteArgumentValue>,
index: usize,
values: Vec<SqliteArgumentValue>,
}
impl SqliteArguments {
pub(crate) fn next(&mut self) -> Option<SqliteArgumentValue> {
if self.index >= self.values.len() {
return None;
}
let mut value = SqliteArgumentValue::Null;
mem::swap(&mut value, &mut self.values[self.index]);
self.index += 1;
Some(value)
}
}
impl Arguments for SqliteArguments {
@ -66,7 +82,13 @@ impl SqliteArgumentValue {
let bytes_len = bytes.len() as i32;
unsafe {
sqlite3_bind_blob(statement.handle.as_ptr(), index, bytes_ptr, bytes_len, None)
sqlite3_bind_blob(
statement.handle.as_ptr(),
index,
bytes_ptr,
bytes_len,
SQLITE_TRANSIENT(),
)
}
}
@ -77,7 +99,13 @@ impl SqliteArgumentValue {
let bytes_len = bytes.len() as i32;
unsafe {
sqlite3_bind_text(statement.handle.as_ptr(), index, bytes_ptr, bytes_len, None)
sqlite3_bind_text(
statement.handle.as_ptr(),
index,
bytes_ptr,
bytes_len,
SQLITE_TRANSIENT(),
)
}
}

View file

@ -12,13 +12,15 @@ use libsqlite3_sys::{
};
use crate::connection::{Connect, Connection};
use crate::runtime::spawn_blocking;
use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::SqliteError;
use crate::url::Url;
pub struct SqliteConnection {
pub(super) handle: NonNull<sqlite3>,
// Storage of the most recently prepared, non-persistent statement
pub(super) statement: Option<SqliteStatement>,
// Storage of persistent statements
pub(super) statements: Vec<SqliteStatement>,
pub(super) statement_by_query: HashMap<String, usize>,
}
@ -66,6 +68,7 @@ fn establish(url: crate::Result<Url>) -> crate::Result<SqliteConnection> {
Ok(SqliteConnection {
handle: NonNull::new(handle).unwrap(),
statement: None,
statements: Vec::with_capacity(10),
statement_by_query: HashMap::with_capacity(10),
})

View file

@ -3,23 +3,15 @@ use futures_core::future::BoxFuture;
use crate::connection::ConnectionSource;
use crate::cursor::Cursor;
use crate::executor::Execute;
use crate::maybe_owned::MaybeOwned;
use crate::pool::Pool;
use crate::sqlite::statement::{SqliteStatement, Step};
use crate::sqlite::statement::Step;
use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow};
enum State<'q> {
Query((&'q str, Option<SqliteArguments>)),
Statement {
query: &'q str,
arguments: Option<SqliteArguments>,
statement: MaybeOwned<SqliteStatement, usize>,
},
}
pub struct SqliteCursor<'c, 'q> {
source: ConnectionSource<'c, SqliteConnection>,
state: State<'q>,
pub(super) source: ConnectionSource<'c, SqliteConnection>,
query: &'q str,
arguments: Option<SqliteArguments>,
pub(super) statement: Option<Option<usize>>,
}
impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> {
@ -30,9 +22,13 @@ impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> {
Self: Sized,
E: Execute<'q, Sqlite>,
{
let (query, arguments) = query.into_parts();
Self {
source: ConnectionSource::Pool(pool.clone()),
state: State::Query(query.into_parts()),
statement: None,
query,
arguments,
}
}
@ -41,9 +37,13 @@ impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> {
Self: Sized,
E: Execute<'q, Sqlite>,
{
let (query, arguments) = query.into_parts();
Self {
source: ConnectionSource::Connection(conn.into()),
state: State::Query(query.into_parts()),
statement: None,
query,
arguments,
}
}
@ -57,41 +57,38 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
) -> crate::Result<Option<SqliteRow<'a>>> {
let conn = cursor.source.resolve().await?;
let statement = loop {
match cursor.state {
State::Query((query, ref mut arguments)) => {
let mut statement = conn.prepare(query, arguments.is_some())?;
let statement_ = statement.resolve(&mut conn.statements);
loop {
if cursor.statement.is_none() {
let key = conn.prepare(&mut cursor.query, cursor.arguments.is_some())?;
if let Some(arguments) = arguments {
statement_.bind(arguments)?;
}
cursor.state = State::Statement {
statement,
query,
arguments: arguments.take(),
};
if let Some(arguments) = &mut cursor.arguments {
conn.statement_mut(key).bind(arguments)?;
}
State::Statement {
ref mut statement, ..
} => {
break statement;
cursor.statement = Some(key);
}
let key = cursor.statement.unwrap();
let statement = conn.statement_mut(key);
let step = statement.step().await?;
match step {
Step::Row => {
return Ok(Some(SqliteRow {
statement: key,
connection: conn,
}));
}
Step::Done if cursor.query.is_empty() => {
return Ok(None);
}
Step::Done => {
cursor.statement = None;
// continue
}
}
};
let statement_ = statement.resolve(&mut conn.statements);
match statement_.step().await? {
Step::Done => {
// TODO: If there is more to do, we need to do more
Ok(None)
}
Step::Row => Ok(Some(SqliteRow {
statement: &*statement_,
})),
}
}

View file

@ -10,8 +10,7 @@ impl Database for Sqlite {
type TypeInfo = super::SqliteTypeInfo;
// TODO?
type TableId = u32;
type TableId = String;
type RawBuffer = Vec<super::SqliteArgumentValue>;
}

View file

@ -5,46 +5,67 @@ use libsqlite3_sys::sqlite3_changes;
use crate::cursor::Cursor;
use crate::describe::{Column, Describe};
use crate::executor::{Execute, Executor, RefExecutor};
use crate::maybe_owned::MaybeOwned;
use crate::sqlite::cursor::SqliteCursor;
use crate::sqlite::statement::{SqliteStatement, Step};
use crate::sqlite::types::SqliteType;
use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo};
impl SqliteConnection {
pub(super) fn statement(&self, key: Option<usize>) -> &SqliteStatement {
match key {
Some(key) => &self.statements[key],
None => self.statement.as_ref().unwrap(),
}
}
pub(super) fn statement_mut(&mut self, key: Option<usize>) -> &mut SqliteStatement {
match key {
Some(key) => &mut self.statements[key],
None => self.statement.as_mut().unwrap(),
}
}
pub(super) fn prepare(
&mut self,
query: &str,
query: &mut &str,
persistent: bool,
) -> crate::Result<MaybeOwned<SqliteStatement, usize>> {
) -> crate::Result<Option<usize>> {
// TODO: Revisit statement caching and allow cache expiration by using a
// generational index
if !persistent {
// A non-persistent query will be immediately prepared and returned
return SqliteStatement::new(&mut self.handle, query, false).map(MaybeOwned::Owned);
// A non-persistent query will be immediately prepared and returned,
// regardless of the current state of the cache
self.statement = Some(SqliteStatement::new(&mut self.handle, query, false)?);
return Ok(None);
}
if let Some(key) = self.statement_by_query.get(query) {
if let Some(key) = self.statement_by_query.get(&**query) {
let statement = &mut self.statements[*key];
// Adjust the passed in query string as if [string3_prepare]
// did the tail parsing
*query = &query[statement.tail..];
// As this statement has very likely been used before, we reset
// it to clear the bindings and its program state
statement.reset();
return Ok(MaybeOwned::Borrowed(*key));
return Ok(Some(*key));
}
// Prepare a new statement object; ensuring to tell SQLite that this will be stored
// for a "long" time and re-used multiple times
let query_key = query.to_owned();
let statement = SqliteStatement::new(&mut self.handle, query, true)?;
let key = self.statements.len();
self.statement_by_query.insert(query.to_owned(), key);
self.statements
.push(SqliteStatement::new(&mut self.handle, query, true)?);
self.statement_by_query.insert(query_key, key);
self.statements.push(statement);
Ok(MaybeOwned::Borrowed(key))
Ok(Some(key))
}
// This is used for [affected_rows] in the public API.
@ -72,15 +93,21 @@ impl Executor for SqliteConnection {
let (mut query, mut arguments) = query.into_parts();
Box::pin(async move {
let mut statement = self.prepare(query, arguments.is_some())?;
let statement_ = statement.resolve(&mut self.statements);
loop {
let key = self.prepare(&mut query, arguments.is_some())?;
let statement = self.statement_mut(key);
if let Some(arguments) = &mut arguments {
statement_.bind(arguments)?;
}
if let Some(arguments) = &mut arguments {
statement.bind(arguments)?;
}
while let Step::Row = statement_.step().await? {
// We only care about the rows modified; ignore
while let Step::Row = statement.step().await? {
// We only care about the rows modified; ignore
}
if query.is_empty() {
break;
}
}
Ok(self.changes())
@ -102,9 +129,9 @@ impl Executor for SqliteConnection {
E: Execute<'q, Self::Database>,
{
Box::pin(async move {
let (query, _) = query.into_parts();
let mut statement = self.prepare(query, false)?;
let statement = statement.resolve(&mut self.statements);
let (mut query, _) = query.into_parts();
let key = self.prepare(&mut query, false)?;
let statement = self.statement(key);
// First let's attempt to describe what we can about parameter types
// Which happens to just be the count, heh

View file

@ -5,10 +5,17 @@ use crate::database::HasRow;
use crate::row::{ColumnIndex, Row};
use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::value::SqliteResultValue;
use crate::sqlite::Sqlite;
use crate::sqlite::{Sqlite, SqliteConnection};
pub struct SqliteRow<'c> {
pub(super) statement: &'c SqliteStatement,
pub(super) statement: Option<usize>,
pub(super) connection: &'c SqliteConnection,
}
impl SqliteRow<'_> {
fn statement(&self) -> &SqliteStatement {
self.connection.statement(self.statement)
}
}
impl<'c> Row<'c> for SqliteRow<'c> {
@ -24,7 +31,7 @@ impl<'c> Row<'c> for SqliteRow<'c> {
// sqlite3_step that returned SQLITE_ROW.
#[allow(unsafe_code)]
let count: c_int = unsafe { sqlite3_data_count(self.statement.handle.as_ptr()) };
let count: c_int = unsafe { sqlite3_data_count(self.statement().handle.as_ptr()) };
count as usize
}
@ -36,6 +43,7 @@ impl<'c> Row<'c> for SqliteRow<'c> {
let index = index.resolve(self)?;
let value = SqliteResultValue {
index,
connection: self.connection,
statement: self.statement,
};
@ -57,7 +65,7 @@ impl ColumnIndex<Sqlite> for usize {
impl ColumnIndex<Sqlite> for &'_ str {
fn resolve(self, row: &<Sqlite as HasRow>::Row) -> crate::Result<usize> {
row.statement
row.statement()
.columns()
.get(self)
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))

View file

@ -1,6 +1,6 @@
use core::cell::{RefCell, RefMut};
use core::ops::Deref;
use core::ptr::{null_mut, NonNull};
use core::ptr::{null, null_mut, NonNull};
use std::collections::HashMap;
use std::ffi::CStr;
@ -22,6 +22,7 @@ pub(crate) enum Step {
}
pub struct SqliteStatement {
pub(super) tail: usize,
pub(super) handle: NonNull<sqlite3_stmt>,
columns: RefCell<Option<HashMap<String, usize>>>,
}
@ -37,7 +38,7 @@ unsafe impl Sync for SqliteStatement {}
impl SqliteStatement {
pub(super) fn new(
handle: &mut NonNull<sqlite3>,
query: &str,
query: &mut &str,
persistent: bool,
) -> crate::Result<Self> {
// TODO: Error on queries that are too large
@ -45,6 +46,7 @@ impl SqliteStatement {
let query_len = query.len() as i32;
let mut statement_handle: *mut sqlite3_stmt = null_mut();
let mut flags = SQLITE_PREPARE_NO_VTAB;
let mut tail: *const i8 = null();
if persistent {
// SQLITE_PREPARE_PERSISTENT
@ -63,10 +65,15 @@ impl SqliteStatement {
query_len,
flags as u32,
&mut statement_handle,
null_mut(),
&mut tail,
)
};
// If pzTail is not NULL then *pzTail is made to point to the first byte
// past the end of the first SQL statement in zSql.
let tail = (tail as usize) - (query_ptr as usize);
*query = &query[tail..].trim();
if status != SQLITE_OK {
return Err(SqliteError::new(status).into());
}
@ -74,6 +81,7 @@ impl SqliteStatement {
Ok(Self {
handle: NonNull::new(statement_handle).unwrap(),
columns: RefCell::new(None),
tail,
})
}
@ -132,8 +140,12 @@ impl SqliteStatement {
}
pub(super) fn bind(&mut self, arguments: &mut SqliteArguments) -> crate::Result<()> {
for (index, value) in arguments.values.iter().enumerate() {
value.bind(self, index + 1)?;
for index in 0..self.params() {
if let Some(value) = arguments.next() {
value.bind(self, index + 1)?;
} else {
break;
}
}
Ok(())

View file

@ -8,11 +8,19 @@ use libsqlite3_sys::{
use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::types::SqliteType;
use crate::sqlite::SqliteConnection;
use core::slice;
pub struct SqliteResultValue<'c> {
pub(super) index: usize,
pub(super) statement: &'c SqliteStatement,
pub(super) statement: Option<usize>,
pub(super) connection: &'c SqliteConnection,
}
impl SqliteResultValue<'_> {
fn statement(&self) -> &SqliteStatement {
self.connection.statement(self.statement)
}
}
// https://www.sqlite.org/c3ref/column_blob.html
@ -24,7 +32,7 @@ impl<'c> SqliteResultValue<'c> {
pub(crate) fn r#type(&self) -> SqliteType {
#[allow(unsafe_code)]
let type_code =
unsafe { sqlite3_column_type(self.statement.handle.as_ptr(), self.index as i32) };
unsafe { sqlite3_column_type(self.statement().handle.as_ptr(), self.index as i32) };
match type_code {
SQLITE_INTEGER => SqliteType::Integer,
@ -40,21 +48,21 @@ impl<'c> SqliteResultValue<'c> {
pub(crate) fn int(&self) -> i32 {
#[allow(unsafe_code)]
unsafe {
sqlite3_column_int(self.statement.handle.as_ptr(), self.index as i32)
sqlite3_column_int(self.statement().handle.as_ptr(), self.index as i32)
}
}
pub(crate) fn int64(&self) -> i64 {
#[allow(unsafe_code)]
unsafe {
sqlite3_column_int64(self.statement.handle.as_ptr(), self.index as i32)
sqlite3_column_int64(self.statement().handle.as_ptr(), self.index as i32)
}
}
pub(crate) fn double(&self) -> f64 {
#[allow(unsafe_code)]
unsafe {
sqlite3_column_double(self.statement.handle.as_ptr(), self.index as i32)
sqlite3_column_double(self.statement().handle.as_ptr(), self.index as i32)
}
}
@ -62,7 +70,8 @@ impl<'c> SqliteResultValue<'c> {
#[allow(unsafe_code)]
let raw = unsafe {
CStr::from_ptr(
sqlite3_column_text(self.statement.handle.as_ptr(), self.index as i32) as *const i8,
sqlite3_column_text(self.statement().handle.as_ptr(), self.index as i32)
as *const i8,
)
};
@ -73,10 +82,10 @@ impl<'c> SqliteResultValue<'c> {
let index = self.index as i32;
#[allow(unsafe_code)]
let ptr = unsafe { sqlite3_column_blob(self.statement.handle.as_ptr(), index) };
let ptr = unsafe { sqlite3_column_blob(self.statement().handle.as_ptr(), index) };
#[allow(unsafe_code)]
let len = unsafe { sqlite3_column_bytes(self.statement.handle.as_ptr(), index) };
let len = unsafe { sqlite3_column_bytes(self.statement().handle.as_ptr(), index) };
#[allow(unsafe_code)]
let raw = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) };

View file

@ -2,3 +2,51 @@
use sqlx::{Cursor, Executor, Row, Sqlite};
use sqlx_test::new;
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_select_expression() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let mut cursor = conn.fetch("SELECT 5");
let row = cursor.next().await?.unwrap();
assert!(5i32 == row.try_get::<i32, _>(0)?);
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_multi_read_write() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let mut cursor = conn.fetch(
"
CREATE TABLE IF NOT EXISTS _sqlx_test (
id INT PRIMARY KEY,
text TEXT NOT NULL
);
SELECT 'Hello World' as _1;
INSERT INTO _sqlx_test (text) VALUES ('this is a test');
SELECT id, text FROM _sqlx_test;
",
);
let row = cursor.next().await?.unwrap();
assert!("Hello World" == row.try_get::<&str, _>("_1")?);
let row = cursor.next().await?.unwrap();
let id: i64 = row.try_get("id")?;
let text: &str = row.try_get("text")?;
assert_eq!(0, id);
assert_eq!("this is a test", text);
Ok(())
}

View file

@ -54,6 +54,40 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_can_execute_multiple_statements() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let affected = conn
.execute(
r#"
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, other INTEGER);
INSERT INTO users DEFAULT VALUES;
"#,
)
.await?;
assert_eq!(affected, 1);
for index in 2..5_i32 {
let (id, other): (i32, i32) = sqlx::query_as(
r#"
INSERT INTO users (other) VALUES (?);
SELECT id, other FROM users WHERE id = last_insert_rowid();
"#,
)
.bind(index)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, index);
assert_eq!(other, index);
}
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_describes() -> anyhow::Result<()> {