feat(sqlite): support expressions and multiple no-data statements in the macros

This commit is contained in:
Ryan Leckey 2020-07-04 02:56:02 -07:00
parent 0def87b689
commit d112c4d807
8 changed files with 390 additions and 38 deletions

View file

@ -46,6 +46,7 @@ pub mod types;
#[macro_use]
pub mod query;
mod column;
mod common;
pub mod database;
pub mod describe;

View file

@ -0,0 +1,113 @@
use crate::describe::{Column, Describe};
use crate::error::Error;
use crate::sqlite::connection::explain::explain;
use crate::sqlite::statement::SqliteStatement;
use crate::sqlite::type_info::DataType;
use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo};
use futures_core::future::BoxFuture;
pub(super) async fn describe(
conn: &mut SqliteConnection,
query: &str,
) -> Result<Describe<Sqlite>, Error> {
describe_with(conn, query, vec![]).await
}
pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>(
conn: &'c mut SqliteConnection,
query: &'q str,
fallback: Vec<SqliteTypeInfo>,
) -> BoxFuture<'e, Result<Describe<Sqlite>, Error>> {
Box::pin(async move {
// describing a statement from SQLite can be involved
// each SQLx statement is comprised of multiple SQL statements
let SqliteConnection {
ref mut handle,
ref worker,
..
} = conn;
let statement = SqliteStatement::prepare(handle, query, false);
let mut columns = Vec::new();
let mut num_params = 0;
let mut statement = statement?;
// we start by finding the first statement that *can* return results
while let Some((statement, _)) = statement.execute()? {
num_params += statement.bind_parameter_count();
let mut stepped = false;
let num = statement.column_count();
if num == 0 {
// no columns in this statement; skip
continue;
}
// next we try to use [column_decltype] to inspect the type of each column
columns.reserve(num);
for col in 0..num {
let name = statement.column_name(col).to_owned();
let type_info = if let Some(ty) = statement.column_decltype(col) {
ty
} else {
// if that fails, we back up and attempt to step the statement
// once *if* its read-only and then use [column_type] as a
// fallback to [column_decltype]
if !stepped && statement.read_only() && fallback.is_empty() {
stepped = true;
worker.execute(statement);
worker.wake();
let _ = worker.step(statement).await?;
}
let mut ty = statement.column_type_info(col);
if ty.0 == DataType::Null {
if fallback.is_empty() {
// this will _still_ fail if there are no actual rows to return
// this happens more often than not for the macros as we tell
// users to execute against an empty database
// as a last resort, we explain the original query and attempt to
// infer what would the expression types be as a fallback
// to [column_decltype]
let fallback = explain(conn, statement.sql()).await?;
return describe_with(conn, query, fallback).await;
}
if let Some(fallback) = fallback.get(col).cloned() {
ty = fallback;
}
}
ty
};
let not_null = statement.column_not_null(col)?;
columns.push(Column {
name,
type_info: Some(type_info),
not_null,
});
}
}
// println!("describe ->> {:#?}", columns);
Ok(Describe {
columns,
params: vec![None; num_params],
})
})
}

View file

@ -3,14 +3,15 @@ use std::sync::Arc;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::TryStreamExt;
use futures_util::{FutureExt, TryStreamExt};
use hashbrown::HashMap;
use crate::common::StatementCache;
use crate::describe::{Column, Describe};
use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::ext::ustr::UStr;
use crate::sqlite::connection::describe::describe;
use crate::sqlite::connection::ConnectionHandle;
use crate::sqlite::statement::{SqliteStatement, StatementHandle};
use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow};
@ -176,34 +177,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
'c: 'e,
E: Execute<'q, Self::Database>,
{
let query = query.query();
let statement = SqliteStatement::prepare(&mut self.handle, query, false);
Box::pin(async move {
let mut params = Vec::new();
let mut columns = Vec::new();
if let Some(statement) = statement?.handles.get(0) {
// NOTE: we can infer *nothing* about parameters apart from the count
params.resize(statement.bind_parameter_count(), None);
let num_columns = statement.column_count();
columns.reserve(num_columns);
for i in 0..num_columns {
let name = statement.column_name(i).to_owned();
let type_info = statement.column_decltype(i);
let not_null = statement.column_not_null(i)?;
columns.push(Column {
name,
type_info,
not_null,
})
}
}
Ok(Describe { params, columns })
})
describe(self, query.query()).boxed()
}
}

View file

@ -0,0 +1,153 @@
use crate::error::Error;
use crate::query_as::query_as;
use crate::sqlite::type_info::DataType;
use crate::sqlite::{SqliteConnection, SqliteTypeInfo};
use hashbrown::HashMap;
const OP_INIT: &str = "Init";
const OP_GOTO: &str = "Goto";
const OP_COLUMN: &str = "Column";
const OP_AGG_STEP: &str = "AggStep";
const OP_MOVE: &str = "Move";
const OP_COPY: &str = "Copy";
const OP_SCOPY: &str = "SCopy";
const OP_INT_COPY: &str = "IntCopy";
const OP_STRING8: &str = "String8";
const OP_INT64: &str = "Int64";
const OP_INTEGER: &str = "Integer";
const OP_REAL: &str = "Real";
const OP_NOT: &str = "Not";
const OP_BLOB: &str = "Blob";
const OP_COUNT: &str = "Count";
const OP_ROWID: &str = "Rowid";
const OP_OR: &str = "Or";
const OP_AND: &str = "And";
const OP_BIT_AND: &str = "BitAnd";
const OP_BIT_OR: &str = "BitOr";
const OP_SHIFT_LEFT: &str = "ShiftLeft";
const OP_SHIFT_RIGHT: &str = "ShiftRight";
const OP_ADD: &str = "Add";
const OP_SUBTRACT: &str = "Subtract";
const OP_MULTIPLY: &str = "Multiply";
const OP_DIVIDE: &str = "Divide";
const OP_REMAINDER: &str = "Remainder";
const OP_CONCAT: &str = "Concat";
const OP_RESULT_ROW: &str = "ResultRow";
fn to_type(op: &str) -> DataType {
match op {
OP_REAL => DataType::Float,
OP_BLOB => DataType::Blob,
OP_AND | OP_OR => DataType::Bool,
OP_ROWID | OP_COUNT | OP_INT64 | OP_INTEGER => DataType::Int64,
OP_STRING8 => DataType::Text,
OP_COLUMN | _ => DataType::Null,
}
}
pub(super) async fn explain(
conn: &mut SqliteConnection,
query: &str,
) -> Result<Vec<SqliteTypeInfo>, Error> {
let mut r = HashMap::<i64, DataType>::with_capacity(6);
let program =
query_as::<_, (i64, String, i64, i64, i64, String)>(&*format!("EXPLAIN {}", query))
.fetch_all(&mut *conn)
.await?;
let mut program_i = 0;
let program_size = program.len();
while program_i < program_size {
let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i];
match &**opcode {
OP_INIT => {
// start at <p2>
program_i = p2 as usize;
continue;
}
OP_GOTO => {
// goto <p2>
program_i = p2 as usize;
continue;
}
OP_COLUMN => {
// r[p3] = <value of column>
r.insert(p3, DataType::Null);
}
OP_AGG_STEP => {
if p4.starts_with("count(") {
// count(_) -> INTEGER
r.insert(p3, DataType::Int64);
} else if let Some(v) = r.get(&p2).copied() {
// r[p3] = AGG ( r[p2] )
r.insert(p3, v);
}
}
OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => {
// r[p2] = r[p1]
if let Some(v) = r.get(&p1).copied() {
r.insert(p2, v);
}
}
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => {
// r[p2] = <value of constant>
r.insert(p2, to_type(&opcode));
}
OP_NOT => {
// r[p2] = NOT r[p1]
if let Some(a) = r.get(&p1).copied() {
r.insert(p2, a);
}
}
OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT
| OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => {
// r[p3] = r[p1] + r[p2]
match (r.get(&p1).copied(), r.get(&p2).copied()) {
(Some(a), Some(b)) => {
r.insert(p3, if matches!(a, DataType::Null) { b } else { a });
}
(Some(v), None) => {
r.insert(p3, v);
}
(None, Some(v)) => {
r.insert(p3, v);
}
_ => {}
}
}
OP_RESULT_ROW => {
// output = r[p1 .. p1 + p2]
let mut output = Vec::with_capacity(p2 as usize);
for i in p1..p1 + p2 {
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
}
return Ok(output);
}
_ => {
// ignore unsupported operations
// if we fail to find an r later, we just give up
}
}
program_i += 1;
}
// no rows
Ok(vec![])
}

View file

@ -15,8 +15,10 @@ use crate::sqlite::connection::establish::establish;
use crate::sqlite::statement::{SqliteStatement, StatementWorker};
use crate::sqlite::{Sqlite, SqliteConnectOptions};
mod describe;
mod establish;
mod executor;
mod explain;
mod handle;
pub(crate) use handle::ConnectionHandle;

View file

@ -11,8 +11,8 @@ use libsqlite3_sys::{
sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype,
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name,
sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type,
sqlite3_column_value, sqlite3_db_handle, sqlite3_stmt, sqlite3_table_column_metadata,
SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8,
sqlite3_column_value, sqlite3_db_handle, sqlite3_sql, sqlite3_stmt, sqlite3_stmt_readonly,
sqlite3_table_column_metadata, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8,
};
use crate::error::{BoxDynError, Error};
@ -38,6 +38,21 @@ impl StatementHandle {
sqlite3_db_handle(self.0.as_ptr())
}
pub(crate) fn read_only(&self) -> bool {
// https://sqlite.org/c3ref/stmt_readonly.html
unsafe { sqlite3_stmt_readonly(self.0.as_ptr()) != 0 }
}
pub(crate) fn sql(&self) -> &str {
// https://sqlite.org/c3ref/expanded_sql.html
unsafe {
let raw = sqlite3_sql(self.0.as_ptr());
debug_assert!(!raw.is_null());
from_utf8_unchecked(CStr::from_ptr(raw).to_bytes())
}
}
#[inline]
pub(crate) fn last_error(&self) -> SqliteError {
SqliteError::new(unsafe { self.db_handle() })
@ -68,6 +83,10 @@ impl StatementHandle {
}
}
pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo {
SqliteTypeInfo(DataType::from_code(self.column_type(index)))
}
#[inline]
pub(crate) fn column_decltype(&self, index: usize) -> Option<SqliteTypeInfo> {
unsafe {

View file

@ -7,7 +7,7 @@ use libsqlite3_sys::{SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQL
use crate::error::BoxDynError;
use crate::type_info::TypeInfo;
#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum DataType {
Null,

View file

@ -1,6 +1,10 @@
use sqlx::describe::Column;
use sqlx::error::DatabaseError;
use sqlx::sqlite::{SqliteConnectOptions, SqliteError};
use sqlx::{sqlite::Sqlite, Executor};
use sqlx_core::describe::Column;
use sqlx::{Connect, SqliteConnection, TypeInfo};
use sqlx_test::new;
use std::env;
fn type_names(columns: &[Column<Sqlite>]) -> Vec<String> {
columns
@ -41,14 +45,101 @@ async fn it_describes_simple() -> anyhow::Result<()> {
async fn it_describes_expression() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let d = conn.describe("SELECT 1 + 10").await?;
let d = conn
.describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef'")
.await?;
let columns = d.columns;
assert_eq!(columns[0].type_info.as_ref().unwrap().name(), "INTEGER");
assert_eq!(columns[0].name, "1 + 10");
assert_eq!(columns[0].not_null, None);
// SQLite cannot infer types for expressions
assert_eq!(columns[0].type_info, None);
assert_eq!(columns[1].type_info.as_ref().unwrap().name(), "REAL");
assert_eq!(columns[1].name, "5.12 * 2");
assert_eq!(columns[1].not_null, None);
assert_eq!(columns[2].type_info.as_ref().unwrap().name(), "TEXT");
assert_eq!(columns[2].name, "'Hello'");
assert_eq!(columns[2].not_null, None);
assert_eq!(columns[3].type_info.as_ref().unwrap().name(), "BLOB");
assert_eq!(columns[3].name, "x'deadbeef'");
assert_eq!(columns[3].not_null, None);
Ok(())
}
#[sqlx_macros::test]
async fn it_describes_expression_from_empty_table() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
conn.execute("CREATE TEMP TABLE _temp_empty ( name TEXT, a INT )")
.await?;
let d = conn
.describe("SELECT COUNT(*), a + 1, name, 5.12, 'Hello' FROM _temp_empty")
.await?;
assert_eq!(d.columns[0].type_info.as_ref().unwrap().name(), "INTEGER");
assert_eq!(d.columns[1].type_info.as_ref().unwrap().name(), "INTEGER");
assert_eq!(d.columns[2].type_info.as_ref().unwrap().name(), "TEXT");
assert_eq!(d.columns[3].type_info.as_ref().unwrap().name(), "REAL");
assert_eq!(d.columns[4].type_info.as_ref().unwrap().name(), "TEXT");
Ok(())
}
#[sqlx_macros::test]
async fn it_describes_insert() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let d = conn
.describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')")
.await?;
assert_eq!(d.columns.len(), 0);
let d = conn
.describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello'); SELECT last_insert_rowid();")
.await?;
assert_eq!(d.columns.len(), 1);
assert_eq!(d.columns[0].type_info.as_ref().unwrap().name(), "INTEGER");
Ok(())
}
#[sqlx_macros::test]
async fn it_describes_insert_with_read_only() -> anyhow::Result<()> {
sqlx_test::setup_if_needed();
let mut options: SqliteConnectOptions = env::var("DATABASE_URL")?.parse().unwrap();
options = options.read_only(true);
let mut conn = SqliteConnection::connect_with(&options).await?;
let d = conn
.describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')")
.await?;
assert_eq!(d.columns.len(), 0);
Ok(())
}
#[sqlx_macros::test]
async fn it_describes_bad_statement() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let err = conn.describe("SELECT 1 FROM not_found").await.unwrap_err();
let err = err
.as_database_error()
.unwrap()
.downcast_ref::<SqliteError>();
assert_eq!(err.message(), "no such table: not_found");
assert_eq!(err.code().as_deref(), Some("1"));
Ok(())
}