diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 74f3a45a..29325bd6 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -26,7 +26,7 @@ pub mod persy; ))] pub mod watchers; -pub trait DatabaseEngine: Send + Sync { +pub trait KeyValueDatabaseEngine: Send + Sync { fn open(config: &Config) -> Result where Self: Sized; @@ -40,7 +40,7 @@ pub trait DatabaseEngine: Send + Sync { } } -pub trait Tree: Send + Sync { +pub trait KeyValueTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; diff --git a/src/database/key_value.rs b/src/database/key_value.rs index 1a793f39..34916e4b 100644 --- a/src/database/key_value.rs +++ b/src/database/key_value.rs @@ -1,10 +1,7 @@ -pub trait Data { - fn get_room_shortstatehash(room_id: &RoomId); -} +use crate::service; - /// Returns the last state hash key added to the db for the given room. - #[tracing::instrument(skip(self))] - pub fn current_shortstatehash(&self, room_id: &RoomId) -> Result> { +impl service::room::state::Data for KeyValueDatabase { + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.roomid_shortstatehash .get(room_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -14,77 +11,21 @@ pub trait Data { }) } -pub struct Service { - db: D, -} - -impl Service { - /// Set the room to the given statehash and update caches. - #[tracing::instrument(skip(self, new_state_ids_compressed, db))] - pub fn force_state( - &self, - room_id: &RoomId, - shortstatehash: u64, - statediffnew :HashSet, - statediffremoved :HashSet, - db: &Database, - ) -> Result<()> { - - for event_id in statediffnew.into_iter().filter_map(|new| { - self.parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let pdu = match self.get_pdu_json(&event_id)? { - Some(pdu) => pdu, - None => continue, - }; - - if pdu.get("type").and_then(|val| val.as_str()) != Some("m.room.member") { - continue; - } - - let pdu: PduEvent = match serde_json::from_str( - &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), - ) { - Ok(pdu) => pdu, - Err(_) => continue, - }; - - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - let membership = match serde_json::from_str::(pdu.content.get()) { - Ok(e) => e.membership, - Err(_) => continue, - }; - - let state_key = match pdu.state_key { - Some(k) => k, - None => continue, - }; - - let user_id = match UserId::parse(state_key) { - Ok(id) => id, - Err(_) => continue, - }; - - self.update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; - } - - self.update_joined_count(room_id, db)?; - + fn set_room_state(&self, room_id: &RoomId, new_shortstatehash: u64 + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { self.roomid_shortstatehash .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + Ok(()) + } + fn set_event_state(&self) -> Result<()> { + db.shorteventid_shortstatehash + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } - /// Returns the leaf pdus of a room. - #[tracing::instrument(skip(self))] - pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result>> { + fn get_pdu_leaves(&self, room_id: &RoomId) -> Result>> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -99,15 +40,11 @@ impl Service { .collect() } - /// Replace the leaves of a room. - /// - /// The provided `event_ids` become the new leaves, this allows a room to have multiple - /// `prev_events`. - #[tracing::instrument(skip(self))] - pub fn replace_pdu_leaves<'a>( + fn set_forward_extremities( &self, room_id: &RoomId, event_ids: impl IntoIterator + Debug, + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -125,230 +62,48 @@ impl Service { Ok(()) } - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, state_ids_compressed, globals))] - pub fn set_event_state( +} + +impl service::room::alias::Data for KeyValueDatabase { + fn set_alias( &self, - event_id: &EventId, - room_id: &RoomId, - state_ids_compressed: HashSet, - globals: &super::globals::Globals, + alias: &RoomAliasId, + room_id: Option<&RoomId> ) -> Result<()> { - let shorteventid = self.get_or_create_shorteventid(event_id, globals)?; - - let previous_shortstatehash = self.current_shortstatehash(room_id)?; - - let state_hash = self.calculate_hash( - &state_ids_compressed - .iter() - .map(|s| &s[..]) - .collect::>(), - ); - - let (shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, globals)?; - - if !already_existed { - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); - - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&state_ids_compressed) - .copied() - .collect(); - - (statediffnew, statediffremoved) - } else { - (state_ids_compressed, HashSet::new()) - }; - self.save_state_from_diff( - shortstatehash, - statediffnew, - statediffremoved, - 1_000_000, // high number because no state will be based on this one - states_parents, - )?; - } - - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - + self.alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes())?; + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xff); + aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); + self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; Ok(()) } - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, new_pdu, globals))] - pub fn append_to_state( + fn remove_alias( &self, - new_pdu: &PduEvent, - globals: &super::globals::Globals, - ) -> Result { - let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - - let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; - - if let Some(p) = previous_shortstatehash { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; - } - - if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - - let shortstatekey = self.get_or_create_shortstatekey( - &new_pdu.kind.to_string().into(), - state_key, - globals, - )?; - - let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; - - let replaces = states_parents - .last() - .map(|info| { - info.1 - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - }) - .unwrap_or_default(); - - if Some(&new) == replaces { - return Ok(previous_shortstatehash.expect("must exist")); - } - - // TODO: statehash with deterministic inputs - let shortstatehash = globals.next_count()?; - - let mut statediffnew = HashSet::new(); - statediffnew.insert(new); + alias: &RoomAliasId, + ) -> Result<()> { + if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { + let mut prefix = room_id.to_vec(); + prefix.push(0xff); - let mut statediffremoved = HashSet::new(); - if let Some(replaces) = replaces { - statediffremoved.insert(*replaces); + for (key, _) in self.aliasid_alias.scan_prefix(prefix) { + self.aliasid_alias.remove(&key)?; } - - self.save_state_from_diff( - shortstatehash, - statediffnew, - statediffremoved, - 2, - states_parents, - )?; - - Ok(shortstatehash) + self.alias_roomid.remove(alias.alias().as_bytes())?; } else { - Ok(previous_shortstatehash.expect("first event in room must be a state event")) - } - } - - #[tracing::instrument(skip(self, invite_event))] - pub fn calculate_invite_state( - &self, - invite_event: &PduEvent, - ) -> Result>> { - let mut state = Vec::new(); - // Add recommended events - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? - { - state.push(e.to_stripped_state_event()); + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Alias does not exist.", + )); } - if let Some(e) = self.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = self.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { - state.push(e.to_stripped_state_event()); - } - - state.push(invite_event.to_stripped_state_event()); - Ok(state) - } - - #[tracing::instrument(skip(self))] - pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64) -> Result<()> { - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) } -} - #[tracing::instrument(skip(self, globals))] - pub fn set_alias( + fn resolve_local_alias( &self, - alias: &RoomAliasId, - room_id: Option<&RoomId>, - globals: &super::globals::Globals, + alias: &RoomAliasId ) -> Result<()> { - if let Some(room_id) = room_id { - // New alias - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xff); - aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; - } else { - // room_id=None means remove alias - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id.to_vec(); - prefix.push(0xff); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - self.alias_roomid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Alias does not exist.", - )); - } - } - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn id_from_alias(&self, alias: &RoomAliasId) -> Result>> { self.alias_roomid .get(alias.alias().as_bytes())? .map(|bytes| { @@ -360,11 +115,10 @@ impl Service { .transpose() } - #[tracing::instrument(skip(self))] - pub fn room_aliases<'a>( - &'a self, + fn local_aliases_for_room( + &self, room_id: &RoomId, - ) -> impl Iterator>> + 'a { + ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -375,26 +129,22 @@ impl Service { .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) }) } +} +impl service::room::directory::Data for KeyValueDatabase { + fn set_public(&self, room_id: &RoomId) -> Result<()> { + self.publicroomids.insert(room_id.as_bytes(), &[])?; + } - #[tracing::instrument(skip(self))] - pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { - if public { - self.publicroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.publicroomids.remove(room_id.as_bytes())?; - } - - Ok(()) + fn set_not_public(&self, room_id: &RoomId) -> Result<()> { + self.publicroomids.remove(room_id.as_bytes())?; } - #[tracing::instrument(skip(self))] - pub fn is_public_room(&self, room_id: &RoomId) -> Result { + fn is_public_room(&self, room_id: &RoomId) -> Result { Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } - #[tracing::instrument(skip(self))] - pub fn public_rooms(&self) -> impl Iterator>> + '_ { + fn public_rooms(&self) -> impl Iterator>> + '_ { self.publicroomids.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes).map_err(|_| { @@ -404,43 +154,14 @@ impl Service { .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) }) } - -use crate::{database::abstraction::Tree, utils, Error, Result}; -use ruma::{ - events::{ - presence::{PresenceEvent, PresenceEventContent}, - receipt::ReceiptEvent, - SyncEphemeralRoomEvent, - }, - presence::PresenceState, - serde::Raw, - signatures::CanonicalJsonObject, - RoomId, UInt, UserId, -}; -use std::{ - collections::{HashMap, HashSet}, - mem, - sync::Arc, -}; - -pub struct RoomEdus { - pub(in super::super) readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId - pub(in super::super) roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count - pub(in super::super) roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count - pub(in super::super) typingid_userid: Arc, // TypingId = RoomId + TimeoutTime + Count - pub(in super::super) roomid_lasttypingupdate: Arc, // LastRoomTypingUpdate = Count - pub(in super::super) presenceid_presence: Arc, // PresenceId = RoomId + Count + UserId - pub(in super::super) userid_lastpresenceupdate: Arc, // LastPresenceUpdate = Count } -impl RoomEdus { - /// Adds an event which will be saved until a new event replaces it (e.g. read receipt). - pub fn readreceipt_update( +impl service::room::edus::Data for KeyValueDatabase { + fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent, - globals: &super::super::globals::Globals, ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -477,8 +198,6 @@ impl RoomEdus { Ok(()) } - /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. - #[tracing::instrument(skip(self))] pub fn readreceipts_since<'a>( &'a self, room_id: &RoomId, @@ -527,14 +246,11 @@ impl RoomEdus { }) } - /// Sets a private read marker at `count`. - #[tracing::instrument(skip(self, globals))] - pub fn private_read_set( + fn private_read_set( &self, room_id: &RoomId, user_id: &UserId, count: u64, - globals: &super::super::globals::Globals, ) -> Result<()> { let mut key = room_id.as_bytes().to_vec(); key.push(0xff); @@ -545,13 +261,9 @@ impl RoomEdus { self.roomuserid_lastprivatereadupdate .insert(&key, &globals.next_count()?.to_be_bytes())?; - - Ok(()) } - /// Returns the private read marker. - #[tracing::instrument(skip(self))] - pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(user_id.as_bytes()); @@ -565,8 +277,7 @@ impl RoomEdus { }) } - /// Returns the count of the last typing update in this room. - pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut key = room_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(user_id.as_bytes()); @@ -583,9 +294,7 @@ impl RoomEdus { .unwrap_or(0)) } - /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is - /// called. - pub fn typing_add( + fn typing_add( &self, user_id: &UserId, room_id: &RoomId, @@ -611,12 +320,10 @@ impl RoomEdus { Ok(()) } - /// Removes a user from typing before the timeout is reached. - pub fn typing_remove( + fn typing_remove( &self, user_id: &UserId, room_id: &RoomId, - globals: &super::super::globals::Globals, ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -643,59 +350,10 @@ impl RoomEdus { Ok(()) } - /// Makes sure that typing events with old timestamps get removed. - fn typings_maintain( - &self, - room_id: &RoomId, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let current_timestamp = utils::millis_since_unix_epoch(); - - let mut found_outdated = false; - - // Find all outdated edus before inserting a new one - for outdated_edu in self - .typingid_userid - .scan_prefix(prefix) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes( - &key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| { - Error::bad_database("RoomTyping has invalid timestamp or delimiters.") - })?[0..mem::size_of::()], - ) - .map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?, - )) - }) - .filter_map(|r| r.ok()) - .take_while(|&(_, timestamp)| timestamp < current_timestamp) - { - // This is an outdated edu (time > timestamp) - self.typingid_userid.remove(&outdated_edu.0)?; - found_outdated = true; - } - - if found_outdated { - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - } - - Ok(()) - } - - /// Returns the count of the last typing update in this room. - #[tracing::instrument(skip(self, globals))] - pub fn last_typing_update( + fn last_typing_update( &self, room_id: &RoomId, - globals: &super::super::globals::Globals, ) -> Result { - self.typings_maintain(room_id, globals)?; - Ok(self .roomid_lasttypingupdate .get(room_id.as_bytes())? @@ -708,10 +366,10 @@ impl RoomEdus { .unwrap_or(0)) } - pub fn typings_all( + fn typings_all( &self, room_id: &RoomId, - ) -> Result> { + ) -> Result> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -726,23 +384,14 @@ impl RoomEdus { user_ids.insert(user_id); } - Ok(SyncEphemeralRoomEvent { - content: ruma::events::typing::TypingEventContent { - user_ids: user_ids.into_iter().collect(), - }, - }) + Ok(user_ids) } - /// Adds a presence event which will be saved until a new event replaces it. - /// - /// Note: This method takes a RoomId because presence updates are always bound to rooms to - /// make sure users outside these rooms can't see them. - pub fn update_presence( + fn update_presence( &self, user_id: &UserId, room_id: &RoomId, presence: PresenceEvent, - globals: &super::super::globals::Globals, ) -> Result<()> { // TODO: Remove old entry? Or maybe just wipe completely from time to time? @@ -767,8 +416,6 @@ impl RoomEdus { Ok(()) } - /// Resets the presence timeout, so the user will stay in their current presence state. - #[tracing::instrument(skip(self))] pub fn ping_presence(&self, user_id: &UserId) -> Result<()> { self.userid_lastpresenceupdate.insert( user_id.as_bytes(), @@ -778,8 +425,7 @@ impl RoomEdus { Ok(()) } - /// Returns the timestamp of the last presence update of this user in millis since the unix epoch. - pub fn last_presence_update(&self, user_id: &UserId) -> Result> { + fn last_presence_update(&self, user_id: &UserId) -> Result> { self.userid_lastpresenceupdate .get(user_id.as_bytes())? .map(|bytes| { @@ -790,125 +436,29 @@ impl RoomEdus { .transpose() } - pub fn get_last_presence_event( + fn get_presence_event( &self, user_id: &UserId, room_id: &RoomId, + count: u64, ) -> Result> { - let last_update = match self.last_presence_update(user_id)? { - Some(last) => last, - None => return Ok(None), - }; - let mut presence_id = room_id.as_bytes().to_vec(); presence_id.push(0xff); - presence_id.extend_from_slice(&last_update.to_be_bytes()); + presence_id.extend_from_slice(&count.to_be_bytes()); presence_id.push(0xff); presence_id.extend_from_slice(user_id.as_bytes()); self.presenceid_presence .get(&presence_id)? - .map(|value| { - let mut presence: PresenceEvent = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Invalid presence event in db."))?; - let current_timestamp: UInt = utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"); - - if presence.content.presence == PresenceState::Online { - // Don't set last_active_ago when the user is online - presence.content.last_active_ago = None; - } else { - // Convert from timestamp to duration - presence.content.last_active_ago = presence - .content - .last_active_ago - .map(|timestamp| current_timestamp - timestamp); - } - - Ok(presence) - }) + .map(|value| parse_presence_event(&value)) .transpose() } - /// Sets all users to offline who have been quiet for too long. - fn _presence_maintain( - &self, - rooms: &super::Rooms, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let current_timestamp = utils::millis_since_unix_epoch(); - - for (user_id_bytes, last_timestamp) in self - .userid_lastpresenceupdate - .iter() - .filter_map(|(k, bytes)| { - Some(( - k, - utils::u64_from_bytes(&bytes) - .map_err(|_| { - Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.") - }) - .ok()?, - )) - }) - .take_while(|(_, timestamp)| current_timestamp.saturating_sub(*timestamp) > 5 * 60_000) - // 5 Minutes - { - // Send new presence events to set the user offline - let count = globals.next_count()?.to_be_bytes(); - let user_id: Box<_> = utils::string_from_bytes(&user_id_bytes) - .map_err(|_| { - Error::bad_database("Invalid UserId bytes in userid_lastpresenceupdate.") - })? - .try_into() - .map_err(|_| Error::bad_database("Invalid UserId in userid_lastpresenceupdate."))?; - for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) { - let mut presence_id = room_id.as_bytes().to_vec(); - presence_id.push(0xff); - presence_id.extend_from_slice(&count); - presence_id.push(0xff); - presence_id.extend_from_slice(&user_id_bytes); - - self.presenceid_presence.insert( - &presence_id, - &serde_json::to_vec(&PresenceEvent { - content: PresenceEventContent { - avatar_url: None, - currently_active: None, - displayname: None, - last_active_ago: Some( - last_timestamp.try_into().expect("time is valid"), - ), - presence: PresenceState::Offline, - status_msg: None, - }, - sender: user_id.to_owned(), - }) - .expect("PresenceEvent can be serialized"), - )?; - } - - self.userid_lastpresenceupdate.insert( - user_id.as_bytes(), - &utils::millis_since_unix_epoch().to_be_bytes(), - )?; - } - - Ok(()) - } - - /// Returns an iterator over the most recent presence updates that happened after the event with id `since`. - #[tracing::instrument(skip(self, since, _rooms, _globals))] - pub fn presence_since( + fn presence_since( &self, room_id: &RoomId, since: u64, - _rooms: &super::Rooms, - _globals: &super::super::globals::Globals, ) -> Result, PresenceEvent>> { - //self.presence_maintain(rooms, globals)?; - let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -931,23 +481,7 @@ impl RoomEdus { ) .map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?; - let mut presence: PresenceEvent = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Invalid presence event in db."))?; - - let current_timestamp: UInt = utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"); - - if presence.content.presence == PresenceState::Online { - // Don't set last_active_ago when the user is online - presence.content.last_active_ago = None; - } else { - // Convert from timestamp to duration - presence.content.last_active_ago = presence - .content - .last_active_ago - .map(|timestamp| current_timestamp - timestamp); - } + let presence = parse_presence_event(&value)?; hashmap.insert(user_id, presence); } @@ -956,8 +490,28 @@ impl RoomEdus { } } - #[tracing::instrument(skip(self))] - pub fn lazy_load_was_sent_before( +fn parse_presence_event(bytes: &[u8]) -> Result { + let mut presence: PresenceEvent = serde_json::from_slice(bytes) + .map_err(|_| Error::bad_database("Invalid presence event in db."))?; + + let current_timestamp: UInt = utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"); + + if presence.content.presence == PresenceState::Online { + // Don't set last_active_ago when the user is online + presence.content.last_active_ago = None; + } else { + // Convert from timestamp to duration + presence.content.last_active_ago = presence + .content + .last_active_ago + .map(|timestamp| current_timestamp - timestamp); + } +} + +impl service::room::lazy_load::Data for KeyValueDatabase { + fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, @@ -974,28 +528,7 @@ impl RoomEdus { Ok(self.lazyloadedids.get(&key)?.is_some()) } - #[tracing::instrument(skip(self))] - pub fn lazy_load_mark_sent( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - lazy_load: HashSet>, - count: u64, - ) { - self.lazy_load_waiting.lock().unwrap().insert( - ( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - count, - ), - lazy_load, - ); - } - - #[tracing::instrument(skip(self))] - pub fn lazy_load_confirm_delivery( + fn lazy_load_confirm_delivery( &self, user_id: &UserId, device_id: &DeviceId, @@ -1025,8 +558,7 @@ impl RoomEdus { Ok(()) } - #[tracing::instrument(skip(self))] - pub fn lazy_load_reset( + fn lazy_load_reset( &self, user_id: &UserId, device_id: &DeviceId, @@ -1045,10 +577,10 @@ impl RoomEdus { Ok(()) } +} - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn exists(&self, room_id: &RoomId) -> Result { +impl service::room::metadata::Data for KeyValueDatabase { + fn exists(&self, room_id: &RoomId) -> Result { let prefix = match self.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), None => return Ok(false), @@ -1062,36 +594,10 @@ impl RoomEdus { .filter(|(k, _)| k.starts_with(&prefix)) .is_some()) } +} - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) - }) - .transpose() - } - - pub fn get_or_create_shortroomid( - &self, - room_id: &RoomId, - globals: &super::globals::Globals, - ) -> Result { - Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { - Some(short) => utils::u64_from_bytes(&short) - .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, - None => { - let short = globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - } - }) - } - - /// Returns the pdu from the outlier tree. - pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { +impl service::room::outlier::Data for KeyValueDatabase { + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? .map_or(Ok(None), |pdu| { @@ -1099,8 +605,7 @@ impl RoomEdus { }) } - /// Returns the pdu from the outlier tree. - pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { + fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? .map_or(Ok(None), |pdu| { @@ -1108,18 +613,16 @@ impl RoomEdus { }) } - /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu))] - pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { self.eventid_outlierpdu.insert( event_id.as_bytes(), &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), ) } +} - - #[tracing::instrument(skip(self, room_id, event_ids))] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { +impl service::room::pdu_metadata::Data for KeyValueDatabase { + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(prev.as_bytes()); @@ -1129,22 +632,19 @@ impl RoomEdus { Ok(()) } - #[tracing::instrument(skip(self))] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(event_id.as_bytes()); Ok(self.referencedevents.get(&key)?.is_some()) } - #[tracing::instrument(skip(self))] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.softfailedeventids.insert(event_id.as_bytes(), &[]) } - #[tracing::instrument(skip(self))] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { + fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.softfailedeventids .get(event_id.as_bytes()) .map(|o| o.is_some()) } - +} diff --git a/src/database/mod.rs b/src/database/mod.rs index a0937c29..a35228aa 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -15,7 +15,7 @@ pub mod users; use self::admin::create_admin_room; use crate::{utils, Config, Error, Result}; -use abstraction::DatabaseEngine; +use abstraction::KeyValueDatabaseEngine; use directories::ProjectDirs; use futures_util::{stream::FuturesUnordered, StreamExt}; use lru_cache::LruCache; @@ -39,8 +39,8 @@ use std::{ use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use tracing::{debug, error, info, warn}; -pub struct Database { - _db: Arc, +pub struct KeyValueDatabase { + _db: Arc, pub globals: globals::Globals, pub users: users::Users, pub uiaa: uiaa::Uiaa, @@ -55,7 +55,7 @@ pub struct Database { pub pusher: pusher::PushData, } -impl Database { +impl KeyValueDatabase { /// Tries to remove the old database but ignores all errors. pub fn try_remove(server_name: &str) -> Result<()> { let mut path = ProjectDirs::from("xyz", "koesters", "conduit") @@ -124,7 +124,7 @@ impl Database { .map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; } - let builder: Arc = match &*config.database_backend { + let builder: Arc = match &*config.database_backend { "sqlite" => { #[cfg(not(feature = "sqlite"))] return Err(Error::BadConfig("Database backend not found.")); @@ -955,7 +955,7 @@ impl Database { } /// Sets the emergency password and push rules for the @conduit account in case emergency password is set -fn set_emergency_access(db: &Database) -> Result { +fn set_emergency_access(db: &KeyValueDatabase) -> Result { let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) .expect("@conduit:server_name is a valid UserId"); @@ -979,39 +979,3 @@ fn set_emergency_access(db: &Database) -> Result { res } - -pub struct DatabaseGuard(OwnedRwLockReadGuard); - -impl Deref for DatabaseGuard { - type Target = OwnedRwLockReadGuard; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(feature = "conduit_bin")] -#[axum::async_trait] -impl axum::extract::FromRequest for DatabaseGuard -where - B: Send, -{ - type Rejection = axum::extract::rejection::ExtensionRejection; - - async fn from_request( - req: &mut axum::extract::RequestParts, - ) -> Result { - use axum::extract::Extension; - - let Extension(db): Extension>> = - Extension::from_request(req).await?; - - Ok(DatabaseGuard(db.read_owned().await)) - } -} - -impl From> for DatabaseGuard { - fn from(val: OwnedRwLockReadGuard) -> Self { - Self(val) - } -} diff --git a/src/main.rs b/src/main.rs index 9a0928a0..a1af9761 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,27 +46,26 @@ use tikv_jemallocator::Jemalloc; #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -#[tokio::main] -async fn main() { - let raw_config = - Figment::new() - .merge( - Toml::file(Env::var("CONDUIT_CONFIG").expect( - "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", - )) - .nested(), - ) - .merge(Env::prefixed("CONDUIT_").global()); - - let config = match raw_config.extract::() { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); - std::process::exit(1); - } - }; +lazy_static! { + static ref DB: Database = { + let raw_config = + Figment::new() + .merge( + Toml::file(Env::var("CONDUIT_CONFIG").expect( + "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", + )) + .nested(), + ) + .merge(Env::prefixed("CONDUIT_").global()); + + let config = match raw_config.extract::() { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); + std::process::exit(1); + } + }; - let start = async { config.warn_deprecated(); let db = match Database::load_or_create(&config).await { @@ -79,8 +78,15 @@ async fn main() { std::process::exit(1); } }; + }; +} - run_server(&config, db).await.unwrap(); +#[tokio::main] +async fn main() { + lazy_static::initialize(&DB); + + let start = async { + run_server(&config).await.unwrap(); }; if config.allow_jaeger { @@ -120,7 +126,8 @@ async fn main() { } } -async fn run_server(config: &Config, db: Arc>) -> io::Result<()> { +async fn run_server() -> io::Result<()> { + let config = DB.globals.config; let addr = SocketAddr::from((config.address, config.port)); let x_requested_with = HeaderName::from_static("x-requested-with"); diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs new file mode 100644 index 00000000..9dbfc7b5 --- /dev/null +++ b/src/service/rooms/alias/data.rs @@ -0,0 +1,22 @@ +pub trait Data { + /// Creates or updates the alias to the given room id. + pub fn set_alias( + alias: &RoomAliasId, + room_id: &RoomId + ) -> Result<()>; + + /// Forgets about an alias. Returns an error if the alias did not exist. + pub fn remove_alias( + alias: &RoomAliasId, + ) -> Result<()>; + + /// Looks up the roomid for the given alias. + pub fn resolve_local_alias( + alias: &RoomAliasId, + ) -> Result<()>; + + /// Returns all local aliases that point to the given room + pub fn local_aliases_for_room( + alias: &RoomAliasId, + ) -> Result<()>; +} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 393ad671..cfe05396 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,66 +1,40 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { #[tracing::instrument(skip(self, globals))] pub fn set_alias( &self, alias: &RoomAliasId, - room_id: Option<&RoomId>, - globals: &super::globals::Globals, + room_id: &RoomId, ) -> Result<()> { - if let Some(room_id) = room_id { - // New alias - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xff); - aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; - } else { - // room_id=None means remove alias - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id.to_vec(); - prefix.push(0xff); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - self.alias_roomid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Alias does not exist.", - )); - } - } + self.db.set_alias(alias, room_id) + } - Ok(()) + #[tracing::instrument(skip(self, globals))] + pub fn remove_alias( + &self, + alias: &RoomAliasId, + ) -> Result<()> { + self.db.remove_alias(alias) } #[tracing::instrument(skip(self))] - pub fn id_from_alias(&self, alias: &RoomAliasId) -> Result>> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in alias_roomid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() + pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>> { + self.db.resolve_local_alias(alias: &RoomAliasId) } #[tracing::instrument(skip(self))] - pub fn room_aliases<'a>( + pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, ) -> impl Iterator>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - }) + self.db.local_aliases_for_room(room_id) } - +} diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs new file mode 100644 index 00000000..83d78853 --- /dev/null +++ b/src/service/rooms/directory/data.rs @@ -0,0 +1,13 @@ +pub trait Data { + /// Adds the room to the public room directory + fn set_public(room_id: &RoomId) -> Result<()>; + + /// Removes the room from the public room directory. + fn set_not_public(room_id: &RoomId) -> Result<()>; + + /// Returns true if the room is in the public room directory. + fn is_public_room(room_id: &RoomId) -> Result; + + /// Returns the unsorted public room directory + fn public_rooms() -> impl Iterator>> + '_; +} diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 8be7bd57..b92933f4 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,29 +1,30 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { #[tracing::instrument(skip(self))] - pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { - if public { - self.publicroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.publicroomids.remove(room_id.as_bytes())?; - } + pub fn set_public(&self, room_id: &RoomId) -> Result<()> { + self.db.set_public(&self, room_id) + } - Ok(()) + #[tracing::instrument(skip(self))] + pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { + self.db.set_not_public(&self, room_id) } #[tracing::instrument(skip(self))] pub fn is_public_room(&self, room_id: &RoomId) -> Result { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) + self.db.is_public_room(&self, room_id) } #[tracing::instrument(skip(self))] pub fn public_rooms(&self) -> impl Iterator>> + '_ { - self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - }) + self.db.public_rooms(&self, room_id) } - +} diff --git a/src/service/rooms/edus/data.rs b/src/service/rooms/edus/data.rs new file mode 100644 index 00000000..16c14cf3 --- /dev/null +++ b/src/service/rooms/edus/data.rs @@ -0,0 +1,91 @@ +pub trait Data { + /// Replaces the previous read receipt. + fn readreceipt_update( + &self, + user_id: &UserId, + room_id: &RoomId, + event: ReceiptEvent, + ) -> Result<()>; + + /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. + fn readreceipts_since( + &self, + room_id: &RoomId, + since: u64, + ) -> impl Iterator< + Item = Result<( + Box, + u64, + Raw, + )>, + >; + + /// Sets a private read marker at `count`. + fn private_read_set( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<()>; + + /// Returns the private read marker. + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + + /// Returns the count of the last typing update in this room. + fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result; + + /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is + /// called. + fn typing_add( + &self, + user_id: &UserId, + room_id: &RoomId, + timeout: u64, + ) -> Result<()>; + + /// Removes a user from typing before the timeout is reached. + fn typing_remove( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<()>; + + /// Returns the count of the last typing update in this room. + fn last_typing_update( + &self, + room_id: &RoomId, + ) -> Result; + + /// Returns all user ids currently typing. + fn typings_all( + &self, + room_id: &RoomId, + ) -> Result>; + + /// Adds a presence event which will be saved until a new event replaces it. + /// + /// Note: This method takes a RoomId because presence updates are always bound to rooms to + /// make sure users outside these rooms can't see them. + fn update_presence( + &self, + user_id: &UserId, + room_id: &RoomId, + presence: PresenceEvent, + ) -> Result<()>; + + /// Resets the presence timeout, so the user will stay in their current presence state. + fn ping_presence(&self, user_id: &UserId) -> Result<()>; + + /// Returns the timestamp of the last presence update of this user in millis since the unix epoch. + fn last_presence_update(&self, user_id: &UserId) -> Result>; + + /// Returns the presence event with correct last_active_ago. + fn get_presence_event(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result>; + + /// Returns the most recent presence updates that happened after the event with id `since`. + fn presence_since( + &self, + room_id: &RoomId, + since: u64, + ) -> Result, PresenceEvent>>; +} diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs index 118efd4c..06adf57e 100644 --- a/src/service/rooms/edus/mod.rs +++ b/src/service/rooms/edus/mod.rs @@ -1,73 +1,21 @@ -use crate::{database::abstraction::Tree, utils, Error, Result}; -use ruma::{ - events::{ - presence::{PresenceEvent, PresenceEventContent}, - receipt::ReceiptEvent, - SyncEphemeralRoomEvent, - }, - presence::PresenceState, - serde::Raw, - signatures::CanonicalJsonObject, - RoomId, UInt, UserId, -}; -use std::{ - collections::{HashMap, HashSet}, - mem, - sync::Arc, -}; +mod data; +pub use data::Data; -pub struct RoomEdus { - pub(in super::super) readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId - pub(in super::super) roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count - pub(in super::super) roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count - pub(in super::super) typingid_userid: Arc, // TypingId = RoomId + TimeoutTime + Count - pub(in super::super) roomid_lasttypingupdate: Arc, // LastRoomTypingUpdate = Count - pub(in super::super) presenceid_presence: Arc, // PresenceId = RoomId + Count + UserId - pub(in super::super) userid_lastpresenceupdate: Arc, // LastPresenceUpdate = Count +use crate::service::*; + +pub struct Service { + db: D, } -impl RoomEdus { - /// Adds an event which will be saved until a new event replaces it (e.g. read receipt). +impl Service<_> { + /// Replaces the previous read receipt. pub fn readreceipt_update( &self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent, - globals: &super::super::globals::Globals, ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) - { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } - - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes()); - room_latest_id.push(0xff); - room_latest_id.extend_from_slice(user_id.as_bytes()); - - self.readreceiptid_readreceipt.insert( - &room_latest_id, - &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) + self.db.readreceipt_update(user_id, room_id, event); } /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. @@ -83,41 +31,7 @@ impl RoomEdus { Raw, )>, > + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - let prefix2 = prefix.clone(); - - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count = - utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::()]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) - .map_err(|_| { - Error::bad_database("Invalid readreceiptid userid bytes in db.") - })?, - ) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - - let mut json = serde_json::from_slice::(&v).map_err(|_| { - Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json.") - })?; - json.remove("room_id"); - - Ok(( - user_id, - count, - Raw::from_json( - serde_json::value::to_raw_value(&json).expect("json is valid raw value"), - ), - )) - }) + self.db.readreceipts_since(room_id, since) } /// Sets a private read marker at `count`. @@ -127,53 +41,19 @@ impl RoomEdus { room_id: &RoomId, user_id: &UserId, count: u64, - globals: &super::super::globals::Globals, ) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; - - self.roomuserid_lastprivatereadupdate - .insert(&key, &globals.next_count()?.to_be_bytes())?; - - Ok(()) + self.db.private_read_set(room_id, user_id, count) } /// Returns the private read marker. #[tracing::instrument(skip(self))] pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { - Error::bad_database("Invalid private read marker bytes") - })?)) - }) + self.db.private_read_get(room_id, user_id) } /// Returns the count of the last typing update in this room. pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) + self.db.last_privateread_update(user_id, room_id) } /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is @@ -183,25 +63,8 @@ impl RoomEdus { user_id: &UserId, room_id: &RoomId, timeout: u64, - globals: &super::super::globals::Globals, ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let count = globals.next_count()?.to_be_bytes(); - - let mut room_typing_id = prefix; - room_typing_id.extend_from_slice(&timeout.to_be_bytes()); - room_typing_id.push(0xff); - room_typing_id.extend_from_slice(&count); - - self.typingid_userid - .insert(&room_typing_id, &*user_id.as_bytes())?; - - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &count)?; - - Ok(()) + self.db.typing_add(user_id, room_id, timeout) } /// Removes a user from typing before the timeout is reached. @@ -209,33 +72,11 @@ impl RoomEdus { &self, user_id: &UserId, room_id: &RoomId, - globals: &super::super::globals::Globals, ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let user_id = user_id.to_string(); - - let mut found_outdated = false; - - // Maybe there are multiple ones from calling roomtyping_add multiple times - for outdated_edu in self - .typingid_userid - .scan_prefix(prefix) - .filter(|(_, v)| &**v == user_id.as_bytes()) - { - self.typingid_userid.remove(&outdated_edu.0)?; - found_outdated = true; - } - - if found_outdated { - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - } - - Ok(()) + self.db.typing_remove(user_id, room_id) } + /* TODO: Do this in background thread? /// Makes sure that typing events with old timestamps get removed. fn typings_maintain( &self, @@ -279,45 +120,23 @@ impl RoomEdus { Ok(()) } + */ /// Returns the count of the last typing update in this room. #[tracing::instrument(skip(self, globals))] pub fn last_typing_update( &self, room_id: &RoomId, - globals: &super::super::globals::Globals, ) -> Result { - self.typings_maintain(room_id, globals)?; - - Ok(self - .roomid_lasttypingupdate - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) + self.db.last_typing_update(room_id) } + /// Returns a new typing EDU. pub fn typings_all( &self, room_id: &RoomId, ) -> Result> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut user_ids = HashSet::new(); - - for (_, user_id) in self.typingid_userid.scan_prefix(prefix) { - let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| { - Error::bad_database("User ID in typingid_userid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?; - - user_ids.insert(user_id); - } + let user_ids = self.db.typings_all(room_id)?; Ok(SyncEphemeralRoomEvent { content: ruma::events::typing::TypingEventContent { @@ -335,52 +154,13 @@ impl RoomEdus { user_id: &UserId, room_id: &RoomId, presence: PresenceEvent, - globals: &super::super::globals::Globals, ) -> Result<()> { - // TODO: Remove old entry? Or maybe just wipe completely from time to time? - - let count = globals.next_count()?.to_be_bytes(); - - let mut presence_id = room_id.as_bytes().to_vec(); - presence_id.push(0xff); - presence_id.extend_from_slice(&count); - presence_id.push(0xff); - presence_id.extend_from_slice(presence.sender.as_bytes()); - - self.presenceid_presence.insert( - &presence_id, - &serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"), - )?; - - self.userid_lastpresenceupdate.insert( - user_id.as_bytes(), - &utils::millis_since_unix_epoch().to_be_bytes(), - )?; - - Ok(()) + self.db.update_presence(user_id, room_id, presence) } /// Resets the presence timeout, so the user will stay in their current presence state. - #[tracing::instrument(skip(self))] pub fn ping_presence(&self, user_id: &UserId) -> Result<()> { - self.userid_lastpresenceupdate.insert( - user_id.as_bytes(), - &utils::millis_since_unix_epoch().to_be_bytes(), - )?; - - Ok(()) - } - - /// Returns the timestamp of the last presence update of this user in millis since the unix epoch. - pub fn last_presence_update(&self, user_id: &UserId) -> Result> { - self.userid_lastpresenceupdate - .get(user_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.") - }) - }) - .transpose() + self.db.ping_presence(user_id) } pub fn get_last_presence_event( @@ -388,42 +168,15 @@ impl RoomEdus { user_id: &UserId, room_id: &RoomId, ) -> Result> { - let last_update = match self.last_presence_update(user_id)? { + let last_update = match self.db.last_presence_update(user_id)? { Some(last) => last, None => return Ok(None), }; - let mut presence_id = room_id.as_bytes().to_vec(); - presence_id.push(0xff); - presence_id.extend_from_slice(&last_update.to_be_bytes()); - presence_id.push(0xff); - presence_id.extend_from_slice(user_id.as_bytes()); - - self.presenceid_presence - .get(&presence_id)? - .map(|value| { - let mut presence: PresenceEvent = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Invalid presence event in db."))?; - let current_timestamp: UInt = utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"); - - if presence.content.presence == PresenceState::Online { - // Don't set last_active_ago when the user is online - presence.content.last_active_ago = None; - } else { - // Convert from timestamp to duration - presence.content.last_active_ago = presence - .content - .last_active_ago - .map(|timestamp| current_timestamp - timestamp); - } - - Ok(presence) - }) - .transpose() + self.db.get_presence_event(room_id, user_id, last_update) } + /* TODO /// Sets all users to offline who have been quiet for too long. fn _presence_maintain( &self, @@ -489,62 +242,15 @@ impl RoomEdus { } Ok(()) - } + }*/ - /// Returns an iterator over the most recent presence updates that happened after the event with id `since`. + /// Returns the most recent presence updates that happened after the event with id `since`. #[tracing::instrument(skip(self, since, _rooms, _globals))] pub fn presence_since( &self, room_id: &RoomId, since: u64, - _rooms: &super::Rooms, - _globals: &super::super::globals::Globals, ) -> Result, PresenceEvent>> { - //self.presence_maintain(rooms, globals)?; - - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - let mut hashmap = HashMap::new(); - - for (key, value) in self - .presenceid_presence - .iter_from(&*first_possible_edu, false) - .take_while(|(key, _)| key.starts_with(&prefix)) - { - let user_id = UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Invalid UserId bytes in presenceid_presence."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?; - - let mut presence: PresenceEvent = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Invalid presence event in db."))?; - - let current_timestamp: UInt = utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"); - - if presence.content.presence == PresenceState::Online { - // Don't set last_active_ago when the user is online - presence.content.last_active_ago = None; - } else { - // Convert from timestamp to duration - presence.content.last_active_ago = presence - .content - .last_active_ago - .map(|timestamp| current_timestamp - timestamp); - } - - hashmap.insert(user_id, presence); - } - - Ok(hashmap) + self.db.presence_since(room_id, since) } } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index e59219b2..5b77586a 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -2,1151 +2,1157 @@ /// An async function that can recursively call itself. type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; -/// When receiving an event one needs to: -/// 0. Check the server is in the room -/// 1. Skip the PDU if we already know about it -/// 2. Check signatures, otherwise drop -/// 3. Check content hash, redact if doesn't match -/// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not -/// timeline events -/// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are -/// also rejected "due to auth events" -/// 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events -/// 7. Persist this event as an outlier -/// 8. If not timeline event: stop -/// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline -/// events -/// 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities -/// doing all the checks in this list starting at 1. These are not timeline events -/// 11. Check the auth of the event passes based on the state of the event -/// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by -/// doing state res where one of the inputs was a previously trusted set of state, don't just -/// trust a set of state we got from a remote) -/// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" -/// it -/// 14. Use state resolution to find new room state -// We use some AsyncRecursiveType hacks here so we can call this async funtion recursively -#[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] -pub(crate) async fn handle_incoming_pdu<'a>( - origin: &'a ServerName, - event_id: &'a EventId, - room_id: &'a RoomId, - value: BTreeMap, - is_timeline_event: bool, - db: &'a Database, - pub_key_map: &'a RwLock>>, -) -> Result>> { - db.rooms.exists(room_id)?.ok_or(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"))?; - - db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of this room is currently disabled on this server."))?; - - // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = db.rooms.get_pdu_id(event_id)? { - return Some(pdu_id.to_vec()); - } - - let create_event = db - .rooms - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; - - let first_pdu_in_room = db - .rooms - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - - let (incoming_pdu, val) = handle_outlier_pdu( - origin, - &create_event, - event_id, - room_id, - value, - db, - pub_key_map, - ) - .await?; - - // 8. if not timeline event: stop - if !is_timeline_event { - return Ok(None); - } - - // Skip old events - if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(None); - } +use crate::service::*; + +pub struct Service; + +impl Service { + /// When receiving an event one needs to: + /// 0. Check the server is in the room + /// 1. Skip the PDU if we already know about it + /// 2. Check signatures, otherwise drop + /// 3. Check content hash, redact if doesn't match + /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not + /// timeline events + /// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are + /// also rejected "due to auth events" + /// 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events + /// 7. Persist this event as an outlier + /// 8. If not timeline event: stop + /// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline + /// events + /// 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities + /// doing all the checks in this list starting at 1. These are not timeline events + /// 11. Check the auth of the event passes based on the state of the event + /// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by + /// doing state res where one of the inputs was a previously trusted set of state, don't just + /// trust a set of state we got from a remote) + /// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" + /// it + /// 14. Use state resolution to find new room state + // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively + #[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] + pub(crate) async fn handle_incoming_pdu<'a>( + origin: &'a ServerName, + event_id: &'a EventId, + room_id: &'a RoomId, + value: BTreeMap, + is_timeline_event: bool, + db: &'a Database, + pub_key_map: &'a RwLock>>, + ) -> Result>> { + db.rooms.exists(room_id)?.ok_or(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"))?; + + db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of this room is currently disabled on this server."))?; + + // 1. Skip the PDU if we already have it as a timeline event + if let Some(pdu_id) = db.rooms.get_pdu_id(event_id)? { + return Some(pdu_id.to_vec()); + } - // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - let sorted_prev_events = fetch_unknown_prev_events(incoming_pdu.prev_events.clone()); + let create_event = db + .rooms + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; - let mut errors = 0; - for prev_id in dbg!(sorted) { - // Check for disabled again because it might have changed - db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of - this room is currently disabled on this server."))?; + let first_pdu_in_room = db + .rooms + .first_pdu_in_room(room_id)? + .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - if let Some((time, tries)) = db - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&*prev_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } + let (incoming_pdu, val) = handle_outlier_pdu( + origin, + &create_event, + event_id, + room_id, + value, + db, + pub_key_map, + ) + .await?; - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", prev_id); - continue; - } + // 8. if not timeline event: stop + if !is_timeline_event { + return Ok(None); } - if errors >= 5 { - break; + // Skip old events + if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(None); } - if let Some((pdu, json)) = eventid_info.remove(&*prev_id) { - // Skip old events - if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - continue; - } + // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + let sorted_prev_events = fetch_unknown_prev_events(incoming_pdu.prev_events.clone()); - let start_time = Instant::now(); - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); + let mut errors = 0; + for prev_id in dbg!(sorted) { + // Check for disabled again because it might have changed + db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of + this room is currently disabled on this server."))?; - if let Err(e) = upgrade_outlier_to_timeline_pdu( - pdu, - json, - &create_event, - origin, - db, - room_id, - pub_key_map, - ) - .await + if let Some((time, tries)) = db + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&*prev_id) { - errors += 1; - warn!("Prev event {} failed: {}", prev_id, e); - match db - .globals - .bad_event_ratelimiter - .write() - .unwrap() - .entry((*prev_id).to_owned()) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1) - } + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", prev_id); + continue; } } - let elapsed = start_time.elapsed(); - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .remove(&room_id.to_owned()); - warn!( - "Handling prev event {} took {}m{}s", - prev_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - } - // Done with prev events, now handling the incoming event - - let start_time = Instant::now(); - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); - let r = upgrade_outlier_to_timeline_pdu( - incoming_pdu, - val, - &create_event, - origin, - db, - room_id, - pub_key_map, - ) - .await; - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .remove(&room_id.to_owned()); - - r -} + if errors >= 5 { + break; + } -#[tracing::instrument(skip(create_event, value, db, pub_key_map))] -fn handle_outlier_pdu<'a>( - origin: &'a ServerName, - create_event: &'a PduEvent, - event_id: &'a EventId, - room_id: &'a RoomId, - value: BTreeMap, - db: &'a Database, - pub_key_map: &'a RwLock>>, -) -> AsyncRecursiveType<'a, Result<(Arc, BTreeMap), String>> { - Box::pin(async move { - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - fetch_required_signing_keys(&value, pub_key_map, db) - .await?; - - // 2. Check signatures, otherwise drop - // 3. check content hash, redact if doesn't match - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - error!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; + if let Some((pdu, json)) = eventid_info.remove(&*prev_id) { + // Skip old events + if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + continue; + } - let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + let start_time = Instant::now(); + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - let mut val = match ruma::signatures::verify_event( - &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?, - &value, - room_version_id, - ) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e); - return Err("Signature verification failed".to_owned()); - } - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - warn!("Calculated hash does not match: {}", event_id); - match ruma::signatures::redact(&value, room_version_id) { - Ok(obj) => obj, - Err(_) => return Err("Redaction failed".to_owned()), + if let Err(e) = upgrade_outlier_to_timeline_pdu( + pdu, + json, + &create_event, + origin, + db, + room_id, + pub_key_map, + ) + .await + { + errors += 1; + warn!("Prev event {} failed: {}", prev_id, e); + match db + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry((*prev_id).to_owned()) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1) + } + } } + let elapsed = start_time.elapsed(); + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .remove(&room_id.to_owned()); + warn!( + "Handling prev event {} took {}m{}s", + prev_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); } - Ok(ruma::signatures::Verified::All) => value, - }; - - // Now that we have checked the signature and hashes we can add the eventID and convert - // to our PduEvent type - val.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| "Event is not a valid PDU.".to_owned())?; + } - // 4. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - warn!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_outliers( - db, + // Done with prev events, now handling the incoming event + + let start_time = Instant::now(); + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); + let r = upgrade_outlier_to_timeline_pdu( + incoming_pdu, + val, + &create_event, origin, - &incoming_pdu - .auth_events - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(), - create_event, + db, room_id, pub_key_map, ) .await; + db.globals + .roomid_federationhandletime + .write() + .unwrap() + .remove(&room_id.to_owned()); - // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events - info!( - "Auth check for {} based on auth events", - incoming_pdu.event_id - ); + r + } - // Build map of auth events - let mut auth_events = HashMap::new(); - for id in &incoming_pdu.auth_events { - let auth_event = match db.rooms.get_pdu(id)? { - Some(e) => e, - None => { - warn!("Could not find auth event {}", id); - continue; - } - }; + #[tracing::instrument(skip(create_event, value, db, pub_key_map))] + fn handle_outlier_pdu<'a>( + origin: &'a ServerName, + create_event: &'a PduEvent, + event_id: &'a EventId, + room_id: &'a RoomId, + value: BTreeMap, + db: &'a Database, + pub_key_map: &'a RwLock>>, + ) -> AsyncRecursiveType<'a, Result<(Arc, BTreeMap), String>> { + Box::pin(async move { + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + + // We go through all the signatures we see on the value and fetch the corresponding signing + // keys + fetch_required_signing_keys(&value, pub_key_map, db) + .await?; + + // 2. Check signatures, otherwise drop + // 3. check content hash, redact if doesn't match + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + error!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + + let mut val = match ruma::signatures::verify_event( + &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?, + &value, + room_version_id, + ) { + Err(e) => { + // Drop + warn!("Dropping bad event {}: {}", event_id, e); + return Err("Signature verification failed".to_owned()); } - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest(ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times." - )); + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + warn!("Calculated hash does not match: {}", event_id); + match ruma::signatures::redact(&value, room_version_id) { + Ok(obj) => obj, + Err(_) => return Err("Redaction failed".to_owned()), + } } - } - } + Ok(ruma::signatures::Verified::All) => value, + }; - // The original create event must be in the auth events - if auth_events - .get(&(StateEventType::RoomCreate, "".to_owned())) - .map(|a| a.as_ref()) - != Some(create_event) - { - return Err(Error::BadRequest(ErrorKind::InvalidParam("Incoming event refers to wrong create event."))); - } + // Now that we have checked the signature and hashes we can add the eventID and convert + // to our PduEvent type + val.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + let incoming_pdu = serde_json::from_value::( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| "Event is not a valid PDU.".to_owned())?; - if !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - ) - .map_err(|e| {error!(e); Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")})? - { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")); - } + // 4. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + warn!("Fetching auth events for {}", incoming_pdu.event_id); + fetch_and_handle_outliers( + db, + origin, + &incoming_pdu + .auth_events + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(), + create_event, + room_id, + pub_key_map, + ) + .await; - info!("Validation successful."); + // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events + info!( + "Auth check for {} based on auth events", + incoming_pdu.event_id + ); - // 7. Persist the event as an outlier. - db.rooms - .add_pdu_outlier(&incoming_pdu.event_id, &val)?; + // Build map of auth events + let mut auth_events = HashMap::new(); + for id in &incoming_pdu.auth_events { + let auth_event = match db.rooms.get_pdu(id)? { + Some(e) => e, + None => { + warn!("Could not find auth event {}", id); + continue; + } + }; - info!("Added pdu as outlier."); + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + } + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest(ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times." + )); + } + } + } - Ok((Arc::new(incoming_pdu), val)) - }) -} + // The original create event must be in the auth events + if auth_events + .get(&(StateEventType::RoomCreate, "".to_owned())) + .map(|a| a.as_ref()) + != Some(create_event) + { + return Err(Error::BadRequest(ErrorKind::InvalidParam("Incoming event refers to wrong create event."))); + } -#[tracing::instrument(skip(incoming_pdu, val, create_event, db, pub_key_map))] -async fn upgrade_outlier_to_timeline_pdu( - incoming_pdu: Arc, - val: BTreeMap, - create_event: &PduEvent, - origin: &ServerName, - db: &Database, - room_id: &RoomId, - pub_key_map: &RwLock>>, -) -> Result>, String> { - // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { - return Ok(Some(pduid)); - } + if !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::, // TODO: third party invite + |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), + ) + .map_err(|e| {error!(e); Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")})? + { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")); + } - if db - .rooms - .is_event_soft_failed(&incoming_pdu.event_id) - .map_err(|_| "Failed to ask db for soft fail".to_owned())? - { - return Err("Event has been soft failed".into()); - } + info!("Validation successful."); - info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); + // 7. Persist the event as an outlier. + db.rooms + .add_pdu_outlier(&incoming_pdu.event_id, &val)?; - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; + info!("Added pdu as outlier."); - let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + Ok((Arc::new(incoming_pdu), val)) + }) + } - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - // doing all the checks in this list starting at 1. These are not timeline events. + #[tracing::instrument(skip(incoming_pdu, val, create_event, db, pub_key_map))] + async fn upgrade_outlier_to_timeline_pdu( + incoming_pdu: Arc, + val: BTreeMap, + create_event: &PduEvent, + origin: &ServerName, + db: &Database, + room_id: &RoomId, + pub_key_map: &RwLock>>, + ) -> Result>, String> { + // Skip the PDU if we already have it as a timeline event + if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { + return Ok(Some(pduid)); + } - // TODO: if we know the prev_events of the incoming event we can avoid the request and build - // the state from a known point and resolve if > 1 prev_event + if db + .rooms + .is_event_soft_failed(&incoming_pdu.event_id) + .map_err(|_| "Failed to ask db for soft fail".to_owned())? + { + return Err("Event has been soft failed".into()); + } - info!("Requesting state at event"); - let mut state_at_incoming_event = None; + info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); - if incoming_pdu.prev_events.len() == 1 { - let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = db - .rooms - .pdu_shortstatehash(prev_event) - .map_err(|_| "Failed talking to db".to_owned())?; + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some(db.rooms.state_full_ids(shortstatehash).await) - } else { - None - }; - - if let Some(Ok(mut state)) = state { - info!("Using cached state"); - let prev_pdu = - db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { - "Could not find prev event, but we know the state.".to_owned() - })?; + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; + // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities + // doing all the checks in this list starting at 1. These are not timeline events. - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } + // TODO: if we know the prev_events of the incoming event we can avoid the request and build + // the state from a known point and resolve if > 1 prev_event - state_at_incoming_event = Some(state); - } - } else { - info!("Calculating state at event using state res"); - let mut extremity_sstatehashes = HashMap::new(); - - let mut okay = true; - for prev_eventid in &incoming_pdu.prev_events { - let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) { - pdu - } else { - okay = false; - break; - }; + info!("Requesting state at event"); + let mut state_at_incoming_event = None; + + if incoming_pdu.prev_events.len() == 1 { + let prev_event = &*incoming_pdu.prev_events[0]; + let prev_event_sstatehash = db + .rooms + .pdu_shortstatehash(prev_event) + .map_err(|_| "Failed talking to db".to_owned())?; - let sstatehash = if let Ok(Some(s)) = db.rooms.pdu_shortstatehash(prev_eventid) { - s + let state = if let Some(shortstatehash) = prev_event_sstatehash { + Some(db.rooms.state_full_ids(shortstatehash).await) } else { - okay = false; - break; + None }; - extremity_sstatehashes.insert(sstatehash, prev_event); - } - - if okay { - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); + if let Some(Ok(mut state)) = state { + info!("Using cached state"); + let prev_pdu = + db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { + "Could not find prev event, but we know the state.".to_owned() + })?; - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: BTreeMap<_, _> = db - .rooms - .state_full_ids(sstatehash) - .await - .map_err(|_| "Failed to ask db for room state.".to_owned())?; - - if let Some(state_key) = &prev_event.state_key { + if let Some(state_key) = &prev_pdu.state_key { let shortstatekey = db .rooms .get_or_create_shortstatekey( - &prev_event.kind.to_string().into(), + &prev_pdu.kind.to_string().into(), state_key, &db.globals, ) .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + + state.insert(shortstatekey, Arc::from(prev_event)); // Now it's the state after the pdu } - let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); + state_at_incoming_event = Some(state); + } + } else { + info!("Calculating state at event using state res"); + let mut extremity_sstatehashes = HashMap::new(); - for (k, id) in leaf_state { - if let Ok((ty, st_key)) = db.rooms.get_statekey_from_short(k) { - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); - } - starting_events.push(id); - } + let mut okay = true; + for prev_eventid in &incoming_pdu.prev_events { + let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) { + pdu + } else { + okay = false; + break; + }; - auth_chain_sets.push( - get_auth_chain(room_id, starting_events, db) - .await - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), - ); + let sstatehash = if let Ok(Some(s)) = db.rooms.pdu_shortstatehash(prev_eventid) { + s + } else { + okay = false; + break; + }; - fork_states.push(state); + extremity_sstatehashes.insert(sstatehash, prev_event); } - let lock = db.globals.stateres_mutex.lock(); + if okay { + let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); + let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); + for (sstatehash, prev_event) in extremity_sstatehashes { + let mut leaf_state: BTreeMap<_, _> = db + .rooms + .state_full_ids(sstatehash) + .await + .map_err(|_| "Failed to ask db for room state.".to_owned())?; + + if let Some(state_key) = &prev_event.state_key { + let shortstatekey = db + .rooms + .get_or_create_shortstatekey( + &prev_event.kind.to_string().into(), + state_key, + &db.globals, + ) + .map_err(|_| "Failed to create shortstatekey.".to_owned())?; + leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + // Now it's the state after the pdu + } + + let mut state = StateMap::with_capacity(leaf_state.len()); + let mut starting_events = Vec::with_capacity(leaf_state.len()); + + for (k, id) in leaf_state { + if let Ok((ty, st_key)) = db.rooms.get_statekey_from_short(k) { + // FIXME: Undo .to_string().into() when StateMap + // is updated to use StateEventType + state.insert((ty.to_string().into(), st_key), id.clone()); + } else { + warn!("Failed to get_statekey_from_short."); + } + starting_events.push(id); + } + + auth_chain_sets.push( + get_auth_chain(room_id, starting_events, db) + .await + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), + ); + + fork_states.push(state); } - res.ok().flatten() - }); - drop(lock); - - state_at_incoming_event = match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - &db.globals, - ) - .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); - None + + let lock = db.globals.stateres_mutex.lock(); + + let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let res = db.rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }); + drop(lock); + + state_at_incoming_event = match result { + Ok(new_state) => Some( + new_state + .into_iter() + .map(|((event_type, state_key), event_id)| { + let shortstatekey = db + .rooms + .get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + &db.globals, + ) + .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; + Ok((shortstatekey, event_id)) + }) + .collect::>()?, + ), + Err(e) => { + warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); + None + } } } } - } - if state_at_incoming_event.is_none() { - info!("Calling /state_ids"); - // Call /state_ids to find out what the state at this pdu is. We trust the server's - // response to some extend, but we still do a lot of checks on the events - match db - .sending - .send_federation_request( - &db.globals, - origin, - get_room_state_ids::v1::Request { - room_id, - event_id: &incoming_pdu.event_id, - }, - ) - .await - { - Ok(res) => { - info!("Fetching state events at event."); - let state_vec = fetch_and_handle_outliers( - db, + if state_at_incoming_event.is_none() { + info!("Calling /state_ids"); + // Call /state_ids to find out what the state at this pdu is. We trust the server's + // response to some extend, but we still do a lot of checks on the events + match db + .sending + .send_federation_request( + &db.globals, origin, - &res.pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(), - create_event, - room_id, - pub_key_map, + get_room_state_ids::v1::Request { + room_id, + event_id: &incoming_pdu.event_id, + }, ) - .await; - - let mut state: BTreeMap<_, Arc> = BTreeMap::new(); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?; + .await + { + Ok(res) => { + info!("Fetching state events at event."); + let state_vec = fetch_and_handle_outliers( + db, + origin, + &res.pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(), + create_event, + room_id, + pub_key_map, + ) + .await; + + let mut state: BTreeMap<_, Arc> = BTreeMap::new(); + for (pdu, _) in state_vec { + let state_key = pdu + .state_key + .clone() + .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?; + + let shortstatekey = db + .rooms + .get_or_create_shortstatekey( + &pdu.kind.to_string().into(), + &state_key, + &db.globals, + ) + .map_err(|_| "Failed to create shortstatekey.".to_owned())?; + + match state.entry(shortstatekey) { + btree_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + } + btree_map::Entry::Occupied(_) => return Err( + "State event's type and state_key combination exists multiple times." + .to_owned(), + ), + } + } - let shortstatekey = db + // The original create event must still be in the state + let create_shortstatekey = db .rooms - .get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - - match state.entry(shortstatekey) { - btree_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - } - btree_map::Entry::Occupied(_) => return Err( - "State event's type and state_key combination exists multiple times." - .to_owned(), - ), + .get_shortstatekey(&StateEventType::RoomCreate, "") + .map_err(|_| "Failed to talk to db.")? + .expect("Room exists"); + + if state.get(&create_shortstatekey).map(|id| id.as_ref()) + != Some(&create_event.event_id) + { + return Err("Incoming event refers to wrong create event.".to_owned()); } + + state_at_incoming_event = Some(state); + } + Err(e) => { + warn!("Fetching state for event failed: {}", e); + return Err("Fetching state for event failed".into()); } + }; + } - // The original create event must still be in the state - let create_shortstatekey = db - .rooms - .get_shortstatekey(&StateEventType::RoomCreate, "") - .map_err(|_| "Failed to talk to db.")? - .expect("Room exists"); + let state_at_incoming_event = + state_at_incoming_event.expect("we always set this to some above"); - if state.get(&create_shortstatekey).map(|id| id.as_ref()) - != Some(&create_event.event_id) - { - return Err("Incoming event refers to wrong create event.".to_owned()); - } + info!("Starting auth check"); + // 11. Check the auth of the event passes based on the state of the event + let check_result = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::, // TODO: third party invite + |k, s| { + db.rooms + .get_shortstatekey(&k.to_string().into(), s) + .ok() + .flatten() + .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) + .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) + }, + ) + .map_err(|_e| "Auth check failed.".to_owned())?; - state_at_incoming_event = Some(state); - } - Err(e) => { - warn!("Fetching state for event failed: {}", e); - return Err("Fetching state for event failed".into()); - } - }; - } + if !check_result { + return Err("Event has failed auth check with state at the event.".into()); + } + info!("Auth check succeeded"); - let state_at_incoming_event = - state_at_incoming_event.expect("we always set this to some above"); + // We start looking at current room state now, so lets lock the room - info!("Starting auth check"); - // 11. Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::, // TODO: third party invite - |k, s| { - db.rooms - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) - }, - ) - .map_err(|_e| "Auth check failed.".to_owned())?; - - if !check_result { - return Err("Event has failed auth check with state at the event.".into()); - } - info!("Auth check succeeded"); + let mutex_state = Arc::clone( + db.globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; - // We start looking at current room state now, so lets lock the room + // Now we calculate the set of extremities this room has after the incoming event has been + // applied. We start with the previous extremities (aka leaves) + info!("Calculating extremities"); + let mut extremities = db + .rooms + .get_pdu_leaves(room_id) + .map_err(|_| "Failed to load room leaves".to_owned())?; - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Now we calculate the set of extremities this room has after the incoming event has been - // applied. We start with the previous extremities (aka leaves) - info!("Calculating extremities"); - let mut extremities = db - .rooms - .get_pdu_leaves(room_id) - .map_err(|_| "Failed to load room leaves".to_owned())?; - - // Remove any forward extremities that are referenced by this incoming event's prev_events - for prev_event in &incoming_pdu.prev_events { - if extremities.contains(prev_event) { - extremities.remove(prev_event); + // Remove any forward extremities that are referenced by this incoming event's prev_events + for prev_event in &incoming_pdu.prev_events { + if extremities.contains(prev_event) { + extremities.remove(prev_event); + } } - } - // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); + // Only keep those extremities were not referenced yet + extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); - info!("Compressing state at event"); - let state_ids_compressed = state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - db.rooms - .compress_state_event(*shortstatekey, id, &db.globals) - .map_err(|_| "Failed to compress_state_event".to_owned()) - }) - .collect::>()?; + info!("Compressing state at event"); + let state_ids_compressed = state_at_incoming_event + .iter() + .map(|(shortstatekey, id)| { + db.rooms + .compress_state_event(*shortstatekey, id, &db.globals) + .map_err(|_| "Failed to compress_state_event".to_owned()) + }) + .collect::>()?; - // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it - info!("Starting soft fail auth check"); + // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it + info!("Starting soft fail auth check"); - let auth_events = db - .rooms - .get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - ) - .map_err(|_| "Failed to get_auth_events.".to_owned())?; - - let soft_fail = !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::, - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|_e| "Auth check failed.".to_owned())?; - - if soft_fail { - append_incoming_pdu( - db, + let auth_events = db + .rooms + .get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + ) + .map_err(|_| "Failed to get_auth_events.".to_owned())?; + + let soft_fail = !state_res::event_auth::auth_check( + &room_version, &incoming_pdu, - val, - extremities.iter().map(Deref::deref), - state_ids_compressed, - soft_fail, - &state_lock, + None::, + |k, s| auth_events.get(&(k.clone(), s.to_owned())), ) - .map_err(|e| { - warn!("Failed to add pdu to db: {}", e); - "Failed to add pdu to db.".to_owned() - })?; + .map_err(|_e| "Auth check failed.".to_owned())?; - // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); - db.rooms - .mark_event_soft_failed(&incoming_pdu.event_id) - .map_err(|_| "Failed to set soft failed flag".to_owned())?; - return Err("Event has been soft failed".into()); - } - - if incoming_pdu.state_key.is_some() { - info!("Loading current room state ids"); - let current_sstatehash = db - .rooms - .current_shortstatehash(room_id) - .map_err(|_| "Failed to load current state hash.".to_owned())? - .expect("every room has state"); + if soft_fail { + append_incoming_pdu( + db, + &incoming_pdu, + val, + extremities.iter().map(Deref::deref), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .map_err(|e| { + warn!("Failed to add pdu to db: {}", e); + "Failed to add pdu to db.".to_owned() + })?; - let current_state_ids = db - .rooms - .state_full_ids(current_sstatehash) - .await - .map_err(|_| "Failed to load room state.")?; + // Soft fail, we keep the event as an outlier but don't add it to the timeline + warn!("Event was soft failed: {:?}", incoming_pdu); + db.rooms + .mark_event_soft_failed(&incoming_pdu.event_id) + .map_err(|_| "Failed to set soft failed flag".to_owned())?; + return Err("Event has been soft failed".into()); + } - info!("Preparing for stateres to derive new room state"); - let mut extremity_sstatehashes = HashMap::new(); + if incoming_pdu.state_key.is_some() { + info!("Loading current room state ids"); + let current_sstatehash = db + .rooms + .current_shortstatehash(room_id) + .map_err(|_| "Failed to load current state hash.".to_owned())? + .expect("every room has state"); - info!("Loading extremities"); - for id in dbg!(&extremities) { - match db + let current_state_ids = db .rooms - .get_pdu(id) - .map_err(|_| "Failed to ask db for pdu.".to_owned())? - { - Some(leaf_pdu) => { - extremity_sstatehashes.insert( - db.rooms - .pdu_shortstatehash(&leaf_pdu.event_id) - .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? - .ok_or_else(|| { - error!( - "Found extremity pdu with no statehash in db: {:?}", - leaf_pdu - ); - "Found pdu with no statehash in db.".to_owned() - })?, - leaf_pdu, - ); - } - _ => { - error!("Missing state snapshot for {:?}", id); - return Err("Missing state snapshot.".to_owned()); + .state_full_ids(current_sstatehash) + .await + .map_err(|_| "Failed to load room state.")?; + + info!("Preparing for stateres to derive new room state"); + let mut extremity_sstatehashes = HashMap::new(); + + info!("Loading extremities"); + for id in dbg!(&extremities) { + match db + .rooms + .get_pdu(id) + .map_err(|_| "Failed to ask db for pdu.".to_owned())? + { + Some(leaf_pdu) => { + extremity_sstatehashes.insert( + db.rooms + .pdu_shortstatehash(&leaf_pdu.event_id) + .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? + .ok_or_else(|| { + error!( + "Found extremity pdu with no statehash in db: {:?}", + leaf_pdu + ); + "Found pdu with no statehash in db.".to_owned() + })?, + leaf_pdu, + ); + } + _ => { + error!("Missing state snapshot for {:?}", id); + return Err("Missing state snapshot.".to_owned()); + } } } - } - let mut fork_states = Vec::new(); + let mut fork_states = Vec::new(); - // 12. Ensure that the state is derived from the previous current state (i.e. we calculated - // by doing state res where one of the inputs was a previously trusted set of state, - // don't just trust a set of state we got from a remote). + // 12. Ensure that the state is derived from the previous current state (i.e. we calculated + // by doing state res where one of the inputs was a previously trusted set of state, + // don't just trust a set of state we got from a remote). - // We do this by adding the current state to the list of fork states - extremity_sstatehashes.remove(¤t_sstatehash); - fork_states.push(current_state_ids); - - // We also add state after incoming event to the fork states - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; + // We do this by adding the current state to the list of fork states + extremity_sstatehashes.remove(¤t_sstatehash); + fork_states.push(current_state_ids); - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); - } - fork_states.push(state_after); - - let mut update_state = false; - // 14. Use state resolution to find new room state - let new_room_state = if fork_states.is_empty() { - return Err("State is empty.".to_owned()); - } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { - info!("State resolution trivial"); - // There was only one state, so it has to be the room's current state (because that is - // always included) - fork_states[0] - .iter() - .map(|(k, id)| { - db.rooms - .compress_state_event(*k, id, &db.globals) - .map_err(|_| "Failed to compress_state_event.".to_owned()) - }) - .collect::>()? - } else { - info!("Loading auth chains"); - // We do need to force an update to this room's state - update_state = true; - - let mut auth_chain_sets = Vec::new(); - for state in &fork_states { - auth_chain_sets.push( - get_auth_chain( - room_id, - state.iter().map(|(_, id)| id.clone()).collect(), - db, + // We also add state after incoming event to the fork states + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + let shortstatekey = db + .rooms + .get_or_create_shortstatekey( + &incoming_pdu.kind.to_string().into(), + state_key, + &db.globals, ) - .await - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), - ); - } + .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - info!("Loading fork states"); + state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + } + fork_states.push(state_after); + + let mut update_state = false; + // 14. Use state resolution to find new room state + let new_room_state = if fork_states.is_empty() { + return Err("State is empty.".to_owned()); + } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { + info!("State resolution trivial"); + // There was only one state, so it has to be the room's current state (because that is + // always included) + fork_states[0] + .iter() + .map(|(k, id)| { + db.rooms + .compress_state_event(*k, id, &db.globals) + .map_err(|_| "Failed to compress_state_event.".to_owned()) + }) + .collect::>()? + } else { + info!("Loading auth chains"); + // We do need to force an update to this room's state + update_state = true; + + let mut auth_chain_sets = Vec::new(); + for state in &fork_states { + auth_chain_sets.push( + get_auth_chain( + room_id, + state.iter().map(|(_, id)| id.clone()).collect(), + db, + ) + .await + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), + ); + } - let fork_states: Vec<_> = fork_states - .into_iter() - .map(|map| { - map.into_iter() - .filter_map(|(k, id)| { - db.rooms - .get_statekey_from_short(k) - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .map_err(|e| warn!("Failed to get_statekey_from_short: {}", e)) - .ok() - }) - .collect::>() - }) - .collect(); - - info!("Resolving state"); - - let lock = db.globals.stateres_mutex.lock(); - let state = match state_res::resolve( - room_version_id, - &fork_states, - auth_chain_sets, - |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); + info!("Loading fork states"); + + let fork_states: Vec<_> = fork_states + .into_iter() + .map(|map| { + map.into_iter() + .filter_map(|(k, id)| { + db.rooms + .get_statekey_from_short(k) + // FIXME: Undo .to_string().into() when StateMap + // is updated to use StateEventType + .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) + .map_err(|e| warn!("Failed to get_statekey_from_short: {}", e)) + .ok() + }) + .collect::>() + }) + .collect(); + + info!("Resolving state"); + + let lock = db.globals.stateres_mutex.lock(); + let state = match state_res::resolve( + room_version_id, + &fork_states, + auth_chain_sets, + |id| { + let res = db.rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }, + ) { + Ok(new_state) => new_state, + Err(_) => { + return Err("State resolution failed, either an event could not be found or deserialization".into()); } - res.ok().flatten() - }, - ) { - Ok(new_state) => new_state, - Err(_) => { - return Err("State resolution failed, either an event could not be found or deserialization".into()); - } + }; + + drop(lock); + + info!("State resolution done. Compressing state"); + + state + .into_iter() + .map(|((event_type, state_key), event_id)| { + let shortstatekey = db + .rooms + .get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + &db.globals, + ) + .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; + db.rooms + .compress_state_event(shortstatekey, &event_id, &db.globals) + .map_err(|_| "Failed to compress state event".to_owned()) + }) + .collect::>()? }; - drop(lock); - - info!("State resolution done. Compressing state"); - - state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - &db.globals, - ) - .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; - db.rooms - .compress_state_event(shortstatekey, &event_id, &db.globals) - .map_err(|_| "Failed to compress state event".to_owned()) - }) - .collect::>()? - }; - - // Set the new room state to the resolved state - if update_state { - info!("Forcing new room state"); - db.rooms - .force_state(room_id, new_room_state, db) - .map_err(|_| "Failed to set new room state.".to_owned())?; + // Set the new room state to the resolved state + if update_state { + info!("Forcing new room state"); + db.rooms + .force_state(room_id, new_room_state, db) + .map_err(|_| "Failed to set new room state.".to_owned())?; + } } - } - info!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); - - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. - - let pdu_id = append_incoming_pdu( - db, - &incoming_pdu, - val, - extremities.iter().map(Deref::deref), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .map_err(|e| { - warn!("Failed to add pdu to db: {}", e); - "Failed to add pdu to db.".to_owned() - })?; - - info!("Appended incoming pdu"); - - // Event has passed all auth/stateres checks - drop(state_lock); - Ok(pdu_id) -} + info!("Appending pdu to timeline"); + extremities.insert(incoming_pdu.event_id.clone()); -/// Find the event and auth it. Once the event is validated (steps 1 - 8) -/// it is appended to the outliers Tree. -/// -/// Returns pdu and if we fetched it over federation the raw json. -/// -/// a. Look in the main timeline (pduid_pdu tree) -/// b. Look at outlier pdu tree -/// c. Ask origin server over federation -/// d. TODO: Ask other servers over federation? -#[tracing::instrument(skip_all)] -pub(crate) fn fetch_and_handle_outliers<'a>( - db: &'a Database, - origin: &'a ServerName, - events: &'a [Arc], - create_event: &'a PduEvent, - room_id: &'a RoomId, - pub_key_map: &'a RwLock>>, -) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> { - Box::pin(async move { - let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. - let mut pdus = vec![]; - for id in events { - if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&**id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } + let pdu_id = append_incoming_pdu( + db, + &incoming_pdu, + val, + extremities.iter().map(Deref::deref), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .map_err(|e| { + warn!("Failed to add pdu to db: {}", e); + "Failed to add pdu to db.".to_owned() + })?; - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", id); - continue; - } - } + info!("Appended incoming pdu"); - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - continue; - } + // Event has passed all auth/stateres checks + drop(state_lock); + Ok(pdu_id) + } - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. - let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::new(); - let mut events_all = HashSet::new(); - let mut i = 0; - while let Some(next_id) = todo_auth_events.pop() { - if events_all.contains(&next_id) { - continue; + /// Find the event and auth it. Once the event is validated (steps 1 - 8) + /// it is appended to the outliers Tree. + /// + /// Returns pdu and if we fetched it over federation the raw json. + /// + /// a. Look in the main timeline (pduid_pdu tree) + /// b. Look at outlier pdu tree + /// c. Ask origin server over federation + /// d. TODO: Ask other servers over federation? + #[tracing::instrument(skip_all)] + pub(crate) fn fetch_and_handle_outliers<'a>( + db: &'a Database, + origin: &'a ServerName, + events: &'a [Arc], + create_event: &'a PduEvent, + room_id: &'a RoomId, + pub_key_map: &'a RwLock>>, + ) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> { + Box::pin(async move { + let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); } + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; + let mut pdus = vec![]; + for id in events { + if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&**id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", id); + continue; + } } - if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) { trace!("Found {} in db", id); + pdus.push((local_pdu, None)); continue; } - info!("Fetching {} over federation.", next_id); - match db - .sending - .send_federation_request( - &db.globals, - origin, - get_event::v1::Request { event_id: &next_id }, - ) - .await - { - Ok(res) => { - info!("Got {} over federation", next_id); - let (calculated_event_id, value) = - match crate::pdu::gen_event_id_canonical_json(&res.pdu, &db) { - Ok(t) => t, - Err(_) => { - back_off((*next_id).to_owned()); - continue; - } - }; + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. + let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::new(); + let mut events_all = HashSet::new(); + let mut i = 0; + while let Some(next_id) = todo_auth_events.pop() { + if events_all.contains(&next_id) { + continue; + } - if calculated_event_id != *next_id { - warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", - next_id, calculated_event_id, &res.pdu); - } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + + if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { + trace!("Found {} in db", id); + continue; + } - if let Some(auth_events) = - value.get("auth_events").and_then(|c| c.as_array()) - { - for auth_event in auth_events { - if let Ok(auth_event) = - serde_json::from_value(auth_event.clone().into()) - { - let a: Arc = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); + info!("Fetching {} over federation.", next_id); + match db + .sending + .send_federation_request( + &db.globals, + origin, + get_event::v1::Request { event_id: &next_id }, + ) + .await + { + Ok(res) => { + info!("Got {} over federation", next_id); + let (calculated_event_id, value) = + match crate::pdu::gen_event_id_canonical_json(&res.pdu, &db) { + Ok(t) => t, + Err(_) => { + back_off((*next_id).to_owned()); + continue; + } + }; + + if calculated_event_id != *next_id { + warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", + next_id, calculated_event_id, &res.pdu); + } + + if let Some(auth_events) = + value.get("auth_events").and_then(|c| c.as_array()) + { + for auth_event in auth_events { + if let Ok(auth_event) = + serde_json::from_value(auth_event.clone().into()) + { + let a: Arc = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); + } } + } else { + warn!("Auth event list invalid"); } - } else { - warn!("Auth event list invalid"); - } - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - } - Err(_) => { - warn!("Failed to fetch event: {}", next_id); - back_off((*next_id).to_owned()); + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + } + Err(_) => { + warn!("Failed to fetch event: {}", next_id); + back_off((*next_id).to_owned()); + } } } - } - for (next_id, value) in events_in_reverse_order.iter().rev() { - match handle_outlier_pdu( - origin, - create_event, - next_id, - room_id, - value.clone(), - db, - pub_key_map, - ) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); + for (next_id, value) in events_in_reverse_order.iter().rev() { + match handle_outlier_pdu( + origin, + create_event, + next_id, + room_id, + value.clone(), + db, + pub_key_map, + ) + .await + { + Ok((pdu, json)) => { + if next_id == id { + pdus.push((pdu, Some(json))); + } + } + Err(e) => { + warn!("Authentication of event {} failed: {:?}", next_id, e); + back_off((**next_id).to_owned()); } - } - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()); } } } - } - pdus - }) -} - + pdus + }) + } -fn fetch_unknown_prev_events(initial_set: Vec>) -> Vec> { - let mut graph: HashMap, _> = HashMap::new(); - let mut eventid_info = HashMap::new(); - let mut todo_outlier_stack: Vec> = initial_set; - let mut amount = 0; + fn fetch_unknown_prev_events(initial_set: Vec>) -> Vec> { + let mut graph: HashMap, _> = HashMap::new(); + let mut eventid_info = HashMap::new(); + let mut todo_outlier_stack: Vec> = initial_set; - while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = fetch_and_handle_outliers( - db, - origin, - &[prev_event_id.clone()], - &create_event, - room_id, - pub_key_map, - ) - .await - .pop() - { - if amount > 100 { - // Max limit reached - warn!("Max prev event limit reached!"); - graph.insert(prev_event_id.clone(), HashSet::new()); - continue; - } + let mut amount = 0; - if let Some(json) = - json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) + while let Some(prev_event_id) = todo_outlier_stack.pop() { + if let Some((pdu, json_opt)) = fetch_and_handle_outliers( + db, + origin, + &[prev_event_id.clone()], + &create_event, + room_id, + pub_key_map, + ) + .await + .pop() { - if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { - amount += 1; - for prev_prev in &pdu.prev_events { - if !graph.contains_key(prev_prev) { - todo_outlier_stack.push(dbg!(prev_prev.clone())); + if amount > 100 { + // Max limit reached + warn!("Max prev event limit reached!"); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue; + } + + if let Some(json) = + json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) + { + if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { + amount += 1; + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(dbg!(prev_prev.clone())); + } } + + graph.insert( + prev_event_id.clone(), + pdu.prev_events.iter().cloned().collect(), + ); + } else { + // Time based check failed + graph.insert(prev_event_id.clone(), HashSet::new()); } - graph.insert( - prev_event_id.clone(), - pdu.prev_events.iter().cloned().collect(), - ); + eventid_info.insert(prev_event_id.clone(), (pdu, json)); } else { - // Time based check failed + // Get json failed, so this was not fetched over federation graph.insert(prev_event_id.clone(), HashSet::new()); } - - eventid_info.insert(prev_event_id.clone(), (pdu, json)); } else { - // Get json failed, so this was not fetched over federation + // Fetch and handle failed graph.insert(prev_event_id.clone(), HashSet::new()); } - } else { - // Fetch and handle failed - graph.insert(prev_event_id.clone(), HashSet::new()); } - } - let sorted = state_res::lexicographical_topological_sort(dbg!(&graph), |event_id| { - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. - println!("{}", event_id); - Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|_| "Error sorting prev events".to_owned())?; - - sorted + let sorted = state_res::lexicographical_topological_sort(dbg!(&graph), |event_id| { + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + println!("{}", event_id); + Ok(( + int!(0), + MilliSecondsSinceUnixEpoch( + eventid_info + .get(event_id) + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), + ), + )) + }) + .map_err(|_| "Error sorting prev events".to_owned())?; + + sorted + } } diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs new file mode 100644 index 00000000..9cf2d8bc --- /dev/null +++ b/src/service/rooms/lazy_loading/data.rs @@ -0,0 +1,24 @@ +pub trait Data { + fn lazy_load_was_sent_before( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ll_user: &UserId, + ) -> Result; + + fn lazy_load_confirm_delivery( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + since: u64, + ) -> Result<()>; + + fn lazy_load_reset( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ) -> Result<()>; +} diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index a402702a..cf00174b 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,4 +1,13 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { #[tracing::instrument(skip(self))] pub fn lazy_load_was_sent_before( &self, @@ -7,14 +16,7 @@ room_id: &RoomId, ll_user: &UserId, ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) + self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) } #[tracing::instrument(skip(self))] @@ -45,27 +47,7 @@ room_id: &RoomId, since: u64, ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); - - for ll_id in user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } - } - - Ok(()) + self.db.lazy_load_confirm_delivery(user_d, device_id, room_id, since) } #[tracing::instrument(skip(self))] @@ -75,17 +57,6 @@ device_id: &DeviceId, room_id: &RoomId, ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); - - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } - - Ok(()) + self.db.lazy_load_reset(user_id, device_id, room_id); } - +} diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs new file mode 100644 index 00000000..58bd3510 --- /dev/null +++ b/src/service/rooms/metadata/data.rs @@ -0,0 +1,3 @@ +pub trait Data { + fn exists(&self, room_id: &RoomId) -> Result; +} diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 5d703451..644cd18f 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,44 +1,16 @@ - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match self.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; +mod data; +pub use data::Data; - // Look for PDUs in that room. - Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } +use crate::service::*; - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) - }) - .transpose() - } +pub struct Service { + db: D, +} - pub fn get_or_create_shortroomid( - &self, - room_id: &RoomId, - globals: &super::globals::Globals, - ) -> Result { - Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { - Some(short) => utils::u64_from_bytes(&short) - .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, - None => { - let short = globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - } - }) +impl Service<_> { + /// Checks if a room exists. + #[tracing::instrument(skip(self))] + pub fn exists(&self, room_id: &RoomId) -> Result { + self.db.exists(room_id) } - +} diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs new file mode 100644 index 00000000..6b534b95 --- /dev/null +++ b/src/service/rooms/outlier/data.rs @@ -0,0 +1,5 @@ +pub trait Data { + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; + fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; +} diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 340e93e4..c82cb628 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,27 +1,26 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { /// Returns the pdu from the outlier tree. pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) + self.db.get_outlier_pdu_json(event_id) } /// Returns the pdu from the outlier tree. pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) + self.db.get_outlier_pdu(event_id) } /// Append the PDU as an outlier. #[tracing::instrument(skip(self, pdu))] pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) + self.db.add_pdu_outlier(event_id, pdu) } - +} diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs new file mode 100644 index 00000000..67787958 --- /dev/null +++ b/src/service/rooms/pdu_metadata/data.rs @@ -0,0 +1,6 @@ +pub trait Data { + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; + fn is_event_soft_failed(&self, event_id: &EventId) -> Result; +} diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index f8ffcee1..6d6df223 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,31 +1,30 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { #[tracing::instrument(skip(self, room_id, event_ids))] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - for prev in event_ids { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; - } - - Ok(()) + self.db.mark_as_referenced(room_id, event_ids) } #[tracing::instrument(skip(self))] pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) + self.db.is_event_referenced(room_id, event_id) } #[tracing::instrument(skip(self))] pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) + self.db.mark_event_soft_failed(event_id) } #[tracing::instrument(skip(self))] pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) + self.db.is_event_soft_failed(event_id) } - +} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 63e8b713..c44d357c 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -196,3 +196,30 @@ }) } + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + self.roomid_shortroomid + .get(room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + }) + .transpose() + } + + pub fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + globals: &super::globals::Globals, + ) -> Result { + Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = globals.next_count()?; + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } + diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 4b42ca8e..8aa76380 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,16 +1,24 @@ pub trait Data { + /// Returns the last state hash key added to the db for the given room. fn get_room_shortstatehash(room_id: &RoomId); -} - /// Returns the last state hash key added to the db for the given room. - #[tracing::instrument(skip(self))] - pub fn current_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) - } + /// Update the current state of the room. + fn set_room_state(room_id: &RoomId, new_shortstatehash: u64 + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ); + + /// Associates a state with an event. + fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()> { + + /// Returns all events we would send as the prev_events of the next event. + fn get_forward_extremities(room_id: &RoomId) -> Result>>; + + /// Replace the forward extremities of the room. + fn set_forward_extremities( + room_id: &RoomId, + event_ids: impl IntoIterator + Debug, + _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { +} +pub struct StateLock; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index da03ad4c..b513ab53 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,25 +1,30 @@ +mod data; +pub use data::Data; + +use crate::service::*; + pub struct Service { db: D, } -impl Service { +impl Service<_> { /// Set the room to the given statehash and update caches. #[tracing::instrument(skip(self, new_state_ids_compressed, db))] pub fn force_state( &self, room_id: &RoomId, shortstatehash: u64, - statediffnew :HashSet, - statediffremoved :HashSet, + statediffnew: HashSet, + statediffremoved: HashSet, db: &Database, ) -> Result<()> { for event_id in statediffnew.into_iter().filter_map(|new| { - self.parse_compressed_state_event(new) + state_compressor::parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { - let pdu = match self.get_pdu_json(&event_id)? { + let pdu = match timeline::get_pdu_json(&event_id)? { Some(pdu) => pdu, None => continue, }; @@ -55,56 +60,12 @@ impl Service { Err(_) => continue, }; - self.update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; + room::state_cache::update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; } - self.update_joined_count(room_id, db)?; + room::state_cache::update_joined_count(room_id, db)?; - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - - Ok(()) - } - - /// Returns the leaf pdus of a room. - #[tracing::instrument(skip(self))] - pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - /// Replace the leaves of a room. - /// - /// The provided `event_ids` become the new leaves, this allows a room to have multiple - /// `prev_events`. - #[tracing::instrument(skip(self))] - pub fn replace_pdu_leaves<'a>( - &self, - room_id: &RoomId, - event_ids: impl IntoIterator + Debug, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - - for event_id in event_ids { - let mut key = prefix.to_owned(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; - } + db.set_room_state(room_id, new_shortstatehash); Ok(()) } @@ -121,11 +82,11 @@ impl Service { state_ids_compressed: HashSet, globals: &super::globals::Globals, ) -> Result<()> { - let shorteventid = self.get_or_create_shorteventid(event_id, globals)?; + let shorteventid = short::get_or_create_shorteventid(event_id, globals)?; - let previous_shortstatehash = self.current_shortstatehash(room_id)?; + let previous_shortstatehash = db.get_room_shortstatehash(room_id)?; - let state_hash = self.calculate_hash( + let state_hash = super::calculate_hash( &state_ids_compressed .iter() .map(|s| &s[..]) @@ -133,11 +94,11 @@ impl Service { ); let (shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, globals)?; + short::get_or_create_shortstatehash(&state_hash, globals)?; if !already_existed { let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + .map_or_else(|| Ok(Vec::new()), |p| room::state_compressor.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -156,7 +117,7 @@ impl Service { } else { (state_ids_compressed, HashSet::new()) }; - self.save_state_from_diff( + state_compressor::save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -165,8 +126,7 @@ impl Service { )?; } - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } @@ -183,7 +143,7 @@ impl Service { ) -> Result { let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { self.shorteventid_shortstatehash @@ -293,4 +253,8 @@ impl Service { Ok(()) } + + pub fn db(&self) -> D { + &self.db + } }