diff --git a/rust-toolchain b/rust-toolchain index a63cb35e..d96ae405 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.52.0 +1.52 diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 0fc8b28a..9e16d90d 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, convert::TryInto}; +use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; @@ -238,6 +238,16 @@ pub async fn register_route( let room_id = RoomId::new(db.globals.server_name()); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let mut content = ruma::events::room::create::CreateEventContent::new(conduit_user.clone()); content.federate = true; content.predecessor = None; @@ -255,6 +265,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // 2. Make conduit bot join @@ -276,6 +287,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // 3. Power levels @@ -300,6 +312,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // 4.1 Join Rules @@ -317,6 +330,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // 4.2 History Visibility @@ -336,6 +350,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // 4.3 Guest Access @@ -353,6 +368,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // 6. Events implied by name and topic @@ -372,6 +388,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; db.rooms.build_and_append_pdu( @@ -388,6 +405,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; // Room alias @@ -410,6 +428,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; @@ -433,6 +452,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; db.rooms.build_and_append_pdu( PduBuilder { @@ -452,6 +472,7 @@ pub async fn register_route( &user_id, &room_id, &db, + &mutex_lock, )?; // Send welcome message @@ -470,6 +491,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, + &mutex_lock, )?; } @@ -641,6 +663,16 @@ pub async fn deactivate_route( third_party_invite: None, }; + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomMember, @@ -652,6 +684,7 @@ pub async fn deactivate_route( &sender_user, &room_id, &db, + &mutex_lock, )?; } diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 4667f25d..a74950b6 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -203,6 +203,16 @@ pub async fn kick_user_route( event.membership = ruma::events::room::member::MembershipState::Leave; // TODO: reason + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomMember, @@ -214,8 +224,11 @@ pub async fn kick_user_route( &sender_user, &body.room_id, &db, + &mutex_lock, )?; + drop(mutex_lock); + db.flush().await?; Ok(kick_user::Response::new().into()) @@ -261,6 +274,16 @@ pub async fn ban_user_route( }, )?; + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomMember, @@ -272,8 +295,11 @@ pub async fn ban_user_route( &sender_user, &body.room_id, &db, + &mutex_lock, )?; + drop(mutex_lock); + db.flush().await?; Ok(ban_user::Response::new().into()) @@ -310,6 +336,16 @@ pub async fn unban_user_route( event.membership = ruma::events::room::member::MembershipState::Leave; + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomMember, @@ -321,8 +357,11 @@ pub async fn unban_user_route( &sender_user, &body.room_id, &db, + &mutex_lock, )?; + drop(mutex_lock); + db.flush().await?; Ok(unban_user::Response::new().into()) @@ -446,6 +485,16 @@ async fn join_room_by_id_helper( ) -> ConduitResult { let sender_user = sender_user.expect("user is authenticated"); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + // Ask a remote server if we don't have this room if !db.rooms.exists(&room_id)? && room_id.server_name() != db.globals.server_name() { let mut make_join_response_and_server = Err(Error::BadServerResponse( @@ -619,16 +668,9 @@ async fn join_room_by_id_helper( // pdu without it's state. This is okay because append_pdu can't fail. let statehashid = db.rooms.append_to_state(&pdu, &db.globals)?; - let count = db.globals.next_count()?; - let mut pdu_id = room_id.as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&count.to_be_bytes()); - db.rooms.append_pdu( &pdu, utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"), - count, - &pdu_id, &[pdu.event_id.clone()], db, )?; @@ -656,9 +698,12 @@ async fn join_room_by_id_helper( &sender_user, &room_id, &db, + &mutex_lock, )?; } + drop(mutex_lock); + db.flush().await?; Ok(join_room_by_id::Response::new(room_id.clone()).into()) @@ -728,13 +773,23 @@ async fn validate_and_add_event_id( Ok((event_id, value)) } -pub async fn invite_helper( +pub async fn invite_helper<'a>( sender_user: &UserId, user_id: &UserId, room_id: &RoomId, db: &Database, is_direct: bool, ) -> Result<()> { + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + if user_id.server_name() != db.globals.server_name() { let prev_events = db .rooms @@ -870,6 +925,8 @@ pub async fn invite_helper( ) .expect("event is valid, we just created it"); + drop(mutex_lock); + let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; let response = db .sending @@ -909,19 +966,26 @@ pub async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = - server_server::handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) - .await - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + let pdu_id = server_server::handle_incoming_pdu( + &origin, + &event_id, + &room_id, + value, + true, + &db, + &pub_key_map, + ) + .await + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Error while handling incoming PDU.", + ) + })? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; for server in db .rooms @@ -953,6 +1017,7 @@ pub async fn invite_helper( &sender_user, room_id, &db, + &mutex_lock, )?; Ok(()) diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 7e898b11..3d8218c6 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -10,6 +10,7 @@ use ruma::{ use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, + sync::Arc, }; #[cfg(feature = "conduit_bin")] @@ -27,6 +28,16 @@ pub async fn send_message_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + // Check if this is a new transaction id if let Some(response) = db.transaction_ids @@ -64,6 +75,7 @@ pub async fn send_message_event_route( &sender_user, &body.room_id, &db, + &mutex_lock, )?; db.transaction_ids.add_txnid( @@ -73,6 +85,8 @@ pub async fn send_message_event_route( event_id.as_bytes(), )?; + drop(mutex_lock); + db.flush().await?; Ok(send_message_event::Response::new(event_id).into()) diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 5281a4a2..d947bbe1 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -9,7 +9,7 @@ use ruma::{ events::EventType, serde::Raw, }; -use std::convert::TryInto; +use std::{convert::TryInto, sync::Arc}; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; @@ -69,9 +69,19 @@ pub async fn set_displayname_route( }) .filter_map(|r| r.ok()) { - let _ = db - .rooms - .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + + let _ = + db.rooms + .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &mutex_lock); // Presence update db.rooms.edus.update_presence( @@ -171,9 +181,19 @@ pub async fn set_avatar_url_route( }) .filter_map(|r| r.ok()) { - let _ = db - .rooms - .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + + let _ = + db.rooms + .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &mutex_lock); // Presence update db.rooms.edus.update_presence( diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index 3db27716..2e4c6519 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Ruma}; use ruma::{ api::client::r0::redact::redact_event, @@ -18,6 +20,16 @@ pub async fn redact_event_route( ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let event_id = db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomRedaction, @@ -32,8 +44,11 @@ pub async fn redact_event_route( &sender_user, &body.room_id, &db, + &mutex_lock, )?; + drop(mutex_lock); + db.flush().await?; Ok(redact_event::Response { event_id }.into()) diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 43625fe5..f48c5e93 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -15,7 +15,7 @@ use ruma::{ serde::Raw, RoomAliasId, RoomId, RoomVersionId, }; -use std::{cmp::max, collections::BTreeMap, convert::TryFrom}; +use std::{cmp::max, collections::BTreeMap, convert::TryFrom, sync::Arc}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -33,6 +33,16 @@ pub async fn create_room_route( let room_id = RoomId::new(db.globals.server_name()); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let alias = body .room_alias_name .as_ref() @@ -69,6 +79,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; // 2. Let the room creator join @@ -90,6 +101,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; // 3. Power levels @@ -144,6 +156,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; // 4. Events set by preset @@ -170,6 +183,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; // 4.2 History Visibility @@ -187,6 +201,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; // 4.3 Guest Access @@ -212,6 +227,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; // 5. Events listed in initial_state @@ -227,7 +243,7 @@ pub async fn create_room_route( } db.rooms - .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db)?; + .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &mutex_lock)?; } // 6. Events implied by name and topic @@ -248,6 +264,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; } @@ -266,10 +283,12 @@ pub async fn create_room_route( &sender_user, &room_id, &db, + &mutex_lock, )?; } // 7. Events implied by invite (and TODO: invite_3pid) + drop(mutex_lock); for user_id in &body.invite { let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await; } @@ -340,6 +359,16 @@ pub async fn upgrade_room_route( // Create a replacement room let replacement_room = RoomId::new(db.globals.server_name()); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Fail if the sender does not have the required permissions let tombstone_event_id = db.rooms.build_and_append_pdu( @@ -357,6 +386,7 @@ pub async fn upgrade_room_route( sender_user, &body.room_id, &db, + &mutex_lock, )?; // Get the old room federations status @@ -397,6 +427,7 @@ pub async fn upgrade_room_route( sender_user, &replacement_room, &db, + &mutex_lock, )?; // Join the new room @@ -418,6 +449,7 @@ pub async fn upgrade_room_route( sender_user, &replacement_room, &db, + &mutex_lock, )?; // Recommended transferable state events list from the specs @@ -451,6 +483,7 @@ pub async fn upgrade_room_route( sender_user, &replacement_room, &db, + &mutex_lock, )?; } @@ -494,8 +527,11 @@ pub async fn upgrade_room_route( sender_user, &body.room_id, &db, + &mutex_lock, )?; + drop(mutex_lock); + db.flush().await?; // Return the replacement room id diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 68246d54..e0e5d29a 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{ database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma, }; @@ -257,6 +259,16 @@ pub async fn send_state_event_for_key_helper( } } + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let event_id = db.rooms.build_and_append_pdu( PduBuilder { event_type, @@ -268,6 +280,7 @@ pub async fn send_state_event_for_key_helper( &sender_user, &room_id, &db, + &mutex_lock, )?; Ok(event_id) diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index c57f1da1..fe113048 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -89,7 +89,9 @@ pub async fn sync_events_route( let we_have_to_wait = rx.borrow().is_none(); if we_have_to_wait { - let _ = rx.changed().await; + if let Err(e) = rx.changed().await { + error!("Error waiting for sync: {}", e); + } } let result = match rx @@ -187,6 +189,18 @@ async fn sync_helper( for room_id in db.rooms.rooms_joined(&sender_user) { let room_id = room_id?; + // Get and drop the lock to wait for remaining operations to finish + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + drop(mutex_lock); + let mut non_timeline_pdus = db .rooms .pdus_until(&sender_user, &room_id, u64::MAX) @@ -225,13 +239,16 @@ async fn sync_helper( // Database queries: - let current_shortstatehash = db.rooms.current_shortstatehash(&room_id)?; + let current_shortstatehash = db + .rooms + .current_shortstatehash(&room_id)? + .expect("All rooms have state"); - // These type is Option>. The outer Option is None when there is no event between - // since and the current room state, meaning there should be no updates. - // The inner Option is None when there is an event, but there is no state hash associated - // with it. This can happen for the RoomCreate event, so all updates should arrive. - let first_pdu_before_since = db.rooms.pdus_until(&sender_user, &room_id, since).next(); + let first_pdu_before_since = db + .rooms + .pdus_until(&sender_user, &room_id, since) + .next() + .transpose()?; let pdus_after_since = db .rooms @@ -239,11 +256,78 @@ async fn sync_helper( .next() .is_some(); - let since_shortstatehash = first_pdu_before_since.as_ref().map(|pdu| { - db.rooms - .pdu_shortstatehash(&pdu.as_ref().ok()?.1.event_id) - .ok()? - }); + let since_shortstatehash = first_pdu_before_since + .as_ref() + .map(|pdu| { + db.rooms + .pdu_shortstatehash(&pdu.1.event_id) + .transpose() + .expect("all pdus have state") + }) + .transpose()?; + + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || { + let joined_member_count = db.rooms.room_members(&room_id).count(); + let invited_member_count = db.rooms.room_members_invited(&room_id).count(); + + // Recalculate heroes (first 5 members) + let mut heroes = Vec::new(); + + if joined_member_count + invited_member_count <= 5 { + // Go through all PDUs and for each member event, check if the user is still joined or + // invited until we have 5 or we reach the end + + for hero in db + .rooms + .all_pdus(&sender_user, &room_id) + .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus + .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) + .map(|(_, pdu)| { + let content = serde_json::from_value::< + ruma::events::room::member::MemberEventContent, + >(pdu.content.clone()) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; + + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::try_from(state_key.clone()).map_err(|_| { + Error::bad_database("Invalid UserId in member PDU.") + })?; + + // The membership was and still is invite or join + if matches!( + content.membership, + MembershipState::Join | MembershipState::Invite + ) && (db.rooms.is_joined(&user_id, &room_id)? + || db.rooms.is_invited(&user_id, &room_id)?) + { + Ok::<_, Error>(Some(state_key.clone())) + } else { + Ok(None) + } + } else { + Ok(None) + } + }) + // Filter out buggy users + .filter_map(|u| u.ok()) + // Filter for possible heroes + .flatten() + { + if heroes.contains(&hero) || hero == sender_user.as_str() { + continue; + } + + heroes.push(hero); + } + } + + ( + Some(joined_member_count), + Some(invited_member_count), + heroes, + ) + }; let ( heroes, @@ -251,63 +335,107 @@ async fn sync_helper( invited_member_count, joined_since_last_sync, state_events, - ) = if pdus_after_since && Some(current_shortstatehash) != since_shortstatehash { - let current_state = db.rooms.room_state_full(&room_id)?; - let current_members = current_state + ) = if since_shortstatehash.is_none() { + // Probably since = 0, we will do an initial sync + let (joined_member_count, invited_member_count, heroes) = calculate_counts(); + + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; + let state_events = current_state_ids .iter() - .filter(|(key, _)| key.0 == EventType::RoomMember) - .map(|(key, value)| (&key.1, value)) // Only keep state key + .map(|id| db.rooms.get_pdu(id)) + .filter_map(|r| r.ok().flatten()) .collect::>(); - let encrypted_room = current_state - .get(&(EventType::RoomEncryption, "".to_owned())) - .is_some(); - let since_state = since_shortstatehash - .as_ref() - .map(|since_shortstatehash| { - since_shortstatehash - .map(|since_shortstatehash| db.rooms.state_full(since_shortstatehash)) - .transpose() - }) - .transpose()?; - let since_encryption = since_state.as_ref().map(|state| { - state - .as_ref() - .map(|state| state.get(&(EventType::RoomEncryption, "".to_owned()))) - }); + ( + heroes, + joined_member_count, + invited_member_count, + true, + state_events, + ) + } else if !pdus_after_since || since_shortstatehash == Some(current_shortstatehash) { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.unwrap(); - // Calculations: - let new_encrypted_room = - encrypted_room && since_encryption.map_or(true, |encryption| encryption.is_none()); + let since_sender_member = db + .rooms + .state_get( + since_shortstatehash, + &EventType::RoomMember, + sender_user.as_str(), + )? + .and_then(|pdu| { + serde_json::from_value::>( + pdu.content.clone(), + ) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); + + let joined_since_last_sync = since_sender_member + .map_or(true, |member| member.membership != MembershipState::Join); + + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; + + let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; - let send_member_count = since_state.as_ref().map_or(true, |since_state| { - since_state.as_ref().map_or(true, |since_state| { - current_members.len() - != since_state + let state_events = if joined_since_last_sync { + current_state_ids + .iter() + .map(|id| db.rooms.get_pdu(id)) + .filter_map(|r| r.ok().flatten()) + .collect::>() + } else { + current_state_ids + .difference(&since_state_ids) + .filter(|id| { + !timeline_pdus .iter() - .filter(|(key, _)| key.0 == EventType::RoomMember) - .count() - }) - }); + .any(|(_, timeline_pdu)| timeline_pdu.event_id == **id) + }) + .map(|id| db.rooms.get_pdu(id)) + .filter_map(|r| r.ok().flatten()) + .collect() + }; - let since_sender_member = since_state.as_ref().map(|since_state| { - since_state.as_ref().and_then(|state| { - state - .get(&(EventType::RoomMember, sender_user.as_str().to_owned())) - .and_then(|pdu| { - serde_json::from_value::< - Raw, - >(pdu.content.clone()) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }) - }) - }); + let encrypted_room = db + .rooms + .state_get(current_shortstatehash, &EventType::RoomEncryption, "")? + .is_some(); + + let since_encryption = + db.rooms + .state_get(since_shortstatehash, &EventType::RoomEncryption, "")?; + + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_none(); + + let send_member_count = state_events + .iter() + .any(|event| event.kind == EventType::RoomMember); if encrypted_room { - for (user_id, current_member) in current_members { + for (user_id, current_member) in db + .rooms + .room_members(&room_id) + .filter_map(|r| r.ok()) + .filter_map(|user_id| { + db.rooms + .state_get( + current_shortstatehash, + &EventType::RoomMember, + user_id.as_str(), + ) + .ok() + .flatten() + .map(|current_member| (user_id, current_member)) + }) + { let current_membership = serde_json::from_value::< Raw, >(current_member.content.clone()) @@ -316,31 +444,23 @@ async fn sync_helper( .map_err(|_| Error::bad_database("Invalid PDU in database."))? .membership; - let since_membership = - since_state - .as_ref() - .map_or(MembershipState::Leave, |since_state| { - since_state - .as_ref() - .and_then(|since_state| { - since_state - .get(&(EventType::RoomMember, user_id.clone())) - .and_then(|since_member| { - serde_json::from_value::< - Raw, - >( - since_member.content.clone() - ) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| { - Error::bad_database("Invalid PDU in database.") - }) - .ok() - }) - }) - .map_or(MembershipState::Leave, |member| member.membership) - }); + let since_membership = db + .rooms + .state_get( + since_shortstatehash, + &EventType::RoomMember, + user_id.as_str(), + )? + .and_then(|since_member| { + serde_json::from_value::< + Raw, + >(since_member.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }) + .map_or(MembershipState::Leave, |member| member.membership); let user_id = UserId::try_from(user_id.clone()) .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; @@ -362,10 +482,6 @@ async fn sync_helper( } } - let joined_since_last_sync = since_sender_member.map_or(true, |member| { - member.map_or(true, |member| member.membership != MembershipState::Join) - }); - if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( @@ -385,100 +501,11 @@ async fn sync_helper( } let (joined_member_count, invited_member_count, heroes) = if send_member_count { - let joined_member_count = db.rooms.room_members(&room_id).count(); - let invited_member_count = db.rooms.room_members_invited(&room_id).count(); - - // Recalculate heroes (first 5 members) - let mut heroes = Vec::new(); - - if joined_member_count + invited_member_count <= 5 { - // Go through all PDUs and for each member event, check if the user is still joined or - // invited until we have 5 or we reach the end - - for hero in db - .rooms - .all_pdus(&sender_user, &room_id) - .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) - .map(|(_, pdu)| { - let content = serde_json::from_value::< - ruma::events::room::member::MemberEventContent, - >(pdu.content.clone()) - .map_err(|_| { - Error::bad_database("Invalid member event in database.") - })?; - - if let Some(state_key) = &pdu.state_key { - let user_id = - UserId::try_from(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; - - // The membership was and still is invite or join - if matches!( - content.membership, - MembershipState::Join | MembershipState::Invite - ) && (db.rooms.is_joined(&user_id, &room_id)? - || db.rooms.is_invited(&user_id, &room_id)?) - { - Ok::<_, Error>(Some(state_key.clone())) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - // Filter out buggy users - .filter_map(|u| u.ok()) - // Filter for possible heroes - .flatten() - { - if heroes.contains(&hero) || hero == sender_user.as_str() { - continue; - } - - heroes.push(hero); - } - } - - ( - Some(joined_member_count), - Some(invited_member_count), - heroes, - ) + calculate_counts() } else { (None, None, Vec::new()) }; - let state_events = if joined_since_last_sync { - current_state - .iter() - .map(|(_, pdu)| pdu.to_sync_state_event()) - .collect() - } else { - match since_state { - None => Vec::new(), - Some(Some(since_state)) => current_state - .iter() - .filter(|(key, value)| { - since_state.get(key).map(|e| &e.event_id) != Some(&value.event_id) - }) - .filter(|(_, value)| { - !timeline_pdus.iter().any(|(_, timeline_pdu)| { - timeline_pdu.kind == value.kind - && timeline_pdu.state_key == value.state_key - }) - }) - .map(|(_, pdu)| pdu.to_sync_state_event()) - .collect(), - Some(None) => current_state - .iter() - .map(|(_, pdu)| pdu.to_sync_state_event()) - .collect(), - } - }; - ( heroes, joined_member_count, @@ -486,8 +513,6 @@ async fn sync_helper( joined_since_last_sync, state_events, ) - } else { - (Vec::new(), None, None, false, Vec::new()) }; // Look for device list updates in this room @@ -578,7 +603,10 @@ async fn sync_helper( events: room_events, }, state: sync_events::State { - events: state_events, + events: state_events + .iter() + .map(|pdu| pdu.to_sync_state_event()) + .collect(), }, ephemeral: sync_events::Ephemeral { events: edus }, }; @@ -625,6 +653,19 @@ async fn sync_helper( let mut left_rooms = BTreeMap::new(); for result in db.rooms.rooms_left(&sender_user) { let (room_id, left_state_events) = result?; + + // Get and drop the lock to wait for remaining operations to finish + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + drop(mutex_lock); + let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; // Left before last sync @@ -651,6 +692,19 @@ async fn sync_helper( let mut invited_rooms = BTreeMap::new(); for result in db.rooms.rooms_invited(&sender_user) { let (room_id, invite_state_events) = result?; + + // Get and drop the lock to wait for remaining operations to finish + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + drop(mutex_lock); + let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; // Invited before last sync diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index 3bb135e7..9faa2555 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -19,7 +19,9 @@ pub async fn send_event_to_device_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); + // TODO: uncomment when https://github.com/vector-im/element-android/issues/3589 is solved // Check if this is a new transaction id + /* if db .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? @@ -27,6 +29,7 @@ pub async fn send_event_to_device_route( { return Ok(send_event_to_device::Response.into()); } + */ for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { diff --git a/src/database.rs b/src/database.rs index 5a896a86..c39f0fbd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -279,7 +279,7 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, prevevent_parent: builder.open_tree("prevevent_parent")?, - pdu_cache: RwLock::new(LruCache::new(1_000_000)), + pdu_cache: RwLock::new(LruCache::new(10_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index 271be1e9..e58184df 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -64,7 +64,7 @@ impl Tree for SledEngineTree { backwards: bool, ) -> Box, Vec)> + Send> { let iter = if backwards { - self.0.range(..from) + self.0.range(..=from) } else { self.0.range(from..) }; diff --git a/src/database/admin.rs b/src/database/admin.rs index cd5fa847..d8b7ae5e 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -10,7 +10,7 @@ use ruma::{ events::{room::message, EventType}, UserId, }; -use tokio::sync::{RwLock, RwLockReadGuard}; +use tokio::sync::{MutexGuard, RwLock, RwLockReadGuard}; pub enum AdminCommand { RegisterAppservice(serde_yaml::Value), @@ -48,38 +48,51 @@ impl Admin { ) .unwrap(); - if conduit_room.is_none() { - warn!("Conduit instance does not have an #admins room. Logging to that room will not work. Restart Conduit after creating a user to fix this."); - } + let conduit_room = match conduit_room { + None => { + warn!("Conduit instance does not have an #admins room. Logging to that room will not work. Restart Conduit after creating a user to fix this."); + return; + } + Some(r) => r, + }; drop(guard); - let send_message = - |message: message::MessageEventContent, guard: RwLockReadGuard<'_, Database>| { - if let Some(conduit_room) = &conduit_room { - guard - .rooms - .build_and_append_pdu( - PduBuilder { - event_type: EventType::RoomMessage, - content: serde_json::to_value(message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &guard, - ) - .unwrap(); - } - }; + let send_message = |message: message::MessageEventContent, + guard: RwLockReadGuard<'_, Database>, + mutex_lock: &MutexGuard<'_, ()>| { + guard + .rooms + .build_and_append_pdu( + PduBuilder { + event_type: EventType::RoomMessage, + content: serde_json::to_value(message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &guard, + mutex_lock, + ) + .unwrap(); + }; loop { tokio::select! { Some(event) = receiver.next() => { let guard = db.read().await; + let mutex = Arc::clone( + guard.globals + .roomid_mutex + .write() + .unwrap() + .entry(conduit_room.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; match event { AdminCommand::RegisterAppservice(yaml) => { @@ -93,15 +106,17 @@ impl Admin { count, appservices.into_iter().filter_map(|r| r.ok()).collect::>().join(", ") ); - send_message(message::MessageEventContent::text_plain(output), guard); + send_message(message::MessageEventContent::text_plain(output), guard, &mutex_lock); } else { - send_message(message::MessageEventContent::text_plain("Failed to get appservices."), guard); + send_message(message::MessageEventContent::text_plain("Failed to get appservices."), guard, &mutex_lock); } } AdminCommand::SendMessage(message) => { - send_message(message, guard) + send_message(message, guard, &mutex_lock); } } + + drop(mutex_lock); } } } diff --git a/src/database/globals.rs b/src/database/globals.rs index 307ec400..0e722973 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -5,7 +5,7 @@ use ruma::{ client::r0::sync::sync_events, federation::discovery::{ServerSigningKeys, VerifyKey}, }, - DeviceId, EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, UserId, + DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, }; use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ @@ -16,7 +16,7 @@ use std::{ sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::{broadcast, watch::Receiver, Semaphore}; +use tokio::sync::{broadcast, watch::Receiver, Mutex, Semaphore}; use trust_dns_resolver::TokioAsyncResolver; use super::abstraction::Tree; @@ -45,6 +45,8 @@ pub struct Globals { pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, pub sync_receivers: RwLock), SyncHandle>>, + pub roomid_mutex: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer pub rotate: RotationHandler, } @@ -197,6 +199,8 @@ impl Globals { bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), + roomid_mutex: RwLock::new(BTreeMap::new()), + roomid_mutex_federation: RwLock::new(BTreeMap::new()), sync_receivers: RwLock::new(BTreeMap::new()), rotate: RotationHandler::new(), }; diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 4d66f9f1..1542db85 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -2,6 +2,7 @@ mod edus; pub use edus::RoomEdus; use member::MembershipState; +use tokio::sync::MutexGuard; use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; use log::{debug, error, warn}; @@ -21,7 +22,7 @@ use ruma::{ uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem, sync::{Arc, RwLock}, @@ -89,7 +90,7 @@ pub struct Rooms { impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { + pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { Ok(self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) @@ -666,11 +667,10 @@ impl Rooms { &self, pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, - count: u64, - pdu_id: &[u8], leaves: &[EventId], db: &Database, - ) -> Result<()> { + ) -> Result> { + // returns pdu id // Make unsigned fields correct. This is not properly documented in the spec, but state // events need to have previous content in the unsigned field, so clients can easily // interpret things like membership changes @@ -708,20 +708,30 @@ impl Rooms { self.replace_pdu_leaves(&pdu.room_id, leaves)?; + let count1 = db.globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if appending // fails self.edus - .private_read_set(&pdu.room_id, &pdu.sender, count, &db.globals)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; + let count2 = db.globals.next_count()?; + let mut pdu_id = pdu.room_id.as_bytes().to_vec(); + pdu_id.push(0xff); + pdu_id.extend_from_slice(&count2.to_be_bytes()); + + // There's a brief moment of time here where the count is updated but the pdu does not + // exist. This could theoretically lead to dropped pdus, but it's extremely rare + self.pduid_pdu.insert( - pdu_id, + &pdu_id, &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), )?; // This also replaces the eventid of any outliers with the correct // pduid, removing the place holder. - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; + self.eventid_pduid + .insert(pdu.event_id.as_bytes(), &pdu_id)?; // See if the event matches any known pushers for user in db @@ -909,7 +919,7 @@ impl Rooms { _ => {} } - Ok(()) + Ok(pdu_id) } pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { @@ -1198,6 +1208,7 @@ impl Rooms { sender: &UserId, room_id: &RoomId, db: &Database, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room mutex ) -> Result { let PduBuilder { event_type, @@ -1206,7 +1217,7 @@ impl Rooms { state_key, redacts, } = pdu_builder; - // TODO: Make sure this isn't called twice in parallel + let prev_events = self .get_pdu_leaves(&room_id)? .into_iter() @@ -1354,11 +1365,9 @@ impl Rooms { // pdu without it's state. This is okay because append_pdu can't fail. let statehashid = self.append_to_state(&pdu, &db.globals)?; - self.append_pdu( + let pdu_id = self.append_pdu( &pdu, pdu_json, - count, - &pdu_id, // Since this PDU references all pdu_leaves we can update the leaves // of the room &[pdu.event_id.clone()], @@ -1495,7 +1504,7 @@ impl Rooms { prefix.push(0xff); let mut current = prefix.clone(); - current.extend_from_slice(&until.to_be_bytes()); + current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` let current: &[u8] = ¤t; @@ -1782,6 +1791,16 @@ impl Rooms { db, )?; } else { + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let mut event = serde_json::from_value::>( self.room_state_get(room_id, &EventType::RoomMember, &user_id.to_string())? .ok_or(Error::BadRequest( @@ -1809,6 +1828,7 @@ impl Rooms { user_id, room_id, db, + &mutex_lock, )?; } diff --git a/src/database/users.rs b/src/database/users.rs index f99084fa..1480d3fa 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -726,10 +726,9 @@ impl Users { json.insert("sender".to_owned(), sender.to_string().into()); json.insert("content".to_owned(), content); - self.todeviceid_events.insert( - &key, - &serde_json::to_vec(&json).expect("Map::to_vec always works"), - )?; + let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); + + self.todeviceid_events.insert(&key, &value)?; Ok(()) } @@ -774,7 +773,7 @@ impl Users { for (key, _) in self .todeviceid_events - .iter_from(&last, true) + .iter_from(&last, true) // this includes last .take_while(move |(k, _)| k.starts_with(&prefix)) .map(|(key, _)| { Ok::<_, Error>(( diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 4f6318a2..a4beac64 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -320,6 +320,7 @@ where }), Err(e) => { warn!("{:?}", e); + // Bad Json Failure((Status::new(583), ())) } } diff --git a/src/server_server.rs b/src/server_server.rs index 25cdd99e..fb49d0ca 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -46,7 +46,7 @@ use ruma::{ receipt::ReceiptType, serde::Raw, signatures::{CanonicalJsonObject, CanonicalJsonValue}, - state_res::{self, Event, RoomVersion, StateMap}, + state_res::{self, RoomVersion, StateMap}, to_device::DeviceIdOrAllDevices, uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UserId, @@ -625,13 +625,44 @@ pub async fn send_transaction_message_route( } }; + // 0. Check the server is in the room + let room_id = match value + .get("room_id") + .and_then(|id| RoomId::try_from(id.as_str()?).ok()) + { + Some(id) => id, + None => { + // Event is invalid + resolved_map.insert(event_id, Err("Event needs a valid RoomId.".to_string())); + continue; + } + }; + + let mutex = Arc::clone( + db.globals + .roomid_mutex_federation + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; let start_time = Instant::now(); resolved_map.insert( event_id.clone(), - handle_incoming_pdu(&body.origin, &event_id, value, true, &db, &pub_key_map) - .await - .map(|_| ()), + handle_incoming_pdu( + &body.origin, + &event_id, + &room_id, + value, + true, + &db, + &pub_key_map, + ) + .await + .map(|_| ()), ); + drop(mutex_lock); let elapsed = start_time.elapsed(); if elapsed > Duration::from_secs(1) { @@ -784,8 +815,8 @@ pub async fn send_transaction_message_route( type AsyncRecursiveResult<'a, T, E> = Pin> + 'a + Send>>; /// When receiving an event one needs to: -/// 0. Skip the PDU if we already know about it -/// 1. Check the server is in the room +/// 0. Check the server is in the room +/// 1. Skip the PDU if we already know about it /// 2. Check signatures, otherwise drop /// 3. Check content hash, redact if doesn't match /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not @@ -810,6 +841,7 @@ type AsyncRecursiveResult<'a, T, E> = Pin( origin: &'a ServerName, event_id: &'a EventId, + room_id: &'a RoomId, value: BTreeMap, is_timeline_event: bool, db: &'a Database, @@ -817,24 +849,6 @@ pub fn handle_incoming_pdu<'a>( ) -> AsyncRecursiveResult<'a, Option>, String> { Box::pin(async move { // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // 0. Skip the PDU if we already have it as a timeline event - if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { - return Ok(Some(pdu_id.to_vec())); - } - - // 1. Check the server is in the room - let room_id = match value - .get("room_id") - .and_then(|id| RoomId::try_from(id.as_str()?).ok()) - { - Some(id) => id, - None => { - // Event is invalid - return Err("Event needs a valid RoomId.".to_string()); - } - }; - match db.rooms.exists(&room_id) { Ok(true) => {} _ => { @@ -842,6 +856,11 @@ pub fn handle_incoming_pdu<'a>( } } + // 1. Skip the PDU if we already have it as a timeline event + if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { + return Ok(Some(pdu_id.to_vec())); + } + // We go through all the signatures we see on the value and fetch the corresponding signing // keys fetch_required_signing_keys(&value, &pub_key_map, db) @@ -901,7 +920,7 @@ pub fn handle_incoming_pdu<'a>( // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // EDIT: Step 5 is not applied anymore because it failed too often debug!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, pub_key_map) + fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, &room_id, pub_key_map) .await .map_err(|e| e.to_string())?; @@ -1002,13 +1021,13 @@ pub fn handle_incoming_pdu<'a>( if incoming_pdu.prev_events.len() == 1 { let prev_event = &incoming_pdu.prev_events[0]; - let state_vec = db + let state = db .rooms .pdu_shortstatehash(prev_event) .map_err(|_| "Failed talking to db".to_owned())? .map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok()) .flatten(); - if let Some(mut state_vec) = state_vec { + if let Some(mut state) = state { if db .rooms .get_pdu(prev_event) @@ -1018,25 +1037,31 @@ pub fn handle_incoming_pdu<'a>( .state_key .is_some() { - state_vec.push(prev_event.clone()); + state.insert(prev_event.clone()); } state_at_incoming_event = Some( - fetch_and_handle_events(db, origin, &state_vec, pub_key_map) - .await - .map_err(|_| "Failed to fetch state events locally".to_owned())? - .into_iter() - .map(|pdu| { + fetch_and_handle_events( + db, + origin, + &state.into_iter().collect::>(), + &room_id, + pub_key_map, + ) + .await + .map_err(|_| "Failed to fetch state events locally".to_owned())? + .into_iter() + .map(|pdu| { + ( ( - ( - pdu.kind.clone(), - pdu.state_key - .clone() - .expect("events from state_full_ids are state events"), - ), - pdu, - ) - }) - .collect(), + pdu.kind.clone(), + pdu.state_key + .clone() + .expect("events from state_full_ids are state events"), + ), + pdu, + ) + }) + .collect(), ); } // TODO: set incoming_auth_events? @@ -1059,12 +1084,18 @@ pub fn handle_incoming_pdu<'a>( { Ok(res) => { debug!("Fetching state events at event."); - let state_vec = - match fetch_and_handle_events(&db, origin, &res.pdu_ids, pub_key_map).await - { - Ok(state) => state, - Err(_) => return Err("Failed to fetch state events.".to_owned()), - }; + let state_vec = match fetch_and_handle_events( + &db, + origin, + &res.pdu_ids, + &room_id, + pub_key_map, + ) + .await + { + Ok(state) => state, + Err(_) => return Err("Failed to fetch state events.".to_owned()), + }; let mut state = BTreeMap::new(); for pdu in state_vec { @@ -1090,8 +1121,14 @@ pub fn handle_incoming_pdu<'a>( } debug!("Fetching auth chain events at event."); - match fetch_and_handle_events(&db, origin, &res.auth_chain_ids, pub_key_map) - .await + match fetch_and_handle_events( + &db, + origin, + &res.auth_chain_ids, + &room_id, + pub_key_map, + ) + .await { Ok(state) => state, Err(_) => return Err("Failed to fetch auth chain.".to_owned()), @@ -1219,18 +1256,10 @@ pub fn handle_incoming_pdu<'a>( let mut auth_events = vec![]; for map in &fork_states { - let mut state_auth = vec![]; - for auth_id in map.values().flat_map(|pdu| &pdu.auth_events) { - match fetch_and_handle_events(&db, origin, &[auth_id.clone()], pub_key_map) - .await - { - // This should always contain exactly one element when Ok - Ok(events) => state_auth.extend_from_slice(&events), - Err(e) => { - debug!("Event was not present: {}", e); - } - } - } + let state_auth = map + .values() + .flat_map(|pdu| pdu.auth_events.clone()) + .collect(); auth_events.push(state_auth); } @@ -1245,10 +1274,7 @@ pub fn handle_incoming_pdu<'a>( .collect::>() }) .collect::>(), - auth_events - .into_iter() - .map(|pdus| pdus.into_iter().map(|pdu| pdu.event_id().clone()).collect()) - .collect(), + auth_events, &|id| { let res = db.rooms.get_pdu(id); if let Err(e) = &res { @@ -1282,11 +1308,13 @@ pub fn handle_incoming_pdu<'a>( pdu_id = Some( append_incoming_pdu( &db, + &room_id, &incoming_pdu, val, extremities, &state_at_incoming_event, ) + .await .map_err(|_| "Failed to add pdu to db.".to_owned())?, ); debug!("Appended incoming pdu."); @@ -1324,6 +1352,7 @@ pub(crate) fn fetch_and_handle_events<'a>( db: &'a Database, origin: &'a ServerName, events: &'a [EventId], + room_id: &'a RoomId, pub_key_map: &'a RwLock>>, ) -> AsyncRecursiveResult<'a, Vec>, Error> { Box::pin(async move { @@ -1377,6 +1406,7 @@ pub(crate) fn fetch_and_handle_events<'a>( match handle_incoming_pdu( origin, &event_id, + &room_id, value.clone(), false, db, @@ -1583,32 +1613,38 @@ pub(crate) async fn fetch_signing_keys( /// Append the incoming event setting the state snapshot to the state from the /// server that sent the event. #[tracing::instrument(skip(db))] -pub(crate) fn append_incoming_pdu( +async fn append_incoming_pdu( db: &Database, + room_id: &RoomId, pdu: &PduEvent, pdu_json: CanonicalJsonObject, new_room_leaves: HashSet, state: &StateMap>, ) -> Result> { - let count = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&count.to_be_bytes()); + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. db.rooms .set_event_state(&pdu.event_id, state, &db.globals)?; - db.rooms.append_pdu( + let pdu_id = db.rooms.append_pdu( pdu, pdu_json, - count, - &pdu_id, &new_room_leaves.into_iter().collect::>(), &db, )?; + drop(mutex_lock); + for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces @@ -1872,7 +1908,11 @@ pub fn get_room_state_ids_route( "Pdu state not found.", ))?; - let pdu_ids = db.rooms.state_full_ids(shortstatehash)?; + let pdu_ids = db + .rooms + .state_full_ids(shortstatehash)? + .into_iter() + .collect(); let mut auth_chain_ids = BTreeSet::::new(); let mut todo = BTreeSet::new(); @@ -2118,18 +2158,36 @@ pub async fn create_join_event_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) - .await - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( + let mutex = Arc::clone( + db.globals + .roomid_mutex_federation + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let pdu_id = handle_incoming_pdu( + &origin, + &event_id, + &body.room_id, + value, + true, + &db, + &pub_key_map, + ) + .await + .map_err(|_| { + Error::BadRequest( ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + "Error while handling incoming PDU.", + ) + })? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; + drop(mutex_lock); let state_ids = db.rooms.state_full_ids(shortstatehash)?;