mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
Add extension support for SQLite (#2062)
* Add extension support for SQLite While SQLite supports loading extensions at run-time via either the C API or the SQL interface, they strongly recommend [1] only enabling the C API so that SQL injections don't allow attackers to run arbitrary extension code. Here we take the most conservative approach, we enable only the C function, and then only when the user requests extensions be loaded in their `SqliteConnectOptions`, and disable it again once we're done loading those requested modules. We don't add any support for loading extensions via environment variables or connection strings. Extensions in the options are stored as an IndexMap as the load order can have side effects, they will be loaded in the order they are supplied by the caller. Extensions with custom entry points are supported, but a default API is exposed as most users will interact with extensions using the defaults. [1]: https://sqlite.org/c3ref/enable_load_extension.html * Add extension testing for SQlite Extends x.py to download an appropriate shared object file for supported operating systems, and uses wget to fetch one into the GitHub Actions context for use by CI. Overriding LD_LIBRARY_PATH for only this specific DB minimises the impact on the rest of the suite.
This commit is contained in:
parent
9de70d2e7a
commit
20877d83fd
6 changed files with 235 additions and 4 deletions
6
.github/workflows/sqlx.yml
vendored
6
.github/workflows/sqlx.yml
vendored
|
@ -46,7 +46,7 @@ jobs:
|
|||
- uses: Swatinem/rust-cache@v1
|
||||
with:
|
||||
key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}
|
||||
|
||||
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
|
@ -144,6 +144,8 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
|
@ -164,6 +166,8 @@ jobs:
|
|||
--test-threads=1
|
||||
env:
|
||||
DATABASE_URL: sqlite://tests/sqlite/sqlite.db
|
||||
RUSTFLAGS: --cfg sqlite_ipaddr
|
||||
LD_LIBRARY_PATH: /tmp/sqlite3-lib
|
||||
|
||||
postgres:
|
||||
name: Postgres
|
||||
|
|
|
@ -3,25 +3,46 @@ use crate::error::Error;
|
|||
use crate::sqlite::connection::handle::ConnectionHandle;
|
||||
use crate::sqlite::connection::{ConnectionState, Statements};
|
||||
use crate::sqlite::{SqliteConnectOptions, SqliteError};
|
||||
use indexmap::IndexMap;
|
||||
use libc::c_void;
|
||||
use libsqlite3_sys::{
|
||||
sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK,
|
||||
sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
|
||||
sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
|
||||
SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
|
||||
SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
|
||||
};
|
||||
use std::ffi::CString;
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::io;
|
||||
use std::ptr::{null, null_mut};
|
||||
use std::os::raw::c_int;
|
||||
use std::ptr::{addr_of_mut, null, null_mut};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
static THREAD_ID: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
enum SqliteLoadExtensionMode {
|
||||
/// Enables only the C-API, leaving the SQL function disabled.
|
||||
Enable,
|
||||
/// Disables both the C-API and the SQL function.
|
||||
DisableAll,
|
||||
}
|
||||
|
||||
impl SqliteLoadExtensionMode {
|
||||
fn as_int(self) -> c_int {
|
||||
match self {
|
||||
SqliteLoadExtensionMode::Enable => 1,
|
||||
SqliteLoadExtensionMode::DisableAll => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EstablishParams {
|
||||
filename: CString,
|
||||
open_flags: i32,
|
||||
busy_timeout: Duration,
|
||||
statement_cache_capacity: usize,
|
||||
log_settings: LogSettings,
|
||||
extensions: IndexMap<CString, Option<CString>>,
|
||||
pub(crate) thread_name: String,
|
||||
pub(crate) command_channel_size: usize,
|
||||
}
|
||||
|
@ -89,17 +110,67 @@ impl EstablishParams {
|
|||
)
|
||||
})?;
|
||||
|
||||
let extensions = options
|
||||
.extensions
|
||||
.iter()
|
||||
.map(|(name, entry)| {
|
||||
let entry = entry
|
||||
.as_ref()
|
||||
.map(|e| {
|
||||
CString::new(e.as_bytes()).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"extension entrypoint names passed to SQLite must not contain nul bytes"
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
Ok((
|
||||
CString::new(name.as_bytes()).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"extension names passed to SQLite must not contain nul bytes",
|
||||
)
|
||||
})?,
|
||||
entry,
|
||||
))
|
||||
})
|
||||
.collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;
|
||||
|
||||
Ok(Self {
|
||||
filename,
|
||||
open_flags: flags,
|
||||
busy_timeout: options.busy_timeout,
|
||||
statement_cache_capacity: options.statement_cache_capacity,
|
||||
log_settings: options.log_settings.clone(),
|
||||
extensions,
|
||||
thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)),
|
||||
command_channel_size: options.command_channel_size,
|
||||
})
|
||||
}
|
||||
|
||||
// Enable extension loading via the db_config function, as recommended by the docs rather
|
||||
// than the more obvious `sqlite3_enable_load_extension`
|
||||
// https://www.sqlite.org/c3ref/db_config.html
|
||||
// https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
|
||||
unsafe fn sqlite3_set_load_extension(
|
||||
db: *mut sqlite3,
|
||||
mode: SqliteLoadExtensionMode,
|
||||
) -> Result<(), Error> {
|
||||
let status = sqlite3_db_config(
|
||||
db,
|
||||
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
|
||||
mode.as_int(),
|
||||
null::<i32>(),
|
||||
);
|
||||
|
||||
if status != SQLITE_OK {
|
||||
return Err(Error::Database(Box::new(SqliteError::new(db))));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
|
||||
let mut handle = null_mut();
|
||||
|
||||
|
@ -131,6 +202,57 @@ impl EstablishParams {
|
|||
sqlite3_extended_result_codes(handle.as_ptr(), 1);
|
||||
}
|
||||
|
||||
if !self.extensions.is_empty() {
|
||||
// Enable loading extensions
|
||||
unsafe {
|
||||
Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
|
||||
}
|
||||
|
||||
for ext in self.extensions.iter() {
|
||||
// `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer
|
||||
// rather than by calling `sqlite3_errmsg`
|
||||
let mut error = null_mut();
|
||||
status = unsafe {
|
||||
sqlite3_load_extension(
|
||||
handle.as_ptr(),
|
||||
ext.0.as_ptr(),
|
||||
ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
|
||||
addr_of_mut!(error),
|
||||
)
|
||||
};
|
||||
|
||||
if status != SQLITE_OK {
|
||||
// SAFETY: We become responsible for any memory allocation at `&error`, so test
|
||||
// for null and take an RAII version for returns
|
||||
let err_msg = if !error.is_null() {
|
||||
unsafe {
|
||||
let e = CStr::from_ptr(error).into();
|
||||
sqlite3_free(error as *mut c_void);
|
||||
e
|
||||
}
|
||||
} else {
|
||||
CString::new("Unknown error when loading extension")
|
||||
.expect("text should be representable as a CString")
|
||||
};
|
||||
return Err(Error::Database(Box::new(SqliteError::extension(
|
||||
handle.as_ptr(),
|
||||
&err_msg,
|
||||
))));
|
||||
}
|
||||
}
|
||||
|
||||
// Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION
|
||||
// on by disabling the flag again once we've loaded all the requested modules.
|
||||
// Fail-fast (via `?`) if disabling the extension loader didn't work for some reason,
|
||||
// avoids an unexpected state going undetected.
|
||||
unsafe {
|
||||
Self::sqlite3_set_load_extension(
|
||||
handle.as_ptr(),
|
||||
SqliteLoadExtensionMode::DisableAll,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Configure a busy timeout
|
||||
// This causes SQLite to automatically sleep in increasing intervals until the time
|
||||
// when there is something locked during [sqlite3_step].
|
||||
|
|
|
@ -35,6 +35,13 @@ impl SqliteError {
|
|||
message: message.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
/// For errors during extension load, the error message is supplied via a separate pointer
|
||||
pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self {
|
||||
let mut err = Self::new(handle);
|
||||
err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() };
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for SqliteError {
|
||||
|
|
|
@ -66,6 +66,11 @@ pub struct SqliteConnectOptions {
|
|||
pub(crate) vfs: Option<Cow<'static, str>>,
|
||||
|
||||
pub(crate) pragmas: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,
|
||||
/// Extensions are specified as a pair of <Extension Name : Optional Entry Point>, the majority
|
||||
/// of SQLite extensions will use the default entry points specified in the docs, these should
|
||||
/// be added to the map with a `None` value.
|
||||
/// <https://www.sqlite.org/loadext.html#loading_an_extension>
|
||||
pub(crate) extensions: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,
|
||||
|
||||
pub(crate) command_channel_size: usize,
|
||||
pub(crate) row_channel_size: usize,
|
||||
|
@ -174,6 +179,7 @@ impl SqliteConnectOptions {
|
|||
immutable: false,
|
||||
vfs: None,
|
||||
pragmas,
|
||||
extensions: Default::default(),
|
||||
collations: Default::default(),
|
||||
serialized: false,
|
||||
thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))),
|
||||
|
@ -414,4 +420,42 @@ impl SqliteConnectOptions {
|
|||
self.vfs = Some(vfs_name.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Load an [extension](https://www.sqlite.org/loadext.html) at run-time when the database connection
|
||||
/// is established, using the default entry point.
|
||||
///
|
||||
/// Most common SQLite extensions can be loaded using this method, for extensions where you need
|
||||
/// to specify the entry point, use [`extension_with_entrypoint`][`Self::extension_with_entrypoint`] instead.
|
||||
///
|
||||
/// Multiple extensions can be loaded by calling the method repeatedly on the options struct, they
|
||||
/// will be loaded in the order they are added.
|
||||
/// ```rust,no_run
|
||||
/// # use sqlx_core::error::Error;
|
||||
/// use std::str::FromStr;
|
||||
/// use sqlx::sqlite::SqliteConnectOptions;
|
||||
/// # fn options() -> Result<SqliteConnectOptions, Error> {
|
||||
/// let options = SqliteConnectOptions::from_str("sqlite://data.db")?
|
||||
/// .extension("vsv")
|
||||
/// .extension("mod_spatialite");
|
||||
/// # Ok(options)
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn extension(mut self, extension_name: impl Into<Cow<'static, str>>) -> Self {
|
||||
self.extensions.insert(extension_name.into(), None);
|
||||
self
|
||||
}
|
||||
|
||||
/// Load an extension with a specified entry point.
|
||||
///
|
||||
/// Useful when using non-standard extensions, or when developing your own, the second argument
|
||||
/// specifies where SQLite should expect to find the extension init routine.
|
||||
pub fn extension_with_entrypoint(
|
||||
mut self,
|
||||
extension_name: impl Into<Cow<'static, str>>,
|
||||
entry_point: impl Into<Cow<'static, str>>,
|
||||
) -> Self {
|
||||
self.extensions
|
||||
.insert(extension_name.into(), Some(entry_point.into()));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
@ -204,6 +204,21 @@ async fn it_executes_with_pool() -> anyhow::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(sqlite_ipaddr)]
|
||||
#[sqlx_macros::test]
|
||||
async fn it_opens_with_extension() -> anyhow::Result<()> {
|
||||
use std::str::FromStr;
|
||||
|
||||
let opts = SqliteConnectOptions::from_str(&dotenvy::var("DATABASE_URL")?)?.extension("ipaddr");
|
||||
|
||||
let mut conn = SqliteConnection::connect_with(&opts).await?;
|
||||
conn.execute("SELECT ipmasklen('192.168.16.12/24');")
|
||||
.await?;
|
||||
conn.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_opens_in_memory() -> anyhow::Result<()> {
|
||||
// If the filename is ":memory:", then a private, temporary in-memory database
|
||||
|
|
39
tests/x.py
39
tests/x.py
|
@ -5,6 +5,8 @@ import os
|
|||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import platform
|
||||
import urllib.request
|
||||
from glob import glob
|
||||
from docker import start_database
|
||||
|
||||
|
@ -23,6 +25,36 @@ dir_workspace = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|||
dir_tests = os.path.join(dir_workspace, "tests")
|
||||
|
||||
|
||||
def maybe_fetch_sqlite_extension():
|
||||
"""
|
||||
For supported platforms, if we're testing SQLite and the file isn't
|
||||
already present, grab a simple extension for testing.
|
||||
|
||||
Returns the extension name if it was downloaded successfully or `None` if not.
|
||||
"""
|
||||
BASE_URL = "https://github.com/nalgeon/sqlean/releases/download/0.15.2/"
|
||||
if platform.system() == "Darwin":
|
||||
if platform.machine() == "arm64":
|
||||
download_url = BASE_URL + "/ipaddr.arm64.dylib"
|
||||
filename = "ipaddr.dylib"
|
||||
else:
|
||||
download_url = BASE_URL + "/ipaddr.dylib"
|
||||
filename = "ipaddr.dylib"
|
||||
elif platform.system() == "Linux":
|
||||
download_url = BASE_URL + "/ipaddr.so"
|
||||
filename = "ipaddr.so"
|
||||
else:
|
||||
# Unsupported OS
|
||||
return None
|
||||
|
||||
if not os.path.exists(filename):
|
||||
content = urllib.request.urlopen(download_url).read()
|
||||
with open(filename, "wb") as fd:
|
||||
fd.write(content)
|
||||
|
||||
return filename.split(".")[0]
|
||||
|
||||
|
||||
def run(command, comment=None, env=None, service=None, tag=None, args=None, database_url_args=None):
|
||||
if argv.list_targets:
|
||||
if tag:
|
||||
|
@ -41,6 +73,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data
|
|||
|
||||
environ = env or {}
|
||||
|
||||
if service == "sqlite":
|
||||
if maybe_fetch_sqlite_extension() is not None:
|
||||
if environ.get("RUSTFLAGS"):
|
||||
environ["RUSTFLAGS"] += " --cfg sqlite_ipaddr"
|
||||
else:
|
||||
environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr"
|
||||
|
||||
if service is not None:
|
||||
database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests)
|
||||
|
||||
|
|
Loading…
Reference in a new issue