diff --git a/src/client_server/context.rs b/src/client_server/context.rs index de7aae93..e93f5a5b 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -137,7 +137,7 @@ pub async fn get_context_route( .expect("All rooms have state"), }; - let state_ids = db.rooms.state_full_ids(shortstatehash)?; + let state_ids = db.rooms.state_full_ids(shortstatehash).await?; let end_token = events_after .last() diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 65107a3c..a1b616be 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -29,7 +29,7 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + collections::{hash_map::Entry, BTreeMap, HashMap}, iter, sync::{Arc, RwLock}, time::{Duration, Instant}, @@ -48,19 +48,20 @@ pub async fn join_room_by_id_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut servers: HashSet<_> = db - .rooms - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); - - servers.insert(body.room_id.server_name().to_owned()); + let mut servers = Vec::new(); // There is no body.server_name for /roomId/join + servers.extend( + db.rooms + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + + servers.push(body.room_id.server_name().to_owned()); let ret = join_room_by_id_helper( &db, @@ -91,19 +92,20 @@ pub async fn join_room_by_id_or_alias_route( let (servers, room_id) = match Box::::try_from(body.room_id_or_alias) { Ok(room_id) => { - let mut servers: HashSet<_> = db - .rooms - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); + let mut servers = body.server_name.clone(); + servers.extend( + db.rooms + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); - servers.insert(room_id.server_name().to_owned()); + servers.push(room_id.server_name().to_owned()); (servers, room_id) } Err(room_alias) => { @@ -413,7 +415,8 @@ pub async fn get_member_events_route( Ok(get_member_events::v3::Response { chunk: db .rooms - .room_state_full(&body.room_id)? + .room_state_full(&body.room_id) + .await? .iter() .filter(|(key, _)| key.0 == StateEventType::RoomMember) .map(|(_, pdu)| pdu.to_member_event().into()) @@ -462,7 +465,7 @@ async fn join_room_by_id_helper( db: &Database, sender_user: Option<&UserId>, room_id: &RoomId, - servers: &HashSet>, + servers: &[Box], _third_party_signed: Option<&IncomingThirdPartySigned>, ) -> Result { let sender_user = sender_user.expect("user is authenticated"); @@ -478,7 +481,7 @@ async fn join_room_by_id_helper( let state_lock = mutex_state.lock().await; // Ask a remote server if we don't have this room - if !db.rooms.exists(room_id)? && room_id.server_name() != db.globals.server_name() { + if !db.rooms.exists(room_id)? { let mut make_join_response_and_server = Err(Error::BadServerResponse( "No server available to assist in joining.", )); @@ -1032,6 +1035,13 @@ pub(crate) async fn invite_helper<'a>( return Ok(()); } + if !db.rooms.is_joined(sender_user, &room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } + let mutex_state = Arc::clone( db.globals .roomid_mutex_state diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 50fe9b4f..4df953cf 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -124,7 +124,8 @@ pub async fn get_state_events_route( Ok(get_state_events::v3::Response { room_state: db .rooms - .room_state_full(&body.room_id)? + .room_state_full(&body.room_id) + .await? .values() .map(|pdu| pdu.to_state_event()) .collect(), diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index d61e6894..0c294b7e 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -230,18 +230,20 @@ async fn sync_helper( for room_id in all_joined_rooms { let room_id = room_id?; - // Get and drop the lock to wait for remaining operations to finish - // This will make sure the we have all events until next_batch - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - drop(insert_lock); + { + // Get and drop the lock to wait for remaining operations to finish + // This will make sure the we have all events until next_batch + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); + } let timeline_pdus; let limited; @@ -296,10 +298,12 @@ async fn sync_helper( // Database queries: - let current_shortstatehash = db - .rooms - .current_shortstatehash(&room_id)? - .expect("All rooms have state"); + let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? { + s + } else { + error!("Room {} has no state", room_id); + continue; + }; let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?; @@ -377,11 +381,12 @@ async fn sync_helper( let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); + let mut i = 0; for (shortstatekey, id) in current_state_ids { let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; @@ -394,6 +399,11 @@ async fn sync_helper( } }; state_events.push(pdu); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } else if !lazy_load_enabled || body.full_state || timeline_users.contains(&state_key) @@ -411,6 +421,11 @@ async fn sync_helper( lazy_loaded.insert(uid); } state_events.push(pdu); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } } @@ -462,8 +477,8 @@ async fn sync_helper( let mut lazy_loaded = HashSet::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; - let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; + let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?; for (key, id) in current_state_ids { if body.full_state || since_state_ids.get(&key) != Some(&id) { @@ -490,6 +505,7 @@ async fn sync_helper( } state_events.push(pdu); + tokio::task::yield_now().await; } } } @@ -753,17 +769,19 @@ async fn sync_helper( for result in all_left_rooms { let (room_id, left_state_events) = result?; - // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - drop(insert_lock); + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); + } let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; @@ -793,17 +811,19 @@ async fn sync_helper( for result in all_invited_rooms { let (room_id, invite_state_events) = result?; - // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - drop(insert_lock); + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); + } let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; diff --git a/src/database/admin.rs b/src/database/admin.rs index c6ef9a64..3ed1a8a9 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -116,7 +116,7 @@ impl Admin { send_message(content, guard, &state_lock); } AdminRoomEvent::ProcessMessage(room_message) => { - let reply_message = process_admin_message(&*guard, room_message); + let reply_message = process_admin_message(&*guard, room_message).await; send_message(reply_message, guard, &state_lock); } @@ -143,7 +143,7 @@ impl Admin { } // Parse and process a message from the admin room -fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { +async fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { let mut lines = room_message.lines(); let command_line = lines.next().expect("each string has at least one line"); let body: Vec<_> = lines.collect(); @@ -161,7 +161,7 @@ fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEven } }; - match process_admin_command(db, admin_command, body) { + match process_admin_command(db, admin_command, body).await { Ok(reply_message) => reply_message, Err(error) => { let markdown_message = format!( @@ -290,7 +290,7 @@ enum AdminCommand { EnableRoom { room_id: Box }, } -fn process_admin_command( +async fn process_admin_command( db: &Database, command: AdminCommand, body: Vec<&str>, @@ -404,7 +404,9 @@ fn process_admin_command( Error::bad_database("Invalid room id field in event in database") })?; let start = Instant::now(); - let count = server_server::get_auth_chain(room_id, vec![event_id], db)?.count(); + let count = server_server::get_auth_chain(room_id, vec![event_id], db) + .await? + .count(); let elapsed = start.elapsed(); RoomMessageEventContent::text_plain(format!( "Loaded auth chain with length {} in {:?}", diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 2c1b8f44..7b3b7506 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -144,20 +144,28 @@ impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] - pub fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = self .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") .1; - full_state - .into_iter() - .map(|compressed| self.parse_compressed_state_event(compressed)) - .collect() + let mut result = BTreeMap::new(); + let mut i = 0; + for compressed in full_state.into_iter() { + let parsed = self.parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) } #[tracing::instrument(skip(self))] - pub fn state_full( + pub async fn state_full( &self, shortstatehash: u64, ) -> Result>> { @@ -166,14 +174,13 @@ impl Rooms { .pop() .expect("there is always one layer") .1; - Ok(full_state - .into_iter() - .map(|compressed| self.parse_compressed_state_event(compressed)) - .filter_map(|r| r.ok()) - .map(|(_, eventid)| self.get_pdu(&eventid)) - .filter_map(|r| r.ok().flatten()) - .map(|pdu| { - Ok::<_, Error>(( + + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state { + let (_, eventid) = self.parse_compressed_state_event(compressed)?; + if let Some(pdu) = self.get_pdu(&eventid)? { + result.insert( ( pdu.kind.to_string().into(), pdu.state_key @@ -182,10 +189,16 @@ impl Rooms { .clone(), ), pdu, - )) - }) - .filter_map(|r| r.ok()) - .collect()) + ); + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -228,7 +241,6 @@ impl Rooms { } /// Returns the state hash for this pdu. - #[tracing::instrument(skip(self))] pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.eventid_shorteventid .get(event_id.as_bytes())? @@ -531,7 +543,6 @@ impl Rooms { } } - #[tracing::instrument(skip(self, globals))] pub fn compress_state_event( &self, shortstatekey: u64, @@ -548,7 +559,6 @@ impl Rooms { } /// Returns shortstatekey, event id - #[tracing::instrument(skip(self, compressed_event))] pub fn parse_compressed_state_event( &self, compressed_event: CompressedStateEvent, @@ -707,7 +717,6 @@ impl Rooms { } /// Returns (shortstatehash, already_existed) - #[tracing::instrument(skip(self, globals))] fn get_or_create_shortstatehash( &self, state_hash: &StateHashId, @@ -728,7 +737,6 @@ impl Rooms { }) } - #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -759,7 +767,6 @@ impl Rooms { Ok(short) } - #[tracing::instrument(skip(self))] pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.roomid_shortroomid .get(room_id.as_bytes())? @@ -770,7 +777,6 @@ impl Rooms { .transpose() } - #[tracing::instrument(skip(self))] pub fn get_shortstatekey( &self, event_type: &StateEventType, @@ -808,7 +814,6 @@ impl Rooms { Ok(short) } - #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shortroomid( &self, room_id: &RoomId, @@ -826,7 +831,6 @@ impl Rooms { }) } - #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shortstatekey( &self, event_type: &StateEventType, @@ -867,7 +871,6 @@ impl Rooms { Ok(short) } - #[tracing::instrument(skip(self))] pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { if let Some(id) = self .shorteventid_cache @@ -896,7 +899,6 @@ impl Rooms { Ok(event_id) } - #[tracing::instrument(skip(self))] pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { if let Some(id) = self .shortstatekey_cache @@ -940,12 +942,12 @@ impl Rooms { /// Returns the full room state. #[tracing::instrument(skip(self))] - pub fn room_state_full( + pub async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_full(current_shortstatehash) + self.state_full(current_shortstatehash).await } else { Ok(HashMap::new()) } @@ -982,14 +984,12 @@ impl Rooms { } /// Returns the `count` of this pdu's id. - #[tracing::instrument(skip(self))] pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) } /// Returns the `count` of this pdu's id. - #[tracing::instrument(skip(self))] pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -1018,7 +1018,6 @@ impl Rooms { } /// Returns the json of a pdu. - #[tracing::instrument(skip(self))] pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -1037,7 +1036,6 @@ impl Rooms { } /// Returns the json of a pdu. - #[tracing::instrument(skip(self))] pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? @@ -1048,7 +1046,6 @@ impl Rooms { } /// Returns the json of a pdu. - #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu_json( &self, event_id: &EventId, @@ -1068,7 +1065,6 @@ impl Rooms { } /// Returns the pdu's id. - #[tracing::instrument(skip(self))] pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } @@ -1076,7 +1072,6 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -1095,7 +1090,6 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - #[tracing::instrument(skip(self))] pub fn get_pdu(&self, event_id: &EventId) -> Result>> { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { return Ok(Some(Arc::clone(p))); @@ -1132,7 +1126,6 @@ impl Rooms { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - #[tracing::instrument(skip(self))] pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -1143,7 +1136,6 @@ impl Rooms { } /// Returns the pdu as a `BTreeMap`. - #[tracing::instrument(skip(self))] pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -1232,7 +1224,6 @@ impl Rooms { } /// Returns the pdu from the outlier tree. - #[tracing::instrument(skip(self))] pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/server_server.rs b/src/server_server.rs index 7b08cf9b..6fa83e4c 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -691,7 +691,7 @@ pub async fn send_transaction_message_route( .roomid_mutex_federation .write() .unwrap() - .entry(room_id.clone()) + .entry(room_id.to_owned()) .or_default(), ); let mutex_lock = mutex.lock().await; @@ -1054,6 +1054,25 @@ pub(crate) async fn handle_incoming_pdu<'a>( } } + if let Some((time, tries)) = db + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&*prev_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", prev_id); + continue; + } + } + if errors >= 5 { break; } @@ -1068,7 +1087,6 @@ pub(crate) async fn handle_incoming_pdu<'a>( .write() .unwrap() .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - let event_id = pdu.event_id.clone(); if let Err(e) = upgrade_outlier_to_timeline_pdu( pdu, json, @@ -1081,7 +1099,21 @@ pub(crate) async fn handle_incoming_pdu<'a>( .await { errors += 1; - warn!("Prev event {} failed: {}", event_id, e); + warn!("Prev event {} failed: {}", prev_id, e); + match db + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry((*prev_id).to_owned()) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1) + } + } } let elapsed = start_time.elapsed(); db.globals @@ -1091,7 +1123,7 @@ pub(crate) async fn handle_incoming_pdu<'a>( .remove(&room_id.to_owned()); warn!( "Handling prev event {} took {}m{}s", - event_id, + prev_id, elapsed.as_secs() / 60, elapsed.as_secs() % 60 ); @@ -1321,8 +1353,11 @@ async fn upgrade_outlier_to_timeline_pdu( .pdu_shortstatehash(prev_event) .map_err(|_| "Failed talking to db".to_owned())?; - let state = - prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); + let state = if let Some(shortstatehash) = prev_event_sstatehash { + Some(db.rooms.state_full_ids(shortstatehash).await) + } else { + None + }; if let Some(Ok(mut state)) = state { info!("Using cached state"); @@ -1378,6 +1413,7 @@ async fn upgrade_outlier_to_timeline_pdu( let mut leaf_state: BTreeMap<_, _> = db .rooms .state_full_ids(sstatehash) + .await .map_err(|_| "Failed to ask db for room state.".to_owned())?; if let Some(state_key) = &prev_event.state_key { @@ -1409,6 +1445,7 @@ async fn upgrade_outlier_to_timeline_pdu( auth_chain_sets.push( get_auth_chain(room_id, starting_events, db) + .await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), ); @@ -1535,6 +1572,7 @@ async fn upgrade_outlier_to_timeline_pdu( let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); + info!("Starting auth check"); // 11. Check the auth of the event passes based on the state of the event let check_result = state_res::event_auth::auth_check( &room_version, @@ -1554,7 +1592,7 @@ async fn upgrade_outlier_to_timeline_pdu( if !check_result { return Err("Event has failed auth check with state at the event.".into()); } - info!("Auth check succeeded."); + info!("Auth check succeeded"); // We start looking at current room state now, so lets lock the room @@ -1570,6 +1608,7 @@ async fn upgrade_outlier_to_timeline_pdu( // Now we calculate the set of extremities this room has after the incoming event has been // applied. We start with the previous extremities (aka leaves) + info!("Calculating extremities"); let mut extremities = db .rooms .get_pdu_leaves(room_id) @@ -1585,16 +1624,18 @@ async fn upgrade_outlier_to_timeline_pdu( // Only keep those extremities were not referenced yet extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); - let current_sstatehash = db - .rooms - .current_shortstatehash(room_id) - .map_err(|_| "Failed to load current state hash.".to_owned())? - .expect("every room has state"); + info!("Compressing state at event"); + let state_ids_compressed = state_at_incoming_event + .iter() + .map(|(shortstatekey, id)| { + db.rooms + .compress_state_event(*shortstatekey, id, &db.globals) + .map_err(|_| "Failed to compress_state_event".to_owned()) + }) + .collect::>()?; - let current_state_ids = db - .rooms - .state_full_ids(current_sstatehash) - .map_err(|_| "Failed to load room state.")?; + // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it + info!("Starting soft fail auth check"); let auth_events = db .rooms @@ -1607,18 +1648,6 @@ async fn upgrade_outlier_to_timeline_pdu( ) .map_err(|_| "Failed to get_auth_events.".to_owned())?; - let state_ids_compressed = state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - db.rooms - .compress_state_event(*shortstatekey, id, &db.globals) - .map_err(|_| "Failed to compress_state_event".to_owned()) - }) - .collect::>()?; - - // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it - info!("Starting soft fail auth check"); - let soft_fail = !state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -1651,6 +1680,19 @@ async fn upgrade_outlier_to_timeline_pdu( } if incoming_pdu.state_key.is_some() { + info!("Loading current room state ids"); + let current_sstatehash = db + .rooms + .current_shortstatehash(room_id) + .map_err(|_| "Failed to load current state hash.".to_owned())? + .expect("every room has state"); + + let current_state_ids = db + .rooms + .state_full_ids(current_sstatehash) + .await + .map_err(|_| "Failed to load room state.")?; + info!("Preparing for stateres to derive new room state"); let mut extremity_sstatehashes = HashMap::new(); @@ -1738,6 +1780,7 @@ async fn upgrade_outlier_to_timeline_pdu( state.iter().map(|(_, id)| id.clone()).collect(), db, ) + .await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), ); @@ -1899,11 +1942,17 @@ pub(crate) fn fetch_and_handle_outliers<'a>( let mut todo_auth_events = vec![Arc::clone(id)]; let mut events_in_reverse_order = Vec::new(); let mut events_all = HashSet::new(); + let mut i = 0; while let Some(next_id) = todo_auth_events.pop() { if events_all.contains(&next_id) { continue; } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { trace!("Found {} in db", id); continue; @@ -2242,7 +2291,7 @@ fn append_incoming_pdu<'a>( } #[tracing::instrument(skip(starting_events, db))] -pub(crate) fn get_auth_chain<'a>( +pub(crate) async fn get_auth_chain<'a>( room_id: &RoomId, starting_events: Vec>, db: &'a Database, @@ -2251,10 +2300,15 @@ pub(crate) fn get_auth_chain<'a>( let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; + let mut i = 0; for id in starting_events { let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } let mut full_auth_chain = HashSet::new(); @@ -2277,6 +2331,7 @@ pub(crate) fn get_auth_chain<'a>( let mut chunk_cache = HashSet::new(); let mut hits2 = 0; let mut misses2 = 0; + let mut i = 0; for (sevent_id, event_id) in chunk { if let Some(cached) = db.rooms.get_auth_chain_from_cache(&[sevent_id])? { hits2 += 1; @@ -2292,6 +2347,11 @@ pub(crate) fn get_auth_chain<'a>( auth_chain.len() ); chunk_cache.extend(auth_chain.iter()); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } }; } println!( @@ -2512,7 +2572,7 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids @@ -2557,7 +2617,8 @@ pub async fn get_room_state_route( let pdus = db .rooms - .state_full_ids(shortstatehash)? + .state_full_ids(shortstatehash) + .await? .into_iter() .map(|(_, id)| { PduEvent::convert_to_outgoing_federation_event( @@ -2566,7 +2627,8 @@ pub async fn get_room_state_route( }) .collect(); - let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids @@ -2616,12 +2678,14 @@ pub async fn get_room_state_ids_route( let pdu_ids = db .rooms - .state_full_ids(shortstatehash)? + .state_full_ids(shortstatehash) + .await? .into_iter() .map(|(_, id)| (*id).to_owned()) .collect(); - let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -2927,12 +2991,13 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = db.rooms.state_full_ids(shortstatehash)?; + let state_ids = db.rooms.state_full_ids(shortstatehash).await?; let auth_chain_ids = get_auth_chain( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), db, - )?; + ) + .await?; let servers = db .rooms