diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 49e3842b..1db3a0aa 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -66,6 +66,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; + self.userroomid_knockedstate.remove(&userroom_id)?; + self.roomuserid_knockedcount.remove(&roomuser_id)?; Ok(()) } @@ -91,6 +93,36 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_knockedstate.remove(&userroom_id)?; + self.roomuserid_knockedcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_knockedstate.insert( + &userroom_id, + &serde_json::to_vec(&Vec::>::new()) + .expect("state to bytes always works"), + )?; + self.roomuserid_knockedcount.insert( + &roomuser_id, + &services().globals.next_count()?.to_be_bytes(), + )?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; Ok(()) } @@ -604,4 +636,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } + + #[tracing::instrument(skip(self))] + fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_knockedstate.get(&userroom_id)?.is_some()) + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 5171d4bb..16b2f556 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -12,7 +12,10 @@ use lru_cache::LruCache; use ruma::{ events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, - room::message::RoomMessageEventContent, + room::{ + member::{MembershipState, RoomMemberEventContent}, + message::RoomMessageEventContent, + }, GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, }, push::Ruleset, @@ -100,6 +103,8 @@ pub struct KeyValueDatabase { pub(super) roomuserid_invitecount: Arc, // InviteCount = Count pub(super) userroomid_leftstate: Arc, pub(super) roomuserid_leftcount: Arc, + pub(super) userroomid_knockedstate: Arc, + pub(super) roomuserid_knockedcount: Arc, pub(super) alias_userid: Arc, // User who created the alias @@ -328,6 +333,8 @@ impl KeyValueDatabase { roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + userroomid_knockedstate: builder.open_tree("userroomid_knockedstate")?, + roomuserid_knockedcount: builder.open_tree("roomuserid_knockedcount")?, alias_userid: builder.open_tree("alias_userid")?, @@ -941,6 +948,60 @@ impl KeyValueDatabase { warn!("Migration: 12 -> 13 finished"); } + if services().globals.database_version()? < 14 { + for username in services().users.list_local_users()? { + let user = match UserId::parse_with_server_name( + username.clone(), + services().globals.server_name(), + ) { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + } + }; + + for room in + services() + .rooms + .metadata + .iter_ids() + .filter_map(|room_id| match room_id { + Ok(room_id) => Some(room_id), + Err(e) => { + warn!("Invalid room id: {e}"); + None + } + }) + { + if services() + .rooms + .state_accessor + .room_state_get(&room, &StateEventType::RoomMember, user.as_str())? + .map(|pdu| { + serde_json::from_str(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + }) + .transpose()? + .map(|content: RoomMemberEventContent| content.membership) + == Some(MembershipState::Knock) + { + services().rooms.state_cache.update_membership( + &room, + &user, + MembershipState::Knock, + &user, + None, + false, + )?; + } + } + } + services().globals.bump_database_version(14)?; + + warn!("Migration: 13 -> 14 finished"); + } + assert_eq!( services().globals.database_version().unwrap(), latest_database_version diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index b511919a..6d44b0df 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -18,6 +18,9 @@ pub trait Data: Send + Sync { ) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + /// Marks a user as knocking on a room + fn mark_as_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; fn get_our_real_users(&self, room_id: &RoomId) -> Result>>; @@ -106,4 +109,6 @@ pub trait Data: Send + Sync { fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; + + fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index c108695d..cef24098 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -181,6 +181,9 @@ impl Service { MembershipState::Leave | MembershipState::Ban => { self.db.mark_as_left(user_id, room_id)?; } + MembershipState::Knock => { + self.db.mark_as_knocked(user_id, room_id)?; + } _ => {} } @@ -350,4 +353,9 @@ impl Service { pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + + #[tracing::instrument(skip(self))] + pub fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.is_knocked(user_id, room_id) + } }