From fa930182aea942380f8db19b5e18152e17d9e634 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 27 Dec 2023 13:22:21 +0000 Subject: [PATCH 1/4] fix(appservices): don't panic on empty registration url perf(appservices): cache regex for namespaces --- src/api/appservice_server.rs | 166 ++++++++++---------- src/api/client_server/alias.rs | 34 ++-- src/api/ruma_wrapper/axum.rs | 15 +- src/database/key_value/appservice.rs | 11 +- src/database/key_value/rooms/state_cache.rs | 47 ++---- src/database/mod.rs | 20 ++- src/service/admin/mod.rs | 7 +- src/service/appservice/data.rs | 8 +- src/service/appservice/mod.rs | 114 +++++++++++++- src/service/mod.rs | 7 +- src/service/rooms/state_cache/data.rs | 8 +- src/service/rooms/state_cache/mod.rs | 4 +- src/service/rooms/timeline/mod.rs | 116 ++++++-------- src/service/sending/mod.rs | 19 ++- src/utils/error.rs | 5 + 15 files changed, 336 insertions(+), 245 deletions(-) diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index 082a1bc2..ab4da79f 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,105 +1,111 @@ use crate::{services, utils, Error, Result}; use bytes::BytesMut; -use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; +use ruma::api::{ + appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, +}; use std::{fmt::Debug, mem, time::Duration}; use tracing::warn; +/// Sends a request to an appservice +/// +/// Only returns None if there is no url specified in the appservice registration file #[tracing::instrument(skip(request))] pub(crate) async fn send_request( - registration: serde_yaml::Value, + registration: Registration, request: T, -) -> Result +) -> Option> where T: Debug, { - let destination = registration.get("url").unwrap().as_str().unwrap(); - let hs_token = registration.get("hs_token").unwrap().as_str().unwrap(); + if let Some(destination) = registration.url { + let hs_token = registration.hs_token.as_str(); - let mut http_request = request - .try_into_http_request::( - destination, - SendAccessToken::IfRequired(hs_token), - &[MatrixVersion::V1_0], - ) - .unwrap() - .map(|body| body.freeze()); + let mut http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(hs_token), + &[MatrixVersion::V1_0], + ) + .unwrap() + .map(|body| body.freeze()); - let mut parts = http_request.uri().clone().into_parts(); - let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); - let symbol = if old_path_and_query.contains('?') { - "&" - } else { - "?" - }; + let mut parts = http_request.uri().clone().into_parts(); + let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); + let symbol = if old_path_and_query.contains('?') { + "&" + } else { + "?" + }; + + parts.path_and_query = Some( + (old_path_and_query + symbol + "access_token=" + hs_token) + .parse() + .unwrap(), + ); + *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - parts.path_and_query = Some( - (old_path_and_query + symbol + "access_token=" + hs_token) - .parse() - .unwrap(), - ); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); + let mut reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); - let mut reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); + *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); - *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); + let url = reqwest_request.url().clone(); + let mut response = match services() + .globals + .default_client() + .execute(reqwest_request) + .await + { + Ok(r) => r, + Err(e) => { + warn!( + "Could not send request to appservice {:?} at {}: {}", + registration.id, destination, e + ); + return Some(Err(e.into())); + } + }; - let url = reqwest_request.url().clone(); - let mut response = match services() - .globals - .default_client() - .execute(reqwest_request) - .await - { - Ok(r) => r, - Err(e) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { warn!( - "Could not send request to appservice {:?} at {}: {}", - registration.get("id"), + "Appservice returned bad response {} {}\n{}\n{:?}", destination, - e + status, + url, + utils::string_from_bytes(&body) ); - return Err(e.into()); } - }; - - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error: {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - warn!( - "Appservice returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - utils::string_from_bytes(&body) + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), ); + Some(response.map_err(|_| { + warn!( + "Appservice returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Server returned bad response.") + })) + } else { + None } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - warn!( - "Appservice returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Server returned bad response.") - }) } diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 7660ca2f..d3a6e39a 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,6 +1,5 @@ use crate::{services, Error, Result, Ruma}; use rand::seq::SliceRandom; -use regex::Regex; use ruma::{ api::{ appservice, @@ -101,31 +100,28 @@ pub(crate) async fn get_alias_helper( match services().rooms.alias.resolve_local_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for (_id, registration) in services().appservice.all()? { - let aliases = registration - .get("namespaces") - .and_then(|ns| ns.get("aliases")) - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::>() - }); - - if aliases - .iter() - .any(|aliases| aliases.is_match(room_alias.as_str())) - && services() + for appservice in services() + .appservice + .registration_info + .read() + .await + .values() + { + if appservice.aliases.is_match(room_alias.as_str()) + && if let Some(opt_result) = services() .sending .send_appservice_request( - registration, + appservice.registration.clone(), appservice::query::query_room_alias::v1::Request { room_alias: room_alias.clone(), }, ) .await - .is_ok() + { + opt_result.is_ok() + } else { + false + } { room_id = Some( services() diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index bbd48614..e841f13a 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -81,12 +81,9 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); let appservices = services().appservice.all().unwrap(); - let appservice_registration = appservices.iter().find(|(_id, registration)| { - registration - .get("as_token") - .and_then(|as_token| as_token.as_str()) - .map_or(false, |as_token| token == Some(as_token)) - }); + let appservice_registration = appservices + .iter() + .find(|(_id, registration)| Some(registration.as_token.as_str()) == token); let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) = appservice_registration { @@ -95,11 +92,7 @@ where let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( - registration - .get("sender_localpart") - .unwrap() - .as_str() - .unwrap(), + registration.sender_localpart.as_str(), services().globals.server_name(), ) .unwrap() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 9a821a65..3243183d 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,10 +1,11 @@ +use ruma::api::appservice::Registration; + use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { - // TODO: Rumaify - let id = yaml.get("id").unwrap().as_str().unwrap(); + fn register_appservice(&self, yaml: Registration) -> Result { + let id = yaml.id.as_str(); self.id_appserviceregistrations.insert( id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes(), @@ -32,7 +33,7 @@ impl service::appservice::Data for KeyValueDatabase { Ok(()) } - fn get_registration(&self, id: &str) -> Result> { + fn get_registration(&self, id: &str) -> Result> { self.cached_registrations .read() .unwrap() @@ -64,7 +65,7 @@ impl service::appservice::Data for KeyValueDatabase { ))) } - fn all(&self) -> Result> { + fn all(&self) -> Result> { self.iter_ids()? .filter_map(|id| id.ok()) .map(move |id| { diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 3dcaf4ae..49e3842b 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,13 +1,16 @@ use std::{collections::HashSet, sync::Arc}; -use regex::Regex; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, + service::{self, appservice::RegistrationInfo}, + services, utils, Error, Result, +}; impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { @@ -184,46 +187,28 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { } #[tracing::instrument(skip(self, room_id, appservice))] - fn appservice_in_room( - &self, - room_id: &RoomId, - appservice: &(String, serde_yaml::Value), - ) -> Result { + fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { let maybe = self .appservice_in_room_cache .read() .unwrap() .get(room_id) - .and_then(|map| map.get(&appservice.0)) + .and_then(|map| map.get(&appservice.registration.id)) .copied(); if let Some(b) = maybe { Ok(b) - } else if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::>() - }); - - let bridge_user_id = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, services().globals.server_name()).ok() - }); + } else { + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + services().globals.server_name(), + ) + .ok(); let in_room = bridge_user_id .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) || self.room_members(room_id).any(|userid| { - userid.map_or(false, |userid| { - users.iter().any(|r| r.is_match(userid.as_str())) - }) + userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())) }); self.appservice_in_room_cache @@ -231,11 +216,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .unwrap() .entry(room_id.to_owned()) .or_default() - .insert(appservice.0.clone(), in_room); + .insert(appservice.registration.id.clone(), in_room); Ok(in_room) - } else { - Ok(false) } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 425ef4e9..5b8588cf 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -8,7 +8,9 @@ use crate::{ use abstraction::{KeyValueDatabaseEngine, KvTree}; use directories::ProjectDirs; use lru_cache::LruCache; + use ruma::{ + api::appservice::Registration, events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, room::message::RoomMessageEventContent, @@ -162,7 +164,7 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, - pub(super) cached_registrations: Arc>>, + pub(super) cached_registrations: Arc>>, pub(super) pdu_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, @@ -967,6 +969,22 @@ impl KeyValueDatabase { ); } + // Inserting registrations into cache + for appservice in services().appservice.all()? { + services() + .appservice + .registration_info + .write() + .await + .insert( + appservice.0, + appservice + .1 + .try_into() + .expect("Should be validated on registration"), + ); + } + // This data is probably outdated db.presenceid_presence.clear()?; diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index d99be878..12bc1cf6 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -8,6 +8,7 @@ use std::{ use clap::Parser; use regex::Regex; use ruma::{ + api::appservice::Registration, events::{ room::{ canonical_alias::RoomCanonicalAliasEventContent, @@ -335,10 +336,9 @@ impl Service { if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { let appservice_config = body[1..body.len() - 1].join("\n"); - let parsed_config = - serde_yaml::from_str::(&appservice_config); + let parsed_config = serde_yaml::from_str::(&appservice_config); match parsed_config { - Ok(yaml) => match services().appservice.register_appservice(yaml) { + Ok(yaml) => match services().appservice.register_appservice(yaml).await { Ok(id) => RoomMessageEventContent::text_plain(format!( "Appservice registered with ID: {id}." )), @@ -361,6 +361,7 @@ impl Service { } => match services() .appservice .unregister_appservice(&appservice_identifier) + .await { Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), Err(e) => RoomMessageEventContent::text_plain(format!( diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 744f0f94..ab19a50c 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,8 +1,10 @@ +use ruma::api::appservice::Registration; + use crate::Result; pub trait Data: Send + Sync { /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; + fn register_appservice(&self, yaml: Registration) -> Result; /// Remove an appservice registration /// @@ -11,9 +13,9 @@ pub trait Data: Send + Sync { /// * `service_name` - the name you send to register the service previously fn unregister_appservice(&self, service_name: &str) -> Result<()>; - fn get_registration(&self, id: &str) -> Result>; + fn get_registration(&self, id: &str) -> Result>; fn iter_ids<'a>(&'a self) -> Result> + 'a>>; - fn all(&self) -> Result>; + fn all(&self) -> Result>; } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 3052964d..40fa3ee8 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,16 +1,113 @@ mod data; +use std::collections::HashMap; + pub use data::Data; -use crate::Result; +use regex::RegexSet; +use ruma::api::appservice::{Namespace, Registration}; +use tokio::sync::RwLock; + +use crate::{services, Result}; + +/// Compiled regular expressions for a namespace +pub struct NamespaceRegex { + pub exclusive: Option, + pub non_exclusive: Option, +} + +impl NamespaceRegex { + /// Checks if this namespace has rights to a namespace + pub fn is_match(&self, heystack: &str) -> bool { + if self.is_exclusive_match(heystack) { + return true; + } + + if let Some(non_exclusive) = &self.non_exclusive { + if non_exclusive.is_match(heystack) { + return true; + } + } + false + } + + /// Checks if this namespace has exlusive rights to a namespace + pub fn is_exclusive_match(&self, heystack: &str) -> bool { + if let Some(exclusive) = &self.exclusive { + if exclusive.is_match(heystack) { + return true; + } + } + false + } +} + +impl TryFrom> for NamespaceRegex { + fn try_from(value: Vec) -> Result { + let mut exclusive = vec![]; + let mut non_exclusive = vec![]; + + for namespace in value { + if namespace.exclusive { + exclusive.push(namespace.regex); + } else { + non_exclusive.push(namespace.regex); + } + } + + Ok(NamespaceRegex { + exclusive: if exclusive.is_empty() { + None + } else { + Some(RegexSet::new(exclusive)?) + }, + non_exclusive: if non_exclusive.is_empty() { + None + } else { + Some(RegexSet::new(non_exclusive)?) + }, + }) + } + + type Error = regex::Error; +} + +/// Compiled regular expressions for an appservice +pub struct RegistrationInfo { + pub registration: Registration, + pub users: NamespaceRegex, + pub aliases: NamespaceRegex, + pub rooms: NamespaceRegex, +} + +impl TryFrom for RegistrationInfo { + fn try_from(value: Registration) -> Result { + Ok(RegistrationInfo { + users: value.namespaces.users.clone().try_into()?, + aliases: value.namespaces.aliases.clone().try_into()?, + rooms: value.namespaces.rooms.clone().try_into()?, + registration: value, + }) + } + + type Error = regex::Error; +} pub struct Service { pub db: &'static dyn Data, + pub registration_info: RwLock>, } impl Service { /// Registers an appservice and returns the ID to the caller - pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { + pub async fn register_appservice(&self, yaml: Registration) -> Result { + services() + .appservice + .registration_info + .write() + .await + .insert(yaml.id.clone(), yaml.clone().try_into()?); + self.db.register_appservice(yaml) } @@ -19,11 +116,18 @@ impl Service { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { + services() + .appservice + .registration_info + .write() + .await + .remove(service_name); + self.db.unregister_appservice(service_name) } - pub fn get_registration(&self, id: &str) -> Result> { + pub fn get_registration(&self, id: &str) -> Result> { self.db.get_registration(id) } @@ -31,7 +135,7 @@ impl Service { self.db.iter_ids() } - pub fn all(&self) -> Result> { + pub fn all(&self) -> Result> { self.db.all() } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 8f9fb0a5..045ccd10 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -4,7 +4,7 @@ use std::{ }; use lru_cache::LruCache; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use crate::{Config, Result}; @@ -56,7 +56,10 @@ impl Services { config: Config, ) -> Result { Ok(Self { - appservice: appservice::Service { db }, + appservice: appservice::Service { + db, + registration_info: RwLock::new(HashMap::new()), + }, pusher: pusher::Service { db }, rooms: rooms::Service { alias: rooms::alias::Service { db }, diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 8921909f..b511919a 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, sync::Arc}; -use crate::Result; +use crate::{service::appservice::RegistrationInfo, Result}; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, @@ -22,11 +22,7 @@ pub trait Data: Send + Sync { fn get_our_real_users(&self, room_id: &RoomId) -> Result>>; - fn appservice_in_room( - &self, - room_id: &RoomId, - appservice: &(String, serde_yaml::Value), - ) -> Result; + fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result; /// Makes a user forget a room. fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index ef1ad61e..c108695d 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -16,7 +16,7 @@ use ruma::{ }; use tracing::warn; -use crate::{services, Error, Result}; +use crate::{service::appservice::RegistrationInfo, services, Error, Result}; pub struct Service { pub db: &'static dyn Data, @@ -205,7 +205,7 @@ impl Service { pub fn appservice_in_room( &self, room_id: &RoomId, - appservice: &(String, serde_yaml::Value), + appservice: &RegistrationInfo, ) -> Result { self.db.appservice_in_room(room_id, appservice) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 097cc82f..1df1db50 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -7,7 +7,7 @@ use std::{ }; pub use data::Data; -use regex::Regex; + use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -21,8 +21,7 @@ use ruma::{ }, push::{Action, Ruleset, Tweak}, serde::Base64, - state_res, - state_res::{Event, RoomVersion}, + state_res::{self, Event, RoomVersion}, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; @@ -33,7 +32,10 @@ use tracing::{error, info, warn}; use crate::{ api::server_server, - service::pdu::{EventHash, PduBuilder}, + service::{ + appservice::NamespaceRegex, + pdu::{EventHash, PduBuilder}, + }, services, utils, Error, PduEvent, Result, }; @@ -522,15 +524,21 @@ impl Service { } } - for appservice in services().appservice.all()? { + for appservice in services() + .appservice + .registration_info + .read() + .await + .values() + { if services() .rooms .state_cache - .appservice_in_room(&pdu.room_id, &appservice)? + .appservice_in_room(&pdu.room_id, appservice)? { services() .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; } @@ -542,73 +550,41 @@ impl Service { .as_ref() .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) { - if let Some(appservice_uid) = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, services().globals.server_name()) - .ok() - }) - { - if state_key_uid == &appservice_uid { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - continue; - } + let appservice_uid = appservice.registration.sender_localpart.as_str(); + if state_key_uid == appservice_uid { + services().sending.send_pdu_appservice( + appservice.registration.id.clone(), + pdu_id.clone(), + )?; + continue; } } } - if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let aliases = namespaces - .get("aliases") - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let rooms = namespaces - .get("rooms") - .and_then(|rooms| rooms.as_sequence()); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == TimelineEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - services() - .rooms - .alias - .local_aliases_for_room(&pdu.room_id) - .filter_map(|r| r.ok()) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; - - if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&pdu.room_id.as_str().into())) - || users.iter().any(matching_users) - { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - } + let matching_users = |users: &NamespaceRegex| { + appservice.users.is_match(pdu.sender.as_str()) + || pdu.kind == TimelineEventType::RoomMember + && pdu + .state_key + .as_ref() + .map_or(false, |state_key| users.is_match(state_key)) + }; + let matching_aliases = |aliases: &NamespaceRegex| { + services() + .rooms + .alias + .local_aliases_for_room(&pdu.room_id) + .filter_map(|r| r.ok()) + .any(|room_alias| aliases.is_match(room_alias.as_str())) + }; + + if matching_aliases(&appservice.aliases) + || appservice.rooms.is_match(pdu.room_id.as_str()) + || matching_users(&appservice.users) + { + services() + .sending + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; } } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index f80c4f0a..bbacfdec 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -22,7 +22,7 @@ use base64::{engine::general_purpose, Engine as _}; use ruma::{ api::{ - appservice, + appservice::{self, Registration}, federation::{ self, transactions::edu::{ @@ -484,7 +484,7 @@ impl Service { let permit = services().sending.maximum_requests.acquire().await; - let response = appservice_server::send_request( + let response = match appservice_server::send_request( services() .appservice .get_registration(id) @@ -511,8 +511,12 @@ impl Service { }, ) .await - .map(|_response| kind.clone()) - .map_err(|e| (kind, e)); + { + None => Ok(kind.clone()), + Some(op_resp) => op_resp + .map(|_response| kind.clone()) + .map_err(|e| (kind.clone(), e)), + }; drop(permit); @@ -698,12 +702,15 @@ impl Service { response } + /// Sends a request to an appservice + /// + /// Only returns None if there is no url specified in the appservice registration file #[tracing::instrument(skip(self, registration, request))] pub async fn send_appservice_request( &self, - registration: serde_yaml::Value, + registration: Registration, request: T, - ) -> Result + ) -> Option> where T: Debug, { diff --git a/src/utils/error.rs b/src/utils/error.rs index 765a31bb..04390283 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -54,6 +54,11 @@ pub enum Error { #[from] source: reqwest::Error, }, + #[error("Could build regular expression: {source}")] + RegexError { + #[from] + source: regex::Error, + }, #[error("{0}")] FederationError(OwnedServerName, RumaError), #[error("Could not do this io: {source}")] From 0bb28f60cfc68beaa521bb7dacdec7a9f92c2288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Fri, 22 Mar 2024 08:52:39 +0100 Subject: [PATCH 2/4] refactor: minor appservice code cleanup --- src/api/appservice_server.rs | 157 ++++++++++++++------------- src/api/client_server/alias.rs | 8 +- src/api/ruma_wrapper/axum.rs | 8 +- src/database/key_value/appservice.rs | 35 ++---- src/database/mod.rs | 19 ---- src/service/appservice/mod.rs | 39 +++++-- src/service/mod.rs | 7 +- src/service/rooms/timeline/mod.rs | 10 +- 8 files changed, 128 insertions(+), 155 deletions(-) diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index ab4da79f..841c32a1 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -17,95 +17,96 @@ pub(crate) async fn send_request( where T: Debug, { - if let Some(destination) = registration.url { - let hs_token = registration.hs_token.as_str(); + let Some(destination) = registration.url else { + return None; + }; - let mut http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(hs_token), - &[MatrixVersion::V1_0], - ) - .unwrap() - .map(|body| body.freeze()); + let hs_token = registration.hs_token.as_str(); - let mut parts = http_request.uri().clone().into_parts(); - let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); - let symbol = if old_path_and_query.contains('?') { - "&" - } else { - "?" - }; + let mut http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(hs_token), + &[MatrixVersion::V1_0], + ) + .unwrap() + .map(|body| body.freeze()); - parts.path_and_query = Some( - (old_path_and_query + symbol + "access_token=" + hs_token) - .parse() - .unwrap(), - ); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - - let mut reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); + let mut parts = http_request.uri().clone().into_parts(); + let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); + let symbol = if old_path_and_query.contains('?') { + "&" + } else { + "?" + }; - *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); + parts.path_and_query = Some( + (old_path_and_query + symbol + "access_token=" + hs_token) + .parse() + .unwrap(), + ); + *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - let url = reqwest_request.url().clone(); - let mut response = match services() - .globals - .default_client() - .execute(reqwest_request) - .await - { - Ok(r) => r, - Err(e) => { - warn!( - "Could not send request to appservice {:?} at {}: {}", - registration.id, destination, e - ); - return Some(Err(e.into())); - } - }; + let mut reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); + *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error: {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { + let url = reqwest_request.url().clone(); + let mut response = match services() + .globals + .default_client() + .execute(reqwest_request) + .await + { + Ok(r) => r, + Err(e) => { warn!( - "Appservice returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - utils::string_from_bytes(&body) + "Could not send request to appservice {:?} at {}: {}", + registration.id, destination, e ); + return Some(Err(e.into())); } + }; + + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }); // TODO: handle timeout - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), + if status != 200 { + warn!( + "Appservice returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + utils::string_from_bytes(&body) ); - Some(response.map_err(|_| { - warn!( - "Appservice returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Server returned bad response.") - })) - } else { - None } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + + Some(response.map_err(|_| { + warn!( + "Appservice returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Server returned bad response.") + })) } diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index d3a6e39a..00ee6c85 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -100,13 +100,7 @@ pub(crate) async fn get_alias_helper( match services().rooms.alias.resolve_local_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for appservice in services() - .appservice - .registration_info - .read() - .await - .values() - { + for appservice in services().appservice.all().await { if appservice.aliases.is_match(room_alias.as_str()) && if let Some(opt_result) = services() .sending diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index e841f13a..6411ab9d 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -80,19 +80,19 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); - let appservices = services().appservice.all().unwrap(); + let appservices = services().appservice.all().await; let appservice_registration = appservices .iter() - .find(|(_id, registration)| Some(registration.as_token.as_str()) == token); + .find(|info| Some(info.registration.as_token.as_str()) == token); let (sender_user, sender_device, sender_servername, from_appservice) = - if let Some((_id, registration)) = appservice_registration { + if let Some(info) = appservice_registration { match metadata.authentication { AuthScheme::AccessToken => { let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( - registration.sender_localpart.as_str(), + info.registration.sender_localpart.as_str(), services().globals.server_name(), ) .unwrap() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 3243183d..b547e66a 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -10,10 +10,6 @@ impl service::appservice::Data for KeyValueDatabase { id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes(), )?; - self.cached_registrations - .write() - .unwrap() - .insert(id.to_owned(), yaml.to_owned()); Ok(id.to_owned()) } @@ -26,33 +22,18 @@ impl service::appservice::Data for KeyValueDatabase { fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations .remove(service_name.as_bytes())?; - self.cached_registrations - .write() - .unwrap() - .remove(service_name); Ok(()) } fn get_registration(&self, id: &str) -> Result> { - self.cached_registrations - .read() - .unwrap() - .get(id) - .map_or_else( - || { - self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes).map_err(|_| { - Error::bad_database( - "Invalid registration bytes in id_appserviceregistrations.", - ) - }) - }) - .transpose() - }, - |r| Ok(Some(r.clone())), - ) + self.id_appserviceregistrations + .get(id.as_bytes())? + .map(|bytes| { + serde_yaml::from_slice(&bytes).map_err(|_| { + Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") + }) + }) + .transpose() } fn iter_ids<'a>(&'a self) -> Result> + 'a>> { diff --git a/src/database/mod.rs b/src/database/mod.rs index 5b8588cf..190e7e12 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -10,7 +10,6 @@ use directories::ProjectDirs; use lru_cache::LruCache; use ruma::{ - api::appservice::Registration, events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, room::message::RoomMessageEventContent, @@ -164,7 +163,6 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, - pub(super) cached_registrations: Arc>>, pub(super) pdu_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, @@ -374,7 +372,6 @@ impl KeyValueDatabase { global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, - cached_registrations: Arc::new(RwLock::new(HashMap::new())), pdu_cache: Mutex::new(LruCache::new( config .pdu_cache_capacity @@ -969,22 +966,6 @@ impl KeyValueDatabase { ); } - // Inserting registrations into cache - for appservice in services().appservice.all()? { - services() - .appservice - .registration_info - .write() - .await - .insert( - appservice.0, - appservice - .1 - .try_into() - .expect("Should be validated on registration"), - ); - } - // This data is probably outdated db.presenceid_presence.clear()?; diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 40fa3ee8..6b9e21f1 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -10,7 +10,8 @@ use tokio::sync::RwLock; use crate::{services, Result}; -/// Compiled regular expressions for a namespace +/// Compiled regular expressions for a namespace. +#[derive(Clone, Debug)] pub struct NamespaceRegex { pub exclusive: Option, pub non_exclusive: Option, @@ -72,7 +73,8 @@ impl TryFrom> for NamespaceRegex { type Error = regex::Error; } -/// Compiled regular expressions for an appservice +/// Appservice registration combined with its compiled regular expressions. +#[derive(Clone, Debug)] pub struct RegistrationInfo { pub registration: Registration, pub users: NamespaceRegex, @@ -95,11 +97,29 @@ impl TryFrom for RegistrationInfo { pub struct Service { pub db: &'static dyn Data, - pub registration_info: RwLock>, + registration_info: RwLock>, } impl Service { - /// Registers an appservice and returns the ID to the caller + pub fn build(db: &'static dyn Data) -> Result { + let mut registration_info = HashMap::new(); + // Inserting registrations into cache + for appservice in db.all()? { + registration_info.insert( + appservice.0, + appservice + .1 + .try_into() + .expect("Should be validated on registration"), + ); + } + + Ok(Self { + db, + registration_info: RwLock::new(registration_info), + }) + } + /// Registers an appservice and returns the ID to the caller. pub async fn register_appservice(&self, yaml: Registration) -> Result { services() .appservice @@ -111,7 +131,7 @@ impl Service { self.db.register_appservice(yaml) } - /// Remove an appservice registration + /// Removes an appservice registration. /// /// # Arguments /// @@ -135,7 +155,12 @@ impl Service { self.db.iter_ids() } - pub fn all(&self) -> Result> { - self.db.all() + pub async fn all(&self) -> Vec { + self.registration_info + .read() + .await + .values() + .cloned() + .collect() } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 045ccd10..0cbe6a82 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -4,7 +4,7 @@ use std::{ }; use lru_cache::LruCache; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use crate::{Config, Result}; @@ -56,10 +56,7 @@ impl Services { config: Config, ) -> Result { Ok(Self { - appservice: appservice::Service { - db, - registration_info: RwLock::new(HashMap::new()), - }, + appservice: appservice::Service::build(db)?, pusher: pusher::Service { db }, rooms: rooms::Service { alias: rooms::alias::Service { db }, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 1df1db50..035513d6 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -524,17 +524,11 @@ impl Service { } } - for appservice in services() - .appservice - .registration_info - .read() - .await - .values() - { + for appservice in services().appservice.all().await { if services() .rooms .state_cache - .appservice_in_room(&pdu.room_id, appservice)? + .appservice_in_room(&pdu.room_id, &appservice)? { services() .sending From 5c650bb67e0f21a9080f54fc443695c7da0783af Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Fri, 22 Mar 2024 17:51:15 +0000 Subject: [PATCH 3/4] refactor: use BTreeMap for cached registration info --- src/service/appservice/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 6b9e21f1..d9ab9eb1 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,6 +1,6 @@ mod data; -use std::collections::HashMap; +use std::collections::BTreeMap; pub use data::Data; @@ -97,12 +97,12 @@ impl TryFrom for RegistrationInfo { pub struct Service { pub db: &'static dyn Data, - registration_info: RwLock>, + registration_info: RwLock>, } impl Service { pub fn build(db: &'static dyn Data) -> Result { - let mut registration_info = HashMap::new(); + let mut registration_info = BTreeMap::new(); // Inserting registrations into cache for appservice in db.all()? { registration_info.insert( From b20483aa13399822fe047349d1af7678f6777259 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Fri, 22 Mar 2024 18:27:14 +0000 Subject: [PATCH 4/4] refactor(appservices): avoid cloning frequently --- src/api/appservice_server.rs | 4 +--- src/api/client_server/alias.rs | 2 +- src/api/ruma_wrapper/axum.rs | 9 +++++---- src/service/admin/mod.rs | 26 +++++++----------------- src/service/appservice/mod.rs | 33 +++++++++++++++++++++++-------- src/service/rooms/timeline/mod.rs | 4 ++-- src/service/sending/mod.rs | 2 +- 7 files changed, 42 insertions(+), 38 deletions(-) diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index 841c32a1..213e4c09 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -17,9 +17,7 @@ pub(crate) async fn send_request( where T: Debug, { - let Some(destination) = registration.url else { - return None; - }; + let destination = registration.url?; let hs_token = registration.hs_token.as_str(); diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 00ee6c85..bc3a5e25 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -100,7 +100,7 @@ pub(crate) async fn get_alias_helper( match services().rooms.alias.resolve_local_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for appservice in services().appservice.all().await { + for appservice in services().appservice.read().await.values() { if appservice.aliases.is_match(room_alias.as_str()) && if let Some(opt_result) = services() .sending diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 6411ab9d..8ba9fa52 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -80,10 +80,11 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); - let appservices = services().appservice.all().await; - let appservice_registration = appservices - .iter() - .find(|info| Some(info.registration.as_token.as_str()) == token); + let appservice_registration = if let Some(token) = token { + services().appservice.find_from_token(token).await + } else { + None + }; let (sender_user, sender_device, sender_servername, from_appservice) = if let Some(info) = appservice_registration { diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 12bc1cf6..f2f60a7a 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -369,25 +369,13 @@ impl Service { )), }, AdminCommand::ListAppservices => { - if let Ok(appservices) = services() - .appservice - .iter_ids() - .map(|ids| ids.collect::>()) - { - let count = appservices.len(); - let output = format!( - "Appservices ({}): {}", - count, - appservices - .into_iter() - .filter_map(|r| r.ok()) - .collect::>() - .join(", ") - ); - RoomMessageEventContent::text_plain(output) - } else { - RoomMessageEventContent::text_plain("Failed to get appservices.") - } + let appservices = services().appservice.iter_ids().await; + let output = format!( + "Appservices ({}): {}", + appservices.len(), + appservices.join(", ") + ); + RoomMessageEventContent::text_plain(output) } AdminCommand::ListRooms => { let room_ids = services().rooms.metadata.iter_ids(); diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index d9ab9eb1..4bda8961 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -4,6 +4,7 @@ use std::collections::BTreeMap; pub use data::Data; +use futures_util::Future; use regex::RegexSet; use ruma::api::appservice::{Namespace, Registration}; use tokio::sync::RwLock; @@ -147,20 +148,36 @@ impl Service { self.db.unregister_appservice(service_name) } - pub fn get_registration(&self, id: &str) -> Result> { - self.db.get_registration(id) - } - - pub fn iter_ids(&self) -> Result> + '_> { - self.db.iter_ids() + pub async fn get_registration(&self, id: &str) -> Option { + self.registration_info + .read() + .await + .get(id) + .cloned() + .map(|info| info.registration) } - pub async fn all(&self) -> Vec { + pub async fn iter_ids(&self) -> Vec { self.registration_info .read() .await - .values() + .keys() .cloned() .collect() } + + pub async fn find_from_token(&self, token: &str) -> Option { + self.read() + .await + .values() + .find(|info| info.registration.as_token == token) + .cloned() + } + + pub fn read( + &self, + ) -> impl Future>> + { + self.registration_info.read() + } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 035513d6..379d97fe 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -524,11 +524,11 @@ impl Service { } } - for appservice in services().appservice.all().await { + for appservice in services().appservice.read().await.values() { if services() .rooms .state_cache - .appservice_in_room(&pdu.room_id, &appservice)? + .appservice_in_room(&pdu.room_id, appservice)? { services() .sending diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index bbacfdec..45cca173 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -488,7 +488,7 @@ impl Service { services() .appservice .get_registration(id) - .map_err(|e| (kind.clone(), e))? + .await .ok_or_else(|| { ( kind.clone(),