sqlite: implement remainder of query API

This commit is contained in:
Ryan Leckey 2020-03-14 14:09:27 -07:00
parent a3799c3496
commit 5f27026459
16 changed files with 278 additions and 125 deletions

View file

@ -47,7 +47,7 @@ pub(crate) enum ConnectionSource<'c, C>
where
C: Connect,
{
Connection(MaybeOwned<'c, PoolConnection<C>, C>),
Connection(MaybeOwned<PoolConnection<C>, &'c mut C>),
#[allow(dead_code)]
Pool(Pool<C>),

View file

@ -1,21 +1,31 @@
use core::borrow::{Borrow, BorrowMut};
use core::ops::{Deref, DerefMut};
pub(crate) enum MaybeOwned<'a, O, B = O> {
pub(crate) enum MaybeOwned<O, B> {
#[allow(dead_code)]
Borrowed(&'a mut B),
Borrowed(B),
#[allow(dead_code)]
Owned(O),
}
impl<'a, O, B> From<&'a mut B> for MaybeOwned<'a, O, B> {
impl<O> MaybeOwned<O, usize> {
#[allow(dead_code)]
pub(crate) fn resolve<'a, 'b: 'a>(&'a mut self, collection: &'b mut Vec<O>) -> &'a mut O {
match self {
MaybeOwned::Owned(ref mut val) => val,
MaybeOwned::Borrowed(index) => &mut collection[*index],
}
}
}
impl<'a, O, B> From<&'a mut B> for MaybeOwned<O, &'a mut B> {
fn from(val: &'a mut B) -> Self {
MaybeOwned::Borrowed(val)
}
}
impl<'a, O, B> Deref for MaybeOwned<'a, O, B>
impl<O, B> Deref for MaybeOwned<O, B>
where
O: Borrow<B>,
{
@ -29,7 +39,7 @@ where
}
}
impl<'a, O, B> DerefMut for MaybeOwned<'a, O, B>
impl<O, B> DerefMut for MaybeOwned<O, B>
where
O: BorrowMut<B>,
{

View file

@ -1,7 +1,9 @@
use core::ffi::c_void;
use libc::c_int;
use libsqlite3_sys::{
sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, sqlite3_bind_null,
sqlite3_bind_text, SQLITE_OK,
sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
sqlite3_bind_null, sqlite3_bind_text, SQLITE_OK,
};
use crate::arguments::Arguments;
@ -57,7 +59,14 @@ impl SqliteArgumentValue {
#[allow(unsafe_code)]
let status: c_int = match self {
SqliteArgumentValue::Blob(value) => {
todo!("bind/blob");
// TODO: Handle bytes that are too large
let bytes = value.as_slice();
let bytes_ptr = bytes.as_ptr() as *const c_void;
let bytes_len = bytes.len() as i32;
unsafe {
sqlite3_bind_blob(statement.handle.as_ptr(), index, bytes_ptr, bytes_len, None)
}
}
SqliteArgumentValue::Text(value) => {

View file

@ -3,11 +3,12 @@ use core::ptr::{null, null_mut, NonNull};
use std::collections::HashMap;
use std::convert::TryInto;
use std::ffi::CString;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_util::future;
use libsqlite3_sys::{
sqlite3, sqlite3_open_v2, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_NOMUTEX,
sqlite3, sqlite3_close, sqlite3_open_v2, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_NOMUTEX,
SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
};
@ -19,7 +20,9 @@ use crate::url::Url;
pub struct SqliteConnection {
pub(super) handle: NonNull<sqlite3>,
pub(super) cache_statement: HashMap<String, SqliteStatement>,
pub(super) statements: Vec<SqliteStatement>,
pub(super) statement_by_query: HashMap<String, usize>,
pub(super) columns_by_query: HashMap<String, Arc<HashMap<String, usize>>>,
}
// SAFE: A sqlite3 handle is safe to access from multiple threads provided
@ -40,7 +43,10 @@ unsafe impl Sync for SqliteConnection {}
fn establish(url: crate::Result<Url>) -> crate::Result<SqliteConnection> {
let url = url?;
let url = url.as_str().trim_start_matches("sqlite://");
let url = url
.as_str()
.trim_start_matches("sqlite:")
.trim_start_matches("//");
// By default, we connect to an in-memory database.
// TODO: Handle the error when there are internal NULs in the database URL
@ -62,7 +68,9 @@ fn establish(url: crate::Result<Url>) -> crate::Result<SqliteConnection> {
Ok(SqliteConnection {
handle: NonNull::new(handle).unwrap(),
cache_statement: HashMap::new(),
statements: Vec::with_capacity(10),
statement_by_query: HashMap::with_capacity(10),
columns_by_query: HashMap::new(),
})
}
@ -79,8 +87,8 @@ impl Connect for SqliteConnection {
impl Connection for SqliteConnection {
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
// Box::pin(terminate(self.stream))
todo!()
// All necessary behavior is handled on drop
Box::pin(future::ok(()))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
@ -88,3 +96,17 @@ impl Connection for SqliteConnection {
Box::pin(future::ok(()))
}
}
impl Drop for SqliteConnection {
fn drop(&mut self) {
// Drop all statements first
self.statements.clear();
// Next close the statement
// https://sqlite.org/c3ref/close.html
#[allow(unsafe_code)]
unsafe {
let _ = sqlite3_close(self.handle.as_ptr());
}
}
}

View file

@ -1,5 +1,3 @@
use core::mem::take;
use std::collections::HashMap;
use std::sync::Arc;
@ -8,29 +6,22 @@ use futures_core::future::BoxFuture;
use crate::connection::ConnectionSource;
use crate::cursor::Cursor;
use crate::executor::Execute;
use crate::maybe_owned::MaybeOwned;
use crate::pool::Pool;
use crate::sqlite::statement::{SqliteStatement, Step};
use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow};
enum State<'q> {
Empty,
Query((&'q str, Option<SqliteArguments>)),
Statement {
query: &'q str,
arguments: Option<SqliteArguments>,
statement: SqliteStatement,
statement: MaybeOwned<SqliteStatement, usize>,
},
}
impl Default for State<'_> {
fn default() -> Self {
State::Empty
}
}
pub struct SqliteCursor<'c, 'q> {
source: ConnectionSource<'c, SqliteConnection>,
// query: Option<(&'q str, Option<SqliteArguments>)>,
columns: Arc<HashMap<Box<str>, usize>>,
state: State<'q>,
}
@ -76,9 +67,10 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
match cursor.state {
State::Query((query, ref mut arguments)) => {
let mut statement = conn.prepare(query, arguments.is_some())?;
let statement_ = statement.resolve(&mut conn.statements);
if let Some(arguments) = arguments {
statement.bind(arguments)?;
statement_.bind(arguments)?;
}
cursor.state = State::Statement {
@ -93,44 +85,19 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
} => {
break statement;
}
State::Empty => unreachable!("use after drop"),
}
};
match statement.step().await? {
let statement_ = statement.resolve(&mut conn.statements);
match statement_.step().await? {
Step::Done => {
// TODO: If there is more to do, we need to do more
Ok(None)
}
Step::Row => Ok(Some(SqliteRow {
statement,
columns: Arc::default(),
statement: &*statement_,
})),
}
}
// If there is a statement on our WIP object
// Put it back into the cache IFF this is a persistent query
impl<'c, 'q> Drop for SqliteCursor<'c, 'q> {
fn drop(&mut self) {
match take(&mut self.state) {
State::Statement {
query,
arguments,
statement,
} => {
if arguments.is_some() {
if let ConnectionSource::Connection(connection) = &mut self.source {
connection
.cache_statement
.insert(query.to_owned(), statement);
}
}
}
_ => {}
}
}
}

View file

@ -5,30 +5,54 @@ use libsqlite3_sys::sqlite3_changes;
use crate::cursor::Cursor;
use crate::describe::Describe;
use crate::executor::{Execute, Executor, RefExecutor};
use crate::maybe_owned::MaybeOwned;
use crate::sqlite::arguments::SqliteArguments;
use crate::sqlite::cursor::SqliteCursor;
use crate::sqlite::statement::{SqliteStatement, Step};
use crate::sqlite::{Sqlite, SqliteConnection};
use std::collections::HashMap;
impl SqliteConnection {
pub(super) fn prepare(
&mut self,
query: &str,
persistent: bool,
) -> crate::Result<SqliteStatement> {
if let Some(mut statement) = self.cache_statement.remove(&*query) {
) -> crate::Result<MaybeOwned<SqliteStatement, usize>> {
// TODO: Revisit statement caching and allow cache expiration by using a
// generational index
if !persistent {
// A non-persistent query will be immediately prepared and returned
return SqliteStatement::new(&mut self.handle, query, false).map(MaybeOwned::Owned);
}
if let Some(key) = self.statement_by_query.get(query) {
let statement = &mut self.statements[*key];
// As this statement has very likely been used before, we reset
// it to clear the bindings and its program state
statement.reset();
Ok(statement)
} else {
SqliteStatement::new(&mut self.handle, query, persistent)
return Ok(MaybeOwned::Borrowed(*key));
}
// Prepare a new statement object; ensuring to tell SQLite that this will be stored
// for a "long" time and re-used multiple times
let key = self.statements.len();
self.statement_by_query.insert(query.to_owned(), key);
self.statements
.push(SqliteStatement::new(&mut self.handle, query, true)?);
Ok(MaybeOwned::Borrowed(key))
}
// This is used for [affected_rows] in the public API.
fn changes(&mut self) -> u64 {
// Returns the number of rows modified, inserted or deleted by the most recently
// completed INSERT, UPDATE or DELETE statement.
// https://www.sqlite.org/c3ref/changes.html
#[allow(unsafe_code)]
let changes = unsafe { sqlite3_changes(self.handle.as_ptr()) };
@ -46,14 +70,22 @@ impl Executor for SqliteConnection {
where
E: Execute<'q, Self::Database>,
{
Box::pin(
AffectedRows::<'c, 'q> {
connection: self,
query: query.into_parts(),
statement: None,
let (mut query, mut arguments) = query.into_parts();
Box::pin(async move {
let mut statement = self.prepare(query, arguments.is_some())?;
let mut statement_ = statement.resolve(&mut self.statements);
if let Some(arguments) = &mut arguments {
statement_.bind(arguments)?;
}
.get(),
)
while let Step::Row = statement_.step().await? {
// We only care about the rows modified; ignore
}
Ok(self.changes())
})
}
fn fetch<'q, E>(&mut self, query: E) -> SqliteCursor<'_, 'q>
@ -85,41 +117,3 @@ impl<'e> RefExecutor<'e> for &'e mut SqliteConnection {
SqliteCursor::from_connection(self, query)
}
}
struct AffectedRows<'c, 'q> {
query: (&'q str, Option<SqliteArguments>),
connection: &'c mut SqliteConnection,
statement: Option<SqliteStatement>,
}
impl AffectedRows<'_, '_> {
async fn get(mut self) -> crate::Result<u64> {
let mut statement = self
.connection
.prepare(self.query.0, self.query.1.is_some())?;
if let Some(arguments) = &mut self.query.1 {
statement.bind(arguments)?;
}
while let Step::Row = statement.step().await? {
// we only care about the rows modified; ignore
}
Ok(self.connection.changes())
}
}
impl Drop for AffectedRows<'_, '_> {
fn drop(&mut self) {
// If there is a statement on our WIP object
// Put it back into the cache IFF this is a persistent query
if self.query.1.is_some() {
if let Some(statement) = self.statement.take() {
self.connection
.cache_statement
.insert(self.query.0.to_owned(), statement);
}
}
}
}

View file

@ -23,5 +23,4 @@ pub type SqlitePool = crate::pool::Pool<SqliteConnection>;
make_query_as!(SqliteQueryAs, Sqlite, SqliteRow);
impl_map_row_for_row!(Sqlite, SqliteRow);
impl_column_index_for_row!(Sqlite);
impl_from_row_for_tuples!(Sqlite, SqliteRow);

View file

@ -4,6 +4,7 @@ use std::sync::Arc;
use libc::c_int;
use libsqlite3_sys::sqlite3_data_count;
use crate::database::HasRow;
use crate::row::{ColumnIndex, Row};
use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::value::SqliteResultValue;
@ -11,8 +12,6 @@ use crate::sqlite::Sqlite;
pub struct SqliteRow<'c> {
pub(super) statement: &'c SqliteStatement,
// TODO
pub(super) columns: Arc<HashMap<Box<str>, u16>>,
}
impl<'c> Row<'c> for SqliteRow<'c> {
@ -46,3 +45,25 @@ impl<'c> Row<'c> for SqliteRow<'c> {
Ok(value)
}
}
impl ColumnIndex<Sqlite> for usize {
fn resolve(self, row: &<Sqlite as HasRow>::Row) -> crate::Result<usize> {
let len = Row::len(row);
if self >= len {
return Err(crate::Error::ColumnIndexOutOfBounds { len, index: self });
}
Ok(self)
}
}
impl ColumnIndex<Sqlite> for &'_ str {
fn resolve(self, row: &<Sqlite as HasRow>::Row) -> crate::Result<usize> {
row.statement
.columns()
.get(self)
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))
.map(|&index| index as usize)
}
}

View file

@ -1,8 +1,14 @@
use core::cell::{RefCell, RefMut};
use core::ops::Deref;
use core::ptr::{null_mut, NonNull};
use std::collections::HashMap;
use std::ffi::CStr;
use libsqlite3_sys::{
sqlite3, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset, sqlite3_step, sqlite3_stmt,
SQLITE_DONE, SQLITE_OK, SQLITE_PREPARE_NO_VTAB, SQLITE_PREPARE_PERSISTENT, SQLITE_ROW,
sqlite3, sqlite3_column_count, sqlite3_column_name, sqlite3_finalize, sqlite3_prepare_v3,
sqlite3_reset, sqlite3_step, sqlite3_stmt, SQLITE_DONE, SQLITE_OK, SQLITE_PREPARE_NO_VTAB,
SQLITE_PREPARE_PERSISTENT, SQLITE_ROW,
};
use crate::sqlite::SqliteArguments;
@ -15,6 +21,7 @@ pub(crate) enum Step {
pub struct SqliteStatement {
pub(super) handle: NonNull<sqlite3_stmt>,
columns: RefCell<Option<HashMap<String, usize>>>,
}
// SAFE: See notes for the Send impl on [SqliteConnection].
@ -64,6 +71,32 @@ impl SqliteStatement {
Ok(Self {
handle: NonNull::new(statement_handle).unwrap(),
columns: RefCell::new(None),
})
}
pub(super) fn columns<'a>(&'a self) -> impl Deref<Target = HashMap<String, usize>> + 'a {
RefMut::map(self.columns.borrow_mut(), |columns| {
columns.get_or_insert_with(|| {
// https://sqlite.org/c3ref/column_count.html
#[allow(unsafe_code)]
let count = unsafe { sqlite3_column_count(self.handle.as_ptr()) };
let count = count as usize;
let mut columns = HashMap::with_capacity(count);
for i in 0..count {
// https://sqlite.org/c3ref/column_name.html
#[allow(unsafe_code)]
let name =
unsafe { CStr::from_ptr(sqlite3_column_name(self.handle.as_ptr(), 0)) };
let name = name.to_str().unwrap().to_owned();
columns.insert(name, i);
}
columns
})
})
}

View file

@ -0,0 +1,42 @@
use crate::decode::Decode;
use crate::encode::Encode;
use crate::sqlite::types::{SqliteType, SqliteTypeAffinity};
use crate::sqlite::{Sqlite, SqliteArgumentValue, SqliteResultValue, SqliteTypeInfo};
use crate::types::Type;
impl Type<Sqlite> for [u8] {
fn type_info() -> SqliteTypeInfo {
SqliteTypeInfo::new(SqliteType::Blob, SqliteTypeAffinity::Blob)
}
}
impl Type<Sqlite> for Vec<u8> {
fn type_info() -> SqliteTypeInfo {
<[u8] as Type<Sqlite>>::type_info()
}
}
impl Encode<Sqlite> for [u8] {
fn encode(&self, values: &mut Vec<SqliteArgumentValue>) {
// TODO: look into a way to remove this allocation
values.push(SqliteArgumentValue::Blob(self.to_owned()));
}
}
impl Encode<Sqlite> for Vec<u8> {
fn encode(&self, values: &mut Vec<SqliteArgumentValue>) {
<[u8] as Encode<Sqlite>>::encode(self, values)
}
}
impl<'de> Decode<'de, Sqlite> for &'de [u8] {
fn decode(value: SqliteResultValue<'de>) -> crate::Result<&'de [u8]> {
value.blob()
}
}
impl<'de> Decode<'de, Sqlite> for Vec<u8> {
fn decode(value: SqliteResultValue<'de>) -> crate::Result<Vec<u8>> {
<&[u8] as Decode<Sqlite>>::decode(value).map(ToOwned::to_owned)
}
}

View file

@ -6,6 +6,7 @@ use crate::sqlite::Sqlite;
use crate::types::TypeInfo;
mod bool;
mod bytes;
mod float;
mod int;
mod str;

View file

@ -12,7 +12,7 @@ impl Type<Sqlite> for str {
impl Type<Sqlite> for String {
fn type_info() -> SqliteTypeInfo {
SqliteTypeInfo::new(SqliteType::Text, SqliteTypeAffinity::Text)
<str as Type<Sqlite>>::type_info()
}
}
@ -23,6 +23,12 @@ impl Encode<Sqlite> for str {
}
}
impl Encode<Sqlite> for String {
fn encode(&self, values: &mut Vec<SqliteArgumentValue>) {
<str as Encode<Sqlite>>::encode(self, values)
}
}
impl<'de> Decode<'de, Sqlite> for &'de str {
fn decode(value: SqliteResultValue<'de>) -> crate::Result<&'de str> {
value.text()
@ -31,6 +37,6 @@ impl<'de> Decode<'de, Sqlite> for &'de str {
impl<'de> Decode<'de, Sqlite> for String {
fn decode(value: SqliteResultValue<'de>) -> crate::Result<String> {
Ok(value.text()?.to_owned())
<&str as Decode<Sqlite>>::decode(value).map(ToOwned::to_owned)
}
}

View file

@ -1,18 +1,21 @@
use std::ffi::CStr;
use libsqlite3_sys::{
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_text,
sqlite3_column_type, SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT,
sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_double, sqlite3_column_int,
sqlite3_column_int64, sqlite3_column_text, sqlite3_column_type, SQLITE_BLOB, SQLITE_FLOAT,
SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT,
};
use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::types::SqliteType;
use core::slice;
pub struct SqliteResultValue<'c> {
pub(super) index: usize,
pub(super) statement: &'c SqliteStatement,
}
// https://www.sqlite.org/c3ref/column_blob.html
// https://www.sqlite.org/capi3ref.html#sqlite3_column_blob
// These routines return information about a single column of the current result row of a query.
@ -65,4 +68,19 @@ impl<'c> SqliteResultValue<'c> {
raw.to_str().map_err(crate::Error::decode)
}
pub(crate) fn blob(&self) -> crate::Result<&'c [u8]> {
let index = self.index as i32;
#[allow(unsafe_code)]
let ptr = unsafe { sqlite3_column_blob(self.statement.handle.as_ptr(), index) };
#[allow(unsafe_code)]
let len = unsafe { sqlite3_column_bytes(self.statement.handle.as_ptr(), index) };
#[allow(unsafe_code)]
let raw = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) };
Ok(raw)
}
}

View file

@ -40,10 +40,10 @@ macro_rules! test_unprepared_type {
$(
let query = format!("SELECT {} as _1", $text);
let mut cursor = conn.fetch(&*query);
// // let row = cursor.next().await?.unwrap();
// // let rec = row.try_get::<$ty, _>("_1")?;
let row = cursor.next().await?.unwrap();
let rec = row.try_get::<$ty, _>("_1")?;
// // assert!($value == rec);
assert!($value == rec);
)+
Ok(())
@ -93,7 +93,7 @@ macro_rules! MySql_query_for_test_prepared_type {
#[macro_export]
macro_rules! Sqlite_query_for_test_prepared_type {
() => {
"SELECT ({0} is null or {0} = ?1), ?1 as _1"
"SELECT {} is ?, ? as _1"
};
}

View file

@ -1,6 +1,6 @@
//! Tests for the raw (unprepared) query API for Sqlite.
use sqlx::{Cursor, Executor, Sqlite, Row};
use sqlx::{Cursor, Executor, Row, Sqlite};
use sqlx_test::new;
#[cfg_attr(feature = "runtime-async-std", async_std::test)]

View file

@ -7,9 +7,40 @@ test_type!(null(
"NULL" == None::<i32>
));
test_type!(bool(
test_type!(bool(Sqlite, bool, "FALSE" == false, "TRUE" == true));
test_type!(i32(Sqlite, i32, "94101" == 94101_i32));
test_type!(i64(Sqlite, i64, "9358295312" == 9358295312_i64));
// NOTE: This behavior can be surprising. Floating-point parameters are widening to double which can
// result in strange rounding.
test_type!(f32(
Sqlite,
bool,
"false::boolean" == false,
"true::boolean" == true
f32,
"3.1410000324249268" == 3.141f32 as f64 as f32
));
test_type!(f64(
Sqlite,
f64,
"939399419.1225182" == 939399419.1225182_f64
));
test_type!(string(
Sqlite,
String,
"'this is foo'" == "this is foo",
"''" == ""
));
test_type!(bytes(
Sqlite,
Vec<u8>,
"X'DEADBEEF'"
== vec![0xDE_u8, 0xAD, 0xBE, 0xEF],
"X''"
== Vec::<u8>::new(),
"X'0000000052'"
== vec![0_u8, 0, 0, 0, 0x52]
));