diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 380f86cf..36fa1fcd 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -5,11 +5,10 @@ use ruma::{ api::{ client::{ error::ErrorKind, - space::{get_hierarchy, SpaceHierarchyRoomsChunk, SpaceRoomJoinRule}, + space::{get_hierarchy, SpaceHierarchyRoomsChunk}, }, federation, }, - directory::PublicRoomJoinRule, events::{ room::{ avatar::RoomAvatarEventContent, @@ -18,11 +17,11 @@ use ruma::{ guest_access::{GuestAccess, RoomGuestAccessEventContent}, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent}, - name::RoomNameEventContent, topic::RoomTopicEventContent, }, StateEventType, }, + space::SpaceRoomJoinRule, OwnedRoomId, RoomId, UserId, }; @@ -30,10 +29,15 @@ use tracing::{debug, error, warn}; use crate::{services, Error, PduEvent, Result}; +pub enum CachedJoinRule { + Simplified(SpaceRoomJoinRule), + Full(JoinRule), +} + pub struct CachedSpaceChunk { chunk: SpaceHierarchyRoomsChunk, children: Vec, - join_rule: JoinRule, + join_rule: CachedJoinRule, } pub struct Service { @@ -79,9 +83,15 @@ impl Service { .as_ref() { if let Some(cached) = cached { - if let Some(_join_rule) = - self.handle_join_rule(&cached.join_rule, sender_user, ¤t_room)? - { + let allowed = match &cached.join_rule { + CachedJoinRule::Simplified(s) => { + self.handle_simplified_join_rule(s, sender_user, ¤t_room)? + } + CachedJoinRule::Full(f) => { + self.handle_join_rule(f, sender_user, ¤t_room)? + } + }; + if allowed { if left_to_skip > 0 { left_to_skip -= 1; } else { @@ -152,7 +162,7 @@ impl Service { Some(CachedSpaceChunk { chunk, children: children_ids.clone(), - join_rule, + join_rule: CachedJoinRule::Full(join_rule), }), ); } @@ -182,7 +192,6 @@ impl Service { .await { warn!("Got response from {server} for /hierarchy\n{response:?}"); - let join_rule = self.translate_pjoinrule(&response.room.join_rule)?; let chunk = SpaceHierarchyRoomsChunk { canonical_alias: response.room.canonical_alias, name: response.room.name, @@ -192,7 +201,7 @@ impl Service { world_readable: response.room.world_readable, guest_can_join: response.room.guest_can_join, avatar_url: response.room.avatar_url, - join_rule: self.translate_sjoinrule(&response.room.join_rule)?, + join_rule: response.room.join_rule.clone(), room_type: response.room.room_type, children_state: response.room.children_state, }; @@ -202,9 +211,11 @@ impl Service { .map(|c| c.room_id.clone()) .collect::>(); - if let Some(_join_rule) = - self.handle_join_rule(&join_rule, sender_user, ¤t_room)? - { + if self.handle_simplified_join_rule( + &response.room.join_rule, + sender_user, + ¤t_room, + )? { if left_to_skip > 0 { left_to_skip -= 1; } else { @@ -220,7 +231,7 @@ impl Service { Some(CachedSpaceChunk { chunk, children, - join_rule, + join_rule: CachedJoinRule::Simplified(response.room.join_rule), }), ); @@ -349,15 +360,17 @@ impl Service { }) .transpose()? .unwrap_or(JoinRule::Invite); - self.handle_join_rule(&join_rule, sender_user, room_id)? - .ok_or_else(|| { - debug!("User is not allowed to see room {room_id}"); - // This error will be caught later - Error::BadRequest( - ErrorKind::Forbidden, - "User is not allowed to see the room", - ) - })? + + if !self.handle_join_rule(&join_rule, sender_user, room_id)? { + debug!("User is not allowed to see room {room_id}"); + // This error will be caught later + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "User is not allowed to see the room", + )); + } + + self.translate_joinrule(&join_rule)? }, room_type: services() .rooms @@ -378,20 +391,35 @@ impl Service { }) } - fn translate_pjoinrule(&self, join_rule: &PublicRoomJoinRule) -> Result { + fn translate_joinrule(&self, join_rule: &JoinRule) -> Result { match join_rule { - PublicRoomJoinRule::Knock => Ok(JoinRule::Knock), - PublicRoomJoinRule::Public => Ok(JoinRule::Public), + JoinRule::Invite => Ok(SpaceRoomJoinRule::Invite), + JoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), + JoinRule::Private => Ok(SpaceRoomJoinRule::Private), + JoinRule::Restricted(_) => Ok(SpaceRoomJoinRule::Restricted), + JoinRule::KnockRestricted(_) => Ok(SpaceRoomJoinRule::KnockRestricted), + JoinRule::Public => Ok(SpaceRoomJoinRule::Public), _ => Err(Error::BadServerResponse("Unknown join rule")), } } - fn translate_sjoinrule(&self, join_rule: &PublicRoomJoinRule) -> Result { - match join_rule { - PublicRoomJoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), - PublicRoomJoinRule::Public => Ok(SpaceRoomJoinRule::Public), - _ => Err(Error::BadServerResponse("Unknown join rule")), - } + fn handle_simplified_join_rule( + &self, + join_rule: &SpaceRoomJoinRule, + sender_user: &UserId, + room_id: &RoomId, + ) -> Result { + let allowed = match join_rule { + SpaceRoomJoinRule::Public => true, + SpaceRoomJoinRule::Knock => true, + SpaceRoomJoinRule::Invite => services() + .rooms + .state_cache + .is_joined(sender_user, &room_id)?, + _ => false, + }; + + Ok(allowed) } fn handle_join_rule( @@ -399,30 +427,25 @@ impl Service { join_rule: &JoinRule, sender_user: &UserId, room_id: &RoomId, - ) -> Result> { + ) -> Result { + if self.handle_simplified_join_rule( + &self.translate_joinrule(join_rule)?, + sender_user, + room_id, + )? { + return Ok(true); + } + match join_rule { - JoinRule::Public => Ok::<_, Error>(Some(SpaceRoomJoinRule::Public)), - JoinRule::Knock => Ok(Some(SpaceRoomJoinRule::Knock)), - JoinRule::Invite => { - if services() - .rooms - .state_cache - .is_joined(sender_user, &room_id)? - { - Ok(Some(SpaceRoomJoinRule::Invite)) - } else { - Ok(None) - } - } - JoinRule::Restricted(_r) => { + JoinRule::Restricted(_) => { // TODO: Check rules - Ok(None) + Ok(false) } - JoinRule::KnockRestricted(_r) => { + JoinRule::KnockRestricted(_) => { // TODO: Check rules - Ok(None) + Ok(false) } - _ => Ok(None), + _ => Ok(false), } } }