feat: implement testing utilities (#2001)

This commit is contained in:
Austin Bonander 2022-08-02 14:38:12 -07:00 committed by GitHub
parent 7adbb7fefc
commit 054f61980a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
74 changed files with 3701 additions and 236 deletions

View file

@ -170,7 +170,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
postgres: [14, 9_6]
postgres: [14, 10]
runtime: [async-std, tokio, actix]
tls: [native-tls, rustls]
needs: check
@ -230,7 +230,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
mysql: [8, 5_6]
mysql: [8, 5_7]
runtime: [async-std, tokio, actix]
tls: [native-tls, rustls]
needs: check
@ -257,6 +257,17 @@ jobs:
- run: sleep 60
- uses: actions-rs/cargo@v1
with:
command: test
args: >
--no-default-features
--features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }}
env:
DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled
# MySQL 5.7 supports TLS but not TLSv1.3 as required by RusTLS.
- uses: actions-rs/cargo@v1
if: ${{ !(matrix.mysql == '5_7' && matrix.tls == 'rustls') }}
with:
command: test
args: >
@ -270,7 +281,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
mariadb: [10_6, 10_2]
mariadb: [10_6, 10_3]
runtime: [async-std, tokio, actix]
tls: [native-tls, rustls]
needs: check

444
Cargo.lock generated
View file

@ -37,6 +37,17 @@ version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb07d2053ccdbe10e2af2995a2f116c1330396493dc1269f6a91d0ae82e19704"
[[package]]
name = "argon2"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73"
dependencies = [
"base64ct",
"blake2",
"password-hash",
]
[[package]]
name = "arrayvec"
version = "0.7.2"
@ -228,6 +239,64 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b9496f0c1d1afb7a2af4338bbe1d969cddfead41d87a9fb3aaa6d0bbc7af648"
dependencies = [
"async-trait",
"axum-core",
"axum-macros",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa 1.0.2",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4f44a0e6200e9d11a1cdc989e4b358f6e3d354fbf48478f345a17f4e43f8635"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
]
[[package]]
name = "axum-macros"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6293dae2ec708e679da6736e857cf8532886ef258e92930f38279c12641628b8"
dependencies = [
"heck 0.4.0",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "backoff"
version = "0.4.0"
@ -277,6 +346,15 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "blake2"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9cf849ee05b2ee5fba5e36f97ff8ec2533916700fc0758d40d92136a42f3388"
dependencies = [
"digest",
]
[[package]]
name = "block-buffer"
version = "0.10.2"
@ -406,6 +484,7 @@ dependencies = [
"libc",
"num-integer",
"num-traits",
"serde",
"time 0.1.44",
"winapi",
]
@ -688,6 +767,41 @@ dependencies = [
"syn",
]
[[package]]
name = "darling"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4529658bdda7fd6769b8614be250cdcfc3aeb0ee72fe66f9e41e5e5eb73eac02"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "649c91bc01e8b1eac09fb91e8dbc7d517684ca6be8ebc75bb9cafc894f9fdb6f"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddfc69c5bfcbd2fc09a0f38451d2daf0e372e367986a83906d1b0dbc88134fb5"
dependencies = [
"darling_core",
"quote",
"syn",
]
[[package]]
name = "der"
version = "0.5.1"
@ -915,6 +1029,12 @@ dependencies = [
"spin 0.9.3",
]
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
@ -1199,12 +1319,81 @@ dependencies = [
"digest",
]
[[package]]
name = "http"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399"
dependencies = [
"bytes",
"fnv",
"itoa 1.0.2",
]
[[package]]
name = "http-body"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1"
dependencies = [
"bytes",
"http",
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "496ce29bb5a52785b44e0f7ca2847ae0bb839c9bd28f69acac9b99d461c0c04c"
[[package]]
name = "httpdate"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
version = "0.14.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http",
"http-body",
"httparse",
"httpdate",
"itoa 1.0.2",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.2.3"
@ -1216,6 +1405,12 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "if_chain"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed"
[[package]]
name = "indexmap"
version = "1.9.1"
@ -1224,6 +1419,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
dependencies = [
"autocfg",
"hashbrown",
"serde",
]
[[package]]
@ -1407,6 +1603,12 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]]
name = "matchit"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
[[package]]
name = "md-5"
version = "0.10.1"
@ -1431,6 +1633,12 @@ dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -1737,6 +1945,17 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "password-hash"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.7"
@ -2273,9 +2492,9 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.139"
version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0171ebb889e45aa68b44aee0859b3eede84c6f5f5c228e6f140c0b2a0a46cad6"
checksum = "fc855a42c7967b7c369eb5860f7164ef1f6f81c20c7cc1141f2a604e18723b03"
dependencies = [
"serde_derive",
]
@ -2292,9 +2511,9 @@ dependencies = [
[[package]]
name = "serde_derive"
version = "1.0.139"
version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc1d3230c1de7932af58ad8ffbe1d784bd55efd5a9d84ac24f69c72d83543dfb"
checksum = "6f2122636b9fe3b81f1cb25099fcf2d3f542cdb1d45940d56c713158884a05da"
dependencies = [
"proc-macro2",
"quote",
@ -2312,6 +2531,46 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd"
dependencies = [
"form_urlencoded",
"itoa 1.0.2",
"ryu",
"serde",
]
[[package]]
name = "serde_with"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89df7a26519371a3cce44fbb914c2819c84d9b897890987fa3ab096491cc0ea8"
dependencies = [
"base64",
"chrono",
"hex",
"indexmap",
"serde",
"serde_json",
"serde_with_macros",
"time 0.3.11",
]
[[package]]
name = "serde_with_macros"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de337f322382fcdfbb21a014f7c224ee041a23785651db67b9827403178f698f"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "sha-1"
version = "0.10.0"
@ -2491,6 +2750,7 @@ dependencies = [
"crossbeam-queue",
"digest",
"dirs",
"dotenvy",
"either",
"encoding_rs",
"event-listener",
@ -2555,6 +2815,30 @@ dependencies = [
"tokio",
]
[[package]]
name = "sqlx-example-postgres-axum-social"
version = "0.1.0"
dependencies = [
"anyhow",
"argon2",
"axum",
"dotenvy",
"once_cell",
"rand",
"regex",
"serde",
"serde_json",
"serde_with",
"sqlx",
"thiserror",
"time 0.3.11",
"tokio",
"tower",
"tracing",
"uuid",
"validator",
]
[[package]]
name = "sqlx-example-postgres-listen"
version = "0.1.0"
@ -2725,6 +3009,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "tempfile"
version = "3.3.0"
@ -2819,6 +3109,7 @@ dependencies = [
"itoa 1.0.2",
"libc",
"num_threads",
"serde",
"time-macros",
]
@ -2855,9 +3146,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokio"
version = "1.20.0"
version = "1.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57aec3cfa4c296db7255446efb4928a6be304b431a806216105542a67b6ca82e"
checksum = "7a8325f63a7d4774dd041e363b2409ed1c5cbbd0f867795e661df066b2b0a581"
dependencies = [
"autocfg",
"bytes",
@ -2926,6 +3217,92 @@ dependencies = [
"serde",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
[[package]]
name = "tower-service"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a400e31aa60b9d44a52a8ee0343b5b18566b03a8321e0d321f695cf56e940160"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b7358be39f2f274f322d2aaed611acc57f382e8eb1e5b48cb9ae30933495ce7"
dependencies = [
"once_cell",
]
[[package]]
name = "try-lock"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "trybuild"
version = "1.0.63"
@ -3015,6 +3392,51 @@ name = "uuid"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f"
dependencies = [
"serde",
]
[[package]]
name = "validator"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32ad5bf234c7d3ad1042e5252b7eddb2c4669ee23f32c7dd0e9b7705f07ef591"
dependencies = [
"idna",
"lazy_static",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
"validator_derive",
]
[[package]]
name = "validator_derive"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af"
dependencies = [
"if_chain",
"lazy_static",
"proc-macro-error",
"proc-macro2",
"quote",
"regex",
"syn",
"validator_types",
]
[[package]]
name = "validator_types"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "value-bag"
@ -3061,6 +3483,16 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0"
dependencies = [
"log",
"try-lock",
]
[[package]]
name = "wasi"
version = "0.10.0+wasi-snapshot-preview1"

View file

@ -8,6 +8,7 @@ members = [
"sqlx-cli",
"sqlx-bench",
"examples/mysql/todos",
"examples/postgres/axum-social-with-tests",
"examples/postgres/files",
"examples/postgres/json",
"examples/postgres/listen",
@ -197,6 +198,11 @@ name = "sqlite-derives"
path = "tests/sqlite/derives.rs"
required-features = ["sqlite", "macros"]
[[test]]
name = "sqlite-test-attr"
path = "tests/sqlite/test-attr.rs"
required-features = ["sqlite", "macros", "migrate"]
#
# MySQL
#
@ -221,6 +227,11 @@ name = "mysql-macros"
path = "tests/mysql/macros.rs"
required-features = ["mysql", "macros"]
[[test]]
name = "mysql-test-attr"
path = "tests/mysql/test-attr.rs"
required-features = ["mysql", "macros", "migrate"]
#
# PostgreSQL
#
@ -250,6 +261,11 @@ name = "postgres-derives"
path = "tests/postgres/derives.rs"
required-features = ["postgres", "macros"]
[[test]]
name = "postgres-test-attr"
path = "tests/postgres/test-attr.rs"
required-features = ["postgres", "macros", "migrate"]
#
# Microsoft SQL Server (MSSQL)
#

View file

@ -0,0 +1,33 @@
[package]
name = "sqlx-example-postgres-axum-social"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
# Primary crates
axum = { version = "0.5.13", features = ["macros"] }
sqlx = { version = "0.6.0", path = "../../../", features = ["runtime-tokio-rustls", "postgres", "time", "uuid"] }
tokio = { version = "1.20.1", features = ["rt-multi-thread", "macros"] }
# Important secondary crates
argon2 = "0.4.1"
rand = "0.8.5"
regex = "1.6.0"
serde = "1.0.140"
serde_with = { version = "2.0.0", features = ["time_0_3"] }
time = "0.3.11"
uuid = { version = "1.1.2", features = ["serde"] }
validator = { version = "0.16.0", features = ["derive"] }
# Auxilliary crates
anyhow = "1.0.58"
dotenvy = "0.15.1"
once_cell = "1.13.0"
thiserror = "1.0.31"
tracing = "0.1.35"
[dev-dependencies]
serde_json = "1.0.82"
tower = "0.4.13"

View file

@ -0,0 +1,10 @@
This example demonstrates how to write integration tests for an API build with [Axum] and SQLx using `#[sqlx::test]`.
See also: https://github.com/tokio-rs/axum/blob/main/examples/testing
# Warning
For the sake of brevity, this project omits numerous critical security precautions. You can use it as a starting point,
but deploy to production at your own risk!
[Axum]: https://github.com/tokio-rs/axum

View file

@ -0,0 +1,6 @@
create table "user"
(
user_id uuid primary key default gen_random_uuid(),
username text unique not null,
password_hash text not null
);

View file

@ -0,0 +1,8 @@
create table post (
post_id uuid primary key default gen_random_uuid(),
user_id uuid not null references "user"(user_id),
content text not null,
created_at timestamptz not null default now()
);
create index on post(created_at desc);

View file

@ -0,0 +1,9 @@
create table comment (
comment_id uuid primary key default gen_random_uuid(),
post_id uuid not null references post(post_id),
user_id uuid not null references "user"(user_id),
content text not null,
created_at timestamptz not null default now()
);
create index on comment(post_id, created_at);

View file

@ -0,0 +1,75 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde_with::DisplayFromStr;
use validator::ValidationErrors;
/// An API-friendly error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// A SQLx call returned an error.
///
/// The exact error contents are not reported to the user in order to avoid leaking
/// information about databse internals.
#[error("an internal database error occurred")]
Sqlx(#[from] sqlx::Error),
/// Similarly, we don't want to report random `anyhow` errors to the user.
#[error("an internal server error occurred")]
Anyhow(#[from] anyhow::Error),
#[error("validation error in request body")]
InvalidEntity(#[from] ValidationErrors),
#[error("{0}")]
UnprocessableEntity(String),
#[error("{0}")]
Conflict(String),
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
#[serde_with::serde_as]
#[serde_with::skip_serializing_none]
#[derive(serde::Serialize)]
struct ErrorResponse<'a> {
// Serialize the `Display` output as the error message
#[serde_as(as = "DisplayFromStr")]
message: &'a Error,
errors: Option<&'a ValidationErrors>,
}
let errors = match &self {
Error::InvalidEntity(errors) => Some(errors),
_ => None,
};
// Normally you wouldn't just print this, but it's useful for debugging without
// using a logging framework.
println!("API error: {:?}", self);
(
self.status_code(),
Json(ErrorResponse {
message: &self,
errors,
}),
)
.into_response()
}
}
impl Error {
fn status_code(&self) -> StatusCode {
use Error::*;
match self {
Sqlx(_) | Anyhow(_) => StatusCode::INTERNAL_SERVER_ERROR,
InvalidEntity(_) | UnprocessableEntity(_) => StatusCode::UNPROCESSABLE_ENTITY,
Conflict(_) => StatusCode::CONFLICT,
}
}
}

View file

@ -0,0 +1,26 @@
use anyhow::Context;
use axum::{Extension, Router};
use sqlx::PgPool;
mod error;
mod post;
mod user;
pub use self::error::Error;
pub type Result<T, E = Error> = ::std::result::Result<T, E>;
pub fn app(db: PgPool) -> Router {
Router::new()
.merge(user::router())
.merge(post::router())
.layer(Extension(db))
}
pub async fn serve(db: PgPool) -> anyhow::Result<()> {
axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
.serve(app(db).into_make_service())
.await
.context("failed to serve API")
}

View file

@ -0,0 +1,100 @@
use axum::extract::Path;
use axum::{Extension, Json, Router};
use axum::routing::get;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use crate::http::user::UserAuth;
use sqlx::PgPool;
use validator::Validate;
use crate::http::Result;
use time::format_description::well_known::Rfc3339;
use uuid::Uuid;
pub fn router() -> Router {
Router::new().route(
"/v1/post/:postId/comment",
get(get_post_comments).post(create_post_comment),
)
}
#[derive(Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
struct CreateCommentRequest {
auth: UserAuth,
#[validate(length(min = 1, max = 1000))]
content: String,
}
#[serde_with::serde_as]
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Comment {
comment_id: Uuid,
username: String,
content: String,
// `OffsetDateTime`'s default serialization format is not standard.
#[serde_as(as = "Rfc3339")]
created_at: OffsetDateTime,
}
// #[axum::debug_handler] // very useful!
async fn create_post_comment(
db: Extension<PgPool>,
Path(post_id): Path<Uuid>,
Json(req): Json<CreateCommentRequest>,
) -> Result<Json<Comment>> {
req.validate()?;
let user_id = req.auth.verify(&*db).await?;
let comment = sqlx::query_as!(
Comment,
// language=PostgreSQL
r#"
with inserted_comment as (
insert into comment(user_id, post_id, content)
values ($1, $2, $3)
returning comment_id, user_id, content, created_at
)
select comment_id, username, content, created_at
from inserted_comment
inner join "user" using (user_id)
"#,
user_id,
post_id,
req.content
)
.fetch_one(&*db)
.await?;
Ok(Json(comment))
}
/// Returns comments in ascending chronological order.
async fn get_post_comments(
db: Extension<PgPool>,
Path(post_id): Path<Uuid>,
) -> Result<Json<Vec<Comment>>> {
// Note: normally you'd want to put a `LIMIT` on this as well,
// though that would also necessitate implementing pagination.
let comments = sqlx::query_as!(
Comment,
// language=PostgreSQL
r#"
select comment_id, username, content, created_at
from comment
inner join "user" using (user_id)
where post_id = $1
order by created_at
"#,
post_id
)
.fetch_all(&*db)
.await?;
Ok(Json(comments))
}

View file

@ -0,0 +1,93 @@
use axum::{Extension, Json, Router};
use axum::routing::get;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use crate::http::user::UserAuth;
use sqlx::PgPool;
use validator::Validate;
use crate::http::Result;
use time::format_description::well_known::Rfc3339;
use uuid::Uuid;
mod comment;
pub fn router() -> Router {
Router::new()
.route("/v1/post", get(get_posts).post(create_post))
.merge(comment::router())
}
#[derive(Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
struct CreatePostRequest {
auth: UserAuth,
#[validate(length(min = 1, max = 1000))]
content: String,
}
#[serde_with::serde_as]
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Post {
post_id: Uuid,
username: String,
content: String,
// `OffsetDateTime`'s default serialization format is not standard.
#[serde_as(as = "Rfc3339")]
created_at: OffsetDateTime,
}
// #[axum::debug_handler] // very useful!
async fn create_post(
db: Extension<PgPool>,
Json(req): Json<CreatePostRequest>,
) -> Result<Json<Post>> {
req.validate()?;
let user_id = req.auth.verify(&*db).await?;
let post = sqlx::query_as!(
Post,
// language=PostgreSQL
r#"
with inserted_post as (
insert into post(user_id, content)
values ($1, $2)
returning post_id, user_id, content, created_at
)
select post_id, username, content, created_at
from inserted_post
inner join "user" using (user_id)
"#,
user_id,
req.content
)
.fetch_one(&*db)
.await?;
Ok(Json(post))
}
/// Returns posts in descending chronological order.
async fn get_posts(db: Extension<PgPool>) -> Result<Json<Vec<Post>>> {
// Note: normally you'd want to put a `LIMIT` on this as well,
// though that would also necessitate implementing pagination.
let posts = sqlx::query_as!(
Post,
// language=PostgreSQL
r#"
select post_id, username, content, created_at
from post
inner join "user" using (user_id)
order by created_at desc
"#
)
.fetch_all(&*db)
.await?;
Ok(Json(posts))
}

View file

@ -0,0 +1,95 @@
use axum::http::StatusCode;
use axum::{routing::post, Extension, Json, Router};
use once_cell::sync::Lazy;
use rand::Rng;
use regex::Regex;
use std::time::Duration;
use serde::Deserialize;
use sqlx::{PgExecutor, PgPool};
use uuid::Uuid;
use validator::Validate;
use crate::http::{Error, Result};
pub type UserId = Uuid;
pub fn router() -> Router {
Router::new().route("/v1/user", post(create_user))
}
static USERNAME_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"^[0-9A-Za-z_]+$").unwrap());
// CREATE USER
#[derive(Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct UserAuth {
#[validate(length(min = 3, max = 16), regex = "USERNAME_REGEX")]
username: String,
#[validate(length(min = 8, max = 32))]
password: String,
}
// WARNING: this API has none of the checks that a normal user signup flow implements,
// such as email or phone verification.
async fn create_user(db: Extension<PgPool>, Json(req): Json<UserAuth>) -> Result<StatusCode> {
req.validate()?;
let UserAuth { username, password } = req;
// It would be irresponsible to store passwords in plaintext, however.
let password_hash = crate::password::hash(password).await?;
sqlx::query!(
// language=PostgreSQL
r#"
insert into "user"(username, password_hash)
values ($1, $2)
"#,
username,
password_hash
)
.execute(&*db)
.await
.map_err(|e| match e {
sqlx::Error::Database(dbe) if dbe.constraint() == Some("user_username_key") => {
Error::Conflict("username taken".into())
}
_ => e.into(),
})?;
Ok(StatusCode::NO_CONTENT)
}
impl UserAuth {
// NOTE: normally we wouldn't want to verify the username and password every time,
// but persistent sessions would have complicated the example.
pub async fn verify(self, db: impl PgExecutor<'_> + Send) -> Result<UserId> {
self.validate()?;
let maybe_user = sqlx::query!(
r#"select user_id, password_hash from "user" where username = $1"#,
self.username
)
.fetch_optional(db)
.await?;
if let Some(user) = maybe_user {
let verified = crate::password::verify(self.password, user.password_hash).await?;
if verified {
return Ok(user.user_id);
}
}
// Sleep a random amount of time to avoid leaking existence of a user in timing.
let sleep_duration =
rand::thread_rng().gen_range(Duration::from_millis(100)..=Duration::from_millis(500));
tokio::time::sleep(sleep_duration).await;
Err(Error::UnprocessableEntity(
"invalid username/password".into(),
))
}
}

View file

@ -0,0 +1,3 @@
pub mod http;
mod password;

View file

@ -0,0 +1,19 @@
use anyhow::Context;
use sqlx::postgres::PgPoolOptions;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let database_url = dotenvy::var("DATABASE_URL")
// The error from `var()` doesn't mention the environment variable.
.context("DATABASE_URL must be set")?;
let db = PgPoolOptions::new()
.max_connections(20)
.connect(&database_url)
.await
.context("failed to connect to DATABASE_URL")?;
sqlx::migrate!().run(&db).await?;
sqlx_example_postgres_axum_social::http::serve(db).await
}

View file

@ -0,0 +1,34 @@
use anyhow::{anyhow, Context};
use tokio::task;
use argon2::password_hash::SaltString;
use argon2::{password_hash, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
pub async fn hash(password: String) -> anyhow::Result<String> {
task::spawn_blocking(move || {
let salt = SaltString::generate(rand::thread_rng());
Ok(Argon2::default()
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow!(e).context("failed to hash password"))?
.to_string())
})
.await
.context("panic in hash()")?
}
pub async fn verify(password: String, hash: String) -> anyhow::Result<bool> {
task::spawn_blocking(move || {
let hash = PasswordHash::new(&hash)
.map_err(|e| anyhow!(e).context("BUG: password hash invalid"))?;
let res = Argon2::default().verify_password(password.as_bytes(), &hash);
match res {
Ok(()) => Ok(true),
Err(password_hash::Error::Password) => Ok(false),
Err(e) => Err(anyhow!(e).context("failed to verify password")),
}
})
.await
.context("panic in verify()")?
}

View file

@ -0,0 +1,152 @@
use sqlx::PgPool;
use sqlx_example_postgres_axum_social::http;
use axum::http::{Request, StatusCode};
use tower::ServiceExt;
use std::borrow::BorrowMut;
use common::{expect_rfc3339_timestamp, expect_uuid, response_json, RequestBuilderExt};
use serde_json::json;
mod common;
#[sqlx::test(fixtures("users", "posts"))]
async fn test_create_comment(db: PgPool) {
let mut app = http::app(db);
// Happy path!
let mut resp1 = app
.borrow_mut()
.oneshot(
Request::post("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").json(json! {
{
"auth": {
"username": "bob",
"password": "pro gamer 1990"
},
"content": "lol bet ur still bad, 1v1 me"
}
}),
)
.await
.unwrap();
assert_eq!(resp1.status(), StatusCode::OK);
let resp1_json = response_json(&mut resp1).await;
assert_eq!(resp1_json["username"], "bob");
assert_eq!(resp1_json["content"], "lol bet ur still bad, 1v1 me");
let _comment_id = expect_uuid(&resp1_json["commentId"]);
let _created_at = expect_rfc3339_timestamp(&resp1_json["createdAt"]);
// Incorrect username
let mut resp2 = app
.borrow_mut()
.oneshot(
Request::post("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").json(json! {
{
"auth": {
"username": "bobbbbbb",
"password": "pro gamer 1990"
},
"content": "lol bet ur still bad, 1v1 me"
}
}),
)
.await
.unwrap();
assert_eq!(resp2.status(), StatusCode::UNPROCESSABLE_ENTITY);
let resp2_json = response_json(&mut resp2).await;
assert_eq!(resp2_json["message"], "invalid username/password");
// Incorrect password
let mut resp3 = app
.borrow_mut()
.oneshot(
Request::post("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").json(json! {
{
"auth": {
"username": "bob",
"password": "pro gamer 1990"
},
"content": "lol bet ur still bad, 1v1 me"
}
}),
)
.await
.unwrap();
assert_eq!(resp3.status(), StatusCode::UNPROCESSABLE_ENTITY);
let resp3_json = response_json(&mut resp3).await;
assert_eq!(resp3_json["message"], "invalid username/password");
}
#[sqlx::test(fixtures("users", "posts", "comments"))]
async fn test_list_comments(db: PgPool) {
let mut app = http::app(db);
// This only has the happy path.
let mut resp = app
.borrow_mut()
.oneshot(Request::get("/v1/post/d9ca2672-24c5-4442-b32f-cd717adffbaa/comment").empty_body())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let resp_json = response_json(&mut resp).await;
let comments = resp_json
.as_array()
.expect("expected request to return an array");
assert_eq!(comments.len(), 2);
assert_eq!(comments[0]["username"], "bob");
assert_eq!(comments[0]["content"], "lol bet ur still bad, 1v1 me");
let _comment_id = expect_uuid(&comments[0]["commentId"]);
let created_at_0 = expect_rfc3339_timestamp(&comments[0]["createdAt"]);
assert_eq!(comments[1]["username"], "alice");
assert_eq!(comments[1]["content"], "you're on!");
let _comment_id = expect_uuid(&comments[1]["commentId"]);
let created_at_1 = expect_rfc3339_timestamp(&comments[1]["createdAt"]);
assert!(
created_at_0 < created_at_1,
"comments must be assorted in ascending order"
);
let mut resp = app
.borrow_mut()
.oneshot(Request::get("/v1/post/7e3d4d16-a35e-46ba-8223-b4f1debbfbfe/comment").empty_body())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let resp_json = response_json(&mut resp).await;
let comments = resp_json
.as_array()
.expect("expected request to return an array");
assert_eq!(comments.len(), 1);
assert_eq!(comments[0]["username"], "alice");
assert_eq!(comments[0]["content"], "lol you're just mad you lost :P");
let _comment_id = expect_uuid(&comments[0]["commentId"]);
let _created_at = expect_rfc3339_timestamp(&comments[0]["createdAt"]);
}

View file

@ -0,0 +1,72 @@
// This is imported by different tests that use different functions.
#![allow(dead_code)]
use axum::body::{Body, BoxBody, HttpBody};
use axum::http::header::CONTENT_TYPE;
use axum::http::{request, Request};
use axum::response::Response;
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;
use uuid::Uuid;
pub trait RequestBuilderExt {
fn json(self, json: serde_json::Value) -> Request<Body>;
fn empty_body(self) -> Request<Body>;
}
impl RequestBuilderExt for request::Builder {
fn json(self, json: serde_json::Value) -> Request<Body> {
self.header("Content-Type", "application/json")
.body(Body::from(json.to_string()))
.expect("failed to build request")
}
fn empty_body(self) -> Request<Body> {
self.body(Body::empty()).expect("failed to build request")
}
}
#[track_caller]
pub async fn response_json(resp: &mut Response<BoxBody>) -> serde_json::Value {
assert_eq!(
resp.headers()
.get(CONTENT_TYPE)
.expect("expected Content-Type"),
"application/json"
);
let body = resp.body_mut();
let mut bytes = Vec::new();
while let Some(res) = body.data().await {
let chunk = res.expect("error reading response body");
bytes.extend_from_slice(&chunk[..]);
}
serde_json::from_slice(&bytes).expect("failed to read response body as json")
}
#[track_caller]
pub fn expect_string(value: &serde_json::Value) -> &str {
value
.as_str()
.unwrap_or_else(|| panic!("expected string, got {:?}", value))
}
#[track_caller]
pub fn expect_uuid(value: &serde_json::Value) -> Uuid {
expect_string(value)
.parse::<Uuid>()
.unwrap_or_else(|e| panic!("failed to parse UUID from {:?}: {}", value, e))
}
#[track_caller]
pub fn expect_rfc3339_timestamp(value: &serde_json::Value) -> OffsetDateTime {
let s = expect_string(value);
OffsetDateTime::parse(s, &Rfc3339)
.unwrap_or_else(|e| panic!("failed to parse RFC-3339 timestamp from {:?}: {}", value, e))
}

View file

@ -0,0 +1,12 @@
INSERT INTO public.comment (comment_id, post_id, user_id, content, created_at)
VALUES
-- from: bob
('3a86b8f8-827b-4f14-94a2-34517b4b5bde', 'd9ca2672-24c5-4442-b32f-cd717adffbaa',
'c994b839-84f4-4509-ad49-59119133d6f5', 'lol bet ur still bad, 1v1 me', '2022-07-29 01:52:31.167673'),
-- from: alice
('d6f862b5-2b87-4af4-b15e-6b3398729e6d', 'd9ca2672-24c5-4442-b32f-cd717adffbaa',
'51b374f1-93ae-4c5c-89dd-611bda8412ce', 'you''re on!', '2022-07-29 01:53:53.115782'),
-- from: alice
('1eed85ae-adae-473c-8d05-b1dae0a1df63', '7e3d4d16-a35e-46ba-8223-b4f1debbfbfe',
'51b374f1-93ae-4c5c-89dd-611bda8412ce', 'lol you''re just mad you lost :P', '2022-07-29 01:55:50.116119');

View file

@ -0,0 +1,8 @@
INSERT INTO public.post (post_id, user_id, content, created_at)
VALUES
-- from: alice
('d9ca2672-24c5-4442-b32f-cd717adffbaa', '51b374f1-93ae-4c5c-89dd-611bda8412ce',
'This new computer is blazing fast!', '2022-07-29 01:36:24.679082'),
-- from: bob
('7e3d4d16-a35e-46ba-8223-b4f1debbfbfe', 'c994b839-84f4-4509-ad49-59119133d6f5', '@alice is a haxxor',
'2022-07-29 01:54:45.823523');

View file

@ -0,0 +1,10 @@
INSERT INTO public."user" (user_id, username, password_hash)
VALUES
-- username: "alice"; password: "rustacean since 2015"
('51b374f1-93ae-4c5c-89dd-611bda8412ce', 'alice',
'$argon2id$v=19$m=4096,t=3,p=1$3v3ats/tYTXAYs3q9RycDw$ZltwjS3oQwPuNmL9f6DNb+sH5N81dTVZhVNbUQzmmVU'),
-- username: "bob"; password: "pro gamer 1990"
('c994b839-84f4-4509-ad49-59119133d6f5', 'bob',
'$argon2id$v=19$m=4096,t=3,p=1$1zbkRinUH9WHzkyu8C1Vlg$70pu5Cca/s3d0nh5BYQGkN7+s9cqlNxTE7rFZaUaP4c');

View file

@ -0,0 +1,120 @@
use sqlx::PgPool;
use sqlx_example_postgres_axum_social::http;
use axum::http::{Request, StatusCode};
use tower::ServiceExt;
use std::borrow::BorrowMut;
use common::{expect_rfc3339_timestamp, expect_uuid, response_json, RequestBuilderExt};
use serde_json::json;
mod common;
#[sqlx::test(fixtures("users"))]
async fn test_create_post(db: PgPool) {
let mut app = http::app(db);
// Happy path!
let mut resp1 = app
.borrow_mut()
.oneshot(Request::post("/v1/post").json(json! {
{
"auth": {
"username": "alice",
"password": "rustacean since 2015"
},
"content": "This new computer is blazing fast!"
}
}))
.await
.unwrap();
assert_eq!(resp1.status(), StatusCode::OK);
let resp1_json = response_json(&mut resp1).await;
assert_eq!(resp1_json["username"], "alice");
assert_eq!(resp1_json["content"], "This new computer is blazing fast!");
let _post_id = expect_uuid(&resp1_json["postId"]);
let _created_at = expect_rfc3339_timestamp(&resp1_json["createdAt"]);
// Incorrect username
let mut resp2 = app
.borrow_mut()
.oneshot(Request::post("/v1/post").json(json! {
{
"auth": {
"username": "aliceee",
"password": "rustacean since 2015"
},
"content": "This new computer is blazing fast!"
}
}))
.await
.unwrap();
assert_eq!(resp2.status(), StatusCode::UNPROCESSABLE_ENTITY);
let resp2_json = response_json(&mut resp2).await;
assert_eq!(resp2_json["message"], "invalid username/password");
// Incorrect password
let mut resp3 = app
.borrow_mut()
.oneshot(Request::post("/v1/post").json(json! {
{
"auth": {
"username": "alice",
"password": "rustaceansince2015"
},
"content": "This new computer is blazing fast!"
}
}))
.await
.unwrap();
assert_eq!(resp3.status(), StatusCode::UNPROCESSABLE_ENTITY);
let resp3_json = response_json(&mut resp3).await;
assert_eq!(resp3_json["message"], "invalid username/password");
}
#[sqlx::test(fixtures("users", "posts"))]
async fn test_list_posts(db: PgPool) {
// This only has the happy path.
let mut resp = http::app(db)
.oneshot(Request::get("/v1/post").empty_body())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let resp_json = response_json(&mut resp).await;
let posts = resp_json
.as_array()
.expect("expected GET /v1/post to return an array");
assert_eq!(posts.len(), 2);
assert_eq!(posts[0]["username"], "bob");
assert_eq!(posts[0]["content"], "@alice is a haxxor");
let _post_id = expect_uuid(&posts[0]["postId"]);
let created_at_0 = expect_rfc3339_timestamp(&posts[0]["createdAt"]);
assert_eq!(posts[1]["username"], "alice");
assert_eq!(posts[1]["content"], "This new computer is blazing fast!");
let _post_id = expect_uuid(&posts[1]["postId"]);
let created_at_1 = expect_rfc3339_timestamp(&posts[1]["createdAt"]);
assert!(
created_at_0 > created_at_1,
"posts must be sorted in descending order"
);
}

View file

@ -0,0 +1,89 @@
use sqlx::PgPool;
use sqlx_example_postgres_axum_social::http;
use axum::http::{Request, StatusCode};
use tower::ServiceExt;
use std::borrow::BorrowMut;
use common::{response_json, RequestBuilderExt};
use serde_json::json;
mod common;
#[sqlx::test]
async fn test_create_user(db: PgPool) {
let mut app = http::app(db);
// Happy path!
let resp1 = app
.borrow_mut()
// We handle JSON objects directly to sanity check the serialization and deserialization
.oneshot(Request::post("/v1/user").json(json! {{
"username": "alice",
"password": "rustacean since 2015"
}}))
.await
.unwrap();
assert_eq!(resp1.status(), StatusCode::NO_CONTENT);
// Username taken
let mut resp2 = app
.borrow_mut()
.oneshot(Request::post("/v1/user").json(json! {{
"username": "alice",
"password": "uhhh i forgot"
}}))
.await
.unwrap();
assert_eq!(resp2.status(), StatusCode::CONFLICT);
let resp2_json = response_json(&mut resp2).await;
assert_eq!(resp2_json["message"], "username taken");
// Invalid username
let mut resp3 = app
.borrow_mut()
.oneshot(Request::post("/v1/user").json(json! {{
"username": "definitely an invalid username",
"password": "password"
}}))
.await
.unwrap();
assert_eq!(resp3.status(), StatusCode::UNPROCESSABLE_ENTITY);
let resp3_json = response_json(&mut resp3).await;
assert_eq!(resp3_json["message"], "validation error in request body");
assert!(
resp3_json["errors"]["username"].is_array(),
"errors.username is not an array: {:?}",
resp3_json
);
// Invalid password
let mut resp4 = app
.borrow_mut()
.oneshot(Request::post("/v1/user").json(json! {{
"username": "bobby123",
"password": ""
}}))
.await
.unwrap();
assert_eq!(resp4.status(), StatusCode::UNPROCESSABLE_ENTITY);
let resp4_json = response_json(&mut resp4).await;
assert_eq!(resp4_json["message"], "validation error in request body");
assert!(
resp4_json["errors"]["password"].is_array(),
"errors.password is not an array: {:?}",
resp4_json
);
}

Binary file not shown.

View file

@ -175,6 +175,8 @@ indexmap = "1.6.0"
hkdf = { version = "0.12.0", optional = true }
event-listener = "2.5.2"
dotenvy = "0.15"
[dev-dependencies]
sqlx = { version = "0.6.0", path = "..", features = ["postgres", "sqlite", "mysql"] }
tokio = { version = "1", features = ["rt"] }

View file

@ -103,6 +103,12 @@ pub mod mysql;
#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))]
pub mod mssql;
// Implements test support with automatic DB management.
#[cfg(feature = "migrate")]
pub mod testing;
pub use sqlx_rt::test_block_on;
/// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance.
use ahash::AHashMap as HashMap;
//type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;

View file

@ -93,7 +93,15 @@ impl Migrator {
<A::Connection as Deref>::Target: Migrate,
{
let mut conn = migrator.acquire().await?;
self.run_direct(&mut *conn).await
}
// Getting around the annoying "implementation of `Acquire` is not general enough" error
#[doc(hidden)]
pub async fn run_direct<C>(&self, conn: &mut C) -> Result<(), MigrateError>
where
C: Migrate,
{
// lock the database for exclusive access by the migrator
conn.lock().await?;

View file

@ -60,7 +60,33 @@ impl MySqlConnection {
}
// Upgrade to TLS if we were asked to and the server supports it
tls::maybe_upgrade(&mut stream, options).await?;
#[cfg(feature = "_tls-rustls")]
{
// To aid in debugging: https://github.com/rustls/rustls/issues/893
let local_addr = stream.local_addr();
match tls::maybe_upgrade(&mut stream, options).await {
Ok(()) => (),
#[cfg(feature = "_tls-rustls")]
Err(Error::Io(ioe)) => {
if let Some(&rustls::Error::CorruptMessage) =
ioe.get_ref().and_then(|e| e.downcast_ref())
{
log::trace!("got corrupt message on socket {:?}", local_addr);
}
return Err(Error::Io(ioe));
}
Err(e) => return Err(e),
}
}
#[cfg(not(feature = "_tls-rustls"))]
{
tls::maybe_upgrade(&mut stream, options).await?
}
let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) {
Some(plugin.scramble(&mut stream, password, &nonce).await?)

View file

@ -22,6 +22,9 @@ mod value;
#[cfg(feature = "migrate")]
mod migrate;
#[cfg(feature = "migrate")]
mod testing;
pub use arguments::MySqlArguments;
pub use column::MySqlColumn;
pub use connection::MySqlConnection;

View file

@ -0,0 +1,223 @@
use std::fmt::Write;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use futures_core::future::BoxFuture;
use once_cell::sync::OnceCell;
use crate::connection::Connection;
use crate::error::Error;
use crate::executor::Executor;
use crate::mysql::{MySql, MySqlConnectOptions, MySqlConnection};
use crate::pool::{Pool, PoolOptions};
use crate::query::query;
use crate::query_builder::QueryBuilder;
use crate::query_scalar::query_scalar;
use crate::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport};
// Using a blocking `OnceCell` here because the critical sections are short.
static MASTER_POOL: OnceCell<Pool<MySql>> = OnceCell::new();
// Automatically delete any databases created before the start of the test binary.
static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
impl TestSupport for MySql {
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
Box::pin(async move {
let res = test_context(args).await;
res
})
}
fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
let mut conn = MASTER_POOL
.get()
.expect("cleanup_test() invoked outside `#[sqlx::test]")
.acquire()
.await?;
let db_id = db_id(db_name);
conn.execute(&format!("drop database if exists {};", db_name)[..])
.await?;
query("delete from _sqlx_test_databases where db_id = ?")
.bind(&db_id)
.execute(&mut conn)
.await?;
Ok(())
})
}
fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
Box::pin(async move {
let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
let mut conn = MySqlConnection::connect(&url).await?;
let num_deleted = do_cleanup(&mut conn).await?;
let _ = conn.close().await;
Ok(Some(num_deleted))
})
}
fn snapshot(
_conn: &mut Self::Connection,
) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
// TODO: I want to get the testing feature out the door so this will have to wait,
// but I'm keeping the code around for now because I plan to come back to it.
todo!()
}
}
async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
let master_opts = MySqlConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");
let pool = PoolOptions::new()
// MySql's normal connection limit is 150 plus 1 superuser connection
// We don't want to use the whole cap and there may be fuzziness here due to
// concurrently running tests anyway.
.max_connections(20)
// Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
.after_release(|_conn, _| Box::pin(async move { Ok(false) }))
.connect_lazy_with(master_opts);
let master_pool = match MASTER_POOL.try_insert(pool) {
Ok(inserted) => inserted,
Err((existing, pool)) => {
// Sanity checks.
assert_eq!(
existing.connect_options().host,
pool.connect_options().host,
"DATABASE_URL changed at runtime, host differs"
);
assert_eq!(
existing.connect_options().database,
pool.connect_options().database,
"DATABASE_URL changed at runtime, database differs"
);
existing
}
};
let mut conn = master_pool.acquire().await?;
// language=MySQL
conn.execute(
r#"
create table if not exists _sqlx_test_databases (
db_id bigint unsigned primary key auto_increment,
test_path text not null,
created_at timestamp not null default current_timestamp
);
"#,
)
.await?;
// Only run cleanup if the test binary just started.
if DO_CLEANUP.swap(false, Ordering::SeqCst) {
do_cleanup(&mut conn).await?;
}
query("insert into _sqlx_test_databases(test_path) values (?)")
.bind(&args.test_path)
.execute(&mut conn)
.await?;
// MySQL doesn't have `INSERT ... RETURNING`
let new_db_id: u64 = query_scalar("select last_insert_id()")
.fetch_one(&mut conn)
.await?;
let new_db_name = db_name(new_db_id);
conn.execute(&format!("create database {}", new_db_name)[..])
.await?;
eprintln!("created database {}", new_db_name);
Ok(TestContext {
pool_opts: PoolOptions::new()
// Don't allow a single test to take all the connections.
// Most tests shouldn't require more than 5 connections concurrently,
// or else they're likely doing too much in one test.
.max_connections(5)
// Close connections ASAP if left in the idle queue.
.idle_timeout(Some(Duration::from_secs(1)))
.parent(master_pool.clone()),
connect_opts: master_pool.connect_options().clone().database(&new_db_name),
db_name: new_db_name,
})
}
async fn do_cleanup(conn: &mut MySqlConnection) -> Result<usize, Error> {
let delete_db_ids: Vec<u64> = query_scalar(
"select db_id from _sqlx_test_databases where created_at < current_timestamp()",
)
.fetch_all(&mut *conn)
.await?;
if delete_db_ids.is_empty() {
return Ok(0);
}
let mut deleted_db_ids = Vec::with_capacity(delete_db_ids.len());
let mut command = String::new();
for db_id in delete_db_ids {
command.clear();
let db_name = db_name(db_id);
writeln!(command, "drop database if exists {}", db_name).ok();
match conn.execute(&*command).await {
Ok(_deleted) => {
deleted_db_ids.push(db_id);
}
// Assume a database error just means the DB is still in use.
Err(Error::Database(dbe)) => {
eprintln!("could not clean test database {:?}: {}", db_id, dbe)
}
// Bubble up other errors
Err(e) => return Err(e),
}
}
let mut query = QueryBuilder::new("delete from _sqlx_test_databases where db_id in (");
let mut separated = query.separated(",");
for db_id in &deleted_db_ids {
separated.push_bind(db_id);
}
drop(separated);
query.push(")").build().execute(&mut *conn).await?;
Ok(deleted_db_ids.len())
}
fn db_name(id: u64) -> String {
format!("_sqlx_test_database_{}", id)
}
fn db_id(name: &str) -> u64 {
name.trim_start_matches("_sqlx_test_database_")
.parse()
.unwrap_or_else(|_1| panic!("failed to parse ID from database name {:?}", name))
}
#[test]
fn test_db_name_id() {
assert_eq!(db_name(12345), "_sqlx_test_database_12345");
assert_eq!(db_id("_sqlx_test_database_12345"), 12345);
}

View file

@ -1,6 +1,7 @@
#![allow(dead_code)]
use std::io;
use std::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
@ -30,6 +31,14 @@ impl Socket {
.map(Socket::Unix)
}
pub fn local_addr(&self) -> Option<SocketAddr> {
match self {
Self::Tcp(tcp) => tcp.local_addr().ok(),
#[cfg(unix)]
Self::Unix(_) => None,
}
}
#[cfg(not(unix))]
pub async fn connect_uds(_: impl AsRef<Path>) -> io::Result<Self> {
Err(io::Error::new(

View file

@ -3,7 +3,7 @@ use crate::connection::ConnectOptions;
use crate::connection::Connection;
use crate::database::Database;
use crate::error::Error;
use crate::pool::{deadline_as_timeout, CloseEvent, PoolOptions};
use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions};
use crossbeam_queue::ArrayQueue;
use futures_intrusive::sync::{Semaphore, SemaphoreReleaser};
@ -12,16 +12,13 @@ use std::cmp;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
use crate::pool::options::PoolConnectionMetadata;
use futures_util::future::{self};
use futures_util::FutureExt;
use std::time::{Duration, Instant};
/// Ihe number of permits to release to wake all waiters, such as on `PoolInner::close()`.
///
/// This should be large enough to realistically wake all tasks waiting on the pool without
/// potentially overflowing the permits count in the semaphore itself.
const WAKE_ALL_PERMITS: usize = usize::MAX / 2;
pub(crate) struct PoolInner<DB: Database> {
pub(super) connect_options: <DB::Connection as Connection>::Options,
pub(super) idle_conns: ArrayQueue<Idle<DB>>,
@ -40,16 +37,19 @@ impl<DB: Database> PoolInner<DB> {
) -> Arc<Self> {
let capacity = options.max_connections as usize;
// ensure the permit count won't overflow if we release `WAKE_ALL_PERMITS`
// this assert should never fire on 64-bit targets as `max_connections` is a u32
let _ = capacity
.checked_add(WAKE_ALL_PERMITS)
.expect("max_connections exceeds max capacity of the pool");
let semaphore_capacity = if let Some(parent) = &options.parent_pool {
assert!(options.max_connections <= parent.options().max_connections);
assert_eq!(options.fair, parent.options().fair);
// The child pool must steal permits from the parent
0
} else {
capacity
};
let pool = Self {
connect_options,
idle_conns: ArrayQueue::new(capacity),
semaphore: Semaphore::new(options.fair, capacity),
semaphore: Semaphore::new(options.fair, semaphore_capacity),
size: AtomicU32::new(0),
num_idle: AtomicUsize::new(0),
is_closed: AtomicBool::new(false),
@ -82,31 +82,22 @@ impl<DB: Database> PoolInner<DB> {
}
pub(super) fn close<'a>(self: &'a Arc<Self>) -> impl Future<Output = ()> + 'a {
let already_closed = self.is_closed.swap(true, Ordering::AcqRel);
if !already_closed {
// if we were the one to mark this closed, release enough permits to wake all waiters
// we can't just do `usize::MAX` because that would overflow
// and we can't do this more than once cause that would _also_ overflow
self.semaphore.release(WAKE_ALL_PERMITS);
self.on_closed.notify(usize::MAX);
}
self.is_closed.store(true, Ordering::Release);
self.on_closed.notify(usize::MAX);
async move {
// Close any currently idle connections in the pool.
while let Some(idle) = self.idle_conns.pop() {
let _ = idle.live.float((*self).clone()).close().await;
}
for permits in 1..=self.options.max_connections as usize {
// Close any currently idle connections in the pool.
while let Some(idle) = self.idle_conns.pop() {
let _ = idle.live.float((*self).clone()).close().await;
}
// Wait for all permits to be released.
let _permits = self
.semaphore
.acquire(WAKE_ALL_PERMITS + (self.options.max_connections as usize))
.await;
if self.size() == 0 {
break;
}
// Clean up any remaining connections.
while let Some(idle) = self.idle_conns.pop() {
let _ = idle.live.float((*self).clone()).close().await;
// Wait for all permits to be released.
let _permits = self.semaphore.acquire(permits).await;
}
}
}
@ -117,6 +108,67 @@ impl<DB: Database> PoolInner<DB> {
}
}
/// Attempt to pull a permit from `self.semaphore` or steal one from the parent.
///
/// If we steal a permit from the parent but *don't* open a connection,
/// it should be returned to the parent.
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> Result<SemaphoreReleaser<'a>, Error> {
let parent = self
.parent()
// If we're already at the max size, we shouldn't try to steal from the parent.
// This is just going to cause unnecessary churn in `acquire()`.
.filter(|_| self.size() < self.options.max_connections);
let acquire_self = self.semaphore.acquire(1).fuse();
let mut close_event = self.close_event();
if let Some(parent) = parent {
let acquire_parent = parent.0.semaphore.acquire(1);
let parent_close_event = parent.0.close_event();
futures_util::pin_mut!(
acquire_parent,
acquire_self,
close_event,
parent_close_event
);
let mut poll_parent = false;
future::poll_fn(|cx| {
if close_event.as_mut().poll(cx).is_ready() {
return Poll::Ready(Err(Error::PoolClosed));
}
if parent_close_event.as_mut().poll(cx).is_ready() {
// Propagate the parent's close event to the child.
let _ = self.close();
return Poll::Ready(Err(Error::PoolClosed));
}
if let Poll::Ready(permit) = acquire_self.as_mut().poll(cx) {
return Poll::Ready(Ok(permit));
}
// Don't try the parent right away.
if poll_parent {
acquire_parent.as_mut().poll(cx).map(Ok)
} else {
poll_parent = true;
cx.waker().wake_by_ref();
Poll::Pending
}
})
.await
} else {
close_event.do_until(acquire_self).await
}
}
fn parent(&self) -> Option<&Pool<DB>> {
self.options.parent_pool.as_ref()
}
#[inline]
pub(super) fn try_acquire(self: &Arc<Self>) -> Option<Floating<DB, Idle<DB>>> {
if self.is_closed() {
@ -124,6 +176,7 @@ impl<DB: Database> PoolInner<DB> {
}
let permit = self.semaphore.try_acquire(1)?;
self.pop_idle(permit).ok()
}
@ -184,11 +237,9 @@ impl<DB: Database> PoolInner<DB> {
self.options.acquire_timeout,
async {
loop {
let permit = self.semaphore.acquire(1).await;
// Handles the close-event internally
let permit = self.acquire_permit().await?;
if self.is_closed() {
return Err(Error::PoolClosed);
}
// First attempt to pop a connection from the idle queue.
let guard = match self.pop_idle(permit) {
@ -207,7 +258,12 @@ impl<DB: Database> PoolInner<DB> {
// we can open a new connection
guard
} else {
// This can happen for a child pool that's at its connection limit.
log::debug!("woke but was unable to acquire idle connection or open new one; retrying");
// If so, we're likely in the current-thread runtime if it's Tokio
// and so we should yield to let any spawned release_to_pool() tasks
// execute.
sqlx_rt::yield_now().await;
continue;
}
};
@ -334,6 +390,15 @@ impl<DB: Database> PoolInner<DB> {
}
}
impl<DB: Database> Drop for PoolInner<DB> {
fn drop(&mut self) {
if let Some(parent) = &self.options.parent_pool {
// Release the stolen permits.
parent.0.semaphore.release(self.semaphore.permits());
}
}
}
/// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise.
fn is_beyond_max_lifetime<DB: Database>(live: &Live<DB>, options: &PoolOptions<DB>) -> bool {
options
@ -486,6 +551,8 @@ impl<DB: Database> DecrementSizeGuard<DB> {
}
/// Release the semaphore permit without decreasing the pool size.
///
/// If the permit was stolen from the pool's parent, it will be returned to the child's semaphore.
fn release_permit(self) {
self.pool.semaphore.release(1);
self.cancel();

View file

@ -80,6 +80,8 @@ pub struct PoolOptions<DB: Database> {
pub(crate) max_lifetime: Option<Duration>,
pub(crate) idle_timeout: Option<Duration>,
pub(crate) fair: bool,
pub(crate) parent_pool: Option<Pool<DB>>,
}
/// Metadata for the connection being processed by a [`PoolOptions`] callback.
@ -125,6 +127,7 @@ impl<DB: Database> PoolOptions<DB> {
idle_timeout: Some(Duration::from_secs(10 * 60)),
max_lifetime: Some(Duration::from_secs(30 * 60)),
fair: true,
parent_pool: None,
}
}
@ -400,6 +403,19 @@ impl<DB: Database> PoolOptions<DB> {
self
}
/// Set the parent `Pool` from which the new pool will inherit its semaphore.
///
/// This is currently an internal-only API.
///
/// ### Panics
/// If `self.max_connections` is greater than the setting the given pool was created with,
/// or `self.fair` differs from the setting the given pool was created with.
#[doc(hidden)]
pub fn parent(mut self, pool: Pool<DB>) -> Self {
self.parent_pool = Some(pool);
self
}
/// Create a new pool from this `PoolOptions` and immediately open at least one connection.
///
/// This ensures the configuration is correct.

View file

@ -24,6 +24,9 @@ mod value;
#[cfg(feature = "migrate")]
mod migrate;
#[cfg(feature = "migrate")]
mod testing;
pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use column::PgColumn;

View file

@ -0,0 +1,204 @@
use std::fmt::Write;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use futures_core::future::BoxFuture;
use once_cell::sync::OnceCell;
use crate::connection::Connection;
use crate::error::Error;
use crate::executor::Executor;
use crate::pool::{Pool, PoolOptions};
use crate::postgres::{PgConnectOptions, PgConnection, Postgres};
use crate::query::query;
use crate::query_scalar::query_scalar;
use crate::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport};
// Using a blocking `OnceCell` here because the critical sections are short.
static MASTER_POOL: OnceCell<Pool<Postgres>> = OnceCell::new();
// Automatically delete any databases created before the start of the test binary.
static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
impl TestSupport for Postgres {
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
Box::pin(async move {
let res = test_context(args).await;
res
})
}
fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
let mut conn = MASTER_POOL
.get()
.expect("cleanup_test() invoked outside `#[sqlx::test]")
.acquire()
.await?;
conn.execute(&format!("drop database if exists {0:?};", db_name)[..])
.await?;
query("delete from _sqlx_test.databases where db_name = $1")
.bind(&db_name)
.execute(&mut conn)
.await?;
Ok(())
})
}
fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
Box::pin(async move {
let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
let mut conn = PgConnection::connect(&url).await?;
let num_deleted = do_cleanup(&mut conn).await?;
let _ = conn.close().await;
Ok(Some(num_deleted))
})
}
fn snapshot(
_conn: &mut Self::Connection,
) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
// TODO: I want to get the testing feature out the door so this will have to wait,
// but I'm keeping the code around for now because I plan to come back to it.
todo!()
}
}
async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
let master_opts = PgConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");
let pool = PoolOptions::new()
// Postgres' normal connection limit is 100 plus 3 superuser connections
// We don't want to use the whole cap and there may be fuzziness here due to
// concurrently running tests anyway.
.max_connections(20)
// Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
.after_release(|_conn, _| Box::pin(async move { Ok(false) }))
.connect_lazy_with(master_opts);
let master_pool = match MASTER_POOL.try_insert(pool) {
Ok(inserted) => inserted,
Err((existing, pool)) => {
// Sanity checks.
assert_eq!(
existing.connect_options().host,
pool.connect_options().host,
"DATABASE_URL changed at runtime, host differs"
);
assert_eq!(
existing.connect_options().database,
pool.connect_options().database,
"DATABASE_URL changed at runtime, database differs"
);
existing
}
};
let mut conn = master_pool.acquire().await?;
// language=PostgreSQL
conn.execute(
// Explicit lock avoids this latent bug: https://stackoverflow.com/a/29908840
// I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically,
// but a clearly related bug with `CREATE TABLE` has been known since 2007:
// https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org
r#"
lock table pg_catalog.pg_namespace in share row exclusive mode;
create schema if not exists _sqlx_test;
create table if not exists _sqlx_test.databases (
db_name text primary key,
test_path text not null,
created_at timestamptz not null default now()
);
create index if not exists databases_created_at
on _sqlx_test.databases(created_at);
create sequence if not exists _sqlx_test.database_ids;
"#,
)
.await?;
// Only run cleanup if the test binary just started.
if DO_CLEANUP.swap(false, Ordering::SeqCst) {
do_cleanup(&mut conn).await?;
}
let new_db_name: String = query_scalar(
r#"
insert into _sqlx_test.databases(db_name, test_path)
select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1
returning db_name
"#,
)
.bind(&args.test_path)
.fetch_one(&mut conn)
.await?;
conn.execute(&format!("create database {:?}", new_db_name)[..])
.await?;
Ok(TestContext {
pool_opts: PoolOptions::new()
// Don't allow a single test to take all the connections.
// Most tests shouldn't require more than 5 connections concurrently,
// or else they're likely doing too much in one test.
.max_connections(5)
// Close connections ASAP if left in the idle queue.
.idle_timeout(Some(Duration::from_secs(1)))
.parent(master_pool.clone()),
connect_opts: master_pool.connect_options().clone().database(&new_db_name),
db_name: new_db_name,
})
}
async fn do_cleanup(conn: &mut PgConnection) -> Result<usize, Error> {
let delete_db_names: Vec<String> =
query_scalar("select db_name from _sqlx_test.databases where created_at < now()")
.fetch_all(&mut *conn)
.await?;
if delete_db_names.is_empty() {
return Ok(0);
}
let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
let delete_db_names = delete_db_names.into_iter();
let mut command = String::new();
for db_name in delete_db_names {
command.clear();
writeln!(command, "drop database if exists {:?};", db_name).ok();
match conn.execute(&*command).await {
Ok(_deleted) => {
deleted_db_names.push(db_name);
}
// Assume a database error just means the DB is still in use.
Err(Error::Database(dbe)) => {
eprintln!("could not clean test database {:?}: {}", db_name, dbe)
}
// Bubble up other errors
Err(e) => return Err(e),
}
}
query("delete from _sqlx_test.databases where db_name = any($1::text[])")
.bind(&deleted_db_names)
.execute(&mut *conn)
.await?;
Ok(deleted_db_names.len())
}

View file

@ -43,6 +43,9 @@ mod value;
#[cfg(feature = "migrate")]
mod migrate;
#[cfg(feature = "migrate")]
mod testing;
/// An alias for [`Pool`][crate::pool::Pool], specialized for SQLite.
pub type SqlitePool = crate::pool::Pool<Sqlite>;

View file

@ -0,0 +1,81 @@
use crate::error::Error;
use crate::pool::PoolOptions;
use crate::sqlite::{Sqlite, SqliteConnectOptions};
use crate::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport};
use futures_core::future::BoxFuture;
use std::path::{Path, PathBuf};
const BASE_PATH: &str = "target/sqlx/test-dbs";
impl TestSupport for Sqlite {
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
Box::pin(async move {
let res = test_context(args).await;
res
})
}
fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move { Ok(sqlx_rt::fs::remove_file(db_name).await?) })
}
fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
Box::pin(async move {
sqlx_rt::fs::remove_dir_all(BASE_PATH).await?;
Ok(None)
})
}
fn snapshot(
_conn: &mut Self::Connection,
) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
todo!()
}
}
async fn test_context(args: &TestArgs) -> Result<TestContext<Sqlite>, Error> {
let db_path = convert_path(args.test_path);
if let Some(parent_path) = Path::parent(db_path.as_ref()) {
sqlx_rt::fs::create_dir_all(parent_path)
.await
.expect("failed to create folders");
}
if Path::exists(db_path.as_ref()) {
sqlx_rt::fs::remove_file(&db_path)
.await
.expect("failed to remove database from previous test run");
}
Ok(TestContext {
connect_opts: SqliteConnectOptions::new()
.filename(&db_path)
.create_if_missing(true),
// This doesn't really matter for SQLite as the databases are independent of each other.
// The main limitation is going to be the number of concurrent running tests.
pool_opts: PoolOptions::new().max_connections(1000),
db_name: db_path,
})
}
fn convert_path(test_path: &str) -> String {
let mut path = PathBuf::from(BASE_PATH);
for segment in test_path.split("::") {
path.push(segment);
}
path.set_extension("sqlite");
path.into_os_string()
.into_string()
.expect("path should be UTF-8")
}
#[test]
fn test_convert_path() {
let path = convert_path("foo::bar::baz::quux");
assert_eq!(path, "target/sqlx/test-dbs/foo/bar/baz/quux.sqlite");
}

View file

@ -0,0 +1,280 @@
//! TODO: automatic test fixture capture
use crate::database::{Database, HasArguments};
use crate::query_builder::QueryBuilder;
use indexmap::set::IndexSet;
use std::cmp;
use std::collections::{BTreeMap, HashMap};
use std::marker::PhantomData;
use std::sync::Arc;
pub type Result<T, E = FixtureError> = std::result::Result<T, E>;
/// A snapshot of the current state of the database.
///
/// Can be used to generate an `INSERT` fixture for populating an empty database,
/// or in the future it may be possible to generate a fixture from the difference between
/// two snapshots.
pub struct FixtureSnapshot<DB> {
tables: BTreeMap<TableName, Table>,
db: PhantomData<DB>,
}
#[derive(Debug, thiserror::Error)]
#[error("could not create fixture: {0}")]
pub struct FixtureError(String);
pub struct Fixture<DB> {
ops: Vec<FixtureOp>,
db: PhantomData<DB>,
}
enum FixtureOp {
Insert {
table: TableName,
columns: Vec<ColumnName>,
rows: Vec<Vec<Value>>,
},
// TODO: handle updates and deletes by diffing two snapshots
}
type TableName = Arc<str>;
type ColumnName = Arc<str>;
type Value = String;
struct Table {
name: TableName,
columns: IndexSet<ColumnName>,
rows: Vec<Vec<Value>>,
foreign_keys: HashMap<ColumnName, (TableName, ColumnName)>,
}
macro_rules! fixture_assert (
($cond:expr, $msg:literal $($arg:tt)*) => {
if !($cond) {
return Err(FixtureError(format!($msg $($arg)*)))
}
}
);
impl<DB: Database> FixtureSnapshot<DB> {
/// Generate a fixture to reproduce this snapshot from an empty database using `INSERT`s.
///
/// Note that this doesn't take into account any triggers that might modify the data before
/// it's stored.
///
/// The `INSERT` statements are ordered on a best-effort basis to satisfy any foreign key
/// constraints (data from tables with no foreign keys are inserted first, then the tables
/// that reference those tables, and so on).
///
/// If a cycle in foreign-key constraints is detected, this returns with an error.
pub fn additive_fixture(&self) -> Result<Fixture<DB>> {
let visit_order = self.calculate_visit_order()?;
let mut ops = Vec::new();
for table_name in visit_order {
let table = self.tables.get(&table_name).unwrap();
ops.push(FixtureOp::Insert {
table: table_name,
columns: table.columns.iter().cloned().collect(),
rows: table.rows.clone(),
});
}
Ok(Fixture { ops, db: self.db })
}
/// Determine an order for outputting `INSERTS` for each table by calculating the max
/// length of all its foreign key chains.
///
/// This should hopefully ensure that there are no foreign-key errors.
fn calculate_visit_order(&self) -> Result<Vec<TableName>> {
let mut table_depths = HashMap::with_capacity(self.tables.len());
let mut visited_set = IndexSet::with_capacity(self.tables.len());
for table in self.tables.values() {
foreign_key_depth(&self.tables, table, &mut table_depths, &mut visited_set)?;
visited_set.clear();
}
let mut table_names: Vec<TableName> = table_depths.keys().cloned().collect();
table_names.sort_by_key(|name| table_depths.get(name).unwrap());
Ok(table_names)
}
}
/// Implements `ToString` but not `Display` because it uses [`QueryBuilder`] internally,
/// which appends to an internal string.
impl<DB: Database> ToString for Fixture<DB>
where
for<'a> <DB as HasArguments<'a>>::Arguments: Default,
{
fn to_string(&self) -> String {
let mut query = QueryBuilder::<DB>::new("");
for op in &self.ops {
match op {
FixtureOp::Insert {
table,
columns,
rows,
} => {
// Sanity check, empty tables shouldn't appear in snapshots anyway.
if columns.is_empty() || rows.is_empty() {
continue;
}
query.push(format_args!("INSERT INTO {} (", table));
let mut separated = query.separated(", ");
for column in columns {
separated.push(column);
}
query.push(")\n");
query.push_values(rows, |mut separated, row| {
for value in row {
separated.push(value);
}
});
query.push(";\n");
}
}
}
query.into_sql()
}
}
fn foreign_key_depth(
tables: &BTreeMap<TableName, Table>,
table: &Table,
depths: &mut HashMap<TableName, usize>,
visited_set: &mut IndexSet<TableName>,
) -> Result<usize> {
if let Some(&depth) = depths.get(&table.name) {
return Ok(depth);
}
// This keeps us from looping forever.
fixture_assert!(
visited_set.insert(table.name.clone()),
"foreign key cycle detected: {:?} -> {:?}",
visited_set,
table.name
);
let mut refdepth = 0;
for (colname, (refname, refcol)) in &table.foreign_keys {
let referenced = tables.get(refname).ok_or_else(|| {
FixtureError(format!(
"table {:?} in foreign key `{}.{} references {}.{}` does not exist",
refname, table.name, colname, refname, refcol
))
})?;
refdepth = cmp::max(
refdepth,
foreign_key_depth(tables, referenced, depths, visited_set)?,
);
}
let depth = refdepth + 1;
depths.insert(table.name.clone(), depth);
Ok(depth)
}
#[test]
#[cfg(feature = "postgres")]
fn test_additive_fixture() -> Result<()> {
use crate::postgres::Postgres;
let mut snapshot = FixtureSnapshot {
tables: BTreeMap::new(),
db: PhantomData::<Postgres>,
};
snapshot.tables.insert(
"foo".into(),
Table {
name: "foo".into(),
columns: ["foo_id", "foo_a", "foo_b"]
.into_iter()
.map(Arc::<str>::from)
.collect(),
rows: vec![vec!["1".into(), "'asdf'".into(), "true".into()]],
foreign_keys: HashMap::new(),
},
);
// foreign-keyed to `foo`
// since `tables` is a `BTreeMap` we would expect a naive algorithm to visit this first.
snapshot.tables.insert(
"bar".into(),
Table {
name: "bar".into(),
columns: ["bar_id", "foo_id", "bar_a", "bar_b"]
.into_iter()
.map(Arc::<str>::from)
.collect(),
rows: vec![vec![
"1234".into(),
"1".into(),
"'2022-07-22 23:27:48.775113301+00:00'".into(),
"3.14".into(),
]],
foreign_keys: [("foo_id".into(), ("foo".into(), "foo_id".into()))]
.into_iter()
.collect(),
},
);
// foreign-keyed to both `foo` and `bar`
snapshot.tables.insert(
"baz".into(),
Table {
name: "baz".into(),
columns: ["baz_id", "bar_id", "foo_id", "baz_a", "baz_b"]
.into_iter()
.map(Arc::<str>::from)
.collect(),
rows: vec![vec![
"5678".into(),
"1234".into(),
"1".into(),
"'2022-07-22 23:27:48.775113301+00:00'".into(),
"3.14".into(),
]],
foreign_keys: [
("foo_id".into(), ("foo".into(), "foo_id".into())),
("bar_id".into(), ("bar".into(), "bar_id".into())),
]
.into_iter()
.collect(),
},
);
let fixture = snapshot.additive_fixture()?;
assert_eq!(
fixture.to_string(),
"INSERT INTO foo (foo_id, foo_a, foo_b)\n\
VALUES (1, 'asdf', true);\n\
INSERT INTO bar (bar_id, foo_id, bar_a, bar_b)\n\
VALUES (1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n\
INSERT INTO baz (baz_id, bar_id, foo_id, baz_a, baz_b)\n\
VALUES (5678, 1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n"
);
Ok(())
}

View file

@ -0,0 +1,262 @@
use std::future::Future;
use std::time::Duration;
use futures_core::future::BoxFuture;
pub use fixtures::FixtureSnapshot;
use sqlx_rt::test_block_on;
use crate::connection::{ConnectOptions, Connection};
use crate::database::Database;
use crate::error::Error;
use crate::executor::Executor;
use crate::migrate::{Migrate, Migrator};
use crate::pool::{Pool, PoolConnection, PoolOptions};
mod fixtures;
pub trait TestSupport: Database {
/// Get parameters to construct a `Pool` suitable for testing.
///
/// This `Pool` instance will behave somewhat specially:
/// * all handles share a single global semaphore to avoid exceeding the connection limit
/// on the database server.
/// * each invocation results in a different temporary database.
///
/// The implementation may require `DATABASE_URL` to be set in order to manage databases.
/// The user credentials it contains must have the privilege to create and drop databases.
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>>;
fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>>;
/// Cleanup any test databases that are no longer in-use.
///
/// Returns a count of the databases deleted, if possible.
///
/// The implementation may require `DATABASE_URL` to be set in order to manage databases.
/// The user credentials it contains must have the privilege to create and drop databases.
fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>>;
/// Take a snapshot of the current state of the database (data only).
///
/// This snapshot can then be used to generate test fixtures.
fn snapshot(conn: &mut Self::Connection)
-> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>>;
}
pub struct TestFixture {
pub path: &'static str,
pub contents: &'static str,
}
pub struct TestArgs {
pub test_path: &'static str,
pub migrator: Option<&'static Migrator>,
pub fixtures: &'static [TestFixture],
}
pub trait TestFn {
type Output;
fn run_test(self, args: TestArgs) -> Self::Output;
}
pub trait TestTermination {
fn is_success(&self) -> bool;
}
pub struct TestContext<DB: Database> {
pub pool_opts: PoolOptions<DB>,
pub connect_opts: <DB::Connection as Connection>::Options,
pub db_name: String,
}
impl<DB, Fut> TestFn for fn(Pool<DB>) -> Fut
where
DB: TestSupport + Database,
DB::Connection: Migrate,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
Fut: Future,
Fut::Output: TestTermination,
{
type Output = Fut::Output;
fn run_test(self, args: TestArgs) -> Self::Output {
run_test_with_pool(args, self)
}
}
impl<DB, Fut> TestFn for fn(PoolConnection<DB>) -> Fut
where
DB: TestSupport + Database,
DB::Connection: Migrate,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
Fut: Future,
Fut::Output: TestTermination,
{
type Output = Fut::Output;
fn run_test(self, args: TestArgs) -> Self::Output {
run_test_with_pool(args, |pool| async move {
let conn = pool
.acquire()
.await
.expect("failed to acquire test pool connection");
let res = (self)(conn).await;
pool.close().await;
res
})
}
}
impl<DB, Fut> TestFn for fn(PoolOptions<DB>, <DB::Connection as Connection>::Options) -> Fut
where
DB: Database + TestSupport,
DB::Connection: Migrate,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
Fut: Future,
Fut::Output: TestTermination,
{
type Output = Fut::Output;
fn run_test(self, args: TestArgs) -> Self::Output {
run_test(args, self)
}
}
impl<Fut> TestFn for fn() -> Fut
where
Fut: Future,
{
type Output = Fut::Output;
fn run_test(self, args: TestArgs) -> Self::Output {
assert!(
args.fixtures.is_empty(),
"fixtures cannot be applied for a bare function"
);
test_block_on(self())
}
}
impl TestArgs {
pub fn new(test_path: &'static str) -> Self {
TestArgs {
test_path,
migrator: None,
fixtures: &[],
}
}
pub fn migrator(&mut self, migrator: &'static Migrator) {
self.migrator = Some(migrator);
}
pub fn fixtures(&mut self, fixtures: &'static [TestFixture]) {
self.fixtures = fixtures;
}
}
impl TestTermination for () {
fn is_success(&self) -> bool {
true
}
}
impl<T, E> TestTermination for Result<T, E> {
fn is_success(&self) -> bool {
self.is_ok()
}
}
fn run_test_with_pool<DB, F, Fut>(args: TestArgs, test_fn: F) -> Fut::Output
where
DB: TestSupport,
DB::Connection: Migrate,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
F: FnOnce(Pool<DB>) -> Fut,
Fut: Future,
Fut::Output: TestTermination,
{
let test_path = args.test_path;
run_test::<DB, _, _>(args, |pool_opts, connect_opts| async move {
let pool = pool_opts
.connect_with(connect_opts)
.await
.expect("failed to connect test pool");
let res = test_fn(pool.clone()).await;
let close_timed_out = sqlx_rt::timeout(Duration::from_secs(10), pool.close())
.await
.is_err();
if close_timed_out {
eprintln!("test {} held onto Pool after exiting", test_path);
}
res
})
}
fn run_test<DB, F, Fut>(args: TestArgs, test_fn: F) -> Fut::Output
where
DB: TestSupport,
DB::Connection: Migrate,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
F: FnOnce(PoolOptions<DB>, <DB::Connection as Connection>::Options) -> Fut,
Fut: Future,
Fut::Output: TestTermination,
{
test_block_on(async move {
let test_context = DB::test_context(&args)
.await
.expect("failed to connect to setup test database");
setup_test_db::<DB>(&test_context.connect_opts, &args).await;
let res = test_fn(test_context.pool_opts, test_context.connect_opts).await;
if res.is_success() {
if let Err(e) = DB::cleanup_test(&test_context.db_name).await {
eprintln!(
"failed to delete database {:?}: {}",
test_context.db_name, e
);
}
}
res
})
}
async fn setup_test_db<DB: Database>(
copts: &<DB::Connection as Connection>::Options,
args: &TestArgs,
) where
DB::Connection: Migrate + Sized,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
{
let mut conn = copts
.connect()
.await
.expect("failed to connect to test database");
if let Some(migrator) = args.migrator {
migrator
.run_direct(&mut conn)
.await
.expect("failed to apply migrations");
}
for fixture in args.fixtures {
(&mut conn)
.execute(fixture.contents)
.await
.unwrap_or_else(|e| panic!("failed to apply test fixture {:?}: {:?}", fixture.path, e));
}
conn.close()
.await
.expect("failed to close setup connection");
}

View file

@ -2,8 +2,8 @@ use proc_macro2::Span;
use std::env;
use std::path::{Path, PathBuf};
pub(crate) fn resolve_path(path: &str, err_span: Span) -> syn::Result<PathBuf> {
let path = Path::new(path);
pub(crate) fn resolve_path(path: impl AsRef<Path>, err_span: Span) -> syn::Result<PathBuf> {
let path = path.as_ref();
if path.is_absolute() {
return Err(syn::Error::new(

View file

@ -21,6 +21,9 @@ mod database;
mod derives;
mod query;
// The compiler gives misleading help messages about `#[cfg(test)]` when this is just named `test`.
mod test_attr;
#[cfg(feature = "migrate")]
mod migrate;
@ -84,7 +87,7 @@ pub fn migrate(input: TokenStream) -> TokenStream {
use syn::LitStr;
let input = syn::parse_macro_input!(input as LitStr);
match migrate::expand_migrator_from_dir(input) {
match migrate::expand_migrator_from_lit_dir(input) {
Ok(ts) => ts.into(),
Err(e) => {
if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
@ -97,40 +100,20 @@ pub fn migrate(input: TokenStream) -> TokenStream {
}
}
#[doc(hidden)]
#[proc_macro_attribute]
pub fn test(_attr: TokenStream, input: TokenStream) -> TokenStream {
pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(args as syn::AttributeArgs);
let input = syn::parse_macro_input!(input as syn::ItemFn);
let ret = &input.sig.output;
let name = &input.sig.ident;
let body = &input.block;
let attrs = &input.attrs;
let result = if cfg!(feature = "_rt-tokio") {
quote! {
#[test]
#(#attrs)*
fn #name() #ret {
::sqlx_rt::tokio::runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.build()
.unwrap()
.block_on(async { #body })
match test_attr::expand(args, input) {
Ok(ts) => ts.into(),
Err(e) => {
if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
parse_err.to_compile_error().into()
} else {
let msg = e.to_string();
quote!(::std::compile_error!(#msg)).into()
}
}
} else if cfg!(feature = "_rt-async-std") {
quote! {
#[test]
#(#attrs)*
fn #name() #ret {
::sqlx_rt::async_std::task::block_on(async { #body })
}
}
} else {
panic!("one of 'runtime-actix', 'runtime-async-std' or 'runtime-tokio' features must be enabled");
};
result.into()
}
}

View file

@ -3,6 +3,7 @@ use quote::{quote, ToTokens, TokenStreamExt};
use sha2::{Digest, Sha384};
use sqlx_core::migrate::MigrationType;
use std::fs;
use std::path::Path;
use syn::LitStr;
pub struct QuotedMigrationType(MigrationType);
@ -56,8 +57,20 @@ impl ToTokens for QuotedMigration {
}
// mostly copied from sqlx-core/src/migrate/source.rs
pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result<TokenStream> {
let path = crate::common::resolve_path(&dir.value(), dir.span())?;
pub(crate) fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result<TokenStream> {
expand_migrator_from_dir(&dir.value(), dir.span())
}
pub(crate) fn expand_migrator_from_dir(
dir: &str,
err_span: proc_macro2::Span,
) -> crate::Result<TokenStream> {
let path = crate::common::resolve_path(dir, err_span)?;
expand_migrator(&path)
}
pub(crate) fn expand_migrator(path: &Path) -> crate::Result<TokenStream> {
let mut migrations = Vec::new();
for entry in fs::read_dir(&path)? {

View file

@ -0,0 +1,217 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::LitStr;
struct Args {
fixtures: Vec<LitStr>,
migrations: MigrationsOpt,
}
enum MigrationsOpt {
InferredPath,
ExplicitPath(LitStr),
ExplicitMigrator(syn::Path),
Disabled,
}
pub fn expand(args: syn::AttributeArgs, input: syn::ItemFn) -> crate::Result<TokenStream> {
if input.sig.inputs.is_empty() {
if !args.is_empty() {
if cfg!(feature = "migrate") {
return Err(syn::Error::new_spanned(
args.first().unwrap(),
"control attributes are not allowed unless \
the `migrate` feature is enabled and \
automatic test DB management is used; see docs",
)
.into());
}
return Err(syn::Error::new_spanned(
args.first().unwrap(),
"control attributes are not allowed unless \
automatic test DB management is used; see docs",
)
.into());
}
return Ok(expand_simple(input));
}
#[cfg(feature = "migrate")]
return expand_advanced(args, input);
#[cfg(not(feature = "migrate"))]
return Err(syn::Error::new_spanned(input, "`migrate` feature required").into());
}
fn expand_simple(input: syn::ItemFn) -> TokenStream {
let ret = &input.sig.output;
let name = &input.sig.ident;
let body = &input.block;
let attrs = &input.attrs;
quote! {
#[test]
#(#attrs)*
fn #name() #ret {
::sqlx::test_block_on(async { #body })
}
}
}
#[cfg(feature = "migrate")]
fn expand_advanced(args: syn::AttributeArgs, input: syn::ItemFn) -> crate::Result<TokenStream> {
let ret = &input.sig.output;
let name = &input.sig.ident;
let inputs = &input.sig.inputs;
let body = &input.block;
let attrs = &input.attrs;
let args = parse_args(args)?;
let fn_arg_types = inputs.iter().map(|_| quote! { _ });
let fixtures = args.fixtures.into_iter().map(|fixture| {
let path = format!("fixtures/{}.sql", fixture.value());
quote! {
::sqlx::testing::TestFixture {
path: #path,
contents: include_str!(#path),
}
}
});
let migrations = match args.migrations {
MigrationsOpt::ExplicitPath(path) => {
let migrator = crate::migrate::expand_migrator_from_lit_dir(path)?;
quote! { args.migrator(&#migrator); }
}
MigrationsOpt::InferredPath if !inputs.is_empty() => {
let migrations_path = crate::common::resolve_path("./migrations", Span::call_site())?;
if migrations_path.is_dir() {
let migrator = crate::migrate::expand_migrator(&migrations_path)?;
quote! { args.migrator(&#migrator); }
} else {
quote! {}
}
}
MigrationsOpt::ExplicitMigrator(path) => {
quote! { args.migrator(&#path); }
}
_ => quote! {},
};
Ok(quote! {
#[test]
#(#attrs)*
fn #name() #ret {
async fn inner(#inputs) #ret {
#body
}
let mut args = ::sqlx::testing::TestArgs::new(concat!(module_path!(), "::", stringify!(#name)));
#migrations
args.fixtures(&[#(#fixtures),*]);
// We need to give a coercion site or else we get "unimplemented trait" errors.
let f: fn(#(#fn_arg_types),*) -> _ = inner;
::sqlx::testing::TestFn::run_test(f, args)
}
})
}
#[cfg(feature = "migrate")]
fn parse_args(attr_args: syn::AttributeArgs) -> syn::Result<Args> {
let mut fixtures = vec![];
let mut migrations = MigrationsOpt::InferredPath;
for arg in attr_args {
match arg {
syn::NestedMeta::Meta(syn::Meta::List(list)) if list.path.is_ident("fixtures") => {
if !fixtures.is_empty() {
return Err(syn::Error::new_spanned(list, "duplicate `fixtures` arg"));
}
for nested in list.nested {
match nested {
syn::NestedMeta::Lit(syn::Lit::Str(litstr)) => fixtures.push(litstr),
other => {
return Err(syn::Error::new_spanned(other, "expected string literal"))
}
}
}
}
syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue))
if namevalue.path.is_ident("migrations") =>
{
if !matches!(migrations, MigrationsOpt::InferredPath) {
return Err(syn::Error::new_spanned(
namevalue,
"cannot have more than one `migrations` or `migrator` arg",
));
}
migrations = match namevalue.lit {
syn::Lit::Bool(litbool) => {
if !litbool.value {
// migrations = false
MigrationsOpt::Disabled
} else {
// migrations = true
return Err(syn::Error::new_spanned(
litbool,
"`migrations = true` is redundant",
));
}
}
// migrations = "<path>"
syn::Lit::Str(litstr) => MigrationsOpt::ExplicitPath(litstr),
_ => {
return Err(syn::Error::new_spanned(
namevalue,
"expected string or `false`",
))
}
};
}
syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue))
if namevalue.path.is_ident("migrator") =>
{
if !matches!(migrations, MigrationsOpt::InferredPath) {
return Err(syn::Error::new_spanned(
namevalue,
"cannot have more than one `migrations` or `migrator` arg",
));
}
migrations = match namevalue.lit {
// migrator = "<path>"
syn::Lit::Str(litstr) => MigrationsOpt::ExplicitMigrator(litstr.parse()?),
_ => {
return Err(syn::Error::new_spanned(
namevalue,
"expected string",
))
}
};
}
other => {
return Err(syn::Error::new_spanned(
other,
"expected `fixtures(\"<filename>\", ...)` or `migrations = \"<path>\" | false` or `migrator = \"<rust path>\"`",
))
}
}
}
Ok(Args {
fixtures,
migrations,
})
}

View file

@ -1,6 +1,10 @@
//! Core runtime support for SQLx. **Semver-exempt**, not for general use.
#[cfg(not(any(
feature = "runtime-actix-native-tls",
feature = "runtime-async-std-native-tls",
feature = "runtime-tokio-native-tls",
feature = "runtime-actix-rustls",
feature = "runtime-async-std-rustls",
feature = "runtime-tokio-rustls",
)))]
@ -11,6 +15,8 @@ compile_error!(
);
#[cfg(any(
all(feature = "_rt-actix", feature = "_rt-async-std"),
all(feature = "_rt-actix", feature = "_rt-tokio"),
all(feature = "_rt-async-std", feature = "_rt-tokio"),
all(feature = "_tls-native-tls", feature = "_tls-rustls"),
))]
@ -20,116 +26,24 @@ compile_error!(
'runtime-tokio-rustls'] can be enabled"
);
#[cfg(feature = "_rt-async-std")]
mod rt_async_std;
#[cfg(any(feature = "_rt-tokio", feature = "_rt-actix"))]
mod rt_tokio;
#[cfg(all(feature = "_tls-native-tls"))]
pub use native_tls;
//
// Tokio
// Actix *OR* Tokio
//
#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-std")))]
pub use tokio::{
self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf,
net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now,
time::sleep, time::timeout,
};
#[cfg(all(unix, feature = "_rt-tokio", not(feature = "_rt-async-std")))]
pub use tokio::net::UnixStream;
#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-std")))]
pub use tokio_runtime::{block_on, enter_runtime};
#[cfg(feature = "_rt-tokio")]
mod tokio_runtime {
use once_cell::sync::Lazy;
use tokio::runtime::{self, Runtime};
// lazily initialize a global runtime once for multiple invocations of the macros
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.build()
.expect("failed to initialize Tokio runtime")
});
pub fn block_on<F: std::future::Future>(future: F) -> F::Output {
RUNTIME.block_on(future)
}
pub fn enter_runtime<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
let _rt = RUNTIME.enter();
f()
}
}
#[cfg(all(any(feature = "_rt-tokio", feature = "_rt-actix"),))]
pub use rt_tokio::*;
#[cfg(all(
feature = "_tls-native-tls",
feature = "_rt-tokio",
not(any(feature = "_tls-rustls", feature = "_rt-async-std")),
))]
pub use tokio_native_tls::{TlsConnector, TlsStream};
#[cfg(all(
feature = "_tls-rustls",
feature = "_rt-tokio",
not(any(feature = "_tls-native-tls", feature = "_rt-async-std")),
))]
pub use tokio_rustls::{client::TlsStream, TlsConnector};
#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-std")))]
#[macro_export]
macro_rules! blocking {
($($expr:tt)*) => {
$crate::tokio::task::spawn_blocking(move || { $($expr)* })
.await.expect("Blocking task failed to complete.")
};
}
//
// async-std
//
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
pub use async_std::{
self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt,
io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite,
net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now,
};
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
#[macro_export]
macro_rules! blocking {
($($expr:tt)*) => {
$crate::async_std::task::spawn_blocking(move || { $($expr)* }).await
};
}
#[cfg(all(unix, feature = "_rt-async-std", not(feature = "_rt-tokio")))]
pub use async_std::os::unix::net::UnixStream;
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
pub use async_std::task::block_on;
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
pub fn enter_runtime<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
// no-op for async-std
f()
}
#[cfg(all(feature = "async-native-tls", not(feature = "tokio-native-tls")))]
pub use async_native_tls::{TlsConnector, TlsStream};
#[cfg(all(
feature = "_tls-rustls",
feature = "_rt-async-std",
not(any(feature = "_tls-native-tls", feature = "_rt-tokio")),
not(any(feature = "_rt-tokio", feature = "_rt-actix"))
))]
pub use futures_rustls::{client::TlsStream, TlsConnector};
pub use rt_async_std::*;

View file

@ -0,0 +1,24 @@
pub use async_std::{
self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt,
io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite,
net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now,
};
#[cfg(unix)]
pub use async_std::os::unix::net::UnixStream;
#[cfg(all(feature = "_tls-native-tls", not(feature = "_tls-rustls")))]
pub use async_native_tls::{TlsConnector, TlsStream};
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
pub use futures_rustls::{client::TlsStream, TlsConnector};
pub use async_std::task::{block_on, block_on as test_block_on};
pub fn enter_runtime<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
// no-op for async-std
f()
}

47
sqlx-rt/src/rt_tokio.rs Normal file
View file

@ -0,0 +1,47 @@
pub use tokio::{
self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf,
net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now,
time::sleep, time::timeout,
};
#[cfg(unix)]
pub use tokio::net::UnixStream;
use once_cell::sync::Lazy;
use tokio::runtime::{self, Runtime};
#[cfg(all(feature = "_tls-native-tls", not(feature = "_tls-rustls")))]
pub use tokio_native_tls::{TlsConnector, TlsStream};
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
pub use tokio_rustls::{client::TlsStream, TlsConnector};
// lazily initialize a global runtime once for multiple invocations of the macros
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.build()
.expect("failed to initialize Tokio runtime")
});
pub fn block_on<F: std::future::Future>(future: F) -> F::Output {
RUNTIME.block_on(future)
}
pub fn enter_runtime<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
let _rt = RUNTIME.enter();
f()
}
pub fn test_block_on<F: std::future::Future>(future: F) -> F::Output {
// For tests, we want a single runtime per thread for isolation.
runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to initialize Tokio test runtime")
.block_on(future)
}

View file

@ -71,6 +71,18 @@ pub extern crate sqlx_macros;
#[doc(hidden)]
pub use sqlx_macros::{FromRow, Type};
// We can't do our normal facade approach with an attribute, but thankfully we can now
// have docs out-of-line quite easily.
#[doc = include_str!("macros/test.md")]
pub use sqlx_macros::test;
#[doc(hidden)]
#[cfg(feature = "migrate")]
pub use sqlx_core::testing;
#[doc(hidden)]
pub use sqlx_core::test_block_on;
#[cfg(feature = "macros")]
mod macros;

218
src/macros/test.md Normal file
View file

@ -0,0 +1,218 @@
Mark an `async fn` as a test with SQLx support.
The test will automatically be executed in the async runtime according to the chosen
`runtime-{async-std, tokio}-{native-tls, rustls}` feature.
By default, this behaves identically to `#[tokio::test]`<sup>1</sup> or `#[async_std::test]`:
```rust,norun
# // Note if reading these examples directly in `test.md`:
# // lines prefixed with `#` are not meant to be shown;
# // they are supporting code to help the examples to compile successfully.
# #[cfg(feature = "_rt-tokio")]
#[sqlx::test]
async fn test_async_fn() {
tokio::task::yield_now().await;
}
```
However, several advanced features are also supported as shown in the next section.
<sup>1</sup>`#[sqlx::test]` does not recognize any of the control arguments supported by `#[tokio::test]`
as that would have complicated the implementation. If your use case requires any of those, feel free to open an issue.
### Automatic Test Database Management (requires `migrate` feature)
`#[sqlx::test]` can automatically create test databases for you and provide live connections to your test.
For every annotated function, a new test database is created so tests can run against a live database
but are isolated from each other.
This feature is activated by changing the signature of your test function. The following signatures are supported:
* `async fn(Pool<DB>) -> Ret`
* the `Pool`s used by all running tests share a single connection limit to avoid exceeding the server's limit.
* `async fn(PoolConnection<DB>) -> Ret`
* `PoolConnection<Postgres>`, etc.
* `async fn(PoolOptions<DB>, impl ConnectOptions<DB>) -> Ret`
* Where `impl ConnectOptions` is, e.g, `PgConnectOptions`, `MySqlConnectOptions`, etc.
* If your test wants to create its own `Pool` (for example, to set pool callbacks or to modify `ConnectOptions`),
you can use this signature.
Where `DB` is a supported `Database` type and `Ret` is `()` or `Result<_, _>`.
##### Supported Databases
Most of these will require you to set `DATABASE_URL` as an environment variable
or in a `.env` file like `sqlx::query!()` _et al_, to give the test driver a superuser connection with which
to manage test databases.
| Database | Requires `DATABASE_URL` |
| --- | --- |
| Postgres | Yes |
| MySQL | Yes |
| SQLite | No<sup>2</sup> |
Test databases are automatically cleaned up as tests succeed, but failed tests will leave their databases in-place
to facilitate debugging. Note that to simplify the implementation, panics are _always_ considered to be failures,
even for `#[should_panic]` tests.
If you have `sqlx-cli` installed, you can run `sqlx test-db cleanup` to delete all test databases.
Old test databases will also be deleted the next time a test binary using `#[sqlx::test]` is run.
```rust,no_run
# #[cfg(all(feature = "migrate", feature = "postgres"))]
# mod example {
use sqlx::PgPool;
#[sqlx::test]
async fn basic_test(pool: PgPool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
sqlx::query("SELECT * FROM foo")
.fetch_one(&mut conn)
.await?;
assert_eq!(foo.get::<String>("bar"), "foobar!");
Ok(())
}
# }
```
<sup>2</sup> SQLite defaults to `target/sqlx/test-dbs/<path>.sqlite` where `<path>` is the path of the test function
converted to a filesystem path (`::` replaced with `/`).
### Automatic Migrations (requires `migrate` feature)
To ensure a straightforward test implementation against a fresh test database, migrations are automatically applied if a
`migrations` folder is found in the same directory as `CARGO_MANIFEST_DIR` (the directory where the current crate's
`Cargo.toml` resides).
You can override the resolved path relative to `CARGO_MANIFEST_DIR` in the attribute (global overrides are not currently
supported):
```rust,ignore
# #[cfg(all(feature = "migrate", feature = "postgres"))]
# mod example {
use sqlx::PgPool;
#[sqlx::test(migrations = "foo_migrations")]
async fn basic_test(pool: PgPool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
sqlx::query("SELECT * FROM foo")
.fetch_one(&mut conn)
.await?;
assert_eq!(foo.get::<String>("bar"), "foobar!");
Ok(())
}
# }
```
Or if you're already embedding migrations in your main crate, you can reference them directly:
`foo_crate/lib.rs`
```rust,ignore
pub static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("foo_migrations");
```
`foo_crate/tests/foo_test.rs`
```rust,no_run
# #[cfg(all(feature = "migrate", feature = "postgres"))]
# mod example {
use sqlx::PgPool;
# // This is standing in for the main crate since doc examples don't support multiple crates.
# mod foo_crate {
# use std::borrow::Cow;
# static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate::Migrator {
# migrations: Cow::Borrowed(&[]),
# ignore_missing: false,
# };
# }
// You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here.
#[sqlx::test(migrator = "foo_crate::MIGRATOR")]
async fn basic_test(pool: PgPool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
sqlx::query("SELECT * FROM foo")
.fetch_one(&mut conn)
.await?;
assert_eq!(foo.get::<String>("bar"), "foobar!");
Ok(())
}
# }
```
Or disable migrations processing entirely:
```rust,no_run
# #[cfg(all(feature = "migrate", feature = "postgres"))]
# mod example {
use sqlx::PgPool;
#[sqlx::test(migrations = false)]
async fn basic_test(pool: PgPool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
conn.execute("CREATE TABLE foo(bar text)").await?;
sqlx::query("SELECT * FROM foo")
.fetch_one(&mut conn)
.await?;
assert_eq!(foo.get::<String>("bar"), "foobar!");
Ok(())
}
# }
```
### Automatic Fixture Application (requires `migrate` feature)
Since tests are isolated from each other but may require data to already exist in the database to keep from growing
exponentially in complexity, `#[sqlx::test]` also supports applying test fixtures, which are SQL scripts that function
similarly to migrations but are solely intended to insert test data and be arbitrarily composable.
Imagine a basic social app that has users, posts and comments. To test the comment routes, you'd want
the database to already have users and posts in it so the comments tests don't have to duplicate that work.
You can pass a list of fixture names to the attribute like so, and they will be applied in the given order<sup>3</sup>:
```rust,no_run
# #[cfg(all(feature = "migrate", feature = "postgres"))]
# mod example {
# struct App {}
# fn create_app(pool: PgPool) -> App { App {} }
use sqlx::PgPool;
use serde_json::json;
#[sqlx::test(fixtures("users", "posts"))]
async fn test_create_comment(pool: PgPool) -> sqlx::Result<()> {
// See examples/postgres/social-axum-with-tests for a more in-depth example.
let mut app = create_app(pool);
let comment = test_request(
&mut app, "POST", "/v1/comment", json! { "postId": "1234" }
).await?;
assert_eq!(comment["postId"], "1234");
Ok(())
}
# }
```
Fixtures are resolved relative to the current file as `./fixtures/{name}.sql`.
<sup>3</sup>Ordering for test fixtures is entirely up to the application, and each test may choose which fixtures to
apply and which to omit. However, since each fixture is applied separately (sent as a single command string, so wrapped
in an implicit `BEGIN` and `COMMIT`), you will want to make sure to order the fixtures such that foreign key
requirements are always satisfied, or else you might get errors.

View file

@ -2,7 +2,7 @@ version: "3"
services:
#
# MySQL 8.x, 5.7.x, 5.6.x
# MySQL 8.x, 5.7.x
# https://www.mysql.com/support/supportedplatforms/database.html
#
@ -29,20 +29,8 @@ services:
MYSQL_ROOT_HOST: '%'
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: sqlx
mysql_5_6:
image: mysql:5.6
volumes:
- "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
ports:
- 3306
environment:
MYSQL_ROOT_HOST: '%'
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: sqlx
#
# MariaDB 10.6, 10.5, 10.4, 10.3, 10.2
# MariaDB 10.6, 10.5, 10.4, 10.3
# https://mariadb.org/about/#maintenance-policy
#
@ -86,18 +74,8 @@ services:
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: sqlx
mariadb_10_2:
image: mariadb:10.2
volumes:
- "./mysql/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
ports:
- 3306
environment:
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: sqlx
#
# PostgreSQL 14.x, 13.x, 12.x, 11.x 10.x, 9.6.x
# PostgreSQL 14.x, 13.x, 12.x, 11.x 10.x
# https://www.postgresql.org/support/versioning/
#
@ -195,25 +173,6 @@ services:
- "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
command: >
-c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
postgres_9_6:
build:
context: .
dockerfile: postgres/Dockerfile
args:
VERSION: 9.6
ports:
- 5432
environment:
POSTGRES_DB: sqlx
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
POSTGRES_HOST_AUTH_METHOD: md5
volumes:
- "./postgres/setup.sql:/docker-entrypoint-initdb.d/setup.sql"
command: >
-c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
#
# Microsoft SQL Server (MSSQL)
# https://hub.docker.com/_/microsoft-mssql-server

View file

@ -0,0 +1,16 @@
insert into comment(comment_id, post_id, user_id, content, created_at)
values (1,
1,
2,
'lol bet ur still bad, 1v1 me',
timestamp(now(), '-0:50:00')),
(2,
1,
1,
'you''re on!',
timestamp(now(), '-0:45:00')),
(3,
2,
1,
'lol you''re just mad you lost :P',
timestamp(now(), '-0:15:00'));

View file

@ -0,0 +1,9 @@
insert into post(post_id, user_id, content, created_at)
values (1,
1,
'This new computer is lightning-fast!',
timestamp(now(), '-1:00:00')),
(2,
2,
'@alice is a haxxor :(',
timestamp(now(), '-0:30:00'));

View file

@ -0,0 +1,2 @@
insert into user(user_id, username)
values (1, 'alice'), (2, 'bob');

View file

@ -0,0 +1,7 @@
create table user
(
-- integer primary keys are the most efficient in SQLite
user_id integer primary key auto_increment,
-- indexed text values have to have a max length
username varchar(16) unique not null
);

View file

@ -0,0 +1,10 @@
create table post
(
post_id integer primary key auto_increment,
user_id integer not null references user (user_id),
content text not null,
-- Defaults have to be wrapped in parenthesis
created_at datetime default current_timestamp
);
create index post_created_at on post (created_at desc);

View file

@ -0,0 +1,10 @@
create table comment
(
comment_id integer primary key,
post_id integer not null references post (post_id),
user_id integer not null references user (user_id),
content text not null,
created_at datetime default current_timestamp
);
create index comment_created_at on comment (created_at desc);

96
tests/mysql/test-attr.rs Normal file
View file

@ -0,0 +1,96 @@
// The no-arg variant is covered by other tests already.
use sqlx::{MySqlPool, Row};
const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/mysql/migrations");
#[sqlx::test]
async fn it_gets_a_pool(pool: MySqlPool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
let db_name: String = sqlx::query_scalar("select database()")
.fetch_one(&mut conn)
.await?;
assert!(
db_name.starts_with("_sqlx_test_database_"),
"db_name: {:?}",
db_name
);
Ok(())
}
// This should apply migrations and then `fixtures/users.sql`
#[sqlx::test(migrations = "tests/mysql/migrations", fixtures("users"))]
async fn it_gets_users(pool: MySqlPool) -> sqlx::Result<()> {
let usernames: Vec<String> =
sqlx::query_scalar(r#"SELECT username FROM user ORDER BY username"#)
.fetch_all(&pool)
.await?;
assert_eq!(usernames, ["alice", "bob"]);
let post_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM post)")
.fetch_one(&pool)
.await?;
assert!(!post_exists);
let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)")
.fetch_one(&pool)
.await?;
assert!(!comment_exists);
Ok(())
}
#[sqlx::test(migrations = "tests/mysql/migrations", fixtures("users", "posts"))]
async fn it_gets_posts(pool: MySqlPool) -> sqlx::Result<()> {
let post_contents: Vec<String> =
sqlx::query_scalar("SELECT content FROM post ORDER BY created_at")
.fetch_all(&pool)
.await?;
assert_eq!(
post_contents,
[
"This new computer is lightning-fast!",
"@alice is a haxxor :("
]
);
let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)")
.fetch_one(&pool)
.await?;
assert!(!comment_exists);
Ok(())
}
// Try `migrator`
#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))]
async fn it_gets_comments(pool: MySqlPool) -> sqlx::Result<()> {
let post_1_comments: Vec<String> =
sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at")
.bind(&1)
.fetch_all(&pool)
.await?;
assert_eq!(
post_1_comments,
["lol bet ur still bad, 1v1 me", "you're on!"]
);
let post_2_comments: Vec<String> =
sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at")
.bind(&2)
.fetch_all(&pool)
.await?;
assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]);
Ok(())
}

View file

@ -0,0 +1,16 @@
insert into comment(comment_id, post_id, user_id, content, created_at)
values ('fbbbb7dc-dc6f-4649-b663-8d3636035164',
'252c1d98-a9b0-4f18-8298-e59058bdfe16',
'297923c5-a83c-4052-bab0-030887154e52',
'lol bet ur still bad, 1v1 me',
now() + '50 minutes ago'::interval),
('cb7612a2-cff4-4e3e-a768-055f01f25dc4',
'252c1d98-a9b0-4f18-8298-e59058bdfe16',
'297923c5-a83c-4052-bab0-030887154e52',
'you''re on!',
now() + '45 minutes ago'::interval),
('f2164fcc-a770-4f52-8714-d9cc6a1c89cf',
'844265f7-2472-4689-9a2e-b21f40dbf401',
'297923c5-a83c-4052-bab0-030887154e52',
'lol you''re just mad you lost :P',
now() + '15 minutes ago'::interval);

View file

@ -0,0 +1,14 @@
insert into post(post_id, user_id, content, created_at)
values
(
'252c1d98-a9b0-4f18-8298-e59058bdfe16',
'6592b7c0-b531-4613-ace5-94246b7ce0c3',
'This new computer is lightning-fast!',
now() + '1 hour ago'::interval
),
(
'844265f7-2472-4689-9a2e-b21f40dbf401',
'6592b7c0-b531-4613-ace5-94246b7ce0c3',
'@alice is a haxxor :(',
now() + '30 minutes ago'::interval
);

View file

@ -0,0 +1,2 @@
insert into "user"(user_id, username)
values ('6592b7c0-b531-4613-ace5-94246b7ce0c3', 'alice'), ('297923c5-a83c-4052-bab0-030887154e52', 'bob');

View file

@ -0,0 +1,2 @@
-- `gen_random_uuid()` wasn't added until Postgres 13
create extension if not exists "uuid-ossp";

View file

@ -0,0 +1,5 @@
create table "user"
(
user_id uuid primary key default uuid_generate_v1mc(),
username text unique not null
);

View file

@ -0,0 +1,8 @@
create table post (
post_id uuid primary key default uuid_generate_v1mc(),
user_id uuid not null references "user"(user_id),
content text not null,
created_at timestamptz default now()
);
create index on post(created_at desc);

View file

@ -0,0 +1,9 @@
create table comment (
comment_id uuid primary key default uuid_generate_v1mc(),
post_id uuid not null references post(post_id),
user_id uuid not null references "user"(user_id),
content text not null,
created_at timestamptz not null default now()
);
create index on comment(created_at desc);

View file

@ -0,0 +1,94 @@
// The no-arg variant is covered by other tests already.
use sqlx::PgPool;
const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/postgres/migrations");
#[sqlx::test]
async fn it_gets_a_pool(pool: PgPool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
let db_name: String = sqlx::query_scalar("SELECT current_database()")
.fetch_one(&mut conn)
.await?;
assert!(db_name.starts_with("_sqlx_test"), "dbname: {db_name:?}");
Ok(())
}
// This should apply migrations and then `fixtures/users.sql`
#[sqlx::test(migrations = "tests/postgres/migrations", fixtures("users"))]
async fn it_gets_users(pool: PgPool) -> sqlx::Result<()> {
let usernames: Vec<String> =
sqlx::query_scalar(r#"SELECT username FROM "user" ORDER BY username"#)
.fetch_all(&pool)
.await?;
assert_eq!(usernames, ["alice", "bob"]);
let post_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM post)")
.fetch_one(&pool)
.await?;
assert!(!post_exists);
let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)")
.fetch_one(&pool)
.await?;
assert!(!comment_exists);
Ok(())
}
#[sqlx::test(migrations = "tests/postgres/migrations", fixtures("users", "posts"))]
async fn it_gets_posts(pool: PgPool) -> sqlx::Result<()> {
let post_contents: Vec<String> =
sqlx::query_scalar("SELECT content FROM post ORDER BY created_at")
.fetch_all(&pool)
.await?;
assert_eq!(
post_contents,
[
"This new computer is lightning-fast!",
"@alice is a haxxor :("
]
);
let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)")
.fetch_one(&pool)
.await?;
assert!(!comment_exists);
Ok(())
}
// Try `migrator`
#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))]
async fn it_gets_comments(pool: PgPool) -> sqlx::Result<()> {
let post_1_comments: Vec<String> = sqlx::query_scalar(
"SELECT content FROM comment WHERE post_id = $1::uuid ORDER BY created_at",
)
.bind(&"252c1d98-a9b0-4f18-8298-e59058bdfe16")
.fetch_all(&pool)
.await?;
assert_eq!(
post_1_comments,
["lol bet ur still bad, 1v1 me", "you're on!"]
);
let post_2_comments: Vec<String> = sqlx::query_scalar(
"SELECT content FROM comment WHERE post_id = $1::uuid ORDER BY created_at",
)
.bind(&"844265f7-2472-4689-9a2e-b21f40dbf401")
.fetch_all(&pool)
.await?;
assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]);
Ok(())
}

View file

@ -547,7 +547,7 @@ test_prepared_type!(money_vec<Vec<PgMoney>>(Postgres,
"array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)],
));
// FIXME: needed to disable `ltree` tests in Postgres 9.6
// FIXME: needed to disable `ltree` tests in version that don't have a binary format for it
// but `PgLTree` should just fall back to text format
#[cfg(postgres_14)]
test_type!(ltree<sqlx::postgres::types::PgLTree>(Postgres,
@ -555,7 +555,7 @@ test_type!(ltree<sqlx::postgres::types::PgLTree>(Postgres,
"'Alpha.Beta.Delta.Gamma'::ltree" == sqlx::postgres::types::PgLTree::from_iter(["Alpha", "Beta", "Delta", "Gamma"]).unwrap(),
));
// FIXME: needed to disable `ltree` tests in Postgres 9.6
// FIXME: needed to disable `ltree` tests in version that don't have a binary format for it
// but `PgLTree` should just fall back to text format
#[cfg(postgres_14)]
test_type!(ltree_vec<Vec<sqlx::postgres::types::PgLTree>>(Postgres,

View file

@ -0,0 +1,16 @@
insert into comment(comment_id, post_id, user_id, content, created_at)
values (1,
1,
2,
'lol bet ur still bad, 1v1 me',
datetime('now', '-50 minutes')),
(2,
1,
1,
'you''re on!',
datetime('now', '-45 minutes')),
(3,
2,
1,
'lol you''re just mad you lost :P',
datetime('now', '-15 minutes'));

View file

@ -0,0 +1,9 @@
insert into post(post_id, user_id, content, created_at)
values (1,
1,
'This new computer is lightning-fast!',
datetime('now', '-1 hour')),
(2,
2,
'@alice is a haxxor :(',
datetime('now', '-30 minutes'));

View file

@ -0,0 +1,2 @@
insert into "user"(user_id, username)
values (1, 'alice'), (2, 'bob');

View file

@ -0,0 +1,6 @@
create table user
(
-- integer primary keys are the most efficient in SQLite
user_id integer primary key,
username text unique not null
);

View file

@ -0,0 +1,10 @@
create table post
(
post_id integer primary key,
user_id integer not null references user (user_id),
content text not null,
-- Defaults have to be wrapped in parenthesis
created_at datetime default (datetime('now'))
);
create index post_created_at on post (created_at desc);

View file

@ -0,0 +1,10 @@
create table comment
(
comment_id integer primary key,
post_id integer not null references post (post_id),
user_id integer not null references "user" (user_id),
content text not null,
created_at datetime default (datetime('now'))
);
create index comment_created_at on comment (created_at desc);

99
tests/sqlite/test-attr.rs Normal file
View file

@ -0,0 +1,99 @@
// The no-arg variant is covered by other tests already.
use sqlx::{Row, SqlitePool};
const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/sqlite/migrations");
#[sqlx::test]
async fn it_gets_a_pool(pool: SqlitePool) -> sqlx::Result<()> {
let mut conn = pool.acquire().await?;
// https://www.sqlite.org/pragma.html#pragma_database_list
let db = sqlx::query("PRAGMA database_list")
.fetch_one(&mut conn)
.await?;
let db_name = db.get::<String, _>(2);
assert!(
db_name.ends_with("target/sqlx/test-dbs/sqlite_test_attr/it_gets_a_pool.sqlite"),
"db_name: {:?}",
db_name
);
Ok(())
}
// This should apply migrations and then `fixtures/users.sql`
#[sqlx::test(migrations = "tests/sqlite/migrations", fixtures("users"))]
async fn it_gets_users(pool: SqlitePool) -> sqlx::Result<()> {
let usernames: Vec<String> =
sqlx::query_scalar(r#"SELECT username FROM "user" ORDER BY username"#)
.fetch_all(&pool)
.await?;
assert_eq!(usernames, ["alice", "bob"]);
let post_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM post)")
.fetch_one(&pool)
.await?;
assert!(!post_exists);
let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)")
.fetch_one(&pool)
.await?;
assert!(!comment_exists);
Ok(())
}
#[sqlx::test(migrations = "tests/sqlite/migrations", fixtures("users", "posts"))]
async fn it_gets_posts(pool: SqlitePool) -> sqlx::Result<()> {
let post_contents: Vec<String> =
sqlx::query_scalar("SELECT content FROM post ORDER BY created_at")
.fetch_all(&pool)
.await?;
assert_eq!(
post_contents,
[
"This new computer is lightning-fast!",
"@alice is a haxxor :("
]
);
let comment_exists: bool = sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment)")
.fetch_one(&pool)
.await?;
assert!(!comment_exists);
Ok(())
}
// Try `migrator`
#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))]
async fn it_gets_comments(pool: SqlitePool) -> sqlx::Result<()> {
let post_1_comments: Vec<String> =
sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at")
.bind(&1)
.fetch_all(&pool)
.await?;
assert_eq!(
post_1_comments,
["lol bet ur still bad, 1v1 me", "you're on!"]
);
let post_2_comments: Vec<String> =
sqlx::query_scalar("SELECT content FROM comment WHERE post_id = ? ORDER BY created_at")
.bind(&2)
.fetch_all(&pool)
.await?;
assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]);
Ok(())
}

View file

@ -130,7 +130,7 @@ for runtime in ["async-std", "tokio", "actix"]:
# postgres
#
for version in ["14", "13", "12", "11", "10", "9_6"]:
for version in ["14", "13", "12", "11", "10"]:
run(
f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-{tls}",
comment=f"test postgres {version}",
@ -139,7 +139,7 @@ for runtime in ["async-std", "tokio", "actix"]:
)
## +ssl
for version in ["14", "13", "12", "11", "10", "9_6"]:
for version in ["14", "13", "12", "11", "10"]:
run(
f"cargo test --no-default-features --features macros,offline,any,all-types,postgres,runtime-{runtime}-{tls}",
comment=f"test postgres {version} ssl",
@ -152,7 +152,7 @@ for runtime in ["async-std", "tokio", "actix"]:
# mysql
#
for version in ["8", "5_7", "5_6"]:
for version in ["8", "5_7"]:
run(
f"cargo test --no-default-features --features macros,offline,any,all-types,mysql,runtime-{runtime}-{tls}",
comment=f"test mysql {version}",