mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 14:34:19 +00:00
Decode username on read
This commit is contained in:
parent
1d7f5c54cb
commit
f05f6bedd9
3 changed files with 51 additions and 8 deletions
|
@ -1,3 +1,4 @@
|
|||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::ops::Range;
|
||||
|
@ -181,7 +182,7 @@ async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
|
|||
HandshakeResponse {
|
||||
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
username: url.username().unwrap_or("root"),
|
||||
username: &url.username().unwrap_or(Cow::Borrowed("root")),
|
||||
database: url.database(),
|
||||
auth_plugin: &auth_plugin,
|
||||
auth_response: &auth_response,
|
||||
|
|
|
@ -110,13 +110,16 @@ pub struct PgConnection {
|
|||
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
|
||||
async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyData> {
|
||||
// Defaults to postgres@.../postgres
|
||||
let username = url.username().unwrap_or("postgres");
|
||||
let username = url
|
||||
.username()
|
||||
.map(|c| c.into_owned())
|
||||
.unwrap_or(String::from("postgres"));
|
||||
let database = url.database().unwrap_or("postgres");
|
||||
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = &[
|
||||
("user", username),
|
||||
("user", username.as_str()),
|
||||
("database", database),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
|
@ -160,7 +163,7 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyDa
|
|||
|
||||
stream.write(PasswordMessage::Md5 {
|
||||
password: &url.password().unwrap_or_default(),
|
||||
user: username,
|
||||
user: username.as_str(),
|
||||
salt: data.salt,
|
||||
});
|
||||
|
||||
|
@ -193,8 +196,12 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyDa
|
|||
|
||||
if has_sasl || has_sasl_plus {
|
||||
// TODO: Handle -PLUS differently if we're in a TLS stream
|
||||
sasl::authenticate(stream, username, &url.password().unwrap_or_default())
|
||||
.await?;
|
||||
sasl::authenticate(
|
||||
stream,
|
||||
username.as_str(),
|
||||
&url.password().unwrap_or_default(),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
return Err(protocol_err!(
|
||||
"unsupported SASL auth mechanisms: {:?}",
|
||||
|
|
|
@ -48,13 +48,17 @@ impl Url {
|
|||
self.0.port().unwrap_or(default)
|
||||
}
|
||||
|
||||
pub fn username(&self) -> Option<&str> {
|
||||
pub fn username(&self) -> Option<Cow<str>> {
|
||||
let username = self.0.username();
|
||||
|
||||
if username.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(username)
|
||||
Some(
|
||||
percent_encoding::percent_decode_str(username)
|
||||
.decode_utf8()
|
||||
.expect("percent-encoded username contained non-UTF-8 bytes"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,3 +94,34 @@ impl Url {
|
|||
.find_map(|(key_, val)| if key == key_ { Some(val) } else { None })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn azure_connection_string_username_unencoded() {
|
||||
let connection_string =
|
||||
"postgres://username@servername:password@example.postgres.database.azure.com/db";
|
||||
|
||||
let url = Url::try_from(connection_string).expect("Failed to parse URL");
|
||||
|
||||
assert_eq!(
|
||||
url.username().map(|u| u.to_string()),
|
||||
Some(String::from("username@servername"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn azure_connection_string_username_encoded() {
|
||||
let connection_string =
|
||||
"postgres://username%40servername:password@example.postgres.database.azure.com/db";
|
||||
|
||||
let url = Url::try_from(connection_string).expect("Failed to parse URL");
|
||||
|
||||
assert_eq!(
|
||||
url.username().map(|u| u.to_string()),
|
||||
Some(String::from("username@servername"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue