diff --git a/src/client_server.rs b/src/client_server.rs index de76eef7..cd61746b 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -1,5 +1,5 @@ use std::{ - collections::{hash_map, BTreeMap, HashMap}, + collections::{hash_map, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, time::{Duration, SystemTime}, }; @@ -898,7 +898,7 @@ pub fn upload_keys_route( // This check is needed to assure that signatures are kept if db.users.get_device_keys(sender_id, device_id)?.is_none() { db.users - .add_device_keys(sender_id, device_id, device_keys, &db.globals)?; + .add_device_keys(sender_id, device_id, device_keys, &db.rooms, &db.globals)?; } } @@ -2518,20 +2518,41 @@ pub async fn sync_events_route( .unwrap_or(0); let mut presence_updates = HashMap::new(); + let mut device_list_updates = HashSet::new(); for room_id in db.rooms.rooms_joined(&sender_id) { let room_id = room_id?; - let mut pdus = db + let mut non_timeline_pdus = db .rooms .pdus_since(&sender_id, &room_id, since)? - .filter_map(|r| r.ok()) // Filter out buggy events + .filter_map(|r| r.ok()); // Filter out buggy events + + // Take the last 10 events for the timeline + let timeline_pdus = non_timeline_pdus + .by_ref() + .rev() + .take(10) + .collect::>() + .into_iter() + .rev() .collect::>(); + // They /sync response doesn't always return all messages, so we say the output is + // limited unless there are events in non_timeline_pdus + //let mut limited = false; + + let mut state_pdus = Vec::new(); + for pdu in non_timeline_pdus { + if pdu.state_key.is_some() { + state_pdus.push(pdu); + } + } + let mut send_member_count = false; let mut joined_since_last_sync = false; let mut send_notification_counts = false; - for pdu in &pdus { + for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)?.filter_map(|r| r.ok()) { send_notification_counts = true; if pdu.kind == EventType::RoomMember { send_member_count = true; @@ -2544,8 +2565,8 @@ pub async fn sync_events_route( .map_err(|_| Error::bad_database("Invalid PDU in database."))?; if content.membership == ruma::events::room::member::MembershipState::Join { joined_since_last_sync = true; - // Both send_member_count and joined_since_last_sync are set. There's nothing more - // to do + // Both send_member_count and joined_since_last_sync are set. There's + // nothing more to do break; } } @@ -2574,7 +2595,7 @@ pub async fn sync_events_route( let content = serde_json::from_value::< Raw, >(pdu.content.clone()) - .map_err(|_| Error::bad_database("Invalid member event in database."))? + .expect("Raw::from_value always works") .deserialize() .map_err(|_| Error::bad_database("Invalid member event in database."))?; @@ -2592,7 +2613,7 @@ pub async fn sync_events_route( .content .clone(), ) - .map_err(|_| Error::bad_database("Invalid member event in database."))? + .expect("Raw::from_value always works") .deserialize() .map_err(|_| { Error::bad_database("Invalid member event in database.") @@ -2659,15 +2680,7 @@ pub async fn sync_events_route( None }; - // They /sync response doesn't always return all messages, so we say the output is - // limited unless there are enough events - let mut limited = true; - pdus = pdus.split_off(pdus.len().checked_sub(10).unwrap_or_else(|| { - limited = false; - 0 - })); - - let prev_batch = pdus.first().map_or(Ok::<_, Error>(None), |e| { + let prev_batch = timeline_pdus.first().map_or(Ok::<_, Error>(None), |e| { Ok(Some( db.rooms .get_pdu_count(&e.event_id)? @@ -2676,7 +2689,7 @@ pub async fn sync_events_route( )) })?; - let room_events = pdus + let room_events = timeline_pdus .into_iter() .map(|pdu| pdu.to_sync_room_event()) .collect::>(); @@ -2728,7 +2741,7 @@ pub async fn sync_events_route( notification_count, }, timeline: sync_events::Timeline { - limited: limited || joined_since_last_sync, + limited: false || joined_since_last_sync, prev_batch, events: room_events, }, @@ -2751,6 +2764,13 @@ pub async fn sync_events_route( joined_rooms.insert(room_id.clone(), joined_room); } + // Look for device list updates in this room + device_list_updates.extend( + db.users + .keys_changed(&room_id, since) + .filter_map(|r| r.ok()), + ); + // Take presence updates from this room for (user_id, presence) in db.rooms @@ -2885,14 +2905,7 @@ pub async fn sync_events_route( .collect::>(), }, device_lists: sync_events::DeviceLists { - changed: if since != 0 { - db.users - .keys_changed(since) - .filter_map(|u| u.ok()) - .collect() // Filter out buggy events - } else { - Vec::new() - }, + changed: device_list_updates.into_iter().collect(), left: Vec::new(), // TODO }, device_one_time_keys_count: Default::default(), // TODO @@ -3450,6 +3463,7 @@ pub fn upload_signing_keys_route( &master_key, &body.self_signing_key, &body.user_signing_key, + &db.rooms, &db.globals, )?; } @@ -3500,8 +3514,14 @@ pub fn upload_signatures_route( ))? .to_owned(), ); - db.users - .sign_key(&user_id, &key_id, signature, &sender_id, &db.globals)?; + db.users.sign_key( + &user_id, + &key_id, + signature, + &sender_id, + &db.rooms, + &db.globals, + )?; } } } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index fe5721c7..4cd47a17 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -611,44 +611,29 @@ impl Rooms { self.pdus_since(user_id, room_id, 0) } - /// Returns an iterator over all events in a room that happened after the event with id `since`. + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in reverse-chronological order. pub fn pdus_since( &self, user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result>> { - // Create the first part of the full pdu id - let mut pdu_id = room_id.to_string().as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&(since).to_be_bytes()); - - self.pdus_since_pduid(user_id, room_id, &pdu_id) - } - - /// Returns an iterator over all events in a room that happened after the event with id `since`. - pub fn pdus_since_pduid( - &self, - user_id: &UserId, - room_id: &RoomId, - pdu_id: &[u8], - ) -> Result>> { - // Create the first part of the full pdu id + ) -> Result>> { let mut prefix = room_id.to_string().as_bytes().to_vec(); prefix.push(0xff); + // Skip the first pdu if it's exactly at since, because we sent that last time + let mut first_pdu_id = prefix.clone(); + first_pdu_id.extend_from_slice(&(since+1).to_be_bytes()); + + let mut last_pdu_id = prefix.clone(); + last_pdu_id.extend_from_slice(&u64::MAX.to_be_bytes()); + let user_id = user_id.clone(); Ok(self .pduid_pdu - .range(pdu_id..) - // Skip the first pdu if it's exactly at since, because we sent that last time - .skip(if self.pduid_pdu.get(pdu_id)?.is_some() { - 1 - } else { - 0 - }) + .range(first_pdu_id..last_pdu_id) .filter_map(|r| r.ok()) - .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(_, v)| { let mut pdu = serde_json::from_slice::(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; diff --git a/src/database/users.rs b/src/database/users.rs index 5030f32e..7fbdd806 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -9,7 +9,7 @@ use ruma::{ }, }, events::{AnyToDeviceEvent, EventType}, - DeviceId, Raw, UserId, + DeviceId, Raw, UserId, RoomId, }; use std::{collections::BTreeMap, convert::TryFrom, mem, time::SystemTime}; @@ -22,7 +22,7 @@ pub struct Users { pub(super) token_userdeviceid: sled::Tree, pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + AlgorithmAndDeviceId - pub(super) keychangeid_userid: sled::Tree, // KeyChangeId = Count + pub(super) keychangeid_userid: sled::Tree, // KeyChangeId = RoomId + Count pub(super) keyid_key: sled::Tree, // KeyId = UserId + KeyId (depends on key type) pub(super) userid_masterkeyid: sled::Tree, pub(super) userid_selfsigningkeyid: sled::Tree, @@ -371,6 +371,7 @@ impl Users { user_id: &UserId, device_id: &DeviceId, device_keys: &DeviceKeys, + rooms: &super::rooms::Rooms, globals: &super::globals::Globals, ) -> Result<()> { let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); @@ -382,8 +383,15 @@ impl Users { &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), )?; - self.keychangeid_userid - .insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?; + let count = globals.next_count()?.to_be_bytes(); + for room_id in rooms.rooms_joined(&user_id) { + let mut key = room_id?.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + + self.keychangeid_userid + .insert(key, &*user_id.to_string())?; + } Ok(()) } @@ -394,6 +402,7 @@ impl Users { master_key: &CrossSigningKey, self_signing_key: &Option, user_signing_key: &Option, + rooms: &super::rooms::Rooms, globals: &super::globals::Globals, ) -> Result<()> { // TODO: Check signatures @@ -482,8 +491,15 @@ impl Users { .insert(&*user_id.to_string(), user_signing_key_key)?; } - self.keychangeid_userid - .insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?; + let count = globals.next_count()?.to_be_bytes(); + for room_id in rooms.rooms_joined(&user_id) { + let mut key = room_id?.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + + self.keychangeid_userid + .insert(key, &*user_id.to_string())?; + } Ok(()) } @@ -494,6 +510,7 @@ impl Users { key_id: &str, signature: (String, String), sender_id: &UserId, + rooms: &super::rooms::Rooms, globals: &super::globals::Globals, ) -> Result<()> { let mut key = target_id.to_string().as_bytes().to_vec(); @@ -525,19 +542,33 @@ impl Users { .expect("CrossSigningKey::to_string always works"), )?; - self.keychangeid_userid - .insert(globals.next_count()?.to_be_bytes(), &*target_id.to_string())?; + // TODO: Should we notify about this change? + let count = globals.next_count()?.to_be_bytes(); + for room_id in rooms.rooms_joined(&target_id) { + let mut key = room_id?.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + + self.keychangeid_userid + .insert(key, &*target_id.to_string())?; + } Ok(()) } - pub fn keys_changed(&self, since: u64) -> impl Iterator> { + pub fn keys_changed(&self, room_id: &RoomId, since: u64) -> impl Iterator> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + let mut start = prefix.clone(); + start.extend_from_slice(&(since + 1).to_be_bytes()); + self.keychangeid_userid - .range((since + 1).to_be_bytes()..) - .values() - .map(|bytes| { + .range(start..) + .filter_map(|r| r.ok()) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(_, bytes)| { Ok( - UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { + UserId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database( "User ID in devicekeychangeid_userid is invalid unicode.", )