diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 58ed0401..f07f2adb 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -654,7 +654,7 @@ async fn join_room_by_id_helper( // We set the room state after inserting the pdu, so that we never have a moment in time // where events in the current room state do not exist - services().rooms.state.set_room_state(room_id, shortstatehash)?; + services().rooms.state.set_room_state(room_id, shortstatehash, &state_lock)?; let statehashid = services().rooms.state.append_to_state(&parsed_pdu)?; } else { diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 647f4574..11f7ec34 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -857,131 +857,6 @@ pub async fn send_transaction_message_route( Ok(send_transaction_message::v1::Response { pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.to_string()))).collect() }) } -#[tracing::instrument(skip(starting_events))] -pub(crate) async fn get_auth_chain<'a>( - room_id: &RoomId, - starting_events: Vec>, -) -> Result> + 'a> { - const NUM_BUCKETS: usize = 50; - - let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; - - let mut i = 0; - for id in starting_events { - let short = services().rooms.short.get_or_create_shorteventid(&id)?; - let bucket_id = (short % NUM_BUCKETS as u64) as usize; - buckets[bucket_id].insert((short, id.clone())); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - let mut full_auth_chain = HashSet::new(); - - let mut hits = 0; - let mut misses = 0; - for chunk in buckets { - if chunk.is_empty() { - continue; - } - - let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? { - hits += 1; - full_auth_chain.extend(cached.iter().copied()); - continue; - } - misses += 1; - - let mut chunk_cache = HashSet::new(); - let mut hits2 = 0; - let mut misses2 = 0; - let mut i = 0; - for (sevent_id, event_id) in chunk { - if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? { - hits2 += 1; - chunk_cache.extend(cached.iter().copied()); - } else { - misses2 += 1; - let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id)?); - services().rooms - .auth_chain - .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; - println!( - "cache missed event {} with auth chain len {}", - event_id, - auth_chain.len() - ); - chunk_cache.extend(auth_chain.iter()); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - }; - } - println!( - "chunk missed with len {}, event hits2: {}, misses2: {}", - chunk_cache.len(), - hits2, - misses2 - ); - let chunk_cache = Arc::new(chunk_cache); - services().rooms - .auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; - full_auth_chain.extend(chunk_cache.iter()); - } - - println!( - "total: {}, chunk hits: {}, misses: {}", - full_auth_chain.len(), - hits, - misses - ); - - Ok(full_auth_chain - .into_iter() - .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) -} - -#[tracing::instrument(skip(event_id))] -fn get_auth_chain_inner( - room_id: &RoomId, - event_id: &EventId, -) -> Result> { - let mut todo = vec![Arc::from(event_id)]; - let mut found = HashSet::new(); - - while let Some(event_id) = todo.pop() { - match services().rooms.timeline.get_pdu(&event_id) { - Ok(Some(pdu)) => { - if pdu.room_id != room_id { - return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); - } - for auth_event in &pdu.auth_events { - let sauthevent = services() - .rooms.short - .get_or_create_shorteventid(auth_event)?; - - if !found.contains(&sauthevent) { - found.insert(sauthevent); - todo.push(auth_event.clone()); - } - } - } - Ok(None) => { - warn!("Could not find pdu mentioned in auth events: {}", event_id); - } - Err(e) => { - warn!("Could not load event in auth chain: {} {}", event_id, e); - } - } - } - - Ok(found) -} - /// # `GET /_matrix/federation/v1/event/{eventId}` /// /// Retrieves a single event from the server. @@ -1135,7 +1010,7 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; + let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids @@ -1190,7 +1065,7 @@ pub async fn get_room_state_route( .collect(); let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; + services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids @@ -1246,7 +1121,7 @@ pub async fn get_room_state_ids_route( .collect(); let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; + services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -1449,7 +1324,7 @@ async fn create_join_event( drop(mutex_lock); let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; - let auth_chain_ids = get_auth_chain( + let auth_chain_ids = services().rooms.auth_chain.get_auth_chain( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), ) diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index f0325d2b..5674ac07 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,11 +1,11 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId}; use serde::{Serialize, de::DeserializeOwned}; use crate::{Result, database::KeyValueDatabase, service, Error, utils, services}; -impl service::account_data::Data for Arc { +impl service::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] fn update( diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index ee6ae206..f427ba71 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 87119207..199cbf64 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; use async_trait::async_trait; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -9,7 +9,7 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils, services} pub const COUNTER: &[u8] = b"c"; #[async_trait] -impl service::globals::Data for Arc { +impl service::globals::Data for KeyValueDatabase { fn next_count(&self) -> Result { utils::u64_from_bytes(&self.global.increment(COUNTER)?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index c59ed36b..8171451c 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,10 +1,10 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId}; use crate::{Result, service, database::KeyValueDatabase, services, Error, utils}; -impl service::key_backups::Data for Arc { +impl service::key_backups::Data for KeyValueDatabase { fn create_backup( &self, user_id: &UserId, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 1726755a..f0244872 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::api::client::error::ErrorKind; use crate::{database::KeyValueDatabase, service, Error, utils, Result}; -impl service::media::Data for Arc { +impl service::media::Data for KeyValueDatabase { fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result> { let mut key = mxc.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 85d1d864..b05e47be 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; use crate::{service, database::KeyValueDatabase, Error, Result}; -impl service::pusher::Data for Arc { +impl service::pusher::Data for KeyValueDatabase { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 437902df..0aa8dd48 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::alias::Data for Arc { +impl service::rooms::alias::Data for KeyValueDatabase { fn set_alias( &self, alias: &RoomAliasId, diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 2dffb04b..49d39560 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -2,7 +2,7 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{service, database::KeyValueDatabase, Result, utils}; -impl service::rooms::auth_chain::Data for Arc { +impl service::rooms::auth_chain::Data for KeyValueDatabase { fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { // Check RAM cache if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 864e75e9..727004e7 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::directory::Data for Arc { +impl service::rooms::directory::Data for KeyValueDatabase { fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs index 03e4219e..b5007f89 100644 --- a/src/database/key_value/rooms/edus/mod.rs +++ b/src/database/key_value/rooms/edus/mod.rs @@ -2,8 +2,6 @@ mod presence; mod typing; mod read_receipt; -use std::sync::Arc; - use crate::{service, database::KeyValueDatabase}; -impl service::rooms::edus::Data for Arc {} +impl service::rooms::edus::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 5aeb1477..1477c28b 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,10 +1,10 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::edus::presence::Data for Arc { +impl service::rooms::edus::presence::Data for KeyValueDatabase { fn update_presence( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index 7fcb8ac8..a12e2653 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,10 +1,10 @@ -use std::{mem, sync::Arc}; +use std::mem; use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::read_receipt::Data for Arc { +impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { fn readreceipt_update( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index 7f3526d9..b7d35968 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -1,10 +1,10 @@ -use std::{collections::HashSet, sync::Arc}; +use std::collections::HashSet; use ruma::{UserId, RoomId}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::typing::Data for Arc { +impl service::rooms::edus::typing::Data for KeyValueDatabase { fn typing_add( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index b16657aa..133e1d04 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, DeviceId, RoomId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::lazy_loading::Data for Arc { +impl service::rooms::lazy_loading::Data for KeyValueDatabase { fn lazy_load_was_sent_before( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 560beb90..72f62514 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::RoomId; use crate::{service, database::KeyValueDatabase, Result, services}; -impl service::rooms::metadata::Data for Arc { +impl service::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result { let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), @@ -19,4 +17,18 @@ impl service::rooms::metadata::Data for Arc { .filter(|(k, _)| k.starts_with(&prefix)) .is_some()) } + + fn is_disabled(&self, room_id: &RoomId) -> Result { + Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) + } + + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + if disabled { + self.disabledroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.disabledroomids.remove(room_id.as_bytes())?; + } + + Ok(()) + } } diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index 97c29e5b..406943ed 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -15,8 +15,6 @@ mod state_compressor; mod timeline; mod user; -use std::sync::Arc; - use crate::{database::KeyValueDatabase, service}; -impl service::rooms::Data for Arc {} +impl service::rooms::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index b1ae816a..aa975449 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{EventId, signatures::CanonicalJsonObject}; use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result}; -impl service::rooms::outlier::Data for Arc { +impl service::rooms::outlier::Data for KeyValueDatabase { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index f5e8f766..f3ac414f 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -4,7 +4,7 @@ use ruma::{RoomId, EventId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::pdu_metadata::Data for Arc { +impl service::rooms::pdu_metadata::Data for KeyValueDatabase { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 7b8d2783..dfbdbc64 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,10 +1,10 @@ -use std::{mem::size_of, sync::Arc}; +use std::mem::size_of; use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Result, services}; -impl service::rooms::search::Data for Arc { +impl service::rooms::search::Data for KeyValueDatabase { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 9a302b56..ecd12dad 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -1,6 +1,227 @@ use std::sync::Arc; -use crate::{database::KeyValueDatabase, service}; +use ruma::{EventId, events::StateEventType, RoomId}; -impl service::rooms::short::Data for Arc { +use crate::{Result, database::KeyValueDatabase, service, utils, Error, services}; + +impl service::rooms::short::Data for KeyValueDatabase { + fn get_or_create_shorteventid( + &self, + event_id: &EventId, + ) -> Result { + if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { + return Ok(*short); + } + + let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => utils::u64_from_bytes(&shorteventid) + .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + None => { + let shorteventid = services().globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + } + }; + + self.eventidshort_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), short); + + Ok(short) + } + + fn get_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(Some(*short)); + } + + let mut statekey = event_type.to_string().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(state_key.as_bytes()); + + let short = self + .statekey_shortstatekey + .get(&statekey)? + .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose()?; + + if let Some(s) = short { + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), s); + } + + Ok(short) + } + + fn get_or_create_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(*short); + } + + let mut statekey = event_type.to_string().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(state_key.as_bytes()); + + let short = match self.statekey_shortstatekey.get(&statekey)? { + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = services().globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes())?; + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &statekey)?; + shortstatekey + } + }; + + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), short); + + Ok(short) + } + + fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + if let Some(id) = self + .shorteventid_cache + .lock() + .unwrap() + .get_mut(&shorteventid) + { + return Ok(Arc::clone(id)); + } + + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, Arc::clone(&event_id)); + + Ok(event_id) + } + + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + if let Some(id) = self + .shortstatekey_cache + .lock() + .unwrap() + .get_mut(&shortstatekey) + { + return Ok(id.clone()); + } + + let bytes = self + .shortstatekey_statekey + .get(&shortstatekey.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; + + let mut parts = bytes.splitn(2, |&b| b == 0xff); + let eventtype_bytes = parts.next().expect("split always returns one entry"); + let statekey_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + + let event_type = + StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| { + Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?; + + let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { + Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") + })?; + + let result = (event_type, state_key); + + self.shortstatekey_cache + .lock() + .unwrap() + .insert(shortstatekey, result.clone()); + + Ok(result) + } + + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash( + &self, + state_hash: &[u8], + ) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = services().globals.next_count()?; + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + } + }) + } + + fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + self.roomid_shortroomid + .get(room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + }) + .transpose() + } + + fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + ) -> Result { + Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = services().globals.next_count()?; + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } } diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 527c2403..b2822b32 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::state::Data for Arc { +impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.roomid_shortstatehash .get(room_id.as_bytes())? diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index 9af45db3..4d5bd4a1 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; #[async_trait] -impl service::rooms::state_accessor::Data for Arc { +impl service::rooms::state_accessor::Data for KeyValueDatabase { async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index bdb8cf81..5f054858 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw}; use crate::{service, database::KeyValueDatabase, services, Result}; -impl service::rooms::state_cache::Data for Arc { +impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index e1c0280b..aee1890c 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; +use std::{collections::HashSet, mem::size_of}; use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result}; -impl service::rooms::state_compressor::Data for Arc { +impl service::rooms::state_compressor::Data for KeyValueDatabase { fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 2d334b96..0b7286b2 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -5,7 +5,27 @@ use tracing::error; use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services}; -impl service::rooms::timeline::Data for Arc { +impl service::rooms::timeline::Data for KeyValueDatabase { + fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { + let prefix = services().rooms.short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Look for PDUs in that room. + self.pduid_pdu + .iter_from(&prefix, false) + .filter(|(k, _)| k.starts_with(&prefix)) + .map(|(_, pdu)| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid first PDU in db.")) + .map(Arc::new) + }) + .next() + .transpose() + } + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 4d20b00a..3759bda7 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, RoomId}; use crate::{service, database::KeyValueDatabase, utils, Error, Result, services}; -impl service::rooms::user::Data for Arc { +impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); @@ -104,13 +102,13 @@ impl service::rooms::user::Data for Arc { }); // We use the default compare function because keys are sorted correctly (not reversed) - Ok(utils::common_elements(iterators, Ord::cmp) + Ok(Box::new(Box::new(utils::common_elements(iterators, Ord::cmp) .expect("users is not empty") .map(|bytes| { RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { Error::bad_database("Invalid RoomId bytes in userroomid_joined") })?) .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - })) + })))) } } diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index 7fa69081..a63b3c5d 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, DeviceId, TransactionId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::transaction_ids::Data for Arc { +impl service::transaction_ids::Data for KeyValueDatabase { fn add_txnid( &self, user_id: &UserId, diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 8752e55a..cf242dec 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}}; use crate::{database::KeyValueDatabase, service, Error, Result}; -impl service::uiaa::Data for Arc { +impl service::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( &self, user_id: &UserId, diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 1ac85b36..55a518d4 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,11 +1,11 @@ -use std::{mem::size_of, collections::BTreeMap, sync::Arc}; +use std::{mem::size_of, collections::BTreeMap}; use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt}; use tracing::warn; use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result}; -impl service::users::Data for Arc { +impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) @@ -113,7 +113,7 @@ impl service::users::Data for Arc { /// Hash and set the user's password to the Argon2 hash fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { - if let Ok(hash) = utils::calculate_hash(password) { + if let Ok(hash) = utils::calculate_password_hash(password) { self.userid_password .insert(user_id.as_bytes(), hash.as_bytes())?; Ok(()) diff --git a/src/database/mod.rs b/src/database/mod.rs index 35922f0b..68684677 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -238,8 +238,8 @@ impl KeyValueDatabase { } /// Load an existing database or create a new one. - pub async fn load_or_create(config: &Config) -> Result<()> { - Self::check_db_setup(config)?; + pub async fn load_or_create(config: Config) -> Result<()> { + Self::check_db_setup(&config)?; if !Path::new(&config.database_path).exists() { std::fs::create_dir_all(&config.database_path) @@ -251,19 +251,19 @@ impl KeyValueDatabase { #[cfg(not(feature = "sqlite"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "sqlite")] - Arc::new(Arc::::open(config)?) + Arc::new(Arc::::open(&config)?) } "rocksdb" => { #[cfg(not(feature = "rocksdb"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "rocksdb")] - Arc::new(Arc::::open(config)?) + Arc::new(Arc::::open(&config)?) } "persy" => { #[cfg(not(feature = "persy"))] return Err(Error::BadConfig("Database backend not found.")); #[cfg(feature = "persy")] - Arc::new(Arc::::open(config)?) + Arc::new(Arc::::open(&config)?) } _ => { return Err(Error::BadConfig("Database backend not found.")); @@ -402,7 +402,7 @@ impl KeyValueDatabase { }); - let services_raw = Box::new(Services::build(Arc::clone(&db))); + let services_raw = Box::new(Services::build(Arc::clone(&db), config)?); // This is the first and only time we initialize the SERVICE static *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); @@ -825,7 +825,7 @@ impl KeyValueDatabase { info!( "Loaded {} database with version {}", - config.database_backend, latest_database_version + services().globals.config.database_backend, latest_database_version ); } else { services() @@ -837,7 +837,7 @@ impl KeyValueDatabase { warn!( "Created new {} database with version {}", - config.database_backend, latest_database_version + services().globals.config.database_backend, latest_database_version ); } @@ -866,7 +866,7 @@ impl KeyValueDatabase { .sending .start_handler(sending_receiver); - Self::start_cleanup_task(config).await; + Self::start_cleanup_task().await; Ok(()) } @@ -888,8 +888,8 @@ impl KeyValueDatabase { res } - #[tracing::instrument(skip(config))] - pub async fn start_cleanup_task(config: &Config) { + #[tracing::instrument] + pub async fn start_cleanup_task() { use tokio::time::interval; #[cfg(unix)] @@ -898,7 +898,7 @@ impl KeyValueDatabase { use std::time::{Duration, Instant}; - let timer_interval = Duration::from_secs(config.cleanup_second_interval as u64); + let timer_interval = Duration::from_secs(services().globals.config.cleanup_second_interval as u64); tokio::spawn(async move { let mut i = interval(timer_interval); diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 9785478b..1289f7a3 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -18,7 +18,7 @@ use tracing::error; use crate::{service::*, services, utils, Error, Result}; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 32a709c1..0b14314f 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -426,7 +426,7 @@ impl Service { Error::bad_database("Invalid room id field in event in database") })?; let start = Instant::now(); - let count = server_server::get_auth_chain(room_id, vec![event_id]) + let count = services().rooms.auth_chain.get_auth_chain(room_id, vec![event_id]) .await? .count(); let elapsed = start.elapsed(); @@ -615,14 +615,12 @@ impl Service { )) } AdminCommand::DisableRoom { room_id } => { - todo!(); - //services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; - //RoomMessageEventContent::text_plain("Room disabled.") + services().rooms.metadata.disable_room(&room_id, true); + RoomMessageEventContent::text_plain("Room disabled.") } AdminCommand::EnableRoom { room_id } => { - todo!(); - //services().rooms.disabledroomids.remove(room_id.as_bytes())?; - //RoomMessageEventContent::text_plain("Room enabled.") + services().rooms.metadata.disable_room(&room_id, false); + RoomMessageEventContent::text_plain("Room enabled.") } AdminCommand::DeactivateUser { leave_rooms, diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 8fd69dfe..de8d1aa7 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -35,7 +35,7 @@ type SyncHandle = ( ); pub struct Service { - pub db: Box, + pub db: Arc, pub actual_destination_cache: Arc>, // actual_destination, host pub tls_name_override: Arc>, @@ -92,7 +92,7 @@ impl Default for RotationHandler { impl Service { pub fn load( - db: Box, + db: Arc, config: Config, ) -> Result { let keypair = db.load_keypair(); diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 4bd9efd3..a3bed714 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -13,7 +13,7 @@ use ruma::{ use std::{collections::BTreeMap, sync::Arc}; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index f86251fa..d3dd2bdc 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -16,7 +16,7 @@ pub struct FileMeta { } pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/mod.rs b/src/service/mod.rs index a1a728c5..a772c1db 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,4 +1,9 @@ -use std::sync::Arc; +use std::{ + collections::{BTreeMap, HashMap}, + sync::{Arc, Mutex}, +}; + +use crate::{Result, Config}; pub mod account_data; pub mod admin; @@ -30,20 +35,73 @@ pub struct Services { } impl Services { - pub fn build(db: Arc) -> Self { - Self { + pub fn build< + D: appservice::Data + + pusher::Data + + rooms::Data + + transaction_ids::Data + + uiaa::Data + + users::Data + + account_data::Data + + globals::Data + + key_backups::Data + + media::Data, + >( + db: Arc, config: Config + ) -> Result { + Ok(Self { appservice: appservice::Service { db: db.clone() }, pusher: pusher::Service { db: db.clone() }, - rooms: rooms::Service { db: Arc::clone(&db) }, - transaction_ids: transaction_ids::Service { db: Arc::clone(&db) }, - uiaa: uiaa::Service { db: Arc::clone(&db) }, - users: users::Service { db: Arc::clone(&db) }, - account_data: account_data::Service { db: Arc::clone(&db) }, - admin: admin::Service { db: Arc::clone(&db) }, - globals: globals::Service { db: Arc::clone(&db) }, - key_backups: key_backups::Service { db: Arc::clone(&db) }, - media: media::Service { db: Arc::clone(&db) }, - sending: sending::Service { db: Arc::clone(&db) }, - } + rooms: rooms::Service { + alias: rooms::alias::Service { db: db.clone() }, + auth_chain: rooms::auth_chain::Service { db: db.clone() }, + directory: rooms::directory::Service { db: db.clone() }, + edus: rooms::edus::Service { + presence: rooms::edus::presence::Service { db: db.clone() }, + read_receipt: rooms::edus::read_receipt::Service { db: db.clone() }, + typing: rooms::edus::typing::Service { db: db.clone() }, + }, + event_handler: rooms::event_handler::Service, + lazy_loading: rooms::lazy_loading::Service { + db: db.clone(), + lazy_load_waiting: Mutex::new(HashMap::new()), + }, + metadata: rooms::metadata::Service { db: db.clone() }, + outlier: rooms::outlier::Service { db: db.clone() }, + pdu_metadata: rooms::pdu_metadata::Service { db: db.clone() }, + search: rooms::search::Service { db: db.clone() }, + short: rooms::short::Service { db: db.clone() }, + state: rooms::state::Service { db: db.clone() }, + state_accessor: rooms::state_accessor::Service { db: db.clone() }, + state_cache: rooms::state_cache::Service { db: db.clone() }, + state_compressor: rooms::state_compressor::Service { db: db.clone() }, + timeline: rooms::timeline::Service { db: db.clone() }, + user: rooms::user::Service { db: db.clone() }, + }, + transaction_ids: transaction_ids::Service { + db: db.clone() + }, + uiaa: uiaa::Service { + db: db.clone() + }, + users: users::Service { + db: db.clone() + }, + account_data: account_data::Service { + db: db.clone() + }, + admin: admin::Service { sender: todo!() }, + globals: globals::Service::load(db.clone(), config)?, + key_backups: key_backups::Service { + db: db.clone() + }, + media: media::Service { + db: db.clone() + }, + sending: sending::Service { + maximum_requests: todo!(), + sender: todo!(), + }, + }) } } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index ef5888fc..65fb3677 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{RoomAliasId, RoomId}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 5fe0e3e8..e35094bb 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -1,12 +1,14 @@ mod data; -use std::{sync::Arc, collections::HashSet}; +use std::{sync::Arc, collections::{HashSet, BTreeSet}}; pub use data::Data; +use ruma::{RoomId, EventId, api::client::error::ErrorKind}; +use tracing::log::warn; -use crate::Result; +use crate::{Result, services, Error}; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -22,4 +24,131 @@ impl Service { pub fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { self.db.cache_auth_chain(key, auth_chain) } + + #[tracing::instrument(skip(self, starting_events))] + pub async fn get_auth_chain<'a>( + &self, + room_id: &RoomId, + starting_events: Vec>, + ) -> Result> + 'a> { + const NUM_BUCKETS: usize = 50; + + let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; + + let mut i = 0; + for id in starting_events { + let short = services().rooms.short.get_or_create_shorteventid(&id)?; + let bucket_id = (short % NUM_BUCKETS as u64) as usize; + buckets[bucket_id].insert((short, id.clone())); + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + let mut full_auth_chain = HashSet::new(); + + let mut hits = 0; + let mut misses = 0; + for chunk in buckets { + if chunk.is_empty() { + continue; + } + + let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); + if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? { + hits += 1; + full_auth_chain.extend(cached.iter().copied()); + continue; + } + misses += 1; + + let mut chunk_cache = HashSet::new(); + let mut hits2 = 0; + let mut misses2 = 0; + let mut i = 0; + for (sevent_id, event_id) in chunk { + if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? { + hits2 += 1; + chunk_cache.extend(cached.iter().copied()); + } else { + misses2 += 1; + let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); + services().rooms + .auth_chain + .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + println!( + "cache missed event {} with auth chain len {}", + event_id, + auth_chain.len() + ); + chunk_cache.extend(auth_chain.iter()); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + }; + } + println!( + "chunk missed with len {}, event hits2: {}, misses2: {}", + chunk_cache.len(), + hits2, + misses2 + ); + let chunk_cache = Arc::new(chunk_cache); + services().rooms + .auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + full_auth_chain.extend(chunk_cache.iter()); + } + + println!( + "total: {}, chunk hits: {}, misses: {}", + full_auth_chain.len(), + hits, + misses + ); + + Ok(full_auth_chain + .into_iter() + .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + } + + #[tracing::instrument(skip(self, event_id))] + fn get_auth_chain_inner( + &self, + room_id: &RoomId, + event_id: &EventId, + ) -> Result> { + let mut todo = vec![Arc::from(event_id)]; + let mut found = HashSet::new(); + + while let Some(event_id) = todo.pop() { + match services().rooms.timeline.get_pdu(&event_id) { + Ok(Some(pdu)) => { + if pdu.room_id != room_id { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); + } + for auth_event in &pdu.auth_events { + let sauthevent = services() + .rooms.short + .get_or_create_shorteventid(auth_event)?; + + if !found.contains(&sauthevent) { + found.insert(sauthevent); + todo.push(auth_event.clone()); + } + } + } + Ok(None) => { + warn!("Could not find pdu mentioned in auth events: {}", event_id); + } + Err(e) => { + warn!("Could not load event in auth chain: {} {}", event_id, e); + } + } + } + + Ok(found) + } } diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index fb289941..e85afef6 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::RoomId; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index 73b7b5a5..d6578977 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -1,5 +1,5 @@ mod data; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; pub use data::Data; use ruma::{RoomId, UserId, events::presence::PresenceEvent}; @@ -7,7 +7,7 @@ use ruma::{RoomId, UserId, events::presence::PresenceEvent}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index 2a4c0b7f..17708772 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 16a135f8..37520560 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{UserId, RoomId, events::SyncEphemeralRoomEvent}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index ac3cca6a..79f93b50 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -72,13 +72,15 @@ impl Service { )); } - services() + if services() .rooms - .is_disabled(room_id)? - .ok_or(Error::BadRequest( + .metadata + .is_disabled(room_id)? { + return Err(Error::BadRequest( ErrorKind::Forbidden, "Federation of this room is currently disabled on this server.", - ))?; + )); + } // 1. Skip the PDU if we already have it as a timeline event if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { @@ -111,7 +113,7 @@ impl Service { } // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - let (sorted_prev_events, eventid_info) = self.fetch_unknown_prev_events( + let (sorted_prev_events, mut eventid_info) = self.fetch_unknown_prev_events( origin, &create_event, room_id, @@ -122,14 +124,15 @@ impl Service { let mut errors = 0; for prev_id in dbg!(sorted_prev_events) { // Check for disabled again because it might have changed - services() + if services() .rooms - .is_disabled(room_id)? - .ok_or(Error::BadRequest( + .metadata + .is_disabled(room_id)? { + return Err(Error::BadRequest( ErrorKind::Forbidden, - "Federation of - this room is currently disabled on this server.", - ))?; + "Federation of this room is currently disabled on this server.", + )); + } if let Some((time, tries)) = services() .globals @@ -279,14 +282,14 @@ impl Service { Err(e) => { // Drop warn!("Dropping bad event {}: {}", event_id, e); - return Err("Signature verification failed".to_owned()); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed")); } Ok(ruma::signatures::Verified::Signatures) => { // Redact warn!("Calculated hash does not match: {}", event_id); match ruma::signatures::redact(&value, room_version_id) { Ok(obj) => obj, - Err(_) => return Err("Redaction failed".to_owned()), + Err(_) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")), } } Ok(ruma::signatures::Verified::All) => value, @@ -480,7 +483,7 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let prev_event = if let Ok(Some(pdu)) = services().rooms.get_pdu(prev_eventid) { + let prev_event = if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(prev_eventid) { pdu } else { okay = false; @@ -488,7 +491,7 @@ impl Service { }; let sstatehash = - if let Ok(Some(s)) = services().rooms.pdu_shortstatehash(prev_eventid) { + if let Ok(Some(s)) = services().rooms.state_accessor.pdu_shortstatehash(prev_eventid) { s } else { okay = false; @@ -525,7 +528,7 @@ impl Service { let mut starting_events = Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = services().rooms.get_statekey_from_short(k) { + if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); @@ -539,7 +542,7 @@ impl Service { services() .rooms .auth_chain - .get_auth_chain(room_id, starting_events, services()) + .get_auth_chain(room_id, starting_events) .await? .collect(), ); @@ -551,7 +554,7 @@ impl Service { let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.get_pdu(id); + let res = services().rooms.timeline.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } @@ -677,7 +680,7 @@ impl Service { .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) }, ) - .map_err(|_e| "Auth check failed.".to_owned())?; + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; if !check_result { return Err(Error::bad_database("Event has failed auth check with state at the event.")); @@ -714,7 +717,7 @@ impl Service { // Only keep those extremities were not referenced yet extremities - .retain(|id| !matches!(services().rooms.is_event_referenced(room_id, id), Ok(true))); + .retain(|id| !matches!(services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); info!("Compressing state at event"); let state_ids_compressed = state_at_incoming_event @@ -722,7 +725,8 @@ impl Service { .map(|(shortstatekey, id)| { services() .rooms - .compress_state_event(*shortstatekey, id)? + .state_compressor + .compress_state_event(*shortstatekey, id) }) .collect::>()?; @@ -731,6 +735,7 @@ impl Service { let auth_events = services() .rooms + .state .get_auth_events( room_id, &incoming_pdu.kind, @@ -744,10 +749,10 @@ impl Service { &incoming_pdu, None::, |k, s| auth_events.get(&(k.clone(), s.to_owned())), - )?; + ).map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; if soft_fail { - self.append_incoming_pdu( + services().rooms.timeline.append_incoming_pdu( &incoming_pdu, val, extremities.iter().map(std::ops::Deref::deref), @@ -760,8 +765,9 @@ impl Service { warn!("Event was soft failed: {:?}", incoming_pdu); services() .rooms + .pdu_metadata .mark_event_soft_failed(&incoming_pdu.event_id)?; - return Err("Event has been soft failed".into()); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } if incoming_pdu.state_key.is_some() { @@ -798,14 +804,14 @@ impl Service { "Found extremity pdu with no statehash in db: {:?}", leaf_pdu ); - "Found pdu with no statehash in db.".to_owned() + Error::bad_database("Found pdu with no statehash in db.") })?, leaf_pdu, ); } _ => { error!("Missing state snapshot for {:?}", id); - return Err("Missing state snapshot.".to_owned()); + return Err(Error::BadDatabase("Missing state snapshot.")); } } } @@ -835,7 +841,7 @@ impl Service { let mut update_state = false; // 14. Use state resolution to find new room state let new_room_state = if fork_states.is_empty() { - return Err("State is empty.".to_owned()); + panic!("State is empty"); } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { info!("State resolution trivial"); // There was only one state, so it has to be the room's current state (because that is @@ -845,7 +851,8 @@ impl Service { .map(|(k, id)| { services() .rooms - .compress_state_event(*k, id)? + .state_compressor + .compress_state_event(*k, id) }) .collect::>()? } else { @@ -877,9 +884,8 @@ impl Service { .filter_map(|(k, id)| { services() .rooms - .get_statekey_from_short(k)? - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType + .short + .get_statekey_from_short(k) .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) .ok() }) @@ -895,7 +901,7 @@ impl Service { &fork_states, auth_chain_sets, |id| { - let res = services().rooms.get_pdu(id); + let res = services().rooms.timeline.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } @@ -904,7 +910,7 @@ impl Service { ) { Ok(new_state) => new_state, Err(_) => { - return Err("State resolution failed, either an event could not be found or deserialization".into()); + return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization")); } }; @@ -921,6 +927,7 @@ impl Service { .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; services() .rooms + .state_compressor .compress_state_event(shortstatekey, &event_id) }) .collect::>()? @@ -929,9 +936,11 @@ impl Service { // Set the new room state to the resolved state if update_state { info!("Forcing new room state"); + let (sstatehash, _, _) = services().rooms.state_compressor.save_state(room_id, new_room_state)?; services() .rooms - .force_state(room_id, new_room_state)?; + .state + .set_room_state(room_id, sstatehash, &state_lock)?; } } @@ -942,7 +951,7 @@ impl Service { // We use the `state_at_event` instead of `state_after` so we accurately // represent the state for this event. - let pdu_id = self + let pdu_id = services().rooms.timeline .append_incoming_pdu( &incoming_pdu, val, @@ -1017,7 +1026,7 @@ impl Service { // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = services().rooms.get_pdu(id) { + if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { trace!("Found {} in db", id); pdus.push((local_pdu, None)); continue; @@ -1040,7 +1049,7 @@ impl Service { tokio::task::yield_now().await; } - if let Ok(Some(_)) = services().rooms.get_pdu(&next_id) { + if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { trace!("Found {} in db", id); continue; } @@ -1140,6 +1149,7 @@ impl Service { let first_pdu_in_room = services() .rooms + .timeline .first_pdu_in_room(room_id)? .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 90dad21c..760fffee 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,5 +1,5 @@ mod data; -use std::{collections::{HashSet, HashMap}, sync::Mutex}; +use std::{collections::{HashSet, HashMap}, sync::{Mutex, Arc}}; pub use data::Data; use ruma::{DeviceId, UserId, RoomId}; @@ -7,7 +7,7 @@ use ruma::{DeviceId, UserId, RoomId}; use crate::Result; pub struct Service { - db: Box, + db: Arc, lazy_load_waiting: Mutex, Box, Box, u64), HashSet>>>, } diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 9444db41..bc31ee88 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -3,4 +3,6 @@ use crate::Result; pub trait Data: Send + Sync { fn exists(&self, room_id: &RoomId) -> Result; + fn is_disabled(&self, room_id: &RoomId) -> Result; + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 3c21dd19..b6cccd15 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::RoomId; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -14,4 +16,12 @@ impl Service { pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } + + pub fn is_disabled(&self, room_id: &RoomId) -> Result { + self.db.is_disabled(room_id) + } + + pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + self.db.disable_room(room_id, disabled) + } } diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 5493ce48..d36adc4c 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{EventId, signatures::CanonicalJsonObject}; use crate::{Result, PduEvent}; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index a81d05c1..4724f857 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -7,7 +7,7 @@ use ruma::{RoomId, EventId}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index dc571910..ec1ad537 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use crate::Result; use ruma::RoomId; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index bc2b28f0..07a27121 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,2 +1,40 @@ +use std::sync::Arc; + +use ruma::{EventId, events::StateEventType, RoomId}; +use crate::Result; + pub trait Data: Send + Sync { + fn get_or_create_shorteventid( + &self, + event_id: &EventId, + ) -> Result; + + fn get_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result>; + + fn get_or_create_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result; + + fn get_eventid_from_short(&self, shorteventid: u64) -> Result>; + + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; + + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash( + &self, + state_hash: &[u8], + ) -> Result<(u64, bool)>; + + fn get_shortroomid(&self, room_id: &RoomId) -> Result>; + + fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + ) -> Result; } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index a024dc67..08ce5c5a 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -7,7 +7,7 @@ use ruma::{EventId, events::StateEventType, RoomId}; use crate::{Result, Error, utils, services}; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -15,29 +15,7 @@ impl Service { &self, event_id: &EventId, ) -> Result { - if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { - return Ok(*short); - } - - let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { - Some(shorteventid) => utils::u64_from_bytes(&shorteventid) - .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, - None => { - let shorteventid = services().globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - } - }; - - self.eventidshort_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), short); - - Ok(short) + self.db.get_or_create_shorteventid(event_id) } pub fn get_shortstatekey( @@ -45,36 +23,7 @@ impl Service { event_type: &StateEventType, state_key: &str, ) -> Result> { - if let Some(short) = self - .statekeyshort_cache - .lock() - .unwrap() - .get_mut(&(event_type.clone(), state_key.to_owned())) - { - return Ok(Some(*short)); - } - - let mut statekey = event_type.to_string().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(state_key.as_bytes()); - - let short = self - .statekey_shortstatekey - .get(&statekey)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; - - if let Some(s) = short { - self.statekeyshort_cache - .lock() - .unwrap() - .insert((event_type.clone(), state_key.to_owned()), s); - } - - Ok(short) + self.db.get_shortstatekey(event_type, state_key) } pub fn get_or_create_shortstatekey( @@ -82,152 +31,33 @@ impl Service { event_type: &StateEventType, state_key: &str, ) -> Result { - if let Some(short) = self - .statekeyshort_cache - .lock() - .unwrap() - .get_mut(&(event_type.clone(), state_key.to_owned())) - { - return Ok(*short); - } - - let mut statekey = event_type.to_string().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(state_key.as_bytes()); - - let short = match self.statekey_shortstatekey.get(&statekey)? { - Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, - None => { - let shortstatekey = services().globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey)?; - shortstatekey - } - }; - - self.statekeyshort_cache - .lock() - .unwrap() - .insert((event_type.clone(), state_key.to_owned()), short); - - Ok(short) + self.db.get_or_create_shortstatekey(event_type, state_key) } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - if let Some(id) = self - .shorteventid_cache - .lock() - .unwrap() - .get_mut(&shorteventid) - { - return Ok(Arc::clone(id)); - } - - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - - let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - - self.shorteventid_cache - .lock() - .unwrap() - .insert(shorteventid, Arc::clone(&event_id)); - - Ok(event_id) + self.db.get_eventid_from_short(shorteventid) } pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - if let Some(id) = self - .shortstatekey_cache - .lock() - .unwrap() - .get_mut(&shortstatekey) - { - return Ok(id.clone()); - } - - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xff); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = - StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| { - Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?; - - let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { - Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") - })?; - - let result = (event_type, state_key); - - self.shortstatekey_cache - .lock() - .unwrap() - .insert(shortstatekey, result.clone()); - - Ok(result) + self.db.get_statekey_from_short(shortstatekey) } /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash( + pub fn get_or_create_shortstatehash( &self, state_hash: &[u8], ) -> Result<(u64, bool)> { - Ok(match self.statehash_shortstatehash.get(state_hash)? { - Some(shortstatehash) => ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ), - None => { - let shortstatehash = services().globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - } - }) + self.db.get_or_create_shortstatehash(state_hash) } pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) - }) - .transpose() + self.db.get_shortroomid(room_id) } pub fn get_or_create_shortroomid( &self, room_id: &RoomId, ) -> Result { - Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { - Some(short) => utils::u64_from_bytes(&short) - .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, - None => { - let short = services().globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - } - }) + self.db.get_or_create_shortroomid(room_id) } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 53859785..79807c55 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,9 +1,10 @@ mod data; -use std::{collections::HashSet, sync::Arc}; +use std::{collections::{HashSet, HashMap}, sync::Arc}; pub use data::Data; -use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType}, UserId, EventId, serde::Raw, RoomVersionId}; +use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, RoomEventType}, UserId, EventId, serde::Raw, RoomVersionId, state_res::{StateMap, self}}; use serde::Deserialize; +use tokio::sync::MutexGuard; use tracing::warn; use crate::{Result, services, PduEvent, Error, utils::calculate_hash}; @@ -11,7 +12,7 @@ use crate::{Result, services, PduEvent, Error, utils::calculate_hash}; use super::state_compressor::CompressedStateEvent; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -97,7 +98,7 @@ impl Service { room_id: &RoomId, state_ids_compressed: HashSet, ) -> Result { - let shorteventid = services().short.get_or_create_shorteventid(event_id)?; + let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; @@ -109,11 +110,11 @@ impl Service { ); let (shortstatehash, already_existed) = - services().short.get_or_create_shortstatehash(&state_hash)?; + services().rooms.short.get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| services().room.state_compressor.load_shortstatehash_info(p))?; + .map_or_else(|| Ok(Vec::new()), |p| services().rooms.state_compressor.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -132,7 +133,7 @@ impl Service { } else { (state_ids_compressed, HashSet::new()) }; - services().room.state_compressor.save_state_from_diff( + services().rooms.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -141,7 +142,7 @@ impl Service { )?; } - self.db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + self.db.set_event_state(shorteventid, shortstatehash)?; Ok(shortstatehash) } @@ -155,25 +156,24 @@ impl Service { &self, new_pdu: &PduEvent, ) -> Result { - let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id)?; + let shorteventid = services().rooms.short.get_or_create_shorteventid(&new_pdu.event_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; + self.db.set_event_state(shorteventid, p)?; } if let Some(state_key) = &new_pdu.state_key { let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + .map_or_else(|| Ok(Vec::new()), |p| services().rooms.state_compressor.load_shortstatehash_info(p))?; - let shortstatekey = self.get_or_create_shortstatekey( + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( &new_pdu.kind.to_string().into(), state_key, )?; - let new = self.compress_state_event(shortstatekey, &new_pdu.event_id)?; + let new = services().rooms.state_compressor.compress_state_event(shortstatekey, &new_pdu.event_id)?; let replaces = states_parents .last() @@ -199,7 +199,7 @@ impl Service { statediffremoved.insert(*replaces); } - self.save_state_from_diff( + services().rooms.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -221,16 +221,16 @@ impl Service { let mut state = Vec::new(); // Add recommended events if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.room_state_get( + if let Some(e) = services().rooms.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomCanonicalAlias, "", @@ -238,16 +238,16 @@ impl Service { state.push(e.to_stripped_state_event()); } if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? { state.push(e.to_stripped_state_event()); } - if let Some(e) = self.room_state_get( + if let Some(e) = services().rooms.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str(), @@ -260,17 +260,16 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64) -> Result<()> { - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; - - Ok(()) + pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64, + mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.db.set_room_state(room_id, shortstatehash, mutex_lock) } /// Returns the room's version. #[tracing::instrument(skip(self))] pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; + let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event_content: Option = create_event .as_ref() @@ -294,4 +293,50 @@ impl Service { pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { self.db.get_forward_extremities(room_id) } + + /// This fetches auth events from the current state. + #[tracing::instrument(skip(self))] + pub fn get_auth_events( + &self, + room_id: &RoomId, + kind: &RoomEventType, + sender: &UserId, + state_key: Option<&str>, + content: &serde_json::value::RawValue, + ) -> Result>> { + let shortstatehash = + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + current_shortstatehash + } else { + return Ok(HashMap::new()); + }; + + let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) + .expect("content is a valid JSON object"); + + let mut sauthevents = auth_events + .into_iter() + .filter_map(|(event_type, state_key)| { + services().rooms.short.get_shortstatekey(&event_type.to_string().into(), &state_key) + .ok() + .flatten() + .map(|s| (s, (event_type, state_key))) + }) + .collect::>(); + + let full_state = services().rooms.state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + Ok(full_state + .into_iter() + .filter_map(|compressed| services().rooms.state_compressor.parse_compressed_state_event(compressed).ok()) + .filter_map(|(shortstatekey, event_id)| { + sauthevents.remove(&shortstatekey).map(|k| (k, event_id)) + }) + .filter_map(|(k, event_id)| services().rooms.timeline.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu))) + .collect()) + } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 1911e52f..fd299489 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -7,7 +7,7 @@ use ruma::{events::StateEventType, RoomId, EventId}; use crate::{Result, PduEvent}; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -45,7 +45,7 @@ impl Service { event_type: &StateEventType, state_key: &str, ) -> Result>> { - self.db.pdu_state_get(shortstatehash, event_type, state_key) + self.db.state_get(shortstatehash, event_type, state_key) } /// Returns the state hash for this pdu. diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 18d1123e..ab6a0d6c 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -3,12 +3,23 @@ use std::{collections::HashSet, sync::Arc}; pub use data::Data; use regex::Regex; -use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, tag::TagEvent, RoomAccountDataEventType, GlobalAccountDataEventType, direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, AnySyncStateEvent}, serde::Raw, ServerName}; - -use crate::{Result, services, utils, Error}; +use ruma::{ + events::{ + direct::{DirectEvent, DirectEventContent}, + ignored_user_list::IgnoredUserListEvent, + room::{create::RoomCreateEventContent, member::MembershipState}, + tag::{TagEvent, TagEventContent}, + AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, + RoomAccountDataEventType, StateEventType, RoomAccountDataEvent, RoomAccountDataEventContent, + }, + serde::Raw, + RoomId, ServerName, UserId, +}; + +use crate::{services, utils, Error, Result}; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -45,7 +56,9 @@ impl Service { self.db.mark_as_once_joined(user_id, room_id)?; // Check if the room has a predecessor - if let Some(predecessor) = self + if let Some(predecessor) = services() + .rooms + .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? .and_then(|create| serde_json::from_str(create.content.get()).ok()) .and_then(|content: RoomCreateEventContent| content.predecessor) @@ -76,27 +89,41 @@ impl Service { // .ok(); // Copy old tags to new room - if let Some(tag_event) = services().account_data.get::( - Some(&predecessor.room_id), - user_id, - RoomAccountDataEventType::Tag, - )? { - services().account_data + if let Some(tag_event) = services() + .account_data + .get( + Some(&predecessor.room_id), + user_id, + RoomAccountDataEventType::Tag, + )? + .map(|event| { + serde_json::from_str(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }) + { + services() + .account_data .update( Some(room_id), user_id, RoomAccountDataEventType::Tag, - &tag_event, + &tag_event?, ) .ok(); }; // Copy direct chat flag - if let Some(mut direct_event) = services().account_data.get::( + if let Some(mut direct_event) = services().account_data.get( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), - )? { + )? + .map(|event| { + serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }) + { + let direct_event = direct_event?; let mut room_ids_updated = false; for room_ids in direct_event.content.0.values_mut() { @@ -111,7 +138,7 @@ impl Service { None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), - &direct_event, + &serde_json::to_value(&direct_event).expect("to json always works"), )?; } }; @@ -124,13 +151,17 @@ impl Service { // We want to know if the sender is ignored by the receiver let is_ignored = services() .account_data - .get::( + .get( None, // Ignored users are in global account data user_id, // Receiver GlobalAccountDataEventType::IgnoredUserList .to_string() .into(), )? + .map(|event| { + serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }).transpose()? .map_or(false, |ignored| { ignored .content @@ -200,10 +231,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - pub fn get_our_real_users( - &self, - room_id: &RoomId, - ) -> Result>>> { + pub fn get_our_real_users(&self, room_id: &RoomId) -> Result>>> { let maybe = self .our_real_users_cache .read() diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index ab9f4275..0c32c4bd 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -9,7 +9,7 @@ use crate::{Result, utils, services}; use self::data::StateDiff; pub struct Service { - db: Box, + db: Arc, } pub type CompressedStateEvent = [u8; 2 * size_of::()]; @@ -67,7 +67,7 @@ impl Service { ) -> Result { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( - &self + &services().rooms.short .get_or_create_shorteventid(event_id)? .to_be_bytes(), ); @@ -218,7 +218,7 @@ impl Service { HashSet, // added HashSet)> // removed { - let previous_shortstatehash = self.db.current_shortstatehash(room_id)?; + let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; let state_hash = utils::calculate_hash( &new_state_ids_compressed diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index d073e865..2220b5f2 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -5,6 +5,7 @@ use ruma::{signatures::CanonicalJsonObject, EventId, UserId, RoomId}; use crate::{Result, PduEvent}; pub trait Data: Send + Sync { + fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>>; fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; /// Returns the `count` of this pdu's id. diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index e8f42053..78172255 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -21,33 +21,14 @@ use crate::{services, Result, service::pdu::{PduBuilder, EventHash}, Error, PduE use super::state_compressor::CompressedStateEvent; pub struct Service { - db: Box, + db: Arc, } impl Service { - /* - /// Checks if a room exists. #[tracing::instrument(skip(self))] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Look for PDUs in that room. - self.pduid_pdu - .iter_from(&prefix, false) - .filter(|(k, _)| k.starts_with(&prefix)) - .map(|(_, pdu)| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid first PDU in db.")) - .map(Arc::new) - }) - .next() - .transpose() + self.db.first_pdu_in_room(room_id) } - */ #[tracing::instrument(skip(self))] pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { @@ -681,7 +662,8 @@ impl Service { /// Append the incoming event setting the state snapshot to the state from the /// server that sent the event. #[tracing::instrument(skip_all)] - fn append_incoming_pdu<'a>( + pub fn append_incoming_pdu<'a>( + &self, pdu: &PduEvent, pdu_json: CanonicalJsonObject, new_room_leaves: impl IntoIterator + Clone + Debug, diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 7c7dfae6..394a550a 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{RoomId, UserId}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 8ab557f6..fde251b7 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -448,14 +448,6 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(keys))] - fn calculate_hash(keys: &[&[u8]]) -> Vec { - // We only hash the pdu's event ids, not the whole pdu - let bytes = keys.join(&0xff); - let hash = digest::digest(&digest::SHA256, &bytes); - hash.as_ref().to_owned() - } - /// Cleanup event data /// Used for instance after we remove an appservice registration /// diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index a9c516cf..d7066e24 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,11 +1,13 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{UserId, DeviceId, TransactionId}; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 01c0d2f6..73b2273d 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,4 +1,6 @@ mod data; +use std::sync::Arc; + pub use data::Data; use ruma::{api::client::{uiaa::{UiaaInfo, IncomingAuthData, IncomingPassword, AuthType, IncomingUserIdentifier}, error::ErrorKind}, DeviceId, UserId, signatures::CanonicalJsonValue}; @@ -7,7 +9,7 @@ use tracing::error; use crate::{Result, utils, Error, services, api::client_server::SESSION_ID_LENGTH}; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index b13ae1f2..2cf18765 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,5 +1,5 @@ mod data; -use std::{collections::BTreeMap, mem}; +use std::{collections::BTreeMap, mem, sync::Arc}; pub use data::Data; use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, DeviceKeyAlgorithm, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition, error::ErrorKind}, RoomAliasId}; @@ -7,7 +7,7 @@ use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTi use crate::{Result, Error, services}; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 734da2a8..0ee3ae84 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -3,6 +3,7 @@ pub mod error; use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*; +use ring::digest; use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; use std::{ cmp, fmt, @@ -59,7 +60,7 @@ pub fn random_string(length: usize) -> String { } /// Calculate a new hash for the given password -pub fn calculate_hash(password: &str) -> Result { +pub fn calculate_password_hash(password: &str) -> Result { let hashing_config = Config { variant: Variant::Argon2id, ..Default::default() @@ -69,6 +70,15 @@ pub fn calculate_hash(password: &str) -> Result { argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config) } +#[tracing::instrument(skip(keys))] +pub fn calculate_hash(keys: &[&[u8]]) -> Vec { + // We only hash the pdu's event ids, not the whole pdu + let bytes = keys.join(&0xff); + let hash = digest::digest(&digest::SHA256, &bytes); + hash.as_ref().to_owned() +} + + pub fn common_elements( mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering,