diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 3f59be65..6fd131c5 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -6,7 +6,7 @@ use crate::{ services, utils, Error, PduEvent, Result, Ruma, }; use axum::{response::IntoResponse, Json}; -use axum_extra::headers::authorization::Credentials; +use axum_extra::headers::{authorization::Credentials, CacheControl, Header}; use get_profile_information::v1::ProfileField; use http::header::AUTHORIZATION; @@ -96,13 +96,6 @@ impl FedDest { } } - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, ref port) => host + port, - } - } - fn hostname(&self) -> String { match &self { Self::Literal(addr) => addr.ip().to_string(), @@ -138,8 +131,6 @@ where debug!("Preparing to send request to {destination}"); - let mut write_destination_to_cache = false; - let cached_result = services() .globals .actual_destination_cache @@ -148,14 +139,63 @@ where .get(destination) .cloned(); - let (actual_destination, host) = if let Some(result) = cached_result { - result - } else { - write_destination_to_cache = true; - - let result = find_actual_destination(destination).await; + let actual_destination = if let Some(DestinationResponse { + actual_destination, + dest_type, + }) = cached_result + { + match dest_type { + DestType::IsIpOrHasPort => actual_destination, + DestType::LookupFailed { + well_known_retry, + well_known_backoff_mins, + } => { + if well_known_retry < Instant::now() { + find_actual_destination(destination, None, false, Some(well_known_backoff_mins)) + .await + } else { + actual_destination + } + } - (result.0, result.1.into_uri_string()) + DestType::WellKnown { expires } => { + if expires < Instant::now() { + find_actual_destination(destination, None, false, None).await + } else { + actual_destination + } + } + DestType::WellKnownSrv { + srv_expires, + well_known_expires, + well_known_host, + } => { + if well_known_expires < Instant::now() { + find_actual_destination(destination, None, false, None).await + } else if srv_expires < Instant::now() { + find_actual_destination(destination, Some(well_known_host), true, None).await + } else { + actual_destination + } + } + DestType::Srv { + well_known_retry, + well_known_backoff_mins, + srv_expires, + } => { + if well_known_retry < Instant::now() { + find_actual_destination(destination, None, false, Some(well_known_backoff_mins)) + .await + } else if srv_expires < Instant::now() { + find_actual_destination(destination, None, true, Some(well_known_backoff_mins)) + .await + } else { + actual_destination + } + } + } + } else { + find_actual_destination(destination, None, false, None).await }; let actual_destination_str = actual_destination.clone().into_https_string(); @@ -293,17 +333,6 @@ where if status == 200 { debug!("Parsing response bytes from {destination}"); let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && write_destination_to_cache { - services() - .globals - .actual_destination_cache - .write() - .await - .insert( - OwnedServerName::from(destination), - (actual_destination, host), - ); - } response.map_err(|e| { warn!( @@ -348,142 +377,211 @@ fn add_port_to_hostname(destination_str: &str) -> FedDest { FedDest::Named(host.to_owned(), port.to_owned()) } -/// Returns: actual_destination, host header -/// Implemented according to the specification at +#[derive(Clone)] +pub struct DestinationResponse { + pub actual_destination: FedDest, + pub dest_type: DestType, +} + +#[derive(Clone)] +pub enum DestType { + WellKnownSrv { + srv_expires: Instant, + well_known_expires: Instant, + well_known_host: String, + }, + WellKnown { + expires: Instant, + }, + Srv { + srv_expires: Instant, + well_known_retry: Instant, + well_known_backoff_mins: u16, + }, + IsIpOrHasPort, + LookupFailed { + well_known_retry: Instant, + well_known_backoff_mins: u16, + }, +} + +/// Implemented according to the specification at /// Numbers in comments below refer to bullet points in linked section of specification -async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { +async fn find_actual_destination( + destination: &'_ ServerName, + // The host used to potentially lookup SRV records against, only used when only_request_srv is true + well_known_dest: Option, + // Should be used when only the SRV lookup has expired + only_request_srv: bool, + // The backoff time for the last well known failure, if any + well_known_backoff_mins: Option, +) -> FedDest { debug!("Finding actual destination for {destination}"); - let destination_str = destination.as_str().to_owned(); - let mut hostname = destination_str.clone(); - let actual_destination = match get_ip_with_port(&destination_str) { - Some(host_port) => { - debug!("1: IP literal with provided or default port"); - host_port - } - None => { - if let Some(pos) = destination_str.find(':') { - debug!("2: Hostname with included port"); - let (host, port) = destination_str.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) + let destination_str = destination.to_string(); + let next_backoff_mins = well_known_backoff_mins + // Errors are recommended to be cached for up to an hour + .map(|mins| (mins * 2).min(60)) + .unwrap_or(1); + + let (actual_destination, dest_type) = if only_request_srv { + let destination_str = well_known_dest.unwrap_or(destination_str); + let (dest, expires) = get_srv_destination(destination_str).await; + let well_known_retry = + Instant::now() + Duration::from_secs((60 * next_backoff_mins).into()); + ( + dest, + if let Some(expires) = expires { + DestType::Srv { + well_known_backoff_mins: next_backoff_mins, + srv_expires: expires, + + well_known_retry, + } } else { - debug!("Requesting well known for {destination}"); - match request_well_known(destination.as_str()).await { - Some(delegated_hostname) => { - debug!("3: A .well-known file is available"); - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); - match get_ip_with_port(&delegated_hostname) { - Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file - None => { - if let Some(pos) = delegated_hostname.find(':') { - debug!("3.2: Hostname with port in .well-known file"); - let (host, port) = delegated_hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - debug!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = - query_srv_record(&delegated_hostname).await - { - debug!("3.3: SRV lookup successful"); - let force_port = hostname_override.port(); - - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services() - .globals - .tls_name_override - .write() - .unwrap() - .insert( - delegated_hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - warn!("Using SRV record, but could not resolve to IP"); - } - - if let Some(port) = force_port { - FedDest::Named(delegated_hostname, format!(":{port}")) - } else { - add_port_to_hostname(&delegated_hostname) - } + DestType::LookupFailed { + well_known_retry, + well_known_backoff_mins: next_backoff_mins, + } + }, + ) + } else { + match get_ip_with_port(&destination_str) { + Some(host_port) => { + debug!("1: IP literal with provided or default port"); + (host_port, DestType::IsIpOrHasPort) + } + None => { + if let Some(pos) = destination_str.find(':') { + debug!("2: Hostname with included port"); + let (host, port) = destination_str.split_at(pos); + ( + FedDest::Named(host.to_owned(), port.to_owned()), + DestType::IsIpOrHasPort, + ) + } else { + debug!("Requesting well known for {destination_str}"); + match request_well_known(destination_str.as_str()).await { + Some((delegated_hostname, timestamp)) => { + debug!("3: A .well-known file is available"); + match get_ip_with_port(&delegated_hostname) { + // 3.1: IP literal in .well-known file + Some(host_and_port) => { + (host_and_port, DestType::WellKnown { expires: timestamp }) + } + None => { + if let Some(pos) = delegated_hostname.find(':') { + debug!("3.2: Hostname with port in .well-known file"); + let (host, port) = delegated_hostname.split_at(pos); + ( + FedDest::Named(host.to_owned(), port.to_owned()), + DestType::WellKnown { expires: timestamp }, + ) } else { - debug!("3.4: No SRV records, just use the hostname from .well-known"); - add_port_to_hostname(&delegated_hostname) + debug!("Delegated hostname has no port in this branch"); + let (dest, srv_expires) = + get_srv_destination(delegated_hostname.clone()).await; + ( + dest, + if let Some(srv_expires) = srv_expires { + DestType::WellKnownSrv { + srv_expires, + well_known_expires: timestamp, + well_known_host: delegated_hostname, + } + } else { + DestType::WellKnown { expires: timestamp } + }, + ) } } } } - } - None => { - debug!("4: No .well-known or an error occured"); - match query_srv_record(&destination_str).await { - Some(hostname_override) => { - debug!("4: SRV record found"); - let force_port = hostname_override.port(); - - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services() - .globals - .tls_name_override - .write() - .unwrap() - .insert( - hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - warn!("Using SRV record, but could not resolve to IP"); - } - - if let Some(port) = force_port { - FedDest::Named(hostname.clone(), format!(":{port}")) + None => { + debug!("4: No .well-known or an error occured"); + let (dest, expires) = get_srv_destination(destination_str).await; + let well_known_retry = Instant::now() + + Duration::from_secs((60 * next_backoff_mins).into()); + ( + dest, + if let Some(expires) = expires { + DestType::Srv { + srv_expires: expires, + well_known_retry, + well_known_backoff_mins: next_backoff_mins, + } } else { - add_port_to_hostname(&hostname) - } - } - None => { - debug!("5: No SRV record found"); - add_port_to_hostname(&destination_str) - } + DestType::LookupFailed { + well_known_retry, + well_known_backoff_mins: next_backoff_mins, + } + }, + ) } } } } } }; + debug!("Actual destination: {actual_destination:?}"); - // Can't use get_ip_with_port here because we don't want to add a port - // to an IP address if it wasn't specified - let hostname = if let Ok(addr) = hostname.parse::() { - FedDest::Literal(addr) - } else if let Ok(addr) = hostname.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) - } else if let Some(pos) = hostname.find(':') { - let (host, port) = hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - FedDest::Named(hostname, ":8448".to_owned()) + let response = DestinationResponse { + actual_destination, + dest_type, }; - (actual_destination, hostname) + + services() + .globals + .actual_destination_cache + .write() + .await + .insert(destination.to_owned(), response.clone()); + + response.actual_destination +} + +/// Looks up the SRV records for federation usage +/// +/// If no timestamp is returned, that means no SRV record was found +async fn get_srv_destination(delegated_hostname: String) -> (FedDest, Option) { + if let Some((hostname_override, timestamp)) = query_srv_record(&delegated_hostname).await { + debug!("SRV lookup successful"); + let force_port = hostname_override.port(); + + if let Ok(override_ip) = services() + .globals + .dns_resolver() + .lookup_ip(hostname_override.hostname()) + .await + { + services() + .globals + .tls_name_override + .write() + .unwrap() + .insert( + delegated_hostname.clone(), + (override_ip.iter().collect(), force_port.unwrap_or(8448)), + ); + } else { + warn!("Using SRV record, but could not resolve to IP"); + } + + if let Some(port) = force_port { + ( + FedDest::Named(delegated_hostname, format!(":{port}")), + Some(timestamp), + ) + } else { + (add_port_to_hostname(&delegated_hostname), Some(timestamp)) + } + } else { + debug!("No SRV records found"); + (add_port_to_hostname(&delegated_hostname), None) + } } -async fn query_given_srv_record(record: &str) -> Option { +async fn query_given_srv_record(record: &str) -> Option<(FedDest, Instant)> { services() .globals .dns_resolver() @@ -491,16 +589,19 @@ async fn query_given_srv_record(record: &str) -> Option { .await .map(|srv| { srv.iter().next().map(|result| { - FedDest::Named( - result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), + ( + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ), + srv.as_lookup().valid_until(), ) }) }) .unwrap_or(None) } -async fn query_srv_record(hostname: &'_ str) -> Option { +async fn query_srv_record(hostname: &'_ str) -> Option<(FedDest, Instant)> { let hostname = hostname.trim_end_matches('.'); if let Some(host_port) = query_given_srv_record(&format!("_matrix-fed._tcp.{hostname}.")).await @@ -511,7 +612,7 @@ async fn query_srv_record(hostname: &'_ str) -> Option { } } -async fn request_well_known(destination: &str) -> Option { +async fn request_well_known(destination: &str) -> Option<(String, Instant)> { let response = services() .globals .default_client() @@ -519,14 +620,40 @@ async fn request_well_known(destination: &str) -> Option { .send() .await; debug!("Got well known response"); - if let Err(e) = &response { - debug!("Well known error: {e:?}"); - return None; - } - let text = response.ok()?.text().await; + let response = match response { + Err(e) => { + debug!("Well known error: {e:?}"); + return None; + } + Ok(r) => r, + }; + + let mut headers = response.headers().values(); + + let cache_for = CacheControl::decode(&mut headers) + .ok() + .and_then(|cc| { + // Servers should respect the cache control headers present on the response, or use a sensible default when headers are not present. + if cc.no_store() || cc.no_cache() { + Some(Duration::ZERO) + } else { + cc.max_age() + // Servers should additionally impose a maximum cache time for responses: 48 hours is recommended. + .map(|age| age.min(Duration::from_secs(60 * 60 * 48))) + } + }) + // The recommended sensible default is 24 hours. + .unwrap_or_else(|| Duration::from_secs(60 * 60 * 24)); + + let text = response.text().await; debug!("Got well known response text"); - let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; - Some(body.get("m.server")?.as_str()?.to_owned()) + + let host = || { + let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; + body.get("m.server")?.as_str().map(ToOwned::to_owned) + }; + + host().map(|host| (host, Instant::now() + cache_for)) } /// # `GET /_matrix/federation/v1/version` diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c22ffef3..3325e518 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -5,7 +5,7 @@ use ruma::{ OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, }; -use crate::api::server_server::FedDest; +use crate::api::server_server::DestinationResponse; use crate::{services, Config, Error, Result}; use futures_util::FutureExt; @@ -37,7 +37,7 @@ use tracing::{error, info}; use base64::{engine::general_purpose, Engine as _}; -type WellKnownMap = HashMap; +type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type SyncHandle = (