mirror of
https://github.com/dani-garcia/vaultwarden
synced 2024-11-25 13:20:23 +00:00
Implement custom DNS resolver (#3988)
This commit is contained in:
parent
2ad33ec97f
commit
27dc67fadd
6 changed files with 345 additions and 255 deletions
10
Cargo.lock
generated
10
Cargo.lock
generated
|
@ -2819,7 +2819,6 @@ dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"h2 0.4.4",
|
"h2 0.4.4",
|
||||||
"hickory-resolver",
|
|
||||||
"http 1.1.0",
|
"http 1.1.0",
|
||||||
"http-body 1.0.0",
|
"http-body 1.0.0",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
|
@ -3187,9 +3186,9 @@ checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.198"
|
version = "1.0.199"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc"
|
checksum = "0c9f6e76df036c77cd94996771fb40db98187f096dd0b9af39c6c6e452ba966a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
@ -3206,9 +3205,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.198"
|
version = "1.0.199"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9"
|
checksum = "11bd257a6541e141e42ca6d24ae26f7714887b47e89aa739099104c7e4d3b7fc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -3970,6 +3969,7 @@ dependencies = [
|
||||||
"futures",
|
"futures",
|
||||||
"governor",
|
"governor",
|
||||||
"handlebars",
|
"handlebars",
|
||||||
|
"hickory-resolver",
|
||||||
"html5gum",
|
"html5gum",
|
||||||
"job_scheduler_ng",
|
"job_scheduler_ng",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
|
|
|
@ -67,7 +67,7 @@ dashmap = "5.5.3"
|
||||||
|
|
||||||
# Async futures
|
# Async futures
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
tokio = { version = "1.37.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal"] }
|
tokio = { version = "1.37.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] }
|
||||||
|
|
||||||
# A generic serialization/deserialization framework
|
# A generic serialization/deserialization framework
|
||||||
serde = { version = "1.0.198", features = ["derive"] }
|
serde = { version = "1.0.198", features = ["derive"] }
|
||||||
|
@ -123,7 +123,8 @@ email_address = "0.2.4"
|
||||||
handlebars = { version = "5.1.2", features = ["dir_source"] }
|
handlebars = { version = "5.1.2", features = ["dir_source"] }
|
||||||
|
|
||||||
# HTTP client (Used for favicons, version check, DUO and HIBP API)
|
# HTTP client (Used for favicons, version check, DUO and HIBP API)
|
||||||
reqwest = { version = "0.12.4", features = ["native-tls-alpn", "stream", "json", "gzip", "brotli", "socks", "cookies", "hickory-dns"] }
|
reqwest = { version = "0.12.4", features = ["native-tls-alpn", "stream", "json", "gzip", "brotli", "socks", "cookies"] }
|
||||||
|
hickory-resolver = "0.24.1"
|
||||||
|
|
||||||
# Favicon extraction libraries
|
# Favicon extraction libraries
|
||||||
html5gum = "0.5.7"
|
html5gum = "0.5.7"
|
||||||
|
|
246
src/api/icons.rs
246
src/api/icons.rs
|
@ -1,6 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
net::IpAddr,
|
net::IpAddr,
|
||||||
sync::Arc,
|
sync::{Arc, Mutex},
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -16,14 +16,13 @@ use rocket::{http::ContentType, response::Redirect, Route};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
fs::{create_dir_all, remove_file, symlink_metadata, File},
|
fs::{create_dir_all, remove_file, symlink_metadata, File},
|
||||||
io::{AsyncReadExt, AsyncWriteExt},
|
io::{AsyncReadExt, AsyncWriteExt},
|
||||||
net::lookup_host,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader, Tokenizer};
|
use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader, Tokenizer};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::Error,
|
error::Error,
|
||||||
util::{get_reqwest_client_builder, Cached},
|
util::{get_reqwest_client_builder, Cached, CustomDnsResolver, CustomResolverError},
|
||||||
CONFIG,
|
CONFIG,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -49,48 +48,32 @@ static CLIENT: Lazy<Client> = Lazy::new(|| {
|
||||||
let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout());
|
let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout());
|
||||||
let pool_idle_timeout = Duration::from_secs(10);
|
let pool_idle_timeout = Duration::from_secs(10);
|
||||||
// Reuse the client between requests
|
// Reuse the client between requests
|
||||||
let client = get_reqwest_client_builder()
|
get_reqwest_client_builder()
|
||||||
.cookie_provider(Arc::clone(&cookie_store))
|
.cookie_provider(Arc::clone(&cookie_store))
|
||||||
.timeout(icon_download_timeout)
|
.timeout(icon_download_timeout)
|
||||||
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections
|
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections
|
||||||
.pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds
|
.pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds
|
||||||
.hickory_dns(true)
|
.dns_resolver(CustomDnsResolver::instance())
|
||||||
.default_headers(default_headers.clone());
|
.default_headers(default_headers.clone())
|
||||||
|
|
||||||
match client.build() {
|
|
||||||
Ok(client) => client,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
|
|
||||||
get_reqwest_client_builder()
|
|
||||||
.cookie_provider(cookie_store)
|
|
||||||
.timeout(icon_download_timeout)
|
|
||||||
.pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections
|
|
||||||
.pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds
|
|
||||||
.hickory_dns(false)
|
|
||||||
.default_headers(default_headers)
|
|
||||||
.build()
|
.build()
|
||||||
.expect("Failed to build client")
|
.expect("Failed to build client")
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Build Regex only once since this takes a lot of time.
|
// Build Regex only once since this takes a lot of time.
|
||||||
static ICON_SIZE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap());
|
static ICON_SIZE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap());
|
||||||
|
|
||||||
// Special HashMap which holds the user defined Regex to speedup matching the regex.
|
#[get("/<domain>/icon.png")]
|
||||||
static ICON_BLACKLIST_REGEX: Lazy<dashmap::DashMap<String, Regex>> = Lazy::new(dashmap::DashMap::new);
|
fn icon_external(domain: &str) -> Option<Redirect> {
|
||||||
|
|
||||||
async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> {
|
|
||||||
if !is_valid_domain(domain) {
|
if !is_valid_domain(domain) {
|
||||||
warn!("Invalid domain: {}", domain);
|
warn!("Invalid domain: {}", domain);
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
if check_domain_blacklist_reason(domain).await.is_some() {
|
if is_domain_blacklisted(domain) {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let url = template.replace("{}", domain);
|
let url = CONFIG._icon_service_url().replace("{}", domain);
|
||||||
match CONFIG.icon_redirect_code() {
|
match CONFIG.icon_redirect_code() {
|
||||||
301 => Some(Redirect::moved(url)), // legacy permanent redirect
|
301 => Some(Redirect::moved(url)), // legacy permanent redirect
|
||||||
302 => Some(Redirect::found(url)), // legacy temporary redirect
|
302 => Some(Redirect::found(url)), // legacy temporary redirect
|
||||||
|
@ -103,11 +86,6 @@ async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/<domain>/icon.png")]
|
|
||||||
async fn icon_external(domain: &str) -> Option<Redirect> {
|
|
||||||
icon_redirect(domain, &CONFIG._icon_service_url()).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/<domain>/icon.png")]
|
#[get("/<domain>/icon.png")]
|
||||||
async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
|
async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
|
||||||
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
|
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
|
||||||
|
@ -166,153 +144,28 @@ fn is_valid_domain(domain: &str) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TODO: This is extracted from IpAddr::is_global, which is unstable:
|
pub fn is_domain_blacklisted(domain: &str) -> bool {
|
||||||
/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
|
let Some(config_blacklist) = CONFIG.icon_blacklist_regex() else {
|
||||||
/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged
|
return false;
|
||||||
#[allow(clippy::nonminimal_bool)]
|
|
||||||
#[cfg(not(feature = "unstable"))]
|
|
||||||
fn is_global(ip: IpAddr) -> bool {
|
|
||||||
match ip {
|
|
||||||
IpAddr::V4(ip) => {
|
|
||||||
// check if this address is 192.0.0.9 or 192.0.0.10. These addresses are the only two
|
|
||||||
// globally routable addresses in the 192.0.0.0/24 range.
|
|
||||||
if u32::from(ip) == 0xc0000009 || u32::from(ip) == 0xc000000a {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
!ip.is_private()
|
|
||||||
&& !ip.is_loopback()
|
|
||||||
&& !ip.is_link_local()
|
|
||||||
&& !ip.is_broadcast()
|
|
||||||
&& !ip.is_documentation()
|
|
||||||
&& !(ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000))
|
|
||||||
&& !(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0)
|
|
||||||
&& !(ip.octets()[0] & 240 == 240 && !ip.is_broadcast())
|
|
||||||
&& !(ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18)
|
|
||||||
// Make sure the address is not in 0.0.0.0/8
|
|
||||||
&& ip.octets()[0] != 0
|
|
||||||
}
|
|
||||||
IpAddr::V6(ip) => {
|
|
||||||
if ip.is_multicast() && ip.segments()[0] & 0x000f == 14 {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
!ip.is_multicast()
|
|
||||||
&& !ip.is_loopback()
|
|
||||||
&& !((ip.segments()[0] & 0xffc0) == 0xfe80)
|
|
||||||
&& !((ip.segments()[0] & 0xfe00) == 0xfc00)
|
|
||||||
&& !ip.is_unspecified()
|
|
||||||
&& !((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "unstable")]
|
|
||||||
fn is_global(ip: IpAddr) -> bool {
|
|
||||||
ip.is_global()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// These are some tests to check that the implementations match
|
|
||||||
/// The IPv4 can be all checked in 5 mins or so and they are correct as of nightly 2020-07-11
|
|
||||||
/// The IPV6 can't be checked in a reasonable time, so we check about ten billion random ones, so far correct
|
|
||||||
/// Note that the is_global implementation is subject to change as new IP RFCs are created
|
|
||||||
///
|
|
||||||
/// To run while showing progress output:
|
|
||||||
/// cargo test --features sqlite,unstable -- --nocapture --ignored
|
|
||||||
#[cfg(test)]
|
|
||||||
#[cfg(feature = "unstable")]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[ignore]
|
|
||||||
fn test_ipv4_global() {
|
|
||||||
for a in 0..u8::MAX {
|
|
||||||
println!("Iter: {}/255", a);
|
|
||||||
for b in 0..u8::MAX {
|
|
||||||
for c in 0..u8::MAX {
|
|
||||||
for d in 0..u8::MAX {
|
|
||||||
let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d));
|
|
||||||
assert_eq!(ip.is_global(), is_global(ip))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[ignore]
|
|
||||||
fn test_ipv6_global() {
|
|
||||||
use ring::rand::{SecureRandom, SystemRandom};
|
|
||||||
let mut v = [0u8; 16];
|
|
||||||
let rand = SystemRandom::new();
|
|
||||||
for i in 0..1_000 {
|
|
||||||
println!("Iter: {}/1_000", i);
|
|
||||||
for _ in 0..10_000_000 {
|
|
||||||
rand.fill(&mut v).expect("Error generating random values");
|
|
||||||
let ip = IpAddr::V6(std::net::Ipv6Addr::new(
|
|
||||||
(v[14] as u16) << 8 | v[15] as u16,
|
|
||||||
(v[12] as u16) << 8 | v[13] as u16,
|
|
||||||
(v[10] as u16) << 8 | v[11] as u16,
|
|
||||||
(v[8] as u16) << 8 | v[9] as u16,
|
|
||||||
(v[6] as u16) << 8 | v[7] as u16,
|
|
||||||
(v[4] as u16) << 8 | v[5] as u16,
|
|
||||||
(v[2] as u16) << 8 | v[3] as u16,
|
|
||||||
(v[0] as u16) << 8 | v[1] as u16,
|
|
||||||
));
|
|
||||||
assert_eq!(ip.is_global(), is_global(ip))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
enum DomainBlacklistReason {
|
|
||||||
Regex,
|
|
||||||
IP,
|
|
||||||
}
|
|
||||||
|
|
||||||
use cached::proc_macro::cached;
|
|
||||||
#[cached(key = "String", convert = r#"{ domain.to_string() }"#, size = 16, time = 60)]
|
|
||||||
async fn check_domain_blacklist_reason(domain: &str) -> Option<DomainBlacklistReason> {
|
|
||||||
// First check the blacklist regex if there is a match.
|
|
||||||
// This prevents the blocked domain(s) from being leaked via a DNS lookup.
|
|
||||||
if let Some(blacklist) = CONFIG.icon_blacklist_regex() {
|
|
||||||
// Use the pre-generate Regex stored in a Lazy HashMap if there's one, else generate it.
|
|
||||||
let is_match = if let Some(regex) = ICON_BLACKLIST_REGEX.get(&blacklist) {
|
|
||||||
regex.is_match(domain)
|
|
||||||
} else {
|
|
||||||
// Clear the current list if the previous key doesn't exists.
|
|
||||||
// To prevent growing of the HashMap after someone has changed it via the admin interface.
|
|
||||||
if ICON_BLACKLIST_REGEX.len() >= 1 {
|
|
||||||
ICON_BLACKLIST_REGEX.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the regex to store in too the Lazy Static HashMap.
|
|
||||||
let blacklist_regex = Regex::new(&blacklist).unwrap();
|
|
||||||
let is_match = blacklist_regex.is_match(domain);
|
|
||||||
ICON_BLACKLIST_REGEX.insert(blacklist.clone(), blacklist_regex);
|
|
||||||
|
|
||||||
is_match
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if is_match {
|
// Compiled domain blacklist
|
||||||
debug!("Blacklisted domain: {} matched ICON_BLACKLIST_REGEX", domain);
|
static COMPILED_BLACKLIST: Mutex<Option<(String, Regex)>> = Mutex::new(None);
|
||||||
return Some(DomainBlacklistReason::Regex);
|
let mut guard = COMPILED_BLACKLIST.lock().unwrap();
|
||||||
|
|
||||||
|
// If the stored regex is up to date, use it
|
||||||
|
if let Some((value, regex)) = &*guard {
|
||||||
|
if value == &config_blacklist {
|
||||||
|
return regex.is_match(domain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if CONFIG.icon_blacklist_non_global_ips() {
|
// If we don't have a regex stored, or it's not up to date, recreate it
|
||||||
if let Ok(s) = lookup_host((domain, 0)).await {
|
let regex = Regex::new(&config_blacklist).unwrap();
|
||||||
for addr in s {
|
let is_match = regex.is_match(domain);
|
||||||
if !is_global(addr.ip()) {
|
*guard = Some((config_blacklist, regex));
|
||||||
debug!("IP {} for domain '{}' is not a global IP!", addr.ip(), domain);
|
|
||||||
return Some(DomainBlacklistReason::IP);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
None
|
is_match
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
||||||
|
@ -342,6 +195,13 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
||||||
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
|
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
// If this error comes from the custom resolver, this means this is a blacklisted domain
|
||||||
|
// or non global IP, don't save the miss file in this case to avoid leaking it
|
||||||
|
if let Some(error) = CustomResolverError::downcast_ref(&e) {
|
||||||
|
warn!("{error}");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
warn!("Unable to download icon: {:?}", e);
|
warn!("Unable to download icon: {:?}", e);
|
||||||
let miss_indicator = path + ".miss";
|
let miss_indicator = path + ".miss";
|
||||||
save_icon(&miss_indicator, &[]).await;
|
save_icon(&miss_indicator, &[]).await;
|
||||||
|
@ -491,12 +351,12 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||||
let ssldomain = format!("https://{domain}");
|
let ssldomain = format!("https://{domain}");
|
||||||
let httpdomain = format!("http://{domain}");
|
let httpdomain = format!("http://{domain}");
|
||||||
|
|
||||||
// First check the domain as given during the request for both HTTPS and HTTP.
|
// First check the domain as given during the request for HTTPS.
|
||||||
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
|
let resp = match get_page(&ssldomain).await {
|
||||||
Ok(c) => Ok(c),
|
Err(e) if CustomResolverError::downcast_ref(&e).is_none() => {
|
||||||
Err(e) => {
|
// If we get an error that is not caused by the blacklist, we retry with HTTP
|
||||||
let mut sub_resp = Err(e);
|
match get_page(&httpdomain).await {
|
||||||
|
mut sub_resp @ Err(_) => {
|
||||||
// When the domain is not an IP, and has more then one dot, remove all subdomains.
|
// When the domain is not an IP, and has more then one dot, remove all subdomains.
|
||||||
let is_ip = domain.parse::<IpAddr>();
|
let is_ip = domain.parse::<IpAddr>();
|
||||||
if is_ip.is_err() && domain.matches('.').count() > 1 {
|
if is_ip.is_err() && domain.matches('.').count() > 1 {
|
||||||
|
@ -527,6 +387,12 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||||
}
|
}
|
||||||
sub_resp
|
sub_resp
|
||||||
}
|
}
|
||||||
|
res => res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get a result or a blacklist error, just continue
|
||||||
|
res => res,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create the iconlist
|
// Create the iconlist
|
||||||
|
@ -573,21 +439,12 @@ async fn get_page(url: &str) -> Result<Response, Error> {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
||||||
match check_domain_blacklist_reason(url::Url::parse(url).unwrap().host_str().unwrap_or_default()).await {
|
|
||||||
Some(DomainBlacklistReason::Regex) => warn!("Favicon '{}' is from a blacklisted domain!", url),
|
|
||||||
Some(DomainBlacklistReason::IP) => warn!("Favicon '{}' is hosted on a non-global IP!", url),
|
|
||||||
None => (),
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut client = CLIENT.get(url);
|
let mut client = CLIENT.get(url);
|
||||||
if !referer.is_empty() {
|
if !referer.is_empty() {
|
||||||
client = client.header("Referer", referer)
|
client = client.header("Referer", referer)
|
||||||
}
|
}
|
||||||
|
|
||||||
match client.send().await {
|
Ok(client.send().await?.error_for_status()?)
|
||||||
Ok(c) => c.error_for_status().map_err(Into::into),
|
|
||||||
Err(e) => err_silent!(format!("{e}")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Integer with the priority of the type of the icon which to prefer.
|
/// Returns a Integer with the priority of the type of the icon which to prefer.
|
||||||
|
@ -670,12 +527,6 @@ fn parse_sizes(sizes: &str) -> (u16, u16) {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
|
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
|
||||||
match check_domain_blacklist_reason(domain).await {
|
|
||||||
Some(DomainBlacklistReason::Regex) => err_silent!("Domain is blacklisted", domain),
|
|
||||||
Some(DomainBlacklistReason::IP) => err_silent!("Host resolves to a non-global IP", domain),
|
|
||||||
None => (),
|
|
||||||
}
|
|
||||||
|
|
||||||
let icon_result = get_icon_url(domain).await?;
|
let icon_result = get_icon_url(domain).await?;
|
||||||
|
|
||||||
let mut buffer = Bytes::new();
|
let mut buffer = Bytes::new();
|
||||||
|
@ -711,8 +562,8 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
|
||||||
_ => debug!("Extracted icon from data:image uri is invalid"),
|
_ => debug!("Extracted icon from data:image uri is invalid"),
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
match get_page_with_referer(&icon.href, &icon_result.referer).await {
|
let res = get_page_with_referer(&icon.href, &icon_result.referer).await?;
|
||||||
Ok(res) => {
|
|
||||||
buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
|
buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
|
||||||
|
|
||||||
// Check if the icon type is allowed, else try an icon from the list.
|
// Check if the icon type is allowed, else try an icon from the list.
|
||||||
|
@ -725,9 +576,6 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
|
||||||
info!("Downloaded icon from {}", icon.href);
|
info!("Downloaded icon from {}", icon.href);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Err(e) => debug!("{:?}", e),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if buffer.is_empty() {
|
if buffer.is_empty() {
|
||||||
|
|
|
@ -20,7 +20,7 @@ pub use crate::api::{
|
||||||
core::two_factor::send_incomplete_2fa_notifications,
|
core::two_factor::send_incomplete_2fa_notifications,
|
||||||
core::{emergency_notification_reminder_job, emergency_request_timeout_job},
|
core::{emergency_notification_reminder_job, emergency_request_timeout_job},
|
||||||
core::{event_cleanup_job, events_routes as core_events_routes},
|
core::{event_cleanup_job, events_routes as core_events_routes},
|
||||||
icons::routes as icons_routes,
|
icons::{is_domain_blacklisted, routes as icons_routes},
|
||||||
identity::routes as identity_routes,
|
identity::routes as identity_routes,
|
||||||
notifications::routes as notifications_routes,
|
notifications::routes as notifications_routes,
|
||||||
notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS},
|
notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS},
|
||||||
|
|
|
@ -211,8 +211,8 @@ fn launch_info() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
|
fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
|
||||||
// Depending on the main log level we either want to disable or enable logging for trust-dns.
|
// Depending on the main log level we either want to disable or enable logging for hickory.
|
||||||
// Else if there are timeouts it will clutter the logs since trust-dns uses warn for this.
|
// Else if there are timeouts it will clutter the logs since hickory uses warn for this.
|
||||||
let hickory_level = if level >= log::LevelFilter::Debug {
|
let hickory_level = if level >= log::LevelFilter::Debug {
|
||||||
level
|
level
|
||||||
} else {
|
} else {
|
||||||
|
@ -266,7 +266,7 @@ fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
|
||||||
.level_for("handlebars::render", handlebars_level)
|
.level_for("handlebars::render", handlebars_level)
|
||||||
// Prevent cookie_store logs
|
// Prevent cookie_store logs
|
||||||
.level_for("cookie_store", log::LevelFilter::Off)
|
.level_for("cookie_store", log::LevelFilter::Off)
|
||||||
// Variable level for trust-dns used by reqwest
|
// Variable level for hickory used by reqwest
|
||||||
.level_for("hickory_resolver::name_server::name_server", hickory_level)
|
.level_for("hickory_resolver::name_server::name_server", hickory_level)
|
||||||
.level_for("hickory_proto::xfer", hickory_level)
|
.level_for("hickory_proto::xfer", hickory_level)
|
||||||
.level_for("diesel_logger", diesel_logger_level)
|
.level_for("diesel_logger", diesel_logger_level)
|
||||||
|
|
257
src/util.rs
257
src/util.rs
|
@ -4,6 +4,7 @@
|
||||||
use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};
|
use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};
|
||||||
|
|
||||||
use num_traits::ToPrimitive;
|
use num_traits::ToPrimitive;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
use rocket::{
|
use rocket::{
|
||||||
fairing::{Fairing, Info, Kind},
|
fairing::{Fairing, Info, Kind},
|
||||||
http::{ContentType, Header, HeaderMap, Method, Status},
|
http::{ContentType, Header, HeaderMap, Method, Status},
|
||||||
|
@ -701,14 +702,9 @@ where
|
||||||
|
|
||||||
use reqwest::{header, Client, ClientBuilder};
|
use reqwest::{header, Client, ClientBuilder};
|
||||||
|
|
||||||
pub fn get_reqwest_client() -> Client {
|
pub fn get_reqwest_client() -> &'static Client {
|
||||||
match get_reqwest_client_builder().build() {
|
static INSTANCE: Lazy<Client> = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
|
||||||
Ok(client) => client,
|
&INSTANCE
|
||||||
Err(e) => {
|
|
||||||
error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
|
|
||||||
get_reqwest_client_builder().hickory_dns(false).build().expect("Failed to build client")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_reqwest_client_builder() -> ClientBuilder {
|
pub fn get_reqwest_client_builder() -> ClientBuilder {
|
||||||
|
@ -767,3 +763,248 @@ pub fn parse_experimental_client_feature_flags(experimental_client_feature_flags
|
||||||
|
|
||||||
feature_states
|
feature_states
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod dns_resolver {
|
||||||
|
use std::{
|
||||||
|
fmt,
|
||||||
|
net::{IpAddr, SocketAddr},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use reqwest::dns::{Name, Resolve, Resolving};
|
||||||
|
|
||||||
|
use crate::{util::is_global, CONFIG};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CustomResolverError {
|
||||||
|
Blacklist {
|
||||||
|
domain: String,
|
||||||
|
},
|
||||||
|
NonGlobalIp {
|
||||||
|
domain: String,
|
||||||
|
ip: IpAddr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomResolverError {
|
||||||
|
pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> {
|
||||||
|
let mut source = e.source();
|
||||||
|
|
||||||
|
while let Some(err) = source {
|
||||||
|
source = err.source();
|
||||||
|
if let Some(err) = err.downcast_ref::<CustomResolverError>() {
|
||||||
|
return Some(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for CustomResolverError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Blacklist {
|
||||||
|
domain,
|
||||||
|
} => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"),
|
||||||
|
Self::NonGlobalIp {
|
||||||
|
domain,
|
||||||
|
ip,
|
||||||
|
} => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for CustomResolverError {}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CustomDnsResolver {
|
||||||
|
Default(),
|
||||||
|
Hickory(Arc<TokioAsyncResolver>),
|
||||||
|
}
|
||||||
|
type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||||
|
|
||||||
|
impl CustomDnsResolver {
|
||||||
|
pub fn instance() -> Arc<Self> {
|
||||||
|
static INSTANCE: Lazy<Arc<CustomDnsResolver>> = Lazy::new(CustomDnsResolver::new);
|
||||||
|
Arc::clone(&*INSTANCE)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new() -> Arc<Self> {
|
||||||
|
match read_system_conf() {
|
||||||
|
Ok((config, opts)) => {
|
||||||
|
let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone());
|
||||||
|
Arc::new(Self::Hickory(Arc::new(resolver)))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Error creating Hickory resolver, falling back to default: {e:?}");
|
||||||
|
Arc::new(Self::Default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note that we get an iterator of addresses, but we only grab the first one for convenience
|
||||||
|
async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> {
|
||||||
|
pre_resolve(name)?;
|
||||||
|
|
||||||
|
let result = match self {
|
||||||
|
Self::Default() => tokio::net::lookup_host(name).await?.next(),
|
||||||
|
Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(addr) = &result {
|
||||||
|
post_resolve(name, addr.ip())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pre_resolve(name: &str) -> Result<(), CustomResolverError> {
|
||||||
|
if crate::api::is_domain_blacklisted(name) {
|
||||||
|
return Err(CustomResolverError::Blacklist {
|
||||||
|
domain: name.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> {
|
||||||
|
if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) {
|
||||||
|
Err(CustomResolverError::NonGlobalIp {
|
||||||
|
domain: name.to_string(),
|
||||||
|
ip,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Resolve for CustomDnsResolver {
|
||||||
|
fn resolve(&self, name: Name) -> Resolving {
|
||||||
|
let this = self.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
let name = name.as_str();
|
||||||
|
let result = this.resolve_domain(name).await?;
|
||||||
|
Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub use dns_resolver::{CustomDnsResolver, CustomResolverError};
|
||||||
|
|
||||||
|
/// TODO: This is extracted from IpAddr::is_global, which is unstable:
|
||||||
|
/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global
|
||||||
|
/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged
|
||||||
|
#[allow(clippy::nonminimal_bool)]
|
||||||
|
#[cfg(any(not(feature = "unstable"), test))]
|
||||||
|
pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool {
|
||||||
|
match ip {
|
||||||
|
std::net::IpAddr::V4(ip) => {
|
||||||
|
!(ip.octets()[0] == 0 // "This network"
|
||||||
|
|| ip.is_private()
|
||||||
|
|| (ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) //ip.is_shared()
|
||||||
|
|| ip.is_loopback()
|
||||||
|
|| ip.is_link_local()
|
||||||
|
// addresses reserved for future protocols (`192.0.0.0/24`)
|
||||||
|
||(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0)
|
||||||
|
|| ip.is_documentation()
|
||||||
|
|| (ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) // ip.is_benchmarking()
|
||||||
|
|| (ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) //ip.is_reserved()
|
||||||
|
|| ip.is_broadcast())
|
||||||
|
}
|
||||||
|
std::net::IpAddr::V6(ip) => {
|
||||||
|
!(ip.is_unspecified()
|
||||||
|
|| ip.is_loopback()
|
||||||
|
// IPv4-mapped Address (`::ffff:0:0/96`)
|
||||||
|
|| matches!(ip.segments(), [0, 0, 0, 0, 0, 0xffff, _, _])
|
||||||
|
// IPv4-IPv6 Translat. (`64:ff9b:1::/48`)
|
||||||
|
|| matches!(ip.segments(), [0x64, 0xff9b, 1, _, _, _, _, _])
|
||||||
|
// Discard-Only Address Block (`100::/64`)
|
||||||
|
|| matches!(ip.segments(), [0x100, 0, 0, 0, _, _, _, _])
|
||||||
|
// IETF Protocol Assignments (`2001::/23`)
|
||||||
|
|| (matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200)
|
||||||
|
&& !(
|
||||||
|
// Port Control Protocol Anycast (`2001:1::1`)
|
||||||
|
u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001
|
||||||
|
// Traversal Using Relays around NAT Anycast (`2001:1::2`)
|
||||||
|
|| u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002
|
||||||
|
// AMT (`2001:3::/32`)
|
||||||
|
|| matches!(ip.segments(), [0x2001, 3, _, _, _, _, _, _])
|
||||||
|
// AS112-v6 (`2001:4:112::/48`)
|
||||||
|
|| matches!(ip.segments(), [0x2001, 4, 0x112, _, _, _, _, _])
|
||||||
|
// ORCHIDv2 (`2001:20::/28`)
|
||||||
|
|| matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x2F).contains(&b))
|
||||||
|
))
|
||||||
|
|| ((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) // ip.is_documentation()
|
||||||
|
|| ((ip.segments()[0] & 0xfe00) == 0xfc00) //ip.is_unique_local()
|
||||||
|
|| ((ip.segments()[0] & 0xffc0) == 0xfe80)) //ip.is_unicast_link_local()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "unstable"))]
|
||||||
|
pub use is_global_hardcoded as is_global;
|
||||||
|
|
||||||
|
#[cfg(feature = "unstable")]
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn is_global(ip: std::net::IpAddr) -> bool {
|
||||||
|
ip.is_global()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// These are some tests to check that the implementations match
|
||||||
|
/// The IPv4 can be all checked in 30 seconds or so and they are correct as of nightly 2023-07-17
|
||||||
|
/// The IPV6 can't be checked in a reasonable time, so we check over a hundred billion random ones, so far correct
|
||||||
|
/// Note that the is_global implementation is subject to change as new IP RFCs are created
|
||||||
|
///
|
||||||
|
/// To run while showing progress output:
|
||||||
|
/// cargo +nightly test --release --features sqlite,unstable -- --nocapture --ignored
|
||||||
|
#[cfg(test)]
|
||||||
|
#[cfg(feature = "unstable")]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore]
|
||||||
|
fn test_ipv4_global() {
|
||||||
|
for a in 0..u8::MAX {
|
||||||
|
println!("Iter: {}/255", a);
|
||||||
|
for b in 0..u8::MAX {
|
||||||
|
for c in 0..u8::MAX {
|
||||||
|
for d in 0..u8::MAX {
|
||||||
|
let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d));
|
||||||
|
assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {}", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore]
|
||||||
|
fn test_ipv6_global() {
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
std::thread::scope(|s| {
|
||||||
|
for t in 0..16 {
|
||||||
|
let handle = s.spawn(move || {
|
||||||
|
let mut v = [0u8; 16];
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
|
||||||
|
for i in 0..20 {
|
||||||
|
println!("Thread {t} Iter: {i}/50");
|
||||||
|
for _ in 0..500_000_000 {
|
||||||
|
rng.fill(&mut v);
|
||||||
|
let ip = IpAddr::V6(std::net::Ipv6Addr::from(v));
|
||||||
|
assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {ip}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue