From 0bb28f60cfc68beaa521bb7dacdec7a9f92c2288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Fri, 22 Mar 2024 08:52:39 +0100 Subject: [PATCH] refactor: minor appservice code cleanup --- src/api/appservice_server.rs | 157 ++++++++++++++------------- src/api/client_server/alias.rs | 8 +- src/api/ruma_wrapper/axum.rs | 8 +- src/database/key_value/appservice.rs | 35 ++---- src/database/mod.rs | 19 ---- src/service/appservice/mod.rs | 39 +++++-- src/service/mod.rs | 7 +- src/service/rooms/timeline/mod.rs | 10 +- 8 files changed, 128 insertions(+), 155 deletions(-) diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index ab4da79f..841c32a1 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -17,95 +17,96 @@ pub(crate) async fn send_request( where T: Debug, { - if let Some(destination) = registration.url { - let hs_token = registration.hs_token.as_str(); + let Some(destination) = registration.url else { + return None; + }; - let mut http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(hs_token), - &[MatrixVersion::V1_0], - ) - .unwrap() - .map(|body| body.freeze()); + let hs_token = registration.hs_token.as_str(); - let mut parts = http_request.uri().clone().into_parts(); - let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); - let symbol = if old_path_and_query.contains('?') { - "&" - } else { - "?" - }; + let mut http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(hs_token), + &[MatrixVersion::V1_0], + ) + .unwrap() + .map(|body| body.freeze()); - parts.path_and_query = Some( - (old_path_and_query + symbol + "access_token=" + hs_token) - .parse() - .unwrap(), - ); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - - let mut reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); + let mut parts = http_request.uri().clone().into_parts(); + let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); + let symbol = if old_path_and_query.contains('?') { + "&" + } else { + "?" + }; - *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); + parts.path_and_query = Some( + (old_path_and_query + symbol + "access_token=" + hs_token) + .parse() + .unwrap(), + ); + *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - let url = reqwest_request.url().clone(); - let mut response = match services() - .globals - .default_client() - .execute(reqwest_request) - .await - { - Ok(r) => r, - Err(e) => { - warn!( - "Could not send request to appservice {:?} at {}: {}", - registration.id, destination, e - ); - return Some(Err(e.into())); - } - }; + let mut reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); + *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error: {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { + let url = reqwest_request.url().clone(); + let mut response = match services() + .globals + .default_client() + .execute(reqwest_request) + .await + { + Ok(r) => r, + Err(e) => { warn!( - "Appservice returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - utils::string_from_bytes(&body) + "Could not send request to appservice {:?} at {}: {}", + registration.id, destination, e ); + return Some(Err(e.into())); } + }; + + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }); // TODO: handle timeout - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), + if status != 200 { + warn!( + "Appservice returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + utils::string_from_bytes(&body) ); - Some(response.map_err(|_| { - warn!( - "Appservice returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Server returned bad response.") - })) - } else { - None } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + + Some(response.map_err(|_| { + warn!( + "Appservice returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Server returned bad response.") + })) } diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index d3a6e39a..00ee6c85 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -100,13 +100,7 @@ pub(crate) async fn get_alias_helper( match services().rooms.alias.resolve_local_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for appservice in services() - .appservice - .registration_info - .read() - .await - .values() - { + for appservice in services().appservice.all().await { if appservice.aliases.is_match(room_alias.as_str()) && if let Some(opt_result) = services() .sending diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index e841f13a..6411ab9d 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -80,19 +80,19 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); - let appservices = services().appservice.all().unwrap(); + let appservices = services().appservice.all().await; let appservice_registration = appservices .iter() - .find(|(_id, registration)| Some(registration.as_token.as_str()) == token); + .find(|info| Some(info.registration.as_token.as_str()) == token); let (sender_user, sender_device, sender_servername, from_appservice) = - if let Some((_id, registration)) = appservice_registration { + if let Some(info) = appservice_registration { match metadata.authentication { AuthScheme::AccessToken => { let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( - registration.sender_localpart.as_str(), + info.registration.sender_localpart.as_str(), services().globals.server_name(), ) .unwrap() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 3243183d..b547e66a 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -10,10 +10,6 @@ impl service::appservice::Data for KeyValueDatabase { id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes(), )?; - self.cached_registrations - .write() - .unwrap() - .insert(id.to_owned(), yaml.to_owned()); Ok(id.to_owned()) } @@ -26,33 +22,18 @@ impl service::appservice::Data for KeyValueDatabase { fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations .remove(service_name.as_bytes())?; - self.cached_registrations - .write() - .unwrap() - .remove(service_name); Ok(()) } fn get_registration(&self, id: &str) -> Result> { - self.cached_registrations - .read() - .unwrap() - .get(id) - .map_or_else( - || { - self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes).map_err(|_| { - Error::bad_database( - "Invalid registration bytes in id_appserviceregistrations.", - ) - }) - }) - .transpose() - }, - |r| Ok(Some(r.clone())), - ) + self.id_appserviceregistrations + .get(id.as_bytes())? + .map(|bytes| { + serde_yaml::from_slice(&bytes).map_err(|_| { + Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") + }) + }) + .transpose() } fn iter_ids<'a>(&'a self) -> Result> + 'a>> { diff --git a/src/database/mod.rs b/src/database/mod.rs index 5b8588cf..190e7e12 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -10,7 +10,6 @@ use directories::ProjectDirs; use lru_cache::LruCache; use ruma::{ - api::appservice::Registration, events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, room::message::RoomMessageEventContent, @@ -164,7 +163,6 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, - pub(super) cached_registrations: Arc>>, pub(super) pdu_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, @@ -374,7 +372,6 @@ impl KeyValueDatabase { global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, - cached_registrations: Arc::new(RwLock::new(HashMap::new())), pdu_cache: Mutex::new(LruCache::new( config .pdu_cache_capacity @@ -969,22 +966,6 @@ impl KeyValueDatabase { ); } - // Inserting registrations into cache - for appservice in services().appservice.all()? { - services() - .appservice - .registration_info - .write() - .await - .insert( - appservice.0, - appservice - .1 - .try_into() - .expect("Should be validated on registration"), - ); - } - // This data is probably outdated db.presenceid_presence.clear()?; diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 40fa3ee8..6b9e21f1 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -10,7 +10,8 @@ use tokio::sync::RwLock; use crate::{services, Result}; -/// Compiled regular expressions for a namespace +/// Compiled regular expressions for a namespace. +#[derive(Clone, Debug)] pub struct NamespaceRegex { pub exclusive: Option, pub non_exclusive: Option, @@ -72,7 +73,8 @@ impl TryFrom> for NamespaceRegex { type Error = regex::Error; } -/// Compiled regular expressions for an appservice +/// Appservice registration combined with its compiled regular expressions. +#[derive(Clone, Debug)] pub struct RegistrationInfo { pub registration: Registration, pub users: NamespaceRegex, @@ -95,11 +97,29 @@ impl TryFrom for RegistrationInfo { pub struct Service { pub db: &'static dyn Data, - pub registration_info: RwLock>, + registration_info: RwLock>, } impl Service { - /// Registers an appservice and returns the ID to the caller + pub fn build(db: &'static dyn Data) -> Result { + let mut registration_info = HashMap::new(); + // Inserting registrations into cache + for appservice in db.all()? { + registration_info.insert( + appservice.0, + appservice + .1 + .try_into() + .expect("Should be validated on registration"), + ); + } + + Ok(Self { + db, + registration_info: RwLock::new(registration_info), + }) + } + /// Registers an appservice and returns the ID to the caller. pub async fn register_appservice(&self, yaml: Registration) -> Result { services() .appservice @@ -111,7 +131,7 @@ impl Service { self.db.register_appservice(yaml) } - /// Remove an appservice registration + /// Removes an appservice registration. /// /// # Arguments /// @@ -135,7 +155,12 @@ impl Service { self.db.iter_ids() } - pub fn all(&self) -> Result> { - self.db.all() + pub async fn all(&self) -> Vec { + self.registration_info + .read() + .await + .values() + .cloned() + .collect() } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 045ccd10..0cbe6a82 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -4,7 +4,7 @@ use std::{ }; use lru_cache::LruCache; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use crate::{Config, Result}; @@ -56,10 +56,7 @@ impl Services { config: Config, ) -> Result { Ok(Self { - appservice: appservice::Service { - db, - registration_info: RwLock::new(HashMap::new()), - }, + appservice: appservice::Service::build(db)?, pusher: pusher::Service { db }, rooms: rooms::Service { alias: rooms::alias::Service { db }, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 1df1db50..035513d6 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -524,17 +524,11 @@ impl Service { } } - for appservice in services() - .appservice - .registration_info - .read() - .await - .values() - { + for appservice in services().appservice.all().await { if services() .rooms .state_cache - .appservice_in_room(&pdu.room_id, appservice)? + .appservice_in_room(&pdu.room_id, &appservice)? { services() .sending