mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
Keep track of column typing in SQLite EXPLAIN parsing (#1323)
* NewRowid, Column opcodes, better pointer handling * Implement tracking of column typing on sqlite explain parser * fmt for sqlite column typing for explain parsing Co-authored-by: marshoepial <marshoepial@gmail.com>
This commit is contained in:
parent
8bcac0394f
commit
cb3ff28721
2 changed files with 112 additions and 19 deletions
|
@ -17,6 +17,13 @@ const SQLITE_AFF_REAL: u8 = 0x45; /* 'E' */
|
|||
const OP_INIT: &str = "Init";
|
||||
const OP_GOTO: &str = "Goto";
|
||||
const OP_COLUMN: &str = "Column";
|
||||
const OP_MAKE_RECORD: &str = "MakeRecord";
|
||||
const OP_INSERT: &str = "Insert";
|
||||
const OP_IDX_INSERT: &str = "IdxInsert";
|
||||
const OP_OPEN_READ: &str = "OpenRead";
|
||||
const OP_OPEN_WRITE: &str = "OpenWrite";
|
||||
const OP_OPEN_EPHEMERAL: &str = "OpenEphemeral";
|
||||
const OP_OPEN_AUTOINDEX: &str = "OpenAutoindex";
|
||||
const OP_AGG_STEP: &str = "AggStep";
|
||||
const OP_FUNCTION: &str = "Function";
|
||||
const OP_MOVE: &str = "Move";
|
||||
|
@ -34,6 +41,7 @@ const OP_BLOB: &str = "Blob";
|
|||
const OP_VARIABLE: &str = "Variable";
|
||||
const OP_COUNT: &str = "Count";
|
||||
const OP_ROWID: &str = "Rowid";
|
||||
const OP_NEWROWID: &str = "NewRowid";
|
||||
const OP_OR: &str = "Or";
|
||||
const OP_AND: &str = "And";
|
||||
const OP_BIT_AND: &str = "BitAnd";
|
||||
|
@ -48,6 +56,21 @@ const OP_REMAINDER: &str = "Remainder";
|
|||
const OP_CONCAT: &str = "Concat";
|
||||
const OP_RESULT_ROW: &str = "ResultRow";
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
enum RegDataType {
|
||||
Single(DataType),
|
||||
Record(Vec<DataType>),
|
||||
}
|
||||
|
||||
impl RegDataType {
|
||||
fn map_to_datatype(self) -> DataType {
|
||||
match self {
|
||||
RegDataType::Single(d) => d,
|
||||
RegDataType::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::wildcard_in_or_patterns)]
|
||||
fn affinity_to_type(affinity: u8) -> DataType {
|
||||
match affinity {
|
||||
|
@ -73,13 +96,19 @@ fn opcode_to_type(op: &str) -> DataType {
|
|||
}
|
||||
}
|
||||
|
||||
// Opcode Reference: https://sqlite.org/opcode.html
|
||||
pub(super) async fn explain(
|
||||
conn: &mut SqliteConnection,
|
||||
query: &str,
|
||||
) -> Result<(Vec<SqliteTypeInfo>, Vec<Option<bool>>), Error> {
|
||||
let mut r = HashMap::<i64, DataType>::with_capacity(6);
|
||||
// Registers
|
||||
let mut r = HashMap::<i64, RegDataType>::with_capacity(6);
|
||||
// Map between pointer and register
|
||||
let mut r_cursor = HashMap::<i64, Vec<i64>>::with_capacity(6);
|
||||
// Rows that pointers point to
|
||||
let mut p = HashMap::<i64, HashMap<i64, DataType>>::with_capacity(6);
|
||||
|
||||
// Nullable columns
|
||||
let mut n = HashMap::<i64, bool>::with_capacity(6);
|
||||
|
||||
let program =
|
||||
|
@ -119,15 +148,52 @@ pub(super) async fn explain(
|
|||
}
|
||||
|
||||
OP_COLUMN => {
|
||||
r_cursor.entry(p1).or_default().push(p3);
|
||||
//Get the row stored at p1, or NULL; get the column stored at p2, or NULL
|
||||
if let Some(record) = p.get(&p1) {
|
||||
if let Some(col) = record.get(&p2) {
|
||||
// insert into p3 the datatype of the col
|
||||
r.insert(p3, RegDataType::Single(*col));
|
||||
// map between pointer p1 and register p3
|
||||
r_cursor.entry(p1).or_default().push(p3);
|
||||
} else {
|
||||
r.insert(p3, RegDataType::Single(DataType::Null));
|
||||
}
|
||||
} else {
|
||||
r.insert(p3, RegDataType::Single(DataType::Null));
|
||||
}
|
||||
}
|
||||
|
||||
// r[p3] = <value of column>
|
||||
r.insert(p3, DataType::Null);
|
||||
OP_MAKE_RECORD => {
|
||||
// p3 = Record([p1 .. p1 + p2])
|
||||
let mut record = Vec::with_capacity(p2 as usize);
|
||||
for reg in p1..p1 + p2 {
|
||||
record.push(
|
||||
r.get(®)
|
||||
.map(|d| d.clone().map_to_datatype())
|
||||
.unwrap_or(DataType::Null),
|
||||
);
|
||||
}
|
||||
r.insert(p3, RegDataType::Record(record));
|
||||
}
|
||||
|
||||
OP_INSERT | OP_IDX_INSERT => {
|
||||
if let Some(RegDataType::Record(record)) = r.get(&p2) {
|
||||
if let Some(row) = p.get_mut(&p1) {
|
||||
// Insert the record into wherever pointer p1 is
|
||||
*row = (0..).zip(record.iter().copied()).collect();
|
||||
}
|
||||
}
|
||||
//Noop if the register p2 isn't a record, or if pointer p1 does not exist
|
||||
}
|
||||
|
||||
OP_OPEN_READ | OP_OPEN_WRITE | OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => {
|
||||
//Create a new pointer which is referenced by p1
|
||||
p.insert(p1, HashMap::with_capacity(6));
|
||||
}
|
||||
|
||||
OP_VARIABLE => {
|
||||
// r[p2] = <value of variable>
|
||||
r.insert(p2, DataType::Null);
|
||||
r.insert(p2, RegDataType::Single(DataType::Null));
|
||||
n.insert(p3, true);
|
||||
}
|
||||
|
||||
|
@ -136,7 +202,7 @@ pub(super) async fn explain(
|
|||
match from_utf8(p4).map_err(Error::protocol)? {
|
||||
"last_insert_rowid(0)" => {
|
||||
// last_insert_rowid() -> INTEGER
|
||||
r.insert(p3, DataType::Int64);
|
||||
r.insert(p3, RegDataType::Single(DataType::Int64));
|
||||
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
|
||||
}
|
||||
|
||||
|
@ -145,9 +211,9 @@ pub(super) async fn explain(
|
|||
}
|
||||
|
||||
OP_NULL_ROW => {
|
||||
// all values of cursor X are potentially nullable
|
||||
for column in &r_cursor[&p1] {
|
||||
n.insert(*column, true);
|
||||
// all registers that map to cursor X are potentially nullable
|
||||
for register in &r_cursor[&p1] {
|
||||
n.insert(*register, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,9 +222,9 @@ pub(super) async fn explain(
|
|||
|
||||
if p4.starts_with("count(") {
|
||||
// count(_) -> INTEGER
|
||||
r.insert(p3, DataType::Int64);
|
||||
r.insert(p3, RegDataType::Single(DataType::Int64));
|
||||
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
|
||||
} else if let Some(v) = r.get(&p2).copied() {
|
||||
} else if let Some(v) = r.get(&p2).cloned() {
|
||||
// r[p3] = AGG ( r[p2] )
|
||||
r.insert(p3, v);
|
||||
let val = n.get(&p2).copied().unwrap_or(true);
|
||||
|
@ -169,13 +235,13 @@ pub(super) async fn explain(
|
|||
OP_CAST => {
|
||||
// affinity(r[p1])
|
||||
if let Some(v) = r.get_mut(&p1) {
|
||||
*v = affinity_to_type(p2 as u8);
|
||||
*v = RegDataType::Single(affinity_to_type(p2 as u8));
|
||||
}
|
||||
}
|
||||
|
||||
OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => {
|
||||
// r[p2] = r[p1]
|
||||
if let Some(v) = r.get(&p1).copied() {
|
||||
if let Some(v) = r.get(&p1).cloned() {
|
||||
r.insert(p2, v);
|
||||
|
||||
if let Some(null) = n.get(&p1).copied() {
|
||||
|
@ -184,15 +250,16 @@ pub(super) async fn explain(
|
|||
}
|
||||
}
|
||||
|
||||
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => {
|
||||
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID
|
||||
| OP_NEWROWID => {
|
||||
// r[p2] = <value of constant>
|
||||
r.insert(p2, opcode_to_type(&opcode));
|
||||
r.insert(p2, RegDataType::Single(opcode_to_type(&opcode)));
|
||||
n.insert(p2, n.get(&p2).copied().unwrap_or(false));
|
||||
}
|
||||
|
||||
OP_NOT => {
|
||||
// r[p2] = NOT r[p1]
|
||||
if let Some(a) = r.get(&p1).copied() {
|
||||
if let Some(a) = r.get(&p1).cloned() {
|
||||
r.insert(p2, a);
|
||||
let val = n.get(&p1).copied().unwrap_or(true);
|
||||
n.insert(p2, val);
|
||||
|
@ -202,9 +269,16 @@ pub(super) async fn explain(
|
|||
OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT
|
||||
| OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => {
|
||||
// r[p3] = r[p1] + r[p2]
|
||||
match (r.get(&p1).copied(), r.get(&p2).copied()) {
|
||||
match (r.get(&p1).cloned(), r.get(&p2).cloned()) {
|
||||
(Some(a), Some(b)) => {
|
||||
r.insert(p3, if matches!(a, DataType::Null) { b } else { a });
|
||||
r.insert(
|
||||
p3,
|
||||
if matches!(a, RegDataType::Single(DataType::Null)) {
|
||||
b
|
||||
} else {
|
||||
a
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
(Some(v), None) => {
|
||||
|
@ -252,7 +326,11 @@ pub(super) async fn explain(
|
|||
|
||||
if let Some(result) = result {
|
||||
for i in result {
|
||||
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
|
||||
output.push(SqliteTypeInfo(
|
||||
r.remove(&i)
|
||||
.map(|d| d.map_to_datatype())
|
||||
.unwrap_or(DataType::Null),
|
||||
));
|
||||
nullable.push(n.remove(&i));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -171,6 +171,21 @@ async fn it_describes_insert_with_read_only() -> anyhow::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_describes_insert_with_returning() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let d = conn
|
||||
.describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello') RETURNING *")
|
||||
.await?;
|
||||
|
||||
assert_eq!(d.columns().len(), 4);
|
||||
assert_eq!(d.column(0).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.column(1).type_info().name(), "TEXT");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_describes_bad_statement() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
|
Loading…
Reference in a new issue