Add extension support for SQLite (#2062)

* Add extension support for SQLite

While SQLite supports loading extensions at run-time via either the C
API or the SQL interface, they strongly recommend [1] only enabling the C
API so that SQL injections don't allow attackers to run arbitrary
extension code.

Here we take the most conservative approach: we enable only the C
function, and then only when the user requests extensions be loaded in
their `SqliteConnectOptions`, and disable it again once we're done
loading those requested modules. We don't add any support for loading
extensions via environment variables or connection strings.

Extensions in the options are stored as an IndexMap because the load
order can have side effects; they are loaded in the order in which they
are supplied by the caller.

Extensions with custom entry points are supported, but a default API
is exposed as most users will interact with extensions using the
defaults.

[1]: https://sqlite.org/c3ref/enable_load_extension.html

* Add extension testing for SQLite

Extends x.py to download an appropriate shared object file for supported
operating systems, and uses wget to fetch one into the GitHub Actions
context for use by CI.

Overriding LD_LIBRARY_PATH for only this specific DB minimises the
impact on the rest of the suite.
This commit is contained in:
Richard Bradfield 2022-09-01 23:03:27 +01:00 committed by GitHub
parent 9de70d2e7a
commit 20877d83fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 235 additions and 4 deletions

View file

@ -46,7 +46,7 @@ jobs:
- uses: Swatinem/rust-cache@v1
with:
key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}
- uses: actions-rs/cargo@v1
with:
command: check
@ -144,6 +144,8 @@ jobs:
steps:
- uses: actions/checkout@v2
- run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -164,6 +166,8 @@ jobs:
--test-threads=1
env:
DATABASE_URL: sqlite://tests/sqlite/sqlite.db
RUSTFLAGS: --cfg sqlite_ipaddr
LD_LIBRARY_PATH: /tmp/sqlite3-lib
postgres:
name: Postgres

View file

@ -3,25 +3,46 @@ use crate::error::Error;
use crate::sqlite::connection::handle::ConnectionHandle;
use crate::sqlite::connection::{ConnectionState, Statements};
use crate::sqlite::{SqliteConnectOptions, SqliteError};
use indexmap::IndexMap;
use libc::c_void;
use libsqlite3_sys::{
sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK,
sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
};
use std::ffi::CString;
use std::ffi::{CStr, CString};
use std::io;
use std::ptr::{null, null_mut};
use std::os::raw::c_int;
use std::ptr::{addr_of_mut, null, null_mut};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
static THREAD_ID: AtomicU64 = AtomicU64::new(0);
/// The two states this module ever requests for SQLite's run-time extension loader.
enum SqliteLoadExtensionMode {
    /// Turn on the C-API loader only; the SQL function stays disabled.
    Enable,
    /// Turn off both the C-API loader and the SQL function.
    DisableAll,
}

impl SqliteLoadExtensionMode {
    /// Converts the mode into the integer flag handed to `sqlite3_db_config`:
    /// `1` for [`Self::Enable`], `0` for [`Self::DisableAll`].
    fn as_int(self) -> c_int {
        match self {
            Self::DisableAll => 0,
            Self::Enable => 1,
        }
    }
}
/// Pre-validated settings for opening a SQLite connection.
///
/// String-like values are stored as `CString`s so they can be passed to the SQLite C API
/// without further conversion.
pub struct EstablishParams {
/// Database filename, already converted to a C string.
filename: CString,
/// Open flags for the connection (presumably the `SQLITE_OPEN_*` bits passed to
/// `sqlite3_open_v2` - confirm against `establish`).
open_flags: i32,
/// Busy timeout copied from the connect options.
busy_timeout: Duration,
/// Statement cache capacity copied from the connect options.
statement_cache_capacity: usize,
/// Logging configuration cloned from the connect options.
log_settings: LogSettings,
/// Extensions to load when the connection is established: extension name mapped to an
/// optional entry point (`None` selects SQLite's default entry point). The map is iterated
/// in order when loading, so extensions load in the order they were supplied.
extensions: IndexMap<CString, Option<CString>>,
/// Name produced by the options' `thread_name` callback for this connection.
pub(crate) thread_name: String,
/// Command channel size copied from the connect options.
pub(crate) command_channel_size: usize,
}
@ -89,17 +110,67 @@ impl EstablishParams {
)
})?;
let extensions = options
.extensions
.iter()
.map(|(name, entry)| {
let entry = entry
.as_ref()
.map(|e| {
CString::new(e.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension entrypoint names passed to SQLite must not contain nul bytes"
)
})
})
.transpose()?;
Ok((
CString::new(name.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension names passed to SQLite must not contain nul bytes",
)
})?,
entry,
))
})
.collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;
Ok(Self {
filename,
open_flags: flags,
busy_timeout: options.busy_timeout,
statement_cache_capacity: options.statement_cache_capacity,
log_settings: options.log_settings.clone(),
extensions,
thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)),
command_channel_size: options.command_channel_size,
})
}
// Enable extension loading via the db_config function, as recommended by the docs rather
// than the more obvious `sqlite3_enable_load_extension`
// https://www.sqlite.org/c3ref/db_config.html
// https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
unsafe fn sqlite3_set_load_extension(
db: *mut sqlite3,
mode: SqliteLoadExtensionMode,
) -> Result<(), Error> {
let status = sqlite3_db_config(
db,
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
mode.as_int(),
null::<i32>(),
);
if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(db))));
}
Ok(())
}
pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
let mut handle = null_mut();
@ -131,6 +202,57 @@ impl EstablishParams {
sqlite3_extended_result_codes(handle.as_ptr(), 1);
}
if !self.extensions.is_empty() {
// Enable loading extensions
unsafe {
Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
}
for ext in self.extensions.iter() {
// `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer
// rather than by calling `sqlite3_errmsg`
let mut error = null_mut();
status = unsafe {
sqlite3_load_extension(
handle.as_ptr(),
ext.0.as_ptr(),
ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
addr_of_mut!(error),
)
};
if status != SQLITE_OK {
// SAFETY: We become responsible for any memory allocation at `&error`, so test
// for null and take an RAII version for returns
let err_msg = if !error.is_null() {
unsafe {
let e = CStr::from_ptr(error).into();
sqlite3_free(error as *mut c_void);
e
}
} else {
CString::new("Unknown error when loading extension")
.expect("text should be representable as a CString")
};
return Err(Error::Database(Box::new(SqliteError::extension(
handle.as_ptr(),
&err_msg,
))));
}
}
// Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION
// on by disabling the flag again once we've loaded all the requested modules.
// Fail-fast (via `?`) if disabling the extension loader didn't work for some reason,
// avoids an unexpected state going undetected.
unsafe {
Self::sqlite3_set_load_extension(
handle.as_ptr(),
SqliteLoadExtensionMode::DisableAll,
)?;
}
}
// Configure a busy timeout
// This causes SQLite to automatically sleep in increasing intervals until the time
// when there is something locked during [sqlite3_step].

View file

@ -35,6 +35,13 @@ impl SqliteError {
message: message.to_owned(),
}
}
/// Builds the error reported when loading an extension fails.
///
/// `sqlite3_load_extension` supplies its message through a separate out-pointer instead of the
/// connection's error state, so the error is first built from `handle` and its message is then
/// replaced with `error_msg`.
///
/// The message is decoded lossily: loader messages can embed platform loader text or file-system
/// paths, which are not guaranteed to be valid UTF-8, so invalid sequences become U+FFFD instead
/// of being reinterpreted unchecked.
pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self {
    let mut err = Self::new(handle);
    err.message = error_msg.to_string_lossy().into_owned();
    err
}
}
impl Display for SqliteError {

View file

@ -66,6 +66,11 @@ pub struct SqliteConnectOptions {
pub(crate) vfs: Option<Cow<'static, str>>,
pub(crate) pragmas: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,
/// Extensions are specified as a pair of <Extension Name : Optional Entry Point>, the majority
/// of SQLite extensions will use the default entry points specified in the docs, these should
/// be added to the map with a `None` value.
/// <https://www.sqlite.org/loadext.html#loading_an_extension>
pub(crate) extensions: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,
pub(crate) command_channel_size: usize,
pub(crate) row_channel_size: usize,
@ -174,6 +179,7 @@ impl SqliteConnectOptions {
immutable: false,
vfs: None,
pragmas,
extensions: Default::default(),
collations: Default::default(),
serialized: false,
thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))),
@ -414,4 +420,42 @@ impl SqliteConnectOptions {
self.vfs = Some(vfs_name.into());
self
}
/// Registers an [extension](https://www.sqlite.org/loadext.html) to be loaded, through its
/// default entry point, when the database connection is established.
///
/// This covers most common SQLite extensions. When the entry point has to be named explicitly,
/// use [`extension_with_entrypoint`][`Self::extension_with_entrypoint`] instead.
///
/// Call the method once per extension to load several of them; they are loaded in the order
/// in which they were added.
///
/// ```rust,no_run
/// # use sqlx_core::error::Error;
/// use std::str::FromStr;
/// use sqlx::sqlite::SqliteConnectOptions;
/// # fn options() -> Result<SqliteConnectOptions, Error> {
/// let options = SqliteConnectOptions::from_str("sqlite://data.db")?
///     .extension("vsv")
///     .extension("mod_spatialite");
/// # Ok(options)
/// # }
/// ```
pub fn extension(mut self, extension_name: impl Into<Cow<'static, str>>) -> Self {
    let name: Cow<'static, str> = extension_name.into();
    // `None` requests the default entry point.
    self.extensions.insert(name, None);
    self
}
/// Registers an extension to be loaded through an explicitly named entry point.
///
/// Intended for non-standard extensions, or ones you are developing yourself: `entry_point`
/// tells SQLite where to find the extension's init routine.
pub fn extension_with_entrypoint(
    mut self,
    extension_name: impl Into<Cow<'static, str>>,
    entry_point: impl Into<Cow<'static, str>>,
) -> Self {
    let name: Cow<'static, str> = extension_name.into();
    let init_routine: Cow<'static, str> = entry_point.into();
    self.extensions.insert(name, Some(init_routine));
    self
}
}

View file

@ -204,6 +204,21 @@ async fn it_executes_with_pool() -> anyhow::Result<()> {
Ok(())
}
// Only compiled when the build passes `--cfg sqlite_ipaddr`, i.e. when the `ipaddr` test
// extension has been made available to the loader.
#[cfg(sqlite_ipaddr)]
#[sqlx_macros::test]
async fn it_opens_with_extension() -> anyhow::Result<()> {
    use std::str::FromStr;

    let database_url = dotenvy::var("DATABASE_URL")?;
    let opts = SqliteConnectOptions::from_str(&database_url)?.extension("ipaddr");

    let mut conn = SqliteConnection::connect_with(&opts).await?;

    // `ipmasklen` is provided by the extension, so this query fails unless it was loaded.
    conn.execute("SELECT ipmasklen('192.168.16.12/24');").await?;
    conn.close().await?;

    Ok(())
}
#[sqlx_macros::test]
async fn it_opens_in_memory() -> anyhow::Result<()> {
// If the filename is ":memory:", then a private, temporary in-memory database

View file

@ -5,6 +5,8 @@ import os
import sys
import time
import argparse
import platform
import urllib.request
from glob import glob
from docker import start_database
@ -23,6 +25,36 @@ dir_workspace = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
dir_tests = os.path.join(dir_workspace, "tests")
def maybe_fetch_sqlite_extension():
    """
    For supported platforms, make sure a simple SQLite extension is available
    for testing, downloading it into the current working directory if the file
    isn't already present.

    Returns the extension name (the file name without its suffix) when the file
    is present or was downloaded, or `None` on an unsupported platform.
    """
    # No trailing slash: the asset name is joined with exactly one "/" below.
    BASE_URL = "https://github.com/nalgeon/sqlean/releases/download/0.15.2"

    system = platform.system()
    if system == "Darwin":
        # Apple Silicon has its own release asset, but it is stored under the
        # same local file name as the Intel build.
        if platform.machine() == "arm64":
            asset = "ipaddr.arm64.dylib"
        else:
            asset = "ipaddr.dylib"
        filename = "ipaddr.dylib"
    elif system == "Linux":
        asset = "ipaddr.so"
        filename = "ipaddr.so"
    else:
        # Unsupported OS
        return None

    if not os.path.exists(filename):
        download_url = BASE_URL + "/" + asset
        # Read the whole payload before creating the file, so a failed download
        # does not leave an empty file behind; the response is closed on exit.
        with urllib.request.urlopen(download_url) as response:
            content = response.read()
        with open(filename, "wb") as fd:
            fd.write(content)

    return filename.split(".")[0]
def run(command, comment=None, env=None, service=None, tag=None, args=None, database_url_args=None):
if argv.list_targets:
if tag:
@ -41,6 +73,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data
environ = env or {}
if service == "sqlite":
if maybe_fetch_sqlite_extension() is not None:
if environ.get("RUSTFLAGS"):
environ["RUSTFLAGS"] += " --cfg sqlite_ipaddr"
else:
environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr"
if service is not None:
database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests)