Decode username on read

This commit is contained in:
James Waples 2020-03-31 14:17:23 +01:00
parent 1d7f5c54cb
commit f05f6bedd9
3 changed files with 51 additions and 8 deletions

View file

@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Range;
@ -181,7 +182,7 @@ async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
HandshakeResponse {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: url.username().unwrap_or("root"),
username: &url.username().unwrap_or(Cow::Borrowed("root")),
database: url.database(),
auth_plugin: &auth_plugin,
auth_response: &auth_response,

View file

@ -110,13 +110,16 @@ pub struct PgConnection {
// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyData> {
// Defaults to postgres@.../postgres
let username = url.username().unwrap_or("postgres");
let username = url
.username()
.map(|c| c.into_owned())
.unwrap_or(String::from("postgres"));
let database = url.database().unwrap_or("postgres");
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let params = &[
("user", username),
("user", username.as_str()),
("database", database),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
@ -160,7 +163,7 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyDa
stream.write(PasswordMessage::Md5 {
password: &url.password().unwrap_or_default(),
user: username,
user: username.as_str(),
salt: data.salt,
});
@ -193,8 +196,12 @@ async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyDa
if has_sasl || has_sasl_plus {
// TODO: Handle -PLUS differently if we're in a TLS stream
sasl::authenticate(stream, username, &url.password().unwrap_or_default())
.await?;
sasl::authenticate(
stream,
username.as_str(),
&url.password().unwrap_or_default(),
)
.await?;
} else {
return Err(protocol_err!(
"unsupported SASL auth mechanisms: {:?}",

View file

@ -48,13 +48,17 @@ impl Url {
self.0.port().unwrap_or(default)
}
pub fn username(&self) -> Option<&str> {
pub fn username(&self) -> Option<Cow<str>> {
let username = self.0.username();
if username.is_empty() {
None
} else {
Some(username)
Some(
percent_encoding::percent_decode_str(username)
.decode_utf8()
.expect("percent-encoded username contained non-UTF-8 bytes"),
)
}
}
@ -90,3 +94,34 @@ impl Url {
.find_map(|(key_, val)| if key == key_ { Some(val) } else { None })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn azure_connection_string_username_unencoded() {
let connection_string =
"postgres://username@servername:password@example.postgres.database.azure.com/db";
let url = Url::try_from(connection_string).expect("Failed to parse URL");
assert_eq!(
url.username().map(|u| u.to_string()),
Some(String::from("username@servername"))
);
}
#[test]
fn azure_connection_string_username_encoded() {
let connection_string =
"postgres://username%40servername:password@example.postgres.database.azure.com/db";
let url = Url::try_from(connection_string).expect("Failed to parse URL");
assert_eq!(
url.username().map(|u| u.to_string()),
Some(String::from("username@servername"))
);
}
}