feat(macros): move to one-file-per-query for offline mode

Query data is now written to .sqlx/{query_hash}.json directly by the macro
invocations, rather than first being written to target/sqlx/{input_span_hash}.json
and then collected into sqlx-data.json in a separate step.
Author: Jonas Platte, 2020-12-29 20:06:23 +01:00 (committed by Austin Bonander)
parent babd353c2c
commit 1fa2381015
10 changed files with 501 additions and 456 deletions
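
A rough sketch of the new per-query flow follows (illustrative only, not the verbatim macro code; the names QueryData, save_query_data and the anyhow/serde/serde_json/sha2/hex dependencies are assumptions here, while the real implementation lives in the new offline.rs module in the diff below): each macro invocation hashes the SQL text, writes one JSON file per query, and renames it into place so concurrent invocations don't clash.

    use std::fs;
    use std::path::Path;

    use serde::Serialize;
    use sha2::{Digest, Sha256};

    // Illustrative stand-in for the macros' QueryData<DB>; the real type also
    // carries the `Describe<DB>` output from the database.
    #[derive(Serialize)]
    struct QueryData {
        db_name: String,
        query: String,
        describe: serde_json::Value,
        hash: String,
    }

    // Hex-encoded SHA-256 of the SQL text, matching `hash_string` in the diff below.
    fn hash_string(query: &str) -> String {
        hex::encode(Sha256::digest(query.as_bytes()))
    }

    fn save_query_data(
        data: &QueryData,
        target_dir: &Path,
        workspace_root: &Path,
    ) -> anyhow::Result<()> {
        // Write to a scratch file under target/sqlx first ...
        let tmp_dir = target_dir.join("sqlx");
        fs::create_dir_all(&tmp_dir)?;
        let tmp_path = tmp_dir.join(format!("query-{}.json.tmp", data.hash));
        fs::write(&tmp_path, serde_json::to_vec_pretty(data)?)?;

        // ... then rename into .sqlx/query-{hash}.json; the rename is atomic,
        // so concurrent macro invocations don't clobber each other's output.
        let final_path = workspace_root.join(format!(".sqlx/query-{}.json", data.hash));
        fs::rename(&tmp_path, &final_path)?;
        Ok(())
    }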

sqlx-cli/src/cargo.rs (new file, 47 lines)

@ -0,0 +1,47 @@
use anyhow::Context;
use serde::Deserialize;
use std::env;
use std::ffi::{OsStr, OsString};
use std::path::PathBuf;
use std::process::Command;
use std::str;
#[derive(Deserialize)]
pub struct CargoMetadata {
pub target_directory: PathBuf,
pub workspace_root: PathBuf,
}
/// Path to the `cargo` executable
pub fn cargo_path() -> anyhow::Result<OsString> {
env::var_os("CARGO").context("Failed to obtain value of `CARGO`")
}
pub fn manifest_dir(cargo: &OsStr) -> anyhow::Result<PathBuf> {
let stdout = Command::new(&cargo)
.args(&["locate-project", "--message-format=plain"])
.output()
.context("could not locate manifest dir")?
.stdout;
let mut manifest_path: PathBuf = str::from_utf8(&stdout)
.context("output of `cargo locate-project` was not valid UTF-8")?
// get rid of the trailing newline
.trim()
.into();
manifest_path.pop();
Ok(manifest_path)
}
pub fn metadata(cargo: &OsStr) -> anyhow::Result<CargoMetadata> {
let output = Command::new(&cargo)
.args(&["metadata", "--format-version=1"])
.output()
.context("Could not fetch metadata")?;
serde_json::from_slice(&output.stdout)
.context("Invalid `cargo metadata` output")
.map_err(Into::into)
}


@ -2,6 +2,12 @@ use anyhow::Result;
use crate::opt::{Command, DatabaseCommand, MigrateCommand};
use anyhow::{anyhow, Context};
use dotenv::dotenv;
use prepare::PrepareCtx;
use std::env;
mod cargo;
mod database;
// mod migration;
// mod migrator;
@ -74,18 +80,36 @@ pub async fn run(opt: Opt) -> Result<()> {
},
Command::Prepare {
check: false,
merged,
args,
check,
workspace,
database_url,
} => prepare::run(&database_url, merged, args)?,
args,
} => {
let cargo_path = cargo::cargo_path()?;
println!("cargo path: {:?}", cargo_path);
Command::Prepare {
check: true,
merged,
args,
database_url,
} => prepare::check(&database_url, merged, args)?,
let manifest_dir = cargo::manifest_dir(&cargo_path)?;
let metadata = cargo::metadata(&cargo_path)
.context("`prepare` subcommand may only be invoked as `cargo sqlx prepare`")?;
let ctx = PrepareCtx {
workspace,
cargo: cargo_path,
cargo_args: args,
manifest_dir,
target_dir: metadata.target_directory,
workspace_root: metadata.workspace_root,
database_url,
};
println!("{:?}", ctx);
if check {
prepare::check(&ctx)?
} else {
prepare::run(&ctx)?
}
}
};
Ok(())


@ -29,16 +29,21 @@ pub enum Command {
#[clap(long)]
check: bool,
/// Generate a single top-level `sqlx-data.json` file when using a cargo workspace.
/// Do a clean build of all crates in the workspace.
///
/// This option is intended for workspaces where multiple crates use SQLx; if there is only
/// one, it is better to run `cargo sqlx prepare` without this option inside of that crate.
#[clap(long)]
merged: bool,
workspace: bool,
/// Arguments to be passed to `cargo rustc ...`.
#[clap(last = true)]
args: Vec<String>,
#[clap(flatten)]
database_url: DatabaseUrl,
// `DatabaseUrl` doesn't allow it to be optional
/// Location of the DB, by default will be read from the DATABASE_URL env var
#[clap(long, short = 'D', env)]
database_url: Option<String>,
},
#[clap(alias = "mig")]


@ -1,137 +1,82 @@
use anyhow::{bail, Context};
use console::style;
use remove_dir_all::remove_dir_all;
use serde::Deserialize;
use sqlx::any::{AnyConnectOptions, AnyKind};
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
use anyhow::bail;
use std::ffi::OsString;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::time::SystemTime;
use std::{env, fs};
type QueryData = BTreeMap<String, serde_json::Value>;
type JsonObject = serde_json::Map<String, serde_json::Value>;
#[derive(serde::Serialize, serde::Deserialize)]
struct DataFile {
db: String,
#[serde(flatten)]
data: QueryData,
#[derive(Debug)]
pub struct PrepareCtx {
pub workspace: bool,
pub cargo: OsString,
pub cargo_args: Vec<String>,
pub manifest_dir: PathBuf,
pub target_dir: PathBuf,
pub workspace_root: PathBuf,
pub database_url: Option<String>,
}
pub fn run(url: &str, merge: bool, cargo_args: Vec<String>) -> anyhow::Result<()> {
let db_kind = get_db_kind(url)?;
let data = run_prepare_step(url, merge, cargo_args)?;
pub fn run(ctx: &PrepareCtx) -> anyhow::Result<()> {
let root = if ctx.workspace {
&ctx.workspace_root
} else {
&ctx.manifest_dir
};
if data.is_empty() {
println!(
"{} no queries found; please ensure that the `offline` feature is enabled in sqlx",
style("warning:").yellow()
);
}
serde_json::to_writer_pretty(
BufWriter::new(
File::create("sqlx-data.json").context("failed to create/open `sqlx-data.json`")?,
),
&DataFile {
db: db_kind.to_owned(),
data,
},
)
.context("failed to write to `sqlx-data.json`")?;
run_prepare_step(ctx, &root.join(".sqlx"))?;
println!(
"query data written to `sqlx-data.json` in the current directory; \
"query data written to `.sqlx` in the current directory; \
please check this into version control"
);
Ok(())
}
pub fn check(url: &str, merge: bool, cargo_args: Vec<String>) -> anyhow::Result<()> {
let db_kind = get_db_kind(url)?;
let data = run_prepare_step(url, merge, cargo_args)?;
pub fn check(ctx: &PrepareCtx) -> anyhow::Result<()> {
let cache_dir = ctx.target_dir.join("sqlx");
run_prepare_step(ctx, &cache_dir)?;
let data_file = File::open("sqlx-data.json").context(
"failed to open `sqlx-data.json`; you may need to run `cargo sqlx prepare` first",
)?;
let DataFile {
db: expected_db,
data: saved_data,
} = serde_json::from_reader(BufReader::new(data_file))?;
if db_kind != expected_db {
bail!(
"saved prepare data is for {}, not {} (inferred from `DATABASE_URL`)",
expected_db,
db_kind
)
}
if data != saved_data {
bail!("`cargo sqlx prepare` needs to be rerun")
}
// TODO: Compare .sqlx to target/sqlx
// * For files that are only in the former, raise a warning
// * For files that are only in the latter, raise an error
Ok(())
}
fn run_prepare_step(url: &str, merge: bool, cargo_args: Vec<String>) -> anyhow::Result<QueryData> {
fn run_prepare_step(ctx: &PrepareCtx, cache_dir: &Path) -> anyhow::Result<()> {
anyhow::ensure!(
Path::new("Cargo.toml").exists(),
r#"Failed to read `Cargo.toml`.
hint: This command only works in the manifest directory of a Cargo package."#
);
// path to the Cargo executable
let cargo = env::var("CARGO")
.context("`prepare` subcommand may only be invoked as `cargo sqlx prepare`")?;
let output = Command::new(&cargo)
.args(&["metadata", "--format-version=1"])
.output()
.context("Could not fetch metadata")?;
#[derive(Deserialize)]
struct Metadata {
target_directory: PathBuf,
if cache_dir.exists() {
clear_cache_dir(cache_dir)?;
} else {
fs::create_dir(cache_dir)?;
}
let metadata: Metadata =
serde_json::from_slice(&output.stdout).context("Invalid `cargo metadata` output")?;
// try removing the target/sqlx directory before running, as stale files
// have repeatedly caused issues in the past.
let _ = remove_dir_all(metadata.target_directory.join("sqlx"));
let check_status = if merge {
let check_status = Command::new(&cargo).arg("clean").status()?;
let mut check_cmd = Command::new(&ctx.cargo);
if ctx.workspace {
let check_status = Command::new(&ctx.cargo).arg("clean").status()?;
if !check_status.success() {
bail!("`cargo clean` failed with status: {}", check_status);
}
let mut rustflags = env::var("RUSTFLAGS").unwrap_or_default();
rustflags.push_str(&format!(
" --cfg __sqlx_recompile_trigger=\"{}\"",
SystemTime::UNIX_EPOCH.elapsed()?.as_millis()
));
Command::new(&cargo)
.arg("check")
.args(cargo_args)
.env("RUSTFLAGS", rustflags)
.env("SQLX_OFFLINE", "false")
.env("DATABASE_URL", url)
.status()?
check_cmd.arg("check").args(&ctx.cargo_args).env(
"RUSTFLAGS",
format!(
"--cfg __sqlx_recompile_trigger=\"{}\"",
SystemTime::UNIX_EPOCH.elapsed()?.as_millis()
),
);
} else {
Command::new(&cargo)
check_cmd
.arg("rustc")
.args(cargo_args)
.args(&ctx.cargo_args)
.arg("--")
.arg("--emit")
.arg("dep-info,metadata")
@ -140,121 +85,33 @@ hint: This command only works in the manifest directory of a Cargo package."#
.arg(format!(
"__sqlx_recompile_trigger=\"{}\"",
SystemTime::UNIX_EPOCH.elapsed()?.as_millis()
))
.env("SQLX_OFFLINE", "false")
.env("DATABASE_URL", url)
.status()?
};
));
}
// override database url
if let Some(database_url) = &ctx.database_url {
check_cmd.env("DATABASE_URL", database_url);
}
check_cmd
.env("SQLX_OFFLINE", "false")
.env("SQLX_OFFLINE_DIR", cache_dir);
println!("executing {:?}", check_cmd);
let check_status = check_cmd.status()?;
if !check_status.success() {
bail!("`cargo check` failed with status: {}", check_status);
}
let pattern = metadata.target_directory.join("sqlx/query-*.json");
let mut data = BTreeMap::new();
for path in glob::glob(
pattern
.to_str()
.context("CARGO_TARGET_DIR not valid UTF-8")?,
)? {
let path = path?;
let contents = fs::read(&*path)?;
let mut query_data: JsonObject = serde_json::from_slice(&contents)?;
// we lift the `hash` key to the outer map
let hash = query_data
.remove("hash")
.context("expected key `hash` in query data")?;
if let serde_json::Value::String(hash) = hash {
data.insert(hash, serde_json::Value::Object(query_data));
} else {
bail!(
"expected key `hash` in query data to be string, was {:?} instead; file: {}",
hash,
path.display()
)
}
// lazily remove the file, we don't care too much if we can't
let _ = fs::remove_file(&path);
}
Ok(data)
Ok(())
}
fn get_db_kind(url: &str) -> anyhow::Result<&'static str> {
let options = AnyConnectOptions::from_str(&url)?;
// these should match the values of `DatabaseExt::NAME` in `sqlx-macros`
match options.kind() {
#[cfg(feature = "postgres")]
AnyKind::Postgres => Ok("PostgreSQL"),
#[cfg(feature = "mysql")]
AnyKind::MySql => Ok("MySQL"),
#[cfg(feature = "sqlite")]
AnyKind::Sqlite => Ok("SQLite"),
#[cfg(feature = "mssql")]
AnyKind::Mssql => Ok("MSSQL"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::assert_eq;
#[test]
fn data_file_serialization_works() {
let data_file = DataFile {
db: "mysql".to_owned(),
data: {
let mut data = BTreeMap::new();
data.insert("a".to_owned(), json!({"key1": "value1"}));
data.insert("z".to_owned(), json!({"key2": "value2"}));
data
},
};
let data_file = serde_json::to_string(&data_file).expect("Data file serialized.");
assert_eq!(
data_file,
"{\"db\":\"mysql\",\"a\":{\"key1\":\"value1\"},\"z\":{\"key2\":\"value2\"}}"
);
}
#[test]
fn data_file_deserialization_works() {
let data_file =
"{\"db\":\"mysql\",\"a\":{\"key1\":\"value1\"},\"z\":{\"key2\":\"value2\"}}";
let data_file: DataFile = serde_json::from_str(data_file).expect("Data file deserialized.");
let DataFile { db, data } = data_file;
assert_eq!(db, "mysql");
assert_eq!(data.len(), 2);
assert_eq!(data.get("a"), Some(&json!({"key1": "value1"})));
assert_eq!(data.get("z"), Some(&json!({"key2": "value2"})));
}
#[test]
fn data_file_deserialization_works_for_ordered_keys() {
let data_file =
"{\"a\":{\"key1\":\"value1\"},\"db\":\"mysql\",\"z\":{\"key2\":\"value2\"}}";
let data_file: DataFile = serde_json::from_str(data_file).expect("Data file deserialized.");
let DataFile { db, data } = data_file;
assert_eq!(db, "mysql");
assert_eq!(data.len(), 2);
assert_eq!(data.get("a"), Some(&json!({"key1": "value1"})));
assert_eq!(data.get("z"), Some(&json!({"key2": "value2"})));
fn clear_cache_dir(path: &Path) -> anyhow::Result<()> {
for entry in fs::read_dir(path)? {
fs::remove_file(entry?.path())?;
}
Ok(())
}


@ -87,7 +87,7 @@ proc-macro2 = { version = "1.0.36", default-features = false }
sqlx-core = { version = "0.5.12", default-features = false, path = "../sqlx-core" }
sqlx-rt = { version = "0.5.12", default-features = false, path = "../sqlx-rt" }
serde = { version = "1.0.132", features = ["derive"], optional = true }
serde_json = { version = "1.0.73", optional = true }
serde_json = { version = "1.0.73", optional = true, features = ["raw_value"] }
sha2 = { version = "0.10.0", optional = true }
syn = { version = "1.0.84", default-features = false, features = ["full"] }
quote = { version = "1.0.14", default-features = false }


@ -1,198 +0,0 @@
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
#[cfg_attr(feature = "offline", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(
feature = "offline",
serde(bound(
serialize = "Describe<DB>: serde::Serialize",
deserialize = "Describe<DB>: serde::de::DeserializeOwned"
))
)]
#[derive(Debug)]
pub struct QueryData<DB: Database> {
#[allow(dead_code)]
pub(super) query: String,
pub(super) describe: Describe<DB>,
#[cfg(feature = "offline")]
pub(super) hash: String,
}
impl<DB: Database> QueryData<DB> {
pub async fn from_db(
conn: impl Executor<'_, Database = DB>,
query: &str,
) -> crate::Result<Self> {
Ok(QueryData {
query: query.into(),
describe: conn.describe(query).await?,
#[cfg(feature = "offline")]
hash: offline::hash_string(query),
})
}
}
#[cfg(feature = "offline")]
pub mod offline {
use super::QueryData;
use crate::database::DatabaseExt;
use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::BufWriter;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use once_cell::sync::Lazy;
use proc_macro2::Span;
use sqlx_core::describe::Describe;
static OFFLINE_DATA_CACHE: Lazy<Mutex<BTreeMap<PathBuf, OfflineData>>> =
Lazy::new(|| Mutex::new(BTreeMap::new()));
#[derive(serde::Deserialize)]
struct BaseQuery {
query: String,
describe: serde_json::Value,
}
#[derive(serde::Deserialize)]
struct OfflineData {
db: String,
#[serde(flatten)]
hash_to_query: BTreeMap<String, BaseQuery>,
}
impl OfflineData {
fn get_query_from_hash(&self, hash: &str) -> Option<DynQueryData> {
self.hash_to_query.get(hash).map(|base_query| DynQueryData {
db_name: self.db.clone(),
query: base_query.query.to_owned(),
describe: base_query.describe.to_owned(),
hash: hash.to_owned(),
})
}
}
#[derive(serde::Deserialize)]
pub struct DynQueryData {
#[serde(skip)]
pub db_name: String,
pub query: String,
pub describe: serde_json::Value,
#[serde(skip)]
pub hash: String,
}
impl DynQueryData {
/// Find and deserialize the data table for this query from a shared `sqlx-data.json`
/// file. The expected structure is a JSON map keyed by the SHA-256 hash of queries in hex.
pub fn from_data_file(path: impl AsRef<Path>, query: &str) -> crate::Result<Self> {
let path = path.as_ref();
let query_data = {
let mut cache = OFFLINE_DATA_CACHE
.lock()
// Just reset the cache on error
.unwrap_or_else(|posion_err| {
let mut guard = posion_err.into_inner();
*guard = BTreeMap::new();
guard
});
if !cache.contains_key(path) {
let offline_data_contents = fs::read_to_string(path)
.map_err(|e| format!("failed to read path {}: {}", path.display(), e))?;
let offline_data: OfflineData = serde_json::from_str(&offline_data_contents)?;
let _ = cache.insert(path.to_owned(), offline_data);
}
let offline_data = cache
.get(path)
.expect("Missing data should have just been added");
let query_hash = hash_string(query);
let query_data = offline_data
.get_query_from_hash(&query_hash)
.ok_or_else(|| format!("failed to find data for query {}", query))?;
if query != query_data.query {
return Err(format!(
"hash collision for stored queryies:\n{:?}\n{:?}",
query, query_data.query
)
.into());
}
query_data
};
#[cfg(procmacr2_semver_exempt)]
{
let path = path.as_ref().canonicalize()?;
let path = path.to_str().ok_or_else(|| {
format!(
"sqlx-data.json path cannot be represented as a string: {:?}",
path
)
})?;
proc_macro::tracked_path::path(path);
}
Ok(query_data)
}
}
impl<DB: DatabaseExt> QueryData<DB>
where
Describe<DB>: serde::Serialize + serde::de::DeserializeOwned,
{
pub fn from_dyn_data(dyn_data: DynQueryData) -> crate::Result<Self> {
assert!(!dyn_data.db_name.is_empty());
assert!(!dyn_data.hash.is_empty());
if DB::NAME == dyn_data.db_name {
let describe: Describe<DB> = serde_json::from_value(dyn_data.describe)?;
Ok(QueryData {
query: dyn_data.query,
describe,
hash: dyn_data.hash,
})
} else {
Err(format!(
"expected query data for {}, got data for {}",
DB::NAME,
dyn_data.db_name
)
.into())
}
}
pub fn save_in(&self, dir: impl AsRef<Path>, input_span: Span) -> crate::Result<()> {
// we save under the hash of the span representation because that should be unique
// per invocation
let path = dir.as_ref().join(format!(
"query-{}.json",
hash_string(&format!("{:?}", input_span))
));
serde_json::to_writer_pretty(
BufWriter::new(
File::create(&path)
.map_err(|e| format!("failed to open path {}: {}", path.display(), e))?,
),
self,
)
.map_err(Into::into)
}
}
pub fn hash_string(query: &str) -> String {
// picked `sha2` because it's already in the dependency tree for both MySQL and Postgres
use sha2::{Digest, Sha256};
hex::encode(Sha256::digest(query.as_bytes()))
}
}


@ -0,0 +1,39 @@
use crate::database::DatabaseExt;
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
#[cfg(feature = "offline")]
pub mod offline;
#[cfg_attr(feature = "offline", derive(serde::Serialize))]
#[cfg_attr(
feature = "offline",
serde(bound(serialize = "Describe<DB>: serde::Serialize",))
)]
#[derive(Debug)]
pub struct QueryData<DB: DatabaseExt> {
#[allow(dead_code)]
pub(super) query: String,
pub(super) describe: Describe<DB>,
#[cfg(feature = "offline")]
pub(super) hash: String,
#[cfg(feature = "offline")]
db_name: offline::SerializeDbName<DB>,
}
impl<DB: DatabaseExt> QueryData<DB> {
pub async fn from_db(
conn: impl Executor<'_, Database = DB>,
query: &str,
) -> crate::Result<Self> {
Ok(QueryData {
query: query.into(),
describe: conn.describe(query).await?,
#[cfg(feature = "offline")]
hash: super::hash_string(query),
#[cfg(feature = "offline")]
db_name: offline::SerializeDbName::default(),
})
}
}


@ -0,0 +1,244 @@
use once_cell::sync::{Lazy, OnceCell};
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Display, Formatter};
use std::fs::{self, File};
use std::io::BufWriter;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use proc_macro2::Span;
use serde::de::Visitor;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use crate::database::DatabaseExt;
use crate::query;
use crate::query::{Metadata, QueryMacroInput};
use super::QueryData;
#[cfg(feature = "postgres")]
use sqlx_core::postgres::Postgres;
#[cfg(feature = "mysql")]
use sqlx_core::mysql::MySql;
#[cfg(feature = "sqlite")]
use sqlx_core::sqlite::Sqlite;
#[cfg(feature = "mssql")]
use sqlx_core::mssql::Mssql;
static CACHE: Lazy<Mutex<HashMap<PathBuf, Arc<dyn DynQueryData>>>> = Lazy::new(Default::default);
pub struct SerializeDbName<DB>(PhantomData<DB>);
impl<DB> Default for SerializeDbName<DB> {
fn default() -> Self {
SerializeDbName(PhantomData)
}
}
impl<DB: DatabaseExt> Debug for SerializeDbName<DB> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("SerializeDbName")
.field(&DB::DATABASE_PATH)
.finish()
}
}
impl<DB: DatabaseExt> Display for SerializeDbName<DB> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.pad(DB::NAME)
}
}
impl<DB: DatabaseExt> Serialize for SerializeDbName<DB> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(DB::NAME)
}
}
#[derive(serde::Deserialize)]
struct RawQueryData {
db_name: String,
query: String,
#[serde(skip)]
hash: String,
describe: Box<serde_json::value::RawValue>,
}
impl<DB: DatabaseExt> QueryData<DB>
where
Describe<DB>: serde::Serialize + serde::de::DeserializeOwned,
{
pub(crate) fn save(&self, meta: &Metadata, input_span: Span) -> crate::Result<()> {
let output_dir = meta.target_dir.join("sqlx");
fs::create_dir_all(&output_dir)
.map_err(|e| format!("failed to create $TARGET_DIR/sqlx: {:?}", e))?;
// we save under the hash of the span representation because that should be unique
// per invocation
let path = output_dir.join(&format!(
"query-{}.json",
query::hash_string(&format!("{:?}", input_span))
));
serde_json::to_writer_pretty(
BufWriter::new(
File::create(&path)
.map_err(|e| format!("failed to open path {}: {}", path.display(), e))?,
),
self,
)?;
let final_path = meta
.workspace_root()
.join(&format!(".sqlx/query-{}.json", self.hash));
// renaming is atomic so we don't clash with other invocations trying to write
// to the same place
fs::rename(&path, &final_path)
.map_err(|e| format!("failed to move query data to final destination: {:?}", e))?;
Ok(())
}
}
pub trait DynQueryData: Send + Sync + 'static {
fn db_name(&self) -> &str;
fn hash(&self) -> &str;
fn query(&self) -> &str;
#[cfg(feature = "postgres")]
fn to_postgres(&self) -> &QueryData<Postgres> {
panic!(
"saved query data was not for {}, it was for {}",
Postgres::NAME,
self.db_name()
)
}
#[cfg(feature = "mysql")]
fn to_mysql(&self) -> &QueryData<MySql> {
panic!(
"saved query data was not for {}, it was for {}",
MySql::NAME,
self.db_name()
)
}
#[cfg(feature = "sqlite")]
fn to_sqlite(&self) -> &QueryData<Sqlite> {
panic!(
"saved query data was not for {}, it was for {}",
Sqlite::NAME,
self.db_name()
)
}
#[cfg(feature = "mssql")]
fn to_mssql(&self) -> &QueryData<Mssql> {
panic!(
"saved query data was not for {}, it was for {}",
Mssql::NAME,
self.db_name()
)
}
}
macro_rules! impl_dyn_query_data {
($($featname:literal, $db:ty, $method:ident);*$(;)?) => {$(
#[cfg(feature = $featname)]
impl DynQueryData for QueryData<$db> {
fn db_name(&self) -> &str {
<$db as DatabaseExt>::NAME
}
fn hash(&self) -> &str {
&self.hash
}
fn query(&self) -> &str {
&self.query
}
fn $method(&self) -> &QueryData<$db> {
self
}
}
)*}
}
impl_dyn_query_data!(
"postgres", Postgres, to_postgres;
"mysql", MySql, to_mysql;
"sqlite", Sqlite, to_sqlite;
"mssql", Mssql, to_mssql;
);
pub fn get_data(query: &str, path: &Path) -> crate::Result<Arc<dyn DynQueryData>> {
let mut cache = CACHE.lock().unwrap();
if let Some(cached) = cache.get(path).cloned() {
return Ok(cached);
}
#[cfg(procmacr2_semver_exempt)]
{
let path = path.as_ref().canonicalize()?;
let path = path.to_str().ok_or_else(|| {
format!(
"sqlx-data.json path cannot be represented as a string: {:?}",
path
)
})?;
proc_macro::tracked_path::path(path);
}
let offline_data_contents = fs::read_to_string(path)
.map_err(|e| format!("failed to read path {}: {}", path.display(), e))?;
let offline_data: RawQueryData = serde_json::from_str(&offline_data_contents)?;
if query != offline_data.query {
return Err(format!("hash collision for saved query data").into());
}
macro_rules! to_dyn_data(
($($featname:literal, $db:ty);*$(;)?) => {{
let dyn_data: Arc<dyn DynQueryData> = match &*offline_data.db_name {
$(
#[cfg(feature = $featname)]
<$db as DatabaseExt>::NAME => Arc::new(QueryData {
query: offline_data.query,
hash: offline_data.hash,
db_name: SerializeDbName(PhantomData),
describe: serde_json::from_str::<Describe<$db>>(offline_data.describe.get())?,
}),
)*
other => return Err(format!("query data from filesystem used unknown database: {:?}; is the corresponding feature enabled?", other).into())
};
dyn_data
}}
);
let dyn_data = to_dyn_data!(
"postgres", Postgres;
"mysql", MySql;
"sqlite", Sqlite;
"mssql", Mssql;
);
let _ = cache.insert(path.to_owned(), dyn_data.clone());
Ok(dyn_data)
}


@ -10,7 +10,6 @@ use syn::{ExprArray, Type};
pub struct QueryMacroInput {
pub(super) sql: String,
#[cfg_attr(not(feature = "offline"), allow(dead_code))]
pub(super) src_span: Span,
pub(super) record_type: RecordType,
@ -20,6 +19,9 @@ pub struct QueryMacroInput {
pub(super) checked: bool,
pub(super) file_path: Option<String>,
#[cfg(feature = "offline")]
pub(super) hash: String,
}
enum QuerySrc {
@ -98,8 +100,12 @@ impl Parse for QueryMacroInput {
let file_path = src.file_path(src_span)?;
let sql = src.resolve(src_span)?;
Ok(QueryMacroInput {
sql: src.resolve(src_span)?,
#[cfg(feature = "offline")]
hash: super::hash_string(&sql),
sql,
src_span,
record_type,
arg_exprs,


@ -2,13 +2,14 @@ use std::path::PathBuf;
#[cfg(feature = "offline")]
use std::sync::{Arc, Mutex};
use either::Either;
use once_cell::sync::Lazy;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Type;
use url::Url;
pub use input::QueryMacroInput;
use quote::{format_ident, quote};
use sqlx_core::connection::Connection;
use sqlx_core::database::Database;
use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo};
@ -17,14 +18,13 @@ use sqlx_rt::block_on;
use crate::database::DatabaseExt;
use crate::query::data::QueryData;
use crate::query::input::RecordType;
use either::Either;
mod args;
mod data;
mod input;
mod output;
struct Metadata {
pub struct Metadata {
#[allow(unused)]
manifest_dir: PathBuf,
offline: bool,
@ -125,12 +125,16 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
#[cfg(feature = "offline")]
_ => {
let data_file_path = METADATA.manifest_dir.join("sqlx-data.json");
let workspace_root = METADATA.workspace_root();
let data_dir = workspace_root.join(".sqlx");
let data_file_path = data_dir.join(format!("query-{}.json", input.hash));
if data_file_path.exists() {
expand_from_file(input, data_file_path)
} else {
let workspace_data_file_path = METADATA.workspace_root().join("sqlx-data.json");
let workspace_data_file_path = workspace_root.join("sqlx-data.json");
if workspace_data_file_path.exists() {
expand_from_file(input, workspace_data_file_path)
} else {
@ -171,7 +175,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
QueryData::from_db(&mut conn, &input.sql).await
})?;
expand_with_data(input, data, false)
expand_with_data(input, &data, false)
},
#[cfg(not(feature = "postgres"))]
@ -184,7 +188,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
QueryData::from_db(&mut conn, &input.sql).await
})?;
expand_with_data(input, data, false)
expand_with_data(input, &data, false)
},
#[cfg(not(feature = "mssql"))]
@ -197,7 +201,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
QueryData::from_db(&mut conn, &input.sql).await
})?;
expand_with_data(input, data, false)
expand_with_data(input, &data, false)
},
#[cfg(not(feature = "mysql"))]
@ -210,7 +214,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
QueryData::from_db(&mut conn, &input.sql).await
})?;
expand_with_data(input, data, false)
expand_with_data(input, &data, false)
},
#[cfg(not(feature = "sqlite"))]
@ -224,31 +228,22 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
pub fn expand_from_file(input: QueryMacroInput, file: PathBuf) -> crate::Result<TokenStream> {
use data::offline::DynQueryData;
let query_data = DynQueryData::from_data_file(file, &input.sql)?;
assert!(!query_data.db_name.is_empty());
let query_data = data::offline::get_data(&input.sql, &file)?;
match &*query_data.db_name {
match &*query_data.db_name() {
#[cfg(feature = "postgres")]
sqlx_core::postgres::Postgres::NAME => expand_with_data(
input,
QueryData::<sqlx_core::postgres::Postgres>::from_dyn_data(query_data)?,
true,
),
sqlx_core::postgres::Postgres::NAME => {
expand_with_data(input, query_data.to_postgres(), true)
}
#[cfg(feature = "mysql")]
sqlx_core::mysql::MySql::NAME => expand_with_data(
input,
QueryData::<sqlx_core::mysql::MySql>::from_dyn_data(query_data)?,
true,
),
sqlx_core::mysql::MySql::NAME => expand_with_data(input, query_data.to_mysql(), true),
#[cfg(feature = "sqlite")]
sqlx_core::sqlite::Sqlite::NAME => expand_with_data(
input,
QueryData::<sqlx_core::sqlite::Sqlite>::from_dyn_data(query_data)?,
true,
),
sqlx_core::sqlite::Sqlite::NAME => expand_with_data(input, query_data.to_sqlite(), true),
#[cfg(feature = "mssql")]
sqlx_core::mssql::Mssql::NAME => expand_with_data(input, query_data.to_mssql(), true),
_ => Err(format!(
"found query data for {} but the feature for that database was not enabled",
query_data.db_name
query_data.db_name()
)
.into()),
}
@ -272,7 +267,7 @@ impl<DB: Database> DescribeExt for Describe<DB> {}
fn expand_with_data<DB: DatabaseExt>(
input: QueryMacroInput,
data: QueryData<DB>,
data: &QueryData<DB>,
#[allow(unused_variables)] offline: bool,
) -> crate::Result<TokenStream>
where
@ -379,9 +374,27 @@ where
// If the build is offline, the cache is our input so it's pointless to also write data for it.
#[cfg(feature = "offline")]
if !offline {
let save_dir = METADATA.target_dir.join("sqlx");
std::fs::create_dir_all(&save_dir)?;
data.save_in(save_dir, input.src_span)?;
use std::{fs, io};
let save_dir = METADATA.manifest_dir.join(".sqlx");
match fs::metadata(&save_dir) {
Err(e) => {
if e.kind() != io::ErrorKind::NotFound {
// Can't obtain information about .sqlx
return Err(e.into());
}
// .sqlx doesn't exist, do nothing
}
Ok(meta) => {
if !meta.is_dir() {
return Err(".sqlx exists, but is not a directory".into());
}
// .sqlx exists and is a directory, store data
data.save(&METADATA, input.src_span)?;
}
}
}
Ok(ret_tokens)
@ -399,3 +412,11 @@ fn env(name: &str) -> Result<String, std::env::VarError> {
std::env::var(name)
}
}
#[cfg(feature = "offline")]
pub fn hash_string(query: &str) -> String {
// picked `sha2` because it's already in the dependency tree for both MySQL and Postgres
use sha2::{Digest, Sha256};
hex::encode(Sha256::digest(query.as_bytes()))
}