diff --git a/Cargo.lock b/Cargo.lock index 19df999a..209adea4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2018,8 +2018,7 @@ dependencies = [ [[package]] name = "reqwest" version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f242f1488a539a79bac6dbe7c8609ae43b7914b7736210f239a37cccb32525" +source = "git+https://github.com/niuhuan/reqwest?branch=dns-resolver-fn#57b7cf4feb921573dfafad7d34b9ac6e44ead0bd" dependencies = [ "base64 0.13.0", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 1e1b188f..aac840b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ rand = "0.8.4" # Used to hash passwords rust-argon2 = "0.8.3" # Used to send requests -reqwest = { version = "0.11.4", default-features = false, features = ["rustls-tls", "socks"] } +reqwest = { version = "0.11.4", default-features = false, features = ["rustls-tls", "socks"], git = "https://github.com/niuhuan/reqwest", branch = "dns-resolver-fn" } # Used for conduit::Error type thiserror = "1.0.28" # Used to generate thumbnails for images diff --git a/src/appservice_server.rs b/src/appservice_server.rs index 0152c386..b2154b8d 100644 --- a/src/appservice_server.rs +++ b/src/appservice_server.rs @@ -41,11 +41,7 @@ where *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); let url = reqwest_request.url().clone(); - let mut response = globals - .reqwest_client()? - .build()? - .execute(reqwest_request) - .await?; + let mut response = globals.default_client().execute(reqwest_request).await?; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/database/globals.rs b/src/database/globals.rs index 098d8197..decd84c3 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -10,7 +10,7 @@ use std::{ collections::{BTreeMap, HashMap}, fs, future::Future, - net::IpAddr, + net::{IpAddr, SocketAddr}, path::PathBuf, sync::{Arc, Mutex, RwLock}, time::{Duration, Instant}, @@ -39,6 +39,8 @@ pub struct Globals { keypair: Arc, dns_resolver: TokioAsyncResolver, jwt_decoding_key: Option>, + federation_client: reqwest::Client, + default_client: reqwest::Client, pub(super) server_signingkeys: Arc, pub bad_event_ratelimiter: Arc, RateLimitState>>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, @@ -132,6 +134,17 @@ impl Globals { .as_ref() .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static()); + let default_client = reqwest_client_builder(&config)?.build()?; + let name_override = Arc::clone(&tls_name_override); + let federation_client = reqwest_client_builder(&config)? + .resolve_fn(move |domain| { + let read_guard = name_override.read().unwrap(); + let (override_name, port) = read_guard.get(&domain)?; + let first_name = override_name.get(0)?; + Some(SocketAddr::new(*first_name, *port)) + }) + .build()?; + let s = Self { globals, config, @@ -141,6 +154,8 @@ impl Globals { })?, actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), tls_name_override, + federation_client, + default_client, server_signingkeys, jwt_decoding_key, bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), @@ -163,17 +178,16 @@ impl Globals { &self.keypair } - /// Returns a reqwest client which can be used to send requests. - pub fn reqwest_client(&self) -> Result { - let mut reqwest_client_builder = reqwest::Client::builder() - .connect_timeout(Duration::from_secs(30)) - .timeout(Duration::from_secs(60 * 3)) - .pool_max_idle_per_host(1); - if let Some(proxy) = self.config.proxy.to_proxy()? { - reqwest_client_builder = reqwest_client_builder.proxy(proxy); - } + /// Returns a reqwest client which can be used to send requests + pub fn default_client(&self) -> reqwest::Client { + // Client is cheap to clone (Arc wrapper) and avoids lifetime issues + self.default_client.clone() + } - Ok(reqwest_client_builder) + /// Returns a client used for resolving .well-knowns + pub fn federation_client(&self) -> reqwest::Client { + // Client is cheap to clone (Arc wrapper) and avoids lifetime issues + self.federation_client.clone() } #[tracing::instrument(skip(self))] @@ -340,3 +354,15 @@ impl Globals { r } } + +fn reqwest_client_builder(config: &Config) -> Result { + let mut reqwest_client_builder = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(30)) + .timeout(Duration::from_secs(60 * 3)); + + if let Some(proxy) = config.proxy.to_proxy()? { + reqwest_client_builder = reqwest_client_builder.proxy(proxy); + } + + Ok(reqwest_client_builder) +} diff --git a/src/database/pusher.rs b/src/database/pusher.rs index f401834a..e73ab061 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -115,11 +115,7 @@ where //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); let url = reqwest_request.url().clone(); - let response = globals - .reqwest_client()? - .build()? - .execute(reqwest_request) - .await; + let response = globals.default_client().execute(reqwest_request).await; match response { Ok(mut response) => { diff --git a/src/server_server.rs b/src/server_server.rs index e730210a..2c682f6f 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -236,21 +236,7 @@ where let url = reqwest_request.url().clone(); - let mut client = globals.reqwest_client()?; - if let Some((override_name, port)) = globals - .tls_name_override - .read() - .unwrap() - .get(&actual_destination.hostname()) - { - client = client.resolve( - &actual_destination.hostname(), - SocketAddr::new(override_name[0], *port), - ); - // port will be ignored - } - - let response = client.build()?.execute(reqwest_request).await; + let response = globals.federation_client().execute(reqwest_request).await; match response { Ok(mut response) => { @@ -490,10 +476,7 @@ async fn request_well_known( ) -> Option { let body: serde_json::Value = serde_json::from_str( &globals - .reqwest_client() - .ok()? - .build() - .ok()? + .default_client() .get(&format!( "https://{}/.well-known/matrix/server", destination