From cc801528899dd37afcf7669ae5ebfeb050fc1eb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Mon, 20 Jun 2022 12:08:58 +0200 Subject: [PATCH] refactor: split up force_state --- src/service/rooms/state/mod.rs | 54 ++------------------- src/service/rooms/state_compressor/mod.rs | 59 ++++++++++++++++++++++- 2 files changed, 62 insertions(+), 51 deletions(-) diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index eddfe9e0..da03ad4c 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -3,62 +3,16 @@ pub struct Service { } impl Service { - /// Force the creation of a new StateHash and insert it into the db. - /// - /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot. + /// Set the room to the given statehash and update caches. #[tracing::instrument(skip(self, new_state_ids_compressed, db))] pub fn force_state( &self, room_id: &RoomId, - new_state_ids_compressed: HashSet, + shortstatehash: u64, + statediffnew :HashSet, + statediffremoved :HashSet, db: &Database, ) -> Result<()> { - let previous_shortstatehash = self.d.current_shortstatehash(room_id)?; - - let state_hash = self.calculate_hash( - &new_state_ids_compressed - .iter() - .map(|bytes| &bytes[..]) - .collect::>(), - ); - - let (new_shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, &db.globals)?; - - if Some(new_shortstatehash) == previous_shortstatehash { - return Ok(()); - } - - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() - { - let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); - - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&new_state_ids_compressed) - .copied() - .collect(); - - (statediffnew, statediffremoved) - } else { - (new_state_ids_compressed, HashSet::new()) - }; - - if !already_existed { - self.save_state_from_diff( - new_shortstatehash, - statediffnew.clone(), - statediffremoved, - 2, // every state change is 2 event changes on average - states_parents, - )?; - }; for event_id in statediffnew.into_iter().filter_map(|new| { self.parse_compressed_state_event(new) diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index a56c0f5f..197ce844 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -241,6 +241,64 @@ Ok(()) } + /// Returns the new shortstatehash + pub fn save_state( + room_id: &RoomId, + new_state_ids_compressed: HashSet, + ) -> Result<(u64, + HashSet, // added + HashSet)> // removed + { + let previous_shortstatehash = self.d.current_shortstatehash(room_id)?; + + let state_hash = self.calculate_hash( + &new_state_ids_compressed + .iter() + .map(|bytes| &bytes[..]) + .collect::>(), + ); + + let (new_shortstatehash, already_existed) = + self.get_or_create_shortstatehash(&state_hash, &db.globals)?; + + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok(()); + } + + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() + { + let statediffnew: HashSet<_> = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .copied() + .collect(); + + let statediffremoved: HashSet<_> = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .copied() + .collect(); + + (statediffnew, statediffremoved) + } else { + (new_state_ids_compressed, HashSet::new()) + }; + + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved, + 2, // every state change is 2 event changes on average + states_parents, + )?; + }; + + Ok((new_shortstatehash, statediffnew, statediffremoved)) + } + #[tracing::instrument(skip(self))] pub fn get_auth_chain_from_cache<'a>( &'a self, @@ -298,4 +356,3 @@ Ok(()) } -