diff --git a/Cargo.toml b/Cargo.toml index 64b7a233..10be7501 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ tokio = { version = "1.11.0", features = ["fs", "macros", "signal", "sync"] } # Used for storing data permanently sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } -persy = { version = "1.2" , optional = true, features=["background_ops"] } +persy = { version = "1.2" , optional = true, features = ["background_ops"] } # Used for the http request / response body type for Ruma endpoints used with reqwest bytes = "1.1.0" @@ -64,7 +64,7 @@ regex = "1.5.4" # jwt jsonwebtokens jsonwebtoken = "7.2.0" # Performance measurements -tracing = { version = "0.1.26", features = ["release_max_level_warn"] } +tracing = { version = "0.1.26", features = [] } tracing-subscriber = "0.2.20" tracing-flame = "0.1.0" opentelemetry = { version = "0.16.0", features = ["rt-tokio"] } @@ -76,7 +76,7 @@ crossbeam = { version = "0.8.1", optional = true } num_cpus = "1.13.0" threadpool = "1.8.1" heed = { git = "https://github.com/timokoesters/heed.git", rev = "f6f825da7fb2c758867e05ad973ef800a6fe1d5d", optional = true } -rocksdb = { version = "0.17.0", default-features = false, features = ["multi-threaded-cf", "zstd"], optional = true } +rocksdb = { version = "0.17.0", default-features = true, features = ["multi-threaded-cf", "zstd"], optional = true } thread_local = "1.1.3" # used for TURN server authentication diff --git a/src/client_server/context.rs b/src/client_server/context.rs index de7aae93..e93f5a5b 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -137,7 +137,7 @@ pub async fn get_context_route( .expect("All rooms have state"), }; - let state_ids = db.rooms.state_full_ids(shortstatehash)?; + let state_ids = db.rooms.state_full_ids(shortstatehash).await?; let end_token = events_after .last() diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 65107a3c..a1b616be 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -29,7 +29,7 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + collections::{hash_map::Entry, BTreeMap, HashMap}, iter, sync::{Arc, RwLock}, time::{Duration, Instant}, @@ -48,19 +48,20 @@ pub async fn join_room_by_id_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut servers: HashSet<_> = db - .rooms - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); - - servers.insert(body.room_id.server_name().to_owned()); + let mut servers = Vec::new(); // There is no body.server_name for /roomId/join + servers.extend( + db.rooms + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + + servers.push(body.room_id.server_name().to_owned()); let ret = join_room_by_id_helper( &db, @@ -91,19 +92,20 @@ pub async fn join_room_by_id_or_alias_route( let (servers, room_id) = match Box::::try_from(body.room_id_or_alias) { Ok(room_id) => { - let mut servers: HashSet<_> = db - .rooms - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); + let mut servers = body.server_name.clone(); + servers.extend( + db.rooms + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); - servers.insert(room_id.server_name().to_owned()); + servers.push(room_id.server_name().to_owned()); (servers, room_id) } Err(room_alias) => { @@ -413,7 +415,8 @@ pub async fn get_member_events_route( Ok(get_member_events::v3::Response { chunk: db .rooms - .room_state_full(&body.room_id)? + .room_state_full(&body.room_id) + .await? .iter() .filter(|(key, _)| key.0 == StateEventType::RoomMember) .map(|(_, pdu)| pdu.to_member_event().into()) @@ -462,7 +465,7 @@ async fn join_room_by_id_helper( db: &Database, sender_user: Option<&UserId>, room_id: &RoomId, - servers: &HashSet>, + servers: &[Box], _third_party_signed: Option<&IncomingThirdPartySigned>, ) -> Result { let sender_user = sender_user.expect("user is authenticated"); @@ -478,7 +481,7 @@ async fn join_room_by_id_helper( let state_lock = mutex_state.lock().await; // Ask a remote server if we don't have this room - if !db.rooms.exists(room_id)? && room_id.server_name() != db.globals.server_name() { + if !db.rooms.exists(room_id)? { let mut make_join_response_and_server = Err(Error::BadServerResponse( "No server available to assist in joining.", )); @@ -1032,6 +1035,13 @@ pub(crate) async fn invite_helper<'a>( return Ok(()); } + if !db.rooms.is_joined(sender_user, &room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } + let mutex_state = Arc::clone( db.globals .roomid_mutex_state diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 50fe9b4f..4df953cf 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -124,7 +124,8 @@ pub async fn get_state_events_route( Ok(get_state_events::v3::Response { room_state: db .rooms - .room_state_full(&body.room_id)? + .room_state_full(&body.room_id) + .await? .values() .map(|pdu| pdu.to_state_event()) .collect(), diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index d61e6894..0c294b7e 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -230,18 +230,20 @@ async fn sync_helper( for room_id in all_joined_rooms { let room_id = room_id?; - // Get and drop the lock to wait for remaining operations to finish - // This will make sure the we have all events until next_batch - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - drop(insert_lock); + { + // Get and drop the lock to wait for remaining operations to finish + // This will make sure the we have all events until next_batch + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); + } let timeline_pdus; let limited; @@ -296,10 +298,12 @@ async fn sync_helper( // Database queries: - let current_shortstatehash = db - .rooms - .current_shortstatehash(&room_id)? - .expect("All rooms have state"); + let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? { + s + } else { + error!("Room {} has no state", room_id); + continue; + }; let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?; @@ -377,11 +381,12 @@ async fn sync_helper( let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); + let mut i = 0; for (shortstatekey, id) in current_state_ids { let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; @@ -394,6 +399,11 @@ async fn sync_helper( } }; state_events.push(pdu); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } else if !lazy_load_enabled || body.full_state || timeline_users.contains(&state_key) @@ -411,6 +421,11 @@ async fn sync_helper( lazy_loaded.insert(uid); } state_events.push(pdu); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } } @@ -462,8 +477,8 @@ async fn sync_helper( let mut lazy_loaded = HashSet::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; - let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; + let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?; for (key, id) in current_state_ids { if body.full_state || since_state_ids.get(&key) != Some(&id) { @@ -490,6 +505,7 @@ async fn sync_helper( } state_events.push(pdu); + tokio::task::yield_now().await; } } } @@ -753,17 +769,19 @@ async fn sync_helper( for result in all_left_rooms { let (room_id, left_state_events) = result?; - // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - drop(insert_lock); + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); + } let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; @@ -793,17 +811,19 @@ async fn sync_helper( for result in all_invited_rooms { let (room_id, invite_state_events) = result?; - // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - drop(insert_lock); + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); + } let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; diff --git a/src/client_server/unversioned.rs b/src/client_server/unversioned.rs index fd0277c6..8a5c3d25 100644 --- a/src/client_server/unversioned.rs +++ b/src/client_server/unversioned.rs @@ -18,7 +18,12 @@ pub async fn get_supported_versions_route( _body: Ruma, ) -> Result { let resp = get_supported_versions::Response { - versions: vec!["r0.5.0".to_owned(), "r0.6.0".to_owned(), "v1.1".to_owned(), "v1.2".to_owned()], + versions: vec![ + "r0.5.0".to_owned(), + "r0.6.0".to_owned(), + "v1.1".to_owned(), + "v1.2".to_owned(), + ], unstable_features: BTreeMap::from_iter([("org.matrix.e2e_cross_signing".to_owned(), true)]), }; diff --git a/src/database.rs b/src/database.rs index 4a03f18c..a0937c29 100644 --- a/src/database.rs +++ b/src/database.rs @@ -213,6 +213,8 @@ impl Database { userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + disabledroomids: builder.open_tree("disabledroomids")?, + lazyloadedids: builder.open_tree("lazyloadedids")?, userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, diff --git a/src/database/admin.rs b/src/database/admin.rs index dcf09ebc..3ed1a8a9 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -116,7 +116,7 @@ impl Admin { send_message(content, guard, &state_lock); } AdminRoomEvent::ProcessMessage(room_message) => { - let reply_message = process_admin_message(&*guard, room_message); + let reply_message = process_admin_message(&*guard, room_message).await; send_message(reply_message, guard, &state_lock); } @@ -143,7 +143,7 @@ impl Admin { } // Parse and process a message from the admin room -fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { +async fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { let mut lines = room_message.lines(); let command_line = lines.next().expect("each string has at least one line"); let body: Vec<_> = lines.collect(); @@ -161,7 +161,7 @@ fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEven } }; - match process_admin_command(db, admin_command, body) { + match process_admin_command(db, admin_command, body).await { Ok(reply_message) => reply_message, Err(error) => { let markdown_message = format!( @@ -231,9 +231,15 @@ enum AdminCommand { /// List all the currently registered appservices ListAppservices, + /// List all rooms the server knows about + ListRooms, + /// List users in the database ListLocalUsers, + /// List all rooms we are currently handling an incoming pdu from + IncomingFederation, + /// Get the auth_chain of a PDU GetAuthChain { /// An event ID (the $ character followed by the base64 reference hash) @@ -269,6 +275,7 @@ enum AdminCommand { /// Username of the user for whom the password should be reset username: String, }, + /// Create a new user CreateUser { /// Username of the new user @@ -276,9 +283,14 @@ enum AdminCommand { /// Password of the new user, if unspecified one is generated password: Option, }, + + /// Disables incoming federation handling for a room. + DisableRoom { room_id: Box }, + /// Enables incoming federation handling for a room again. + EnableRoom { room_id: Box }, } -fn process_admin_command( +async fn process_admin_command( db: &Database, command: AdminCommand, body: Vec<&str>, @@ -336,6 +348,26 @@ fn process_admin_command( RoomMessageEventContent::text_plain("Failed to get appservices.") } } + AdminCommand::ListRooms => { + let room_ids = db.rooms.iter_ids(); + let output = format!( + "Rooms:\n{}", + room_ids + .filter_map(|r| r.ok()) + .map(|id| id.to_string() + + "\tMembers: " + + &db + .rooms + .room_joined_count(&id) + .ok() + .flatten() + .unwrap_or(0) + .to_string()) + .collect::>() + .join("\n") + ); + RoomMessageEventContent::text_plain(output) + } AdminCommand::ListLocalUsers => match db.users.list_local_users() { Ok(users) => { let mut msg: String = format!("Found {} local user account(s):\n", users.len()); @@ -344,6 +376,22 @@ fn process_admin_command( } Err(e) => RoomMessageEventContent::text_plain(e.to_string()), }, + AdminCommand::IncomingFederation => { + let map = db.globals.roomid_federationhandletime.read().unwrap(); + let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); + + for (r, (e, i)) in map.iter() { + let elapsed = i.elapsed(); + msg += &format!( + "{} {}: {}m{}s\n", + r, + e, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + RoomMessageEventContent::text_plain(&msg) + } AdminCommand::GetAuthChain { event_id } => { let event_id = Arc::::from(event_id); if let Some(event) = db.rooms.get_pdu_json(&event_id)? { @@ -356,7 +404,9 @@ fn process_admin_command( Error::bad_database("Invalid room id field in event in database") })?; let start = Instant::now(); - let count = server_server::get_auth_chain(room_id, vec![event_id], db)?.count(); + let count = server_server::get_auth_chain(room_id, vec![event_id], db) + .await? + .count(); let elapsed = start.elapsed(); RoomMessageEventContent::text_plain(format!( "Loaded auth chain with length {} in {:?}", @@ -545,6 +595,14 @@ fn process_admin_command( "Created user with user_id: {user_id} and password: {password}" )) } + AdminCommand::DisableRoom { room_id } => { + db.rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; + RoomMessageEventContent::text_plain("Room disabled.") + } + AdminCommand::EnableRoom { room_id } => { + db.rooms.disabledroomids.remove(room_id.as_bytes())?; + RoomMessageEventContent::text_plain("Room enabled.") + } }; Ok(reply_message_content) diff --git a/src/database/globals.rs b/src/database/globals.rs index d363e933..7e09128e 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -52,6 +52,8 @@ pub struct Globals { pub roomid_mutex_insert: RwLock, Arc>>>, pub roomid_mutex_state: RwLock, Arc>>>, pub roomid_mutex_federation: RwLock, Arc>>>, // this lock will be held longer + pub roomid_federationhandletime: RwLock, (Box, Instant)>>, + pub stateres_mutex: Arc>, pub rotate: RotationHandler, } @@ -183,6 +185,8 @@ impl Globals { roomid_mutex_state: RwLock::new(HashMap::new()), roomid_mutex_insert: RwLock::new(HashMap::new()), roomid_mutex_federation: RwLock::new(HashMap::new()), + roomid_federationhandletime: RwLock::new(HashMap::new()), + stateres_mutex: Arc::new(Mutex::new(())), sync_receivers: RwLock::new(HashMap::new()), rotate: RotationHandler::new(), }; diff --git a/src/database/rooms.rs b/src/database/rooms.rs index c885c960..7b3b7506 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -76,6 +76,8 @@ pub struct Rooms { pub(super) userroomid_leftstate: Arc, pub(super) roomuserid_leftcount: Arc, + pub(super) disabledroomids: Arc, // Rooms where incoming federation handling is disabled + pub(super) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId pub(super) userroomid_notificationcount: Arc, // NotifyCount = u64 @@ -142,20 +144,28 @@ impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] - pub fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = self .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") .1; - full_state - .into_iter() - .map(|compressed| self.parse_compressed_state_event(compressed)) - .collect() + let mut result = BTreeMap::new(); + let mut i = 0; + for compressed in full_state.into_iter() { + let parsed = self.parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) } #[tracing::instrument(skip(self))] - pub fn state_full( + pub async fn state_full( &self, shortstatehash: u64, ) -> Result>> { @@ -164,14 +174,13 @@ impl Rooms { .pop() .expect("there is always one layer") .1; - Ok(full_state - .into_iter() - .map(|compressed| self.parse_compressed_state_event(compressed)) - .filter_map(|r| r.ok()) - .map(|(_, eventid)| self.get_pdu(&eventid)) - .filter_map(|r| r.ok().flatten()) - .map(|pdu| { - Ok::<_, Error>(( + + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state { + let (_, eventid) = self.parse_compressed_state_event(compressed)?; + if let Some(pdu) = self.get_pdu(&eventid)? { + result.insert( ( pdu.kind.to_string().into(), pdu.state_key @@ -180,10 +189,16 @@ impl Rooms { .clone(), ), pdu, - )) - }) - .filter_map(|r| r.ok()) - .collect()) + ); + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -226,7 +241,6 @@ impl Rooms { } /// Returns the state hash for this pdu. - #[tracing::instrument(skip(self))] pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.eventid_shorteventid .get(event_id.as_bytes())? @@ -529,7 +543,6 @@ impl Rooms { } } - #[tracing::instrument(skip(self, globals))] pub fn compress_state_event( &self, shortstatekey: u64, @@ -546,7 +559,6 @@ impl Rooms { } /// Returns shortstatekey, event id - #[tracing::instrument(skip(self, compressed_event))] pub fn parse_compressed_state_event( &self, compressed_event: CompressedStateEvent, @@ -705,7 +717,6 @@ impl Rooms { } /// Returns (shortstatehash, already_existed) - #[tracing::instrument(skip(self, globals))] fn get_or_create_shortstatehash( &self, state_hash: &StateHashId, @@ -726,7 +737,6 @@ impl Rooms { }) } - #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -757,7 +767,6 @@ impl Rooms { Ok(short) } - #[tracing::instrument(skip(self))] pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.roomid_shortroomid .get(room_id.as_bytes())? @@ -768,7 +777,6 @@ impl Rooms { .transpose() } - #[tracing::instrument(skip(self))] pub fn get_shortstatekey( &self, event_type: &StateEventType, @@ -806,7 +814,6 @@ impl Rooms { Ok(short) } - #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shortroomid( &self, room_id: &RoomId, @@ -824,7 +831,6 @@ impl Rooms { }) } - #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shortstatekey( &self, event_type: &StateEventType, @@ -865,7 +871,6 @@ impl Rooms { Ok(short) } - #[tracing::instrument(skip(self))] pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { if let Some(id) = self .shorteventid_cache @@ -894,7 +899,6 @@ impl Rooms { Ok(event_id) } - #[tracing::instrument(skip(self))] pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { if let Some(id) = self .shortstatekey_cache @@ -938,12 +942,12 @@ impl Rooms { /// Returns the full room state. #[tracing::instrument(skip(self))] - pub fn room_state_full( + pub async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_full(current_shortstatehash) + self.state_full(current_shortstatehash).await } else { Ok(HashMap::new()) } @@ -980,14 +984,12 @@ impl Rooms { } /// Returns the `count` of this pdu's id. - #[tracing::instrument(skip(self))] pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) } /// Returns the `count` of this pdu's id. - #[tracing::instrument(skip(self))] pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -1016,7 +1018,6 @@ impl Rooms { } /// Returns the json of a pdu. - #[tracing::instrument(skip(self))] pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -1035,7 +1036,6 @@ impl Rooms { } /// Returns the json of a pdu. - #[tracing::instrument(skip(self))] pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? @@ -1046,7 +1046,6 @@ impl Rooms { } /// Returns the json of a pdu. - #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu_json( &self, event_id: &EventId, @@ -1066,7 +1065,6 @@ impl Rooms { } /// Returns the pdu's id. - #[tracing::instrument(skip(self))] pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } @@ -1074,7 +1072,6 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -1093,7 +1090,6 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - #[tracing::instrument(skip(self))] pub fn get_pdu(&self, event_id: &EventId) -> Result>> { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { return Ok(Some(Arc::clone(p))); @@ -1130,7 +1126,6 @@ impl Rooms { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - #[tracing::instrument(skip(self))] pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -1141,7 +1136,6 @@ impl Rooms { } /// Returns the pdu as a `BTreeMap`. - #[tracing::instrument(skip(self))] pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -1230,7 +1224,6 @@ impl Rooms { } /// Returns the pdu from the outlier tree. - #[tracing::instrument(skip(self))] pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? @@ -2858,6 +2851,18 @@ impl Rooms { Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } + #[tracing::instrument(skip(self))] + pub fn iter_ids(&self) -> impl Iterator>> + '_ { + self.roomid_shortroomid.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Room ID in publicroomids is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) + }) + } + #[tracing::instrument(skip(self))] pub fn public_rooms(&self) -> impl Iterator>> + '_ { self.publicroomids.iter().map(|(bytes, _)| { @@ -3140,6 +3145,10 @@ impl Rooms { .transpose() } + pub fn is_disabled(&self, room_id: &RoomId) -> Result { + Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) + } + /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] pub fn rooms_joined<'a>( diff --git a/src/ruma_wrapper/axum.rs b/src/ruma_wrapper/axum.rs index fdb140fe..45e9d9a8 100644 --- a/src/ruma_wrapper/axum.rs +++ b/src/ruma_wrapper/axum.rs @@ -338,7 +338,7 @@ impl Credentials for XMatrix { "origin" => origin = Some(value.try_into().ok()?), "key" => key = Some(value.to_owned()), "sig" => sig = Some(value.to_owned()), - _ => warn!( + _ => debug!( "Unexpected field `{}` in X-Matrix Authorization header", name ), diff --git a/src/server_server.rs b/src/server_server.rs index a227f57c..6fa83e4c 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -691,7 +691,7 @@ pub async fn send_transaction_message_route( .roomid_mutex_federation .write() .unwrap() - .entry(room_id.clone()) + .entry(room_id.to_owned()) .or_default(), ); let mutex_lock = mutex.lock().await; @@ -768,7 +768,7 @@ pub async fn send_transaction_message_route( )?; } else { // TODO fetch missing events - debug!("No known event ids in read receipt: {:?}", user_updates); + info!("No known event ids in read receipt: {:?}", user_updates); } } } @@ -926,6 +926,13 @@ pub(crate) async fn handle_incoming_pdu<'a>( } } + match db.rooms.is_disabled(room_id) { + Ok(false) => {} + _ => { + return Err("Federation of this room is currently disabled on this server.".to_owned()); + } + } + // 1. Skip the PDU if we already have it as a timeline event if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(event_id) { return Ok(Some(pdu_id.to_vec())); @@ -1038,6 +1045,34 @@ pub(crate) async fn handle_incoming_pdu<'a>( let mut errors = 0; for prev_id in dbg!(sorted) { + match db.rooms.is_disabled(room_id) { + Ok(false) => {} + _ => { + return Err( + "Federation of this room is currently disabled on this server.".to_owned(), + ); + } + } + + if let Some((time, tries)) = db + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&*prev_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", prev_id); + continue; + } + } + if errors >= 5 { break; } @@ -1047,7 +1082,11 @@ pub(crate) async fn handle_incoming_pdu<'a>( } let start_time = Instant::now(); - let event_id = pdu.event_id.clone(); + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); if let Err(e) = upgrade_outlier_to_timeline_pdu( pdu, json, @@ -1060,19 +1099,44 @@ pub(crate) async fn handle_incoming_pdu<'a>( .await { errors += 1; - warn!("Prev event {} failed: {}", event_id, e); + warn!("Prev event {} failed: {}", prev_id, e); + match db + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry((*prev_id).to_owned()) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1) + } + } } let elapsed = start_time.elapsed(); + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .remove(&room_id.to_owned()); warn!( "Handling prev event {} took {}m{}s", - event_id, + prev_id, elapsed.as_secs() / 60, elapsed.as_secs() % 60 ); } } - upgrade_outlier_to_timeline_pdu( + let start_time = Instant::now(); + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); + let r = upgrade_outlier_to_timeline_pdu( incoming_pdu, val, &create_event, @@ -1081,10 +1145,17 @@ pub(crate) async fn handle_incoming_pdu<'a>( room_id, pub_key_map, ) - .await + .await; + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .remove(&room_id.to_owned()); + + r } -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip(create_event, value, db, pub_key_map))] fn handle_outlier_pdu<'a>( origin: &'a ServerName, create_event: &'a PduEvent, @@ -1166,7 +1237,7 @@ fn handle_outlier_pdu<'a>( .await; // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events - debug!( + info!( "Auth check for {} based on auth events", incoming_pdu.event_id ); @@ -1221,19 +1292,19 @@ fn handle_outlier_pdu<'a>( return Err("Event has failed auth check with auth events.".to_owned()); } - debug!("Validation successful."); + info!("Validation successful."); // 7. Persist the event as an outlier. db.rooms .add_pdu_outlier(&incoming_pdu.event_id, &val) .map_err(|_| "Failed to add pdu as outlier.".to_owned())?; - debug!("Added pdu as outlier."); + info!("Added pdu as outlier."); Ok((Arc::new(incoming_pdu), val)) }) } -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip(incoming_pdu, val, create_event, db, pub_key_map))] async fn upgrade_outlier_to_timeline_pdu( incoming_pdu: Arc, val: BTreeMap, @@ -1255,6 +1326,8 @@ async fn upgrade_outlier_to_timeline_pdu( return Err("Event has been soft failed".into()); } + info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); + let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()).map_err(|e| { warn!("Invalid create event: {}", e); @@ -1270,7 +1343,7 @@ async fn upgrade_outlier_to_timeline_pdu( // TODO: if we know the prev_events of the incoming event we can avoid the request and build // the state from a known point and resolve if > 1 prev_event - debug!("Requesting state at event."); + info!("Requesting state at event"); let mut state_at_incoming_event = None; if incoming_pdu.prev_events.len() == 1 { @@ -1280,11 +1353,14 @@ async fn upgrade_outlier_to_timeline_pdu( .pdu_shortstatehash(prev_event) .map_err(|_| "Failed talking to db".to_owned())?; - let state = - prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); + let state = if let Some(shortstatehash) = prev_event_sstatehash { + Some(db.rooms.state_full_ids(shortstatehash).await) + } else { + None + }; if let Some(Ok(mut state)) = state { - warn!("Using cached state"); + info!("Using cached state"); let prev_pdu = db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { "Could not find prev event, but we know the state.".to_owned() @@ -1307,7 +1383,7 @@ async fn upgrade_outlier_to_timeline_pdu( state_at_incoming_event = Some(state); } } else { - warn!("Calculating state at event using state res"); + info!("Calculating state at event using state res"); let mut extremity_sstatehashes = HashMap::new(); let mut okay = true; @@ -1337,6 +1413,7 @@ async fn upgrade_outlier_to_timeline_pdu( let mut leaf_state: BTreeMap<_, _> = db .rooms .state_full_ids(sstatehash) + .await .map_err(|_| "Failed to ask db for room state.".to_owned())?; if let Some(state_key) = &prev_event.state_key { @@ -1368,6 +1445,7 @@ async fn upgrade_outlier_to_timeline_pdu( auth_chain_sets.push( get_auth_chain(room_id, starting_events, db) + .await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), ); @@ -1375,18 +1453,18 @@ async fn upgrade_outlier_to_timeline_pdu( fork_states.push(state); } - state_at_incoming_event = match state_res::resolve( - room_version_id, - &fork_states, - auth_chain_sets, - |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); - } - res.ok().flatten() - }, - ) { + let lock = db.globals.stateres_mutex.lock(); + + let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let res = db.rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }); + drop(lock); + + state_at_incoming_event = match result { Ok(new_state) => Some( new_state .into_iter() @@ -1407,12 +1485,12 @@ async fn upgrade_outlier_to_timeline_pdu( warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); None } - }; + } } } if state_at_incoming_event.is_none() { - warn!("Calling /state_ids"); + info!("Calling /state_ids"); // Call /state_ids to find out what the state at this pdu is. We trust the server's // response to some extend, but we still do a lot of checks on the events match db @@ -1428,7 +1506,7 @@ async fn upgrade_outlier_to_timeline_pdu( .await { Ok(res) => { - warn!("Fetching state events at event."); + info!("Fetching state events at event."); let state_vec = fetch_and_handle_outliers( db, origin, @@ -1494,6 +1572,7 @@ async fn upgrade_outlier_to_timeline_pdu( let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); + info!("Starting auth check"); // 11. Check the auth of the event passes based on the state of the event let check_result = state_res::event_auth::auth_check( &room_version, @@ -1513,7 +1592,7 @@ async fn upgrade_outlier_to_timeline_pdu( if !check_result { return Err("Event has failed auth check with state at the event.".into()); } - debug!("Auth check succeeded."); + info!("Auth check succeeded"); // We start looking at current room state now, so lets lock the room @@ -1529,6 +1608,7 @@ async fn upgrade_outlier_to_timeline_pdu( // Now we calculate the set of extremities this room has after the incoming event has been // applied. We start with the previous extremities (aka leaves) + info!("Calculating extremities"); let mut extremities = db .rooms .get_pdu_leaves(room_id) @@ -1544,16 +1624,18 @@ async fn upgrade_outlier_to_timeline_pdu( // Only keep those extremities were not referenced yet extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); - let current_sstatehash = db - .rooms - .current_shortstatehash(room_id) - .map_err(|_| "Failed to load current state hash.".to_owned())? - .expect("every room has state"); + info!("Compressing state at event"); + let state_ids_compressed = state_at_incoming_event + .iter() + .map(|(shortstatekey, id)| { + db.rooms + .compress_state_event(*shortstatekey, id, &db.globals) + .map_err(|_| "Failed to compress_state_event".to_owned()) + }) + .collect::>()?; - let current_state_ids = db - .rooms - .state_full_ids(current_sstatehash) - .map_err(|_| "Failed to load room state.")?; + // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it + info!("Starting soft fail auth check"); let auth_events = db .rooms @@ -1566,18 +1648,6 @@ async fn upgrade_outlier_to_timeline_pdu( ) .map_err(|_| "Failed to get_auth_events.".to_owned())?; - let state_ids_compressed = state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - db.rooms - .compress_state_event(*shortstatekey, id, &db.globals) - .map_err(|_| "Failed to compress_state_event".to_owned()) - }) - .collect::>()?; - - // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it - debug!("starting soft fail auth check"); - let soft_fail = !state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -1610,8 +1680,23 @@ async fn upgrade_outlier_to_timeline_pdu( } if incoming_pdu.state_key.is_some() { + info!("Loading current room state ids"); + let current_sstatehash = db + .rooms + .current_shortstatehash(room_id) + .map_err(|_| "Failed to load current state hash.".to_owned())? + .expect("every room has state"); + + let current_state_ids = db + .rooms + .state_full_ids(current_sstatehash) + .await + .map_err(|_| "Failed to load room state.")?; + + info!("Preparing for stateres to derive new room state"); let mut extremity_sstatehashes = HashMap::new(); + info!("Loading extremities"); for id in dbg!(&extremities) { match db .rooms @@ -1671,6 +1756,7 @@ async fn upgrade_outlier_to_timeline_pdu( let new_room_state = if fork_states.is_empty() { return Err("State is empty.".to_owned()); } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { + info!("State resolution trivial"); // There was only one state, so it has to be the room's current state (because that is // always included) fork_states[0] @@ -1682,6 +1768,7 @@ async fn upgrade_outlier_to_timeline_pdu( }) .collect::>()? } else { + info!("Loading auth chains"); // We do need to force an update to this room's state update_state = true; @@ -1693,11 +1780,14 @@ async fn upgrade_outlier_to_timeline_pdu( state.iter().map(|(_, id)| id.clone()).collect(), db, ) + .await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), ); } + info!("Loading fork states"); + let fork_states: Vec<_> = fork_states .into_iter() .map(|map| { @@ -1715,6 +1805,9 @@ async fn upgrade_outlier_to_timeline_pdu( }) .collect(); + info!("Resolving state"); + + let lock = db.globals.stateres_mutex.lock(); let state = match state_res::resolve( room_version_id, &fork_states, @@ -1733,6 +1826,10 @@ async fn upgrade_outlier_to_timeline_pdu( } }; + drop(lock); + + info!("State resolution done. Compressing state"); + state .into_iter() .map(|((event_type, state_key), event_id)| { @@ -1753,13 +1850,14 @@ async fn upgrade_outlier_to_timeline_pdu( // Set the new room state to the resolved state if update_state { + info!("Forcing new room state"); db.rooms .force_state(room_id, new_room_state, db) .map_err(|_| "Failed to set new room state.".to_owned())?; } - debug!("Updated resolved state"); } + info!("Appending pdu to timeline"); extremities.insert(incoming_pdu.event_id.clone()); // Now that the event has passed all auth it is added into the timeline. @@ -1780,7 +1878,7 @@ async fn upgrade_outlier_to_timeline_pdu( "Failed to add pdu to db.".to_owned() })?; - debug!("Appended incoming pdu."); + info!("Appended incoming pdu"); // Event has passed all auth/stateres checks drop(state_lock); @@ -1844,17 +1942,23 @@ pub(crate) fn fetch_and_handle_outliers<'a>( let mut todo_auth_events = vec![Arc::clone(id)]; let mut events_in_reverse_order = Vec::new(); let mut events_all = HashSet::new(); + let mut i = 0; while let Some(next_id) = todo_auth_events.pop() { if events_all.contains(&next_id) { continue; } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { trace!("Found {} in db", id); continue; } - warn!("Fetching {} over federation.", next_id); + info!("Fetching {} over federation.", next_id); match db .sending .send_federation_request( @@ -1865,7 +1969,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( .await { Ok(res) => { - warn!("Got {} over federation", next_id); + info!("Got {} over federation", next_id); let (calculated_event_id, value) = match crate::pdu::gen_event_id_canonical_json(&res.pdu, &db) { Ok(t) => t, @@ -2187,7 +2291,7 @@ fn append_incoming_pdu<'a>( } #[tracing::instrument(skip(starting_events, db))] -pub(crate) fn get_auth_chain<'a>( +pub(crate) async fn get_auth_chain<'a>( room_id: &RoomId, starting_events: Vec>, db: &'a Database, @@ -2196,10 +2300,15 @@ pub(crate) fn get_auth_chain<'a>( let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; + let mut i = 0; for id in starting_events { let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } let mut full_auth_chain = HashSet::new(); @@ -2222,6 +2331,7 @@ pub(crate) fn get_auth_chain<'a>( let mut chunk_cache = HashSet::new(); let mut hits2 = 0; let mut misses2 = 0; + let mut i = 0; for (sevent_id, event_id) in chunk { if let Some(cached) = db.rooms.get_auth_chain_from_cache(&[sevent_id])? { hits2 += 1; @@ -2237,6 +2347,11 @@ pub(crate) fn get_auth_chain<'a>( auth_chain.len() ); chunk_cache.extend(auth_chain.iter()); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } }; } println!( @@ -2457,7 +2572,7 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids @@ -2502,7 +2617,8 @@ pub async fn get_room_state_route( let pdus = db .rooms - .state_full_ids(shortstatehash)? + .state_full_ids(shortstatehash) + .await? .into_iter() .map(|(_, id)| { PduEvent::convert_to_outgoing_federation_event( @@ -2511,7 +2627,8 @@ pub async fn get_room_state_route( }) .collect(); - let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids @@ -2561,12 +2678,14 @@ pub async fn get_room_state_ids_route( let pdu_ids = db .rooms - .state_full_ids(shortstatehash)? + .state_full_ids(shortstatehash) + .await? .into_iter() .map(|(_, id)| (*id).to_owned()) .collect(); - let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -2872,12 +2991,13 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = db.rooms.state_full_ids(shortstatehash)?; + let state_ids = db.rooms.state_full_ids(shortstatehash).await?; let auth_chain_ids = get_auth_chain( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), db, - )?; + ) + .await?; let servers = db .rooms