From c53cc03ff8db65b6b447a852eee85e540ad38cb1 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Thu, 1 Jul 2021 13:38:25 -0600 Subject: [PATCH] address pr comments --- conduit-example.toml | 2 + src/database.rs | 121 +--------------------------------- src/database/proxy.rs | 146 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 118 deletions(-) create mode 100644 src/database/proxy.rs diff --git a/conduit-example.toml b/conduit-example.toml index 66c105be..db0bbb77 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -41,3 +41,5 @@ trusted_servers = ["matrix.org"] #workers = 4 # default: cpu core count * 2 address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy + +proxy = "none" # more examples can be found at src/database/proxy.rs:6 diff --git a/src/database.rs b/src/database.rs index 64b5ee39..0ea4d784 100644 --- a/src/database.rs +++ b/src/database.rs @@ -6,6 +6,7 @@ pub mod appservice; pub mod globals; pub mod key_backups; pub mod media; +pub mod proxy; pub mod pusher; pub mod rooms; pub mod sending; @@ -28,6 +29,8 @@ use std::{ }; use tokio::sync::Semaphore; +use self::proxy::ProxyConfig; + #[derive(Clone, Debug, Deserialize)] pub struct Config { server_name: Box, @@ -85,124 +88,6 @@ pub type Engine = abstraction::SledEngine; #[cfg(feature = "rocksdb")] pub type Engine = abstraction::RocksDbEngine; -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ProxyConfig { - None, - Global { - #[serde(deserialize_with = "crate::utils::deserialize_from_str")] - url: reqwest::Url, - }, - ByDomain(Vec), -} -impl ProxyConfig { - pub fn to_proxy(&self) -> Result> { - Ok(match self.clone() { - ProxyConfig::None => None, - ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), - ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { - proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy - })), - }) - } -} -impl Default for ProxyConfig { - fn default() -> Self { - ProxyConfig::None - } -} - -#[derive(Clone, Debug, Deserialize)] -pub struct PartialProxyConfig { - #[serde(deserialize_with = "crate::utils::deserialize_from_str")] - url: reqwest::Url, - #[serde(default)] - include: Vec, - #[serde(default)] - exclude: Vec, -} -impl PartialProxyConfig { - pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { - let domain = url.domain()?; - let mut included_because = None; // most specific reason it was included - let mut excluded_because = None; // most specific reason it was excluded - if self.include.is_empty() { - // treat empty include list as `*` - included_because = Some(&WildCardedDomain::WildCard) - } - for wc_domain in &self.include { - if wc_domain.matches(domain) { - match included_because { - Some(prev) if !wc_domain.more_specific_than(prev) => (), - _ => included_because = Some(wc_domain), - } - } - } - for wc_domain in &self.exclude { - if wc_domain.matches(domain) { - match excluded_because { - Some(prev) if !wc_domain.more_specific_than(prev) => (), - _ => excluded_because = Some(wc_domain), - } - } - } - match (included_because, excluded_because) { - (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded - (Some(_), None) => Some(&self.url), - _ => None, - } - } -} - -/// A domain name, that optionally allows a * as its first subdomain. -#[derive(Clone, Debug)] -pub enum WildCardedDomain { - WildCard, - WildCarded(String), - Exact(String), -} -impl WildCardedDomain { - pub fn matches(&self, domain: &str) -> bool { - match self { - WildCardedDomain::WildCard => true, - WildCardedDomain::WildCarded(d) => domain.ends_with(d), - WildCardedDomain::Exact(d) => domain == d, - } - } - pub fn more_specific_than(&self, other: &Self) -> bool { - match (self, other) { - (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, - (_, WildCardedDomain::WildCard) => true, - (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), - (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { - a != b && a.ends_with(b) - } - _ => false, - } - } -} -impl std::str::FromStr for WildCardedDomain { - type Err = std::convert::Infallible; - fn from_str(s: &str) -> std::result::Result { - // maybe do some domain validation? - Ok(if s.starts_with("*.") { - WildCardedDomain::WildCarded(s[1..].to_owned()) - } else if s == "*" { - WildCardedDomain::WildCarded("".to_owned()) - } else { - WildCardedDomain::Exact(s.to_owned()) - }) - } -} -impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::de::Deserializer<'de>, - { - crate::utils::deserialize_from_str(deserializer) - } -} - pub struct Database { pub globals: globals::Globals, pub users: users::Users, diff --git a/src/database/proxy.rs b/src/database/proxy.rs new file mode 100644 index 00000000..78e9d2bf --- /dev/null +++ b/src/database/proxy.rs @@ -0,0 +1,146 @@ +use reqwest::{Proxy, Url}; +use serde::Deserialize; + +use crate::Result; + +/// ## Examples: +/// - No proxy (default): +/// ```toml +/// proxy ="none" +/// ``` +/// - Global proxy +/// ```toml +/// [proxy] +/// global = { url = "socks5h://localhost:9050" } +/// ``` +/// - Proxy some domains +/// ```toml +/// [proxy] +/// [[proxy.by_domain]] +/// url = "socks5h://localhost:9050" +/// include = ["*.onion", "matrix.myspecial.onion"] +/// exclude = ["*.myspecial.onion"] +/// ``` +/// ## Include vs. Exclude +/// If include is an empty list, it is assumed to be `["*"]`. +/// +/// If a domain matches both the exclude and include list, the proxy will only be used if it was +/// included because of a more specific rule than it was excluded. In the above example, the proxy +/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProxyConfig { + None, + Global { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + }, + ByDomain(Vec), +} +impl ProxyConfig { + pub fn to_proxy(&self) -> Result> { + Ok(match self.clone() { + ProxyConfig::None => None, + ProxyConfig::Global { url } => Some(Proxy::all(url)?), + ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { + proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy + })), + }) + } +} +impl Default for ProxyConfig { + fn default() -> Self { + ProxyConfig::None + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PartialProxyConfig { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + #[serde(default)] + include: Vec, + #[serde(default)] + exclude: Vec, +} +impl PartialProxyConfig { + pub fn for_url(&self, url: &Url) -> Option<&Url> { + let domain = url.domain()?; + let mut included_because = None; // most specific reason it was included + let mut excluded_because = None; // most specific reason it was excluded + if self.include.is_empty() { + // treat empty include list as `*` + included_because = Some(&WildCardedDomain::WildCard) + } + for wc_domain in &self.include { + if wc_domain.matches(domain) { + match included_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => included_because = Some(wc_domain), + } + } + } + for wc_domain in &self.exclude { + if wc_domain.matches(domain) { + match excluded_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => excluded_because = Some(wc_domain), + } + } + } + match (included_because, excluded_because) { + (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded + (Some(_), None) => Some(&self.url), + _ => None, + } + } +} + +/// A domain name, that optionally allows a * as its first subdomain. +#[derive(Clone, Debug)] +pub enum WildCardedDomain { + WildCard, + WildCarded(String), + Exact(String), +} +impl WildCardedDomain { + pub fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + pub fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { + a != b && a.ends_with(b) + } + _ => false, + } + } +} +impl std::str::FromStr for WildCardedDomain { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + // maybe do some domain validation? + Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_owned()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_owned()) + } else { + WildCardedDomain::Exact(s.to_owned()) + }) + } +} +impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } +}