messing with trait objects

Nyaaori/refactor-next
Timo Kösters 2 years ago committed by Nyaaori
parent 8708cd3b63
commit face766e0f
No known key found for this signature in database
GPG Key ID: E7819C3ED4D1F82E

@ -481,7 +481,7 @@ async fn join_room_by_id_helper(
let (make_join_response, remote_server) = make_join_response_and_server?;
let room_version = match make_join_response.room_version {
Some(room_version) if services().rooms.metadata.is_supported_version(&room_version) => room_version,
Some(room_version) if services().globals.supported_room_versions().contains(&room_version) => room_version,
_ => return Err(Error::BadServerResponse("Room version is not supported")),
};
@ -568,7 +568,7 @@ async fn join_room_by_id_helper(
let mut state = HashMap::new();
let pub_key_map = RwLock::new(BTreeMap::new());
server_server::fetch_join_signing_keys(
services().rooms.event_handler.fetch_join_signing_keys(
&send_join_response,
&room_version,
&pub_key_map,
@ -1048,7 +1048,7 @@ async fn remote_leave_room(
let (make_leave_response, remote_server) = make_leave_response_and_server?;
let room_version_id = match make_leave_response.room_version {
Some(version) if services().rooms.is_supported_version(&version) => version,
Some(version) if services().globals.supported_room_versions().contains(&version) => version,
_ => return Err(Error::BadServerResponse("Room version is not supported")),
};

@ -99,7 +99,7 @@ pub async fn create_room_route(
let room_version = match body.room_version.clone() {
Some(room_version) => {
if services().rooms.is_supported_version(&services(), &room_version) {
if services().globals.supported_room_versions().contains(&room_version) {
room_version
} else {
return Err(Error::BadRequest(
@ -470,7 +470,7 @@ pub async fn upgrade_room_route(
) -> Result<upgrade_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.is_supported_version(&body.new_version) {
if !services().globals.supported_room_versions().contains(&body.new_version) {
return Err(Error::BadRequest(
ErrorKind::UnsupportedRoomVersion,
"This server does not support that room version.",

@ -175,7 +175,7 @@ async fn sync_helper(
services().rooms.edus.presence.ping_presence(&sender_user)?;
// Setup watchers, so if there's no response, we can wait for them
let watcher = services().watch(&sender_user, &sender_device);
let watcher = services().globals.db.watch(&sender_user, &sender_device);
let next_batch = services().globals.current_count()?;
let next_batch_string = next_batch.to_string();

@ -197,7 +197,7 @@ where
request_map.insert("content".to_owned(), json_body.clone());
};
let keys_result = server_server::fetch_signing_keys(
let keys_result = services().rooms.event_handler.fetch_signing_keys(
&x_matrix.origin,
vec![x_matrix.key.to_owned()],
)

@ -664,7 +664,7 @@ pub async fn send_transaction_message_route(
Some(id) => id,
None => {
// Event is invalid
resolved_map.insert(event_id, Err("Event needs a valid RoomId.".to_owned()));
resolved_map.insert(event_id, Err(Error::bad_database("Event needs a valid RoomId.")));
continue;
}
};
@ -707,7 +707,7 @@ pub async fn send_transaction_message_route(
for pdu in &resolved_map {
if let Err(e) = pdu.1 {
if e != "Room is unknown to this server." {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!("Incoming PDU failed {:?}", pdu);
}
}
@ -854,170 +854,7 @@ pub async fn send_transaction_message_route(
}
}
Ok(send_transaction_message::v1::Response { pdus: resolved_map })
}
/// Search the DB for the signing keys of the given server, if we don't have them
/// fetch them from the server and save to our DB.
#[tracing::instrument(skip_all)]
pub(crate) async fn fetch_signing_keys(
origin: &ServerName,
signature_ids: Vec<String>,
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids =
|keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = services()
.globals
.servername_ratelimiter
.read()
.unwrap()
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let s = Arc::clone(
write
.entry(origin.to_owned())
.or_insert_with(|| Arc::new(Semaphore::new(1))),
);
s.acquire_owned()
}
}
.await;
let back_off = |id| match services()
.globals
.bad_signature_ratelimiter
.write()
.unwrap()
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.unwrap()
.get(&signature_ids)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
debug!("Backing off from {:?}", signature_ids);
return Err(Error::BadServerResponse("bad signature, still backing off"));
}
}
trace!("Loading signing keys for {}", origin);
let mut result: BTreeMap<_, _> = services()
.globals
.signing_keys_for(origin)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if contains_all_ids(&result) {
return Ok(result);
}
debug!("Fetching signing keys for {} over federation", origin);
if let Some(server_key) = services()
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
services().globals.add_signing_key(origin, server_key.clone())?;
result.extend(
server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) {
return Ok(result);
}
}
for server in services().globals.trusted_servers() {
debug!("Asking {} for {}'s signing key", server, origin);
if let Some(server_keys) = services()
.sending
.send_federation_request(
server,
get_remote_server_keys::v2::Request::new(
origin,
MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime to large"),
)
.expect("time is valid"),
),
)
.await
.ok()
.map(|resp| {
resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
})
{
trace!("Got signing keys: {:?}", server_keys);
for k in server_keys {
services().globals.add_signing_key(origin, k.clone())?;
result.extend(
k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
}
if contains_all_ids(&result) {
return Ok(result);
}
}
}
drop(permit);
back_off(signature_ids);
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse(
"Failed to find public key for server",
))
Ok(send_transaction_message::v1::Response { pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.to_string()))).collect() })
}
#[tracing::instrument(skip(starting_events))]
@ -1050,7 +887,7 @@ pub(crate) async fn get_auth_chain<'a>(
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services().rooms.auth_chain.get_auth_chain_from_cache(&chunk_key)? {
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
hits += 1;
full_auth_chain.extend(cached.iter().copied());
continue;
@ -1062,7 +899,7 @@ pub(crate) async fn get_auth_chain<'a>(
let mut misses2 = 0;
let mut i = 0;
for (sevent_id, event_id) in chunk {
if let Some(cached) = services().rooms.auth_chain.get_auth_chain_from_cache(&[sevent_id])? {
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
hits2 += 1;
chunk_cache.extend(cached.iter().copied());
} else {
@ -1689,7 +1526,7 @@ pub async fn create_invite_route(
services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?;
if !services().rooms.is_supported_version(&body.room_version) {
if !services().globals.supported_room_versions().contains(&body.room_version) {
return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion {
room_version: body.room_version.clone(),

@ -4,10 +4,10 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils};
impl service::globals::Data for KeyValueDatabase {
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.globals.get(b"keypair")?.map_or_else(
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
self.globals.insert(b"keypair", &keypair)?;
self.global.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
|s| Ok(s.to_vec()),
@ -33,8 +33,10 @@ impl service::globals::Data for KeyValueDatabase {
Ed25519KeyPair::from_der(key, version)
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
});
keypair
}
fn remove_keypair(&self) -> Result<()> {
self.globals.remove(b"keypair")?
self.global.remove(b"keypair")
}
}

@ -1,3 +1,5 @@
use ruma::api::client::error::ErrorKind;
use crate::{database::KeyValueDatabase, service, Error, utils, Result};
impl service::media::Data for KeyValueDatabase {
@ -33,7 +35,7 @@ impl service::media::Data for KeyValueDatabase {
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
prefix.push(0xff);
let (key, _) = self.mediaid_file.scan_prefix(prefix).next().ok_or(Error::NotFound)?;
let (key, _) = self.mediaid_file.scan_prefix(prefix).next().ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xff);

@ -55,6 +55,6 @@ impl service::pusher::Data for KeyValueDatabase {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xff);
self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k))
}
}

@ -56,15 +56,15 @@ impl service::rooms::alias::Data for KeyValueDatabase {
fn local_aliases_for_room(
&self,
room_id: &RoomId,
) -> Result<Box<dyn Iterator<Item=String>>> {
) -> Box<dyn Iterator<Item = Result<Box<RoomAliasId>>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
})
}))
}
}

@ -3,8 +3,8 @@ use std::{collections::HashSet, mem::size_of};
use crate::{service, database::KeyValueDatabase, Result, utils};
impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result<HashSet<u64>> {
self.shorteventid_authchain
fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result<Option<HashSet<u64>>> {
Ok(self.shorteventid_authchain
.get(&shorteventid.to_be_bytes())?
.map(|chain| {
chain
@ -13,7 +13,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
utils::u64_from_bytes(chunk).expect("byte length is correct")
})
.collect()
})
}))
}
fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet<u64>) -> Result<()> {

@ -145,4 +145,6 @@ fn parse_presence_event(bytes: &[u8]) -> Result<PresenceEvent> {
.last_active_ago
.map(|timestamp| current_timestamp - timestamp);
}
Ok(presence)
}

@ -64,7 +64,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
self.readreceiptid_readreceipt
Box::new(self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
@ -91,7 +91,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
serde_json::value::to_raw_value(&json).expect("json is valid raw value"),
),
))
})
}))
}
fn private_read_set(

@ -25,26 +25,19 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
since: u64,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&(
user_id.to_owned(),
device_id.to_owned(),
room_id.to_owned(),
since,
)) {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff);
for ll_id in user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())

@ -1,10 +1,10 @@
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, Result};
use crate::{service, database::KeyValueDatabase, Result, services};
impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match self.get_shortroomid(room_id)? {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};

@ -2,10 +2,10 @@ use std::mem::size_of;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Result};
use crate::{service, database::KeyValueDatabase, utils, Result, services};
impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: u64, message_body: String) -> Result<()> {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
@ -27,7 +27,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
room_id: &RoomId,
search_string: &str,
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>>>, Vec<String>)>> {
let prefix = self
let prefix = services().rooms.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
@ -60,11 +60,11 @@ impl service::rooms::search::Data for KeyValueDatabase {
})
.map(|iter| {
(
iter.map(move |id| {
Box::new(iter.map(move |id| {
let mut pduid = prefix_clone.clone();
pduid.extend_from_slice(&id);
pduid
}),
})),
words,
)
}))

@ -1,13 +1,13 @@
use std::{collections::{BTreeMap, HashMap}, sync::Arc};
use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils, Result};
use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils, Result, services};
use async_trait::async_trait;
use ruma::{EventId, events::StateEventType, RoomId};
#[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = self
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
@ -15,7 +15,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
let mut result = BTreeMap::new();
let mut i = 0;
for compressed in full_state.into_iter() {
let parsed = self.parse_compressed_state_event(compressed)?;
let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i += 1;
@ -30,7 +30,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
&self,
shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = self
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
@ -39,8 +39,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
let mut result = HashMap::new();
let mut i = 0;
for compressed in full_state {
let (_, eventid) = self.parse_compressed_state_event(compressed)?;
if let Some(pdu) = self.get_pdu(&eventid)? {
let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
@ -69,11 +69,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let shortstatekey = match self.get_shortstatekey(event_type, state_key)? {
let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? {
Some(s) => s,
None => return Ok(None),
};
let full_state = self
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
@ -82,7 +82,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
.into_iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
self.parse_compressed_state_event(compressed)
services().rooms.state_compressor.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
@ -96,7 +96,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
@ -122,7 +122,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
@ -136,7 +136,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
@ -150,7 +150,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)

@ -39,8 +39,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
}
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.to_be_bytes().to_vec();
for new in &diff.new {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in &diff.added {
value.extend_from_slice(&new[..]);
}

@ -3,7 +3,7 @@ use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{UserId, RoomId, api::client::error::ErrorKind, EventId, signatures::CanonicalJsonObject};
use tracing::error;
use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result};
use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services};
impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
@ -191,7 +191,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
room_id: &RoomId,
since: u64,
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>>>> {
let prefix = self
let prefix = services().rooms.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
@ -203,7 +203,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let user_id = user_id.to_owned();
Ok(self
Ok(Box::new(self
.pduid_pdu
.iter_from(&first_pdu_id, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
@ -214,7 +214,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
pdu.remove_transaction_id()?;
}
Ok((pdu_id, pdu))
}))
})))
}
/// Returns an iterator over all events and their tokens in a room that happened before the
@ -226,7 +226,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
until: u64,
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>>>> {
// Create the first part of the full pdu id
let prefix = self
let prefix = services().rooms.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
@ -239,7 +239,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let user_id = user_id.to_owned();
Ok(self
Ok(Box::new(self
.pduid_pdu
.iter_from(current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
@ -250,7 +250,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
pdu.remove_transaction_id()?;
}
Ok((pdu_id, pdu))
}))
})))
}
fn pdus_after<'a>(
@ -260,7 +260,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
from: u64,
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>>>> {
// Create the first part of the full pdu id
let prefix = self
let prefix = services().rooms.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
@ -273,7 +273,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let user_id = user_id.to_owned();
Ok(self
Ok(Box::new(self
.pduid_pdu
.iter_from(current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
@ -284,6 +284,6 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
pdu.remove_transaction_id()?;
}
Ok((pdu_id, pdu))
}))
})))
}
}

@ -1,6 +1,6 @@
use ruma::{UserId, RoomId};
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
use crate::{service, database::KeyValueDatabase, utils, Error, Result, services};
impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
@ -50,7 +50,7 @@ impl service::rooms::user::Data for KeyValueDatabase {
token: u64,
shortstatehash: u64,
) -> Result<()> {
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists");
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
@ -60,7 +60,7 @@ impl service::rooms::user::Data for KeyValueDatabase {
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists");
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());

@ -57,12 +57,12 @@ impl service::users::Data for KeyValueDatabase {
/// Returns an iterator over all users on this homeserver.
fn iter(&self) -> Box<dyn Iterator<Item = Result<Box<UserId>>>> {
self.userid_password.iter().map(|(bytes, _)| {
Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
})
}))
}
/// Returns a list of local users as list of usernames.
@ -274,7 +274,7 @@ impl service::users::Data for KeyValueDatabase {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
// All devices have metadata
self.userdeviceid_metadata
Box::new(self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes(
@ -285,7 +285,7 @@ impl service::users::Data for KeyValueDatabase {
)
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
.into())
})
}))
}
/// Replaces the access token of one device.
@ -617,7 +617,7 @@ impl service::users::Data for KeyValueDatabase {
let to = to.unwrap_or(u64::MAX);
self.keychangeid_userid
Box::new(self.keychangeid_userid
.iter_from(&start, false)
.take_while(move |(k, _)| {
k.starts_with(&prefix)
@ -638,7 +638,7 @@ impl service::users::Data for KeyValueDatabase {
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
})
}))
}
fn mark_device_key_update(
@ -646,9 +646,10 @@ impl service::users::Data for KeyValueDatabase {
user_id: &UserId,
) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes();
for room_id in services().rooms.rooms_joined(user_id).filter_map(|r| r.ok()) {
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) {
// Don't send key updates to unencrypted rooms
if services().rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
{
@ -882,12 +883,12 @@ impl service::users::Data for KeyValueDatabase {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
self.userdeviceid_metadata
Box::new(self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes)
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
})
}))
}
/// Creates a new sync filter. Returns the filter id.

@ -1,7 +1,7 @@
pub mod abstraction;
pub mod key_value;
use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms, account_data, media, key_backups, transaction_ids, sending, appservice, pusher}};
use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms::{self, state_compressor::CompressedStateEvent}, account_data, media, key_backups, transaction_ids, sending, appservice, pusher}, services, PduEvent, Services, SERVICES};
use abstraction::KeyValueDatabaseEngine;
use directories::ProjectDirs;
use futures_util::{stream::FuturesUnordered, StreamExt};
@ -9,7 +9,7 @@ use lru_cache::LruCache;
use ruma::{
events::{
push_rules::PushRulesEventContent, room::message::RoomMessageEventContent,
GlobalAccountDataEvent, GlobalAccountDataEventType,
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
},
push::Ruleset,
DeviceId, EventId, RoomId, UserId, signatures::CanonicalJsonValue,
@ -151,6 +151,30 @@ pub struct KeyValueDatabase {
//pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>,
pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>,
pub(super) pdu_cache: Mutex<LruCache<Box<EventId>, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<Box<EventId>, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache: RwLock<HashMap<Box<RoomId>, Arc<HashSet<Box<UserId>>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<Box<RoomId>, HashMap<String, bool>>>,
pub(super) lazy_load_waiting:
Mutex<HashMap<(Box<UserId>, Box<DeviceId>, Box<RoomId>, u64), HashSet<Box<UserId>>>>,
pub(super) stateinfo_cache: Mutex<
LruCache<
u64,
Vec<(
u64, // sstatehash
HashSet<CompressedStateEvent>, // full state
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>, // removed
)>,
>,
>,
pub(super) lasttimelinecount_cache: Mutex<HashMap<Box<RoomId>, u64>>,
}
impl KeyValueDatabase {
@ -214,7 +238,7 @@ impl KeyValueDatabase {
}
/// Load an existing database or create a new one.
pub async fn load_or_create(config: &Config) -> Result<Arc<TokioRwLock<Self>>> {
pub async fn load_or_create(config: &Config) -> Result<()> {
Self::check_db_setup(config)?;
if !Path::new(&config.database_path).exists() {
@ -253,7 +277,7 @@ impl KeyValueDatabase {
let (admin_sender, admin_receiver) = mpsc::unbounded_channel();
let (sending_sender, sending_receiver) = mpsc::unbounded_channel();
let db = Self {
let db = Arc::new(Self {
_db: builder.clone(),
userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?,
@ -345,18 +369,53 @@ impl KeyValueDatabase {
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
};
// TODO: do this after constructing the db
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
pdu_cache: Mutex::new(LruCache::new(
config
.pdu_cache_capacity
.try_into()
.expect("pdu cache capacity fits into usize"),
)),
auth_chain_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
shorteventid_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
eventidshort_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
shortstatekey_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
statekeyshort_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lazy_load_waiting: Mutex::new(HashMap::new()),
stateinfo_cache: Mutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
});
let services_raw = Services::build(Arc::clone(&db));
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(services_raw);
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if guard.users.count()? > 0 {
if services().users.count()? > 0 {
let conduit_user =
UserId::parse_with_server_name("conduit", guard.globals.server_name())
UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is valid");
if !guard.users.exists(&conduit_user)? {
if !services().users.exists(&conduit_user)? {
error!(
"The {} server user does not exist, and the database is not new.",
conduit_user
@ -370,11 +429,10 @@ impl KeyValueDatabase {
// If the database has any data, perform data migrations before starting
let latest_database_version = 11;
if guard.users.count()? > 0 {
let db = &*guard;
if services().users.count()? > 0 {
// MIGRATIONS
if db.globals.database_version()? < 1 {
for (roomserverid, _) in db.rooms.roomserverids.iter() {
if services().globals.database_version()? < 1 {
for (roomserverid, _) in db.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xff);
let room_id = parts.next().expect("split always returns one element");
let servername = match parts.next() {
@ -388,17 +446,17 @@ impl KeyValueDatabase {
serverroomid.push(0xff);
serverroomid.extend_from_slice(room_id);
db.rooms.serverroomids.insert(&serverroomid, &[])?;
db.serverroomids.insert(&serverroomid, &[])?;
}
db.globals.bump_database_version(1)?;
services().globals.bump_database_version(1)?;
warn!("Migration: 0 -> 1 finished");
}
if db.globals.database_version()? < 2 {
if services().globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just ""
for (userid, password) in db.users.userid_password.iter() {
for (userid, password) in db.userid_password.iter() {
let password = utils::string_from_bytes(&password);
let empty_hashed_password = password.map_or(false, |password| {
@ -406,59 +464,59 @@ impl KeyValueDatabase {
});
if empty_hashed_password {
db.users.userid_password.insert(&userid, b"")?;
db.userid_password.insert(&userid, b"")?;
}
}
db.globals.bump_database_version(2)?;
services().globals.bump_database_version(2)?;
warn!("Migration: 1 -> 2 finished");
}
if db.globals.database_version()? < 3 {
if services().globals.database_version()? < 3 {
// Move media to filesystem
for (key, content) in db.media.mediaid_file.iter() {
for (key, content) in db.mediaid_file.iter() {
if content.is_empty() {
continue;
}
let path = db.globals.get_media_file(&key);
let path = services().globals.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
db.media.mediaid_file.insert(&key, &[])?;
db.mediaid_file.insert(&key, &[])?;
}
db.globals.bump_database_version(3)?;
services().globals.bump_database_version(3)?;
warn!("Migration: 2 -> 3 finished");
}
if db.globals.database_version()? < 4 {
// Add federated users to db as deactivated
for our_user in db.users.iter() {
if services().globals.database_version()? < 4 {
// Add federated users to services() as deactivated
for our_user in services().users.iter() {
let our_user = our_user?;
if db.users.is_deactivated(&our_user)? {
if services().users.is_deactivated(&our_user)? {
continue;
}
for room in db.rooms.rooms_joined(&our_user) {
for user in db.rooms.room_members(&room?) {
for room in services().rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) {
let user = user?;
if user.server_name() != db.globals.server_name() {
if user.server_name() != services().globals.server_name() {
println!("Migration: Creating user {}", user);
db.users.create(&user, None)?;
services().users.create(&user, None)?;
}
}
}
}
db.globals.bump_database_version(4)?;
services().globals.bump_database_version(4)?;
warn!("Migration: 3 -> 4 finished");
}
if db.globals.database_version()? < 5 {
if services().globals.database_version()? < 5 {
// Upgrade user data store
for (roomuserdataid, _) in db.account_data.roomuserdataid_accountdata.iter() {
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
let mut parts = roomuserdataid.split(|&b| b == 0xff);
let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap();
@ -470,30 +528,29 @@ impl KeyValueDatabase {
key.push(0xff);
key.extend_from_slice(event_type);
db.account_data
.roomusertype_roomuserdataid
db.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
}
db.globals.bump_database_version(5)?;
services().globals.bump_database_version(5)?;
warn!("Migration: 4 -> 5 finished");
}
if db.globals.database_version()? < 6 {
if services().globals.database_version()? < 6 {
// Set room member count
for (roomid, _) in db.rooms.roomid_shortstatehash.iter() {
for (roomid, _) in db.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
db.rooms.update_joined_count(room_id, &db)?;
services().rooms.state_cache.update_joined_count(room_id)?;
}
db.globals.bump_database_version(6)?;
services().globals.bump_database_version(6)?;
warn!("Migration: 5 -> 6 finished");
}
if db.globals.database_version()? < 7 {
if services().globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<Box<RoomId>, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
@ -513,7 +570,7 @@ impl KeyValueDatabase {
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
db.rooms.state_accessor.load_shortstatehash_info(dbg!(last_roomsstatehash))
services().rooms.state_compressor.load_shortstatehash_info(dbg!(last_roomsstatehash))
},
)?;
@ -535,7 +592,7 @@ impl KeyValueDatabase {
(current_state, HashSet::new())
};
db.rooms.save_state_from_diff(
services().rooms.state_compressor.save_state_from_diff(
dbg!(current_sstatehash),
statediffnew,
statediffremoved,
@ -544,7 +601,7 @@ impl KeyValueDatabase {
)?;
/*
let mut tmp = db.rooms.load_shortstatehash_info(&current_sstatehash, &db)?;
let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
@ -587,14 +644,13 @@ impl KeyValueDatabase {
current_sstatehash = Some(sstatehash);
let event_id = db
.rooms
.shorteventid_eventid
.get(&seventid)
.unwrap()
.unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = db.rooms.get_pdu(event_id).unwrap().unwrap();
let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
@ -615,20 +671,20 @@ impl KeyValueDatabase {
)?;
}
db.globals.bump_database_version(7)?;
services().globals.bump_database_version(7)?;
warn!("Migration: 6 -> 7 finished");
}
if db.globals.database_version()? < 8 {
if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms
for (room_id, _) in db.rooms.roomid_shortstatehash.iter() {
let shortroomid = db.globals.next_count()?.to_be_bytes();
db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?;
for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
// Update pduids db layout
let mut batch = db.rooms.pduid_pdu.iter().filter_map(|(key, v)| {
let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| {
if !key.starts_with(b"!") {
return None;
}
@ -637,7 +693,6 @@ impl KeyValueDatabase {
let count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(room_id)
.unwrap()
@ -649,9 +704,9 @@ impl KeyValueDatabase {
Some((new_key, v))
});
db.rooms.pduid_pdu.insert_batch(&mut batch)?;
db.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| {
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
@ -660,7 +715,6 @@ impl KeyValueDatabase {
let count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(room_id)
.unwrap()
@ -672,17 +726,16 @@ impl KeyValueDatabase {
Some((k, new_value))
});
db.rooms.eventid_pduid.insert_batch(&mut batch2)?;
db.eventid_pduid.insert_batch(&mut batch2)?;
db.globals.bump_database_version(8)?;
services().globals.bump_database_version(8)?;
warn!("Migration: 7 -> 8 finished");
}
if db.globals.database_version()? < 9 {
if services().globals.database_version()? < 9 {
// Update tokenids db layout
let mut iter = db
.rooms
.tokenids
.iter()
.filter_map(|(key, _)| {
@ -696,7 +749,6 @@ impl KeyValueDatabase {
let pdu_id_count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(room_id)
.unwrap()
@ -712,8 +764,7 @@ impl KeyValueDatabase {
.peekable();
while iter.peek().is_some() {
db.rooms
.tokenids
db.tokenids
.insert_batch(&mut iter.by_ref().take(1000))?;
println!("smaller batch done");
}
@ -721,7 +772,6 @@ impl KeyValueDatabase {
info!("Deleting starts");
let batch2: Vec<_> = db
.rooms
.tokenids
.iter()
.filter_map(|(key, _)| {
@ -736,38 +786,37 @@ impl KeyValueDatabase {
for key in batch2 {
println!("del");
db.rooms.tokenids.remove(&key)?;
db.tokenids.remove(&key)?;
}
db.globals.bump_database_version(9)?;
services().globals.bump_database_version(9)?;
warn!("Migration: 8 -> 9 finished");
}
if db.globals.database_version()? < 10 {
if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.rooms.statekey_shortstatekey.iter() {
db.rooms
.shortstatekey_statekey
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
db.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?;
}
// Force E2EE device list updates so we can send them over federation
for user_id in db.users.iter().filter_map(|r| r.ok()) {
db.users
.mark_device_key_update(&user_id, &db.rooms, &db.globals)?;
for user_id in services().users.iter().filter_map(|r| r.ok()) {
services().users
.mark_device_key_update(&user_id)?;
}
db.globals.bump_database_version(10)?;
services().globals.bump_database_version(10)?;
warn!("Migration: 9 -> 10 finished");
}
if db.globals.database_version()? < 11 {
if services().globals.database_version()? < 11 {
db._db
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
db.globals.bump_database_version(11)?;
services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished");
}
@ -779,12 +828,12 @@ impl KeyValueDatabase {
config.database_backend, latest_database_version
);
} else {
guard
services()
.globals
.bump_database_version(latest_database_version)?;
// Create the admin room and server user on first run
create_admin_room().await?;
services().admin.create_admin_room().await?;
warn!(
"Created new {} database with version {}",
@ -793,16 +842,16 @@ impl KeyValueDatabase {
}
// This data is probably outdated
guard.rooms.edus.presenceid_presence.clear()?;
db.presenceid_presence.clear()?;
guard.admin.start_handler(Arc::clone(&db), admin_receiver);
services().admin.start_handler(admin_receiver);
// Set emergency access for the conduit user
match set_emergency_access(&guard) {
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!");
guard.admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"));
services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"));
}
}
Err(e) => {
@ -813,21 +862,19 @@ impl KeyValueDatabase {
}
};
guard
services()
.sending
.start_handler(Arc::clone(&db), sending_receiver);
.start_handler(sending_receiver);
drop(guard);
Self::start_cleanup_task(config).await;
Self::start_cleanup_task(Arc::clone(&db), config).await;
Ok(db)
Ok(())
}
#[cfg(feature = "conduit_bin")]
pub async fn on_shutdown(db: Arc<TokioRwLock<Self>>) {
pub async fn on_shutdown() {
info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers...");
db.read().await.globals.rotate.fire();
services().globals.rotate.fire();
}
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) {
@ -844,33 +891,30 @@ impl KeyValueDatabase {
// Return when *any* user changed his key
// TODO: only send for user they share a room with
futures.push(
self.users
.todeviceid_events
self.todeviceid_events
.watch_prefix(&userdeviceid_prefix),
);
futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(
self.rooms
.userroomid_invitestate
self.userroomid_invitestate
.watch_prefix(&userid_prefix),
);
futures.push(self.rooms.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.rooms
.userroomid_notificationcount
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
);
futures.push(
self.rooms
.userroomid_highlightcount
self.userroomid_highlightcount
.watch_prefix(&userid_prefix),
);
// Events for rooms we are in
for room_id in self.rooms.rooms_joined(user_id).filter_map(|r| r.ok()) {
let short_roomid = self
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) {
let short_roomid = services()
.rooms
.short
.get_shortroomid(&room_id)
.ok()
.flatten()
@ -883,33 +927,28 @@ impl KeyValueDatabase {
roomid_prefix.push(0xff);
// PDUs
futures.push(self.rooms.pduid_pdu.watch_prefix(&short_roomid));
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(
self.rooms
.edus
.roomid_lasttypingupdate
self.roomid_lasttypingupdate
.watch_prefix(&roomid_bytes),
);
futures.push(
self.rooms
.edus
.readreceiptid_readreceipt
self.readreceiptid_readreceipt
.watch_prefix(&roomid_prefix),
);
// Key changes
futures.push(self.users.keychangeid_userid.watch_prefix(&roomid_prefix));
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
// Room account data
let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.account_data
.roomusertype_roomuserdataid
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
}
@ -918,22 +957,20 @@ impl KeyValueDatabase {
globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.account_data
.roomusertype_roomuserdataid
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms)
futures.push(self.users.keychangeid_userid.watch_prefix(&userid_prefix));
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys
futures.push(
self.users
.userid_lastonetimekeyupdate
self.userid_lastonetimekeyupdate
.watch_prefix(&userid_bytes),
);
futures.push(Box::pin(self.globals.rotate.watch()));
futures.push(Box::pin(services().globals.rotate.watch()));
// Wait until one of them finds something
futures.next().await;
@ -950,8 +987,8 @@ impl KeyValueDatabase {
res
}
#[tracing::instrument(skip(db, config))]
pub async fn start_cleanup_task(db: Arc<TokioRwLock<Self>>, config: &Config) {
#[tracing::instrument(skip(config))]
pub async fn start_cleanup_task(config: &Config) {
use tokio::time::interval;
#[cfg(unix)]
@ -984,7 +1021,7 @@ impl KeyValueDatabase {
}
let start = Instant::now();
if let Err(e) = db.read().await._db.cleanup() {
if let Err(e) = services().globals.db._db.cleanup() {
error!("cleanup: Errored: {}", e);
} else {
info!("cleanup: Finished in {:?}", start.elapsed());
@ -995,26 +1032,25 @@ impl KeyValueDatabase {
}
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set
fn set_emergency_access(db: &KeyValueDatabase) -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name())
fn set_emergency_access() -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is a valid UserId");
db.users
.set_password(&conduit_user, db.globals.emergency_password().as_deref())?;
services().users
.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, res) = match db.globals.emergency_password() {
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
db.account_data.update(
services().account_data.update(
None,
&conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&GlobalAccountDataEvent {
content: PushRulesEventContent { global: ruleset },
},
&db.globals,
)?;
res

@ -13,22 +13,16 @@ mod service;
pub mod api;
mod utils;
use std::{cell::Cell, sync::RwLock};
use std::{cell::Cell, sync::{RwLock, Arc}};
pub use config::Config;
pub use utils::error::{Error, Result};
pub use service::{Services, pdu::PduEvent};
pub use api::ruma_wrapper::{Ruma, RumaResponse};
use crate::database::KeyValueDatabase;
pub static SERVICES: RwLock<Option<Arc<Services>>> = RwLock::new(None);
pub static SERVICES: RwLock<Option<ServicesEnum>> = RwLock::new(None);
enum ServicesEnum {
Rocksdb(Services<KeyValueDatabase>)
}
pub fn services<'a>() -> &'a Services<KeyValueDatabase> {
&SERVICES.read().unwrap()
pub fn services<'a>() -> Arc<Services> {
Arc::clone(&SERVICES.read().unwrap())
}

@ -69,19 +69,14 @@ async fn main() {
config.warn_deprecated();
let db = match KeyValueDatabase::load_or_create(&config).await {
Ok(db) => db,
Err(e) => {
eprintln!(
"The database couldn't be loaded or created. The following error occured: {}",
e
);
std::process::exit(1);
}
if let Err(e) = KeyValueDatabase::load_or_create(&config).await {
eprintln!(
"The database couldn't be loaded or created. The following error occured: {}",
e
);
std::process::exit(1);
};
SERVICES.set(db).expect("this is the first and only time we initialize the SERVICE static");
let start = async {
run_server().await.unwrap();
};

@ -17,11 +17,11 @@ use tracing::error;
use crate::{service::*, services, utils, Error, Result};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
pub fn update<T: Serialize>(

@ -3,11 +3,11 @@ pub use data::Data;
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Registers an appservice and returns the ID to the caller
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> {
self.db.register_appservice(yaml)

@ -36,8 +36,8 @@ type SyncHandle = (
Receiver<Option<Result<sync_events::v3::Response>>>, // rx
);
pub struct Service<D: Data> {
pub db: D,
pub struct Service {
pub db: Box<dyn Data>,
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
@ -92,9 +92,9 @@ impl Default for RotationHandler {
}
impl<D: Data> Service<D> {
impl Service {
pub fn load(
db: D,
db: Box<dyn Data>,
config: Config,
) -> Result<Self> {
let keypair = db.load_keypair();

@ -12,11 +12,11 @@ use ruma::{
};
use std::{collections::BTreeMap, sync::Arc};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
pub fn create_backup(
&self,
user_id: &UserId,

@ -15,11 +15,11 @@ pub struct FileMeta {
pub file: Vec<u8>,
}
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Uploads a file.
pub async fn create(
&self,

@ -1,3 +1,5 @@
use std::sync::Arc;
pub mod account_data;
pub mod admin;
pub mod appservice;
@ -12,18 +14,36 @@ pub mod transaction_ids;
pub mod uiaa;
pub mod users;
pub struct Services<D: appservice::Data + pusher::Data + rooms::Data + transaction_ids::Data + uiaa::Data + users::Data + account_data::Data + globals::Data + key_backups::Data + media::Data>
{
pub appservice: appservice::Service<D>,
pub pusher: pusher::Service<D>,
pub rooms: rooms::Service<D>,
pub transaction_ids: transaction_ids::Service<D>,
pub uiaa: uiaa::Service<D>,
pub users: users::Service<D>,
pub account_data: account_data::Service<D>,
pub struct Services {
pub appservice: appservice::Service,
pub pusher: pusher::Service,
pub rooms: rooms::Service,
pub transaction_ids: transaction_ids::Service,
pub uiaa: uiaa::Service,
pub users: users::Service,
pub account_data: account_data::Service,
pub admin: admin::Service,
pub globals: globals::Service<D>,
pub key_backups: key_backups::Service<D>,
pub media: media::Service<D>,
pub globals: globals::Service,
pub key_backups: key_backups::Service,
pub media: media::Service,
pub sending: sending::Service,
}
impl Services {
pub fn build<D: appservice::Data + pusher::Data + rooms::Data + transaction_ids::Data + uiaa::Data + users::Data + account_data::Data + globals::Data + key_backups::Data + media::Data>(db: Arc<D>) {
Self {
appservice: appservice::Service { db: Arc::clone(&db) },
pusher: appservice::Service { db: Arc::clone(&db) },
rooms: appservice::Service { db: Arc::clone(&db) },
transaction_ids: appservice::Service { db: Arc::clone(&db) },
uiaa: appservice::Service { db: Arc::clone(&db) },
users: appservice::Service { db: Arc::clone(&db) },
account_data: appservice::Service { db: Arc::clone(&db) },
admin: appservice::Service { db: Arc::clone(&db) },
globals: appservice::Service { db: Arc::clone(&db) },
key_backups: appservice::Service { db: Arc::clone(&db) },
media: appservice::Service { db: Arc::clone(&db) },
sending: appservice::Service { db: Arc::clone(&db) },
}
}
}

@ -23,11 +23,11 @@ use ruma::{
use std::{fmt::Debug, mem};
use tracing::{error, info, warn};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
self.db.set_pusher(sender, pusher)
}

@ -25,5 +25,5 @@ pub trait Data {
fn local_aliases_for_room(
&self,
room_id: &RoomId,
) -> Result<Box<dyn Iterator<Item=String>>>;
) -> Box<dyn Iterator<Item = Result<Box<RoomAliasId>>>>;
}

@ -4,11 +4,11 @@ pub use data::Data;
use ruma::{RoomAliasId, RoomId};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
#[tracing::instrument(skip(self))]
pub fn set_alias(
&self,

@ -2,6 +2,6 @@ use std::collections::HashSet;
use crate::Result;
pub trait Data {
fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result<HashSet<u64>>;
fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result<Option<HashSet<u64>>>;
fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet<u64>) -> Result<()>;
}

@ -5,11 +5,11 @@ pub use data::Data;
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
#[tracing::instrument(skip(self))]
pub fn get_cached_eventid_authchain<'a>(
&'a self,

@ -4,11 +4,11 @@ use ruma::RoomId;
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
#[tracing::instrument(skip(self))]
pub fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.db.set_public(room_id)

@ -4,8 +4,8 @@ pub mod typing;
pub trait Data: presence::Data + read_receipt::Data + typing::Data {}
pub struct Service<D: Data> {
pub presence: presence::Service<D>,
pub read_receipt: read_receipt::Service<D>,
pub typing: typing::Service<D>,
pub struct Service {
pub presence: presence::Service,
pub read_receipt: read_receipt::Service,
pub typing: typing::Service,
}

@ -6,11 +6,11 @@ use ruma::{RoomId, UserId, events::presence::PresenceEvent};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Adds a presence event which will be saved until a new event replaces it.
///
/// Note: This method takes a RoomId because presence updates are always bound to rooms to

@ -4,11 +4,11 @@ pub use data::Data;
use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Replaces the previous read receipt.
pub fn readreceipt_update(
&self,

@ -4,11 +4,11 @@ use ruma::{UserId, RoomId, events::SyncEphemeralRoomEvent};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is
/// called.
pub fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {

@ -1,14 +1,16 @@
/// An async function that can recursively call itself.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use ruma::{RoomVersionId, signatures::CanonicalJsonObject, api::federation::discovery::{get_server_keys, get_remote_server_keys}};
use tokio::sync::Semaphore;
use std::{
collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::{Arc, RwLock},
time::{Duration, Instant},
sync::{Arc, RwLock, RwLockWriteGuard},
time::{Duration, Instant, SystemTime},
};
use futures_util::{Future, stream::FuturesUnordered};
use futures_util::{Future, stream::FuturesUnordered, StreamExt};
use ruma::{
api::{
client::error::ErrorKind,
@ -22,7 +24,7 @@ use ruma::{
uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tracing::{error, info, trace, warn};
use tracing::{error, info, trace, warn, debug};
use crate::{service::*, services, Result, Error, PduEvent};
@ -53,7 +55,7 @@ impl Service {
/// it
/// 14. Use state resolution to find new room state
// We use some AsyncRecursiveType hacks here so we can call this async funtion recursively
#[tracing::instrument(skip(value, is_timeline_event, pub_key_map))]
#[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))]
pub(crate) async fn handle_incoming_pdu<'a>(
&self,
origin: &'a ServerName,
@ -64,10 +66,11 @@ impl Service {
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
if !services().rooms.metadata.exists(room_id)? {
return Error::BadRequest(
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Room is unknown to this server",
)};
));
}
services()
.rooms
@ -732,7 +735,7 @@ impl Service {
&incoming_pdu.sender,
incoming_pdu.state_key.as_deref(),
&incoming_pdu.content,
)?
)?;
let soft_fail = !state_res::event_auth::auth_check(
&room_version,
@ -821,7 +824,7 @@ impl Service {
let shortstatekey = services()
.rooms
.short
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?;
state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
}
@ -1236,7 +1239,7 @@ impl Service {
let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>();
let fetch_res = fetch_signing_keys(
let fetch_res = self.fetch_signing_keys(
signature_server.as_str().try_into().map_err(|_| {
Error::BadServerResponse("Invalid servername in signatures of server response pdu.")
})?,
@ -1481,4 +1484,168 @@ impl Service {
))
}
}
/// Search the DB for the signing keys of the given server, if we don't have them
/// fetch them from the server and save to our DB.
#[tracing::instrument(skip_all)]
pub async fn fetch_signing_keys(
&self,
origin: &ServerName,
signature_ids: Vec<String>,
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids =
|keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = services()
.globals
.servername_ratelimiter
.read()
.unwrap()
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let s = Arc::clone(
write
.entry(origin.to_owned())
.or_insert_with(|| Arc::new(Semaphore::new(1))),
);
s.acquire_owned()
}
}
.await;
let back_off = |id| match services()
.globals
.bad_signature_ratelimiter
.write()
.unwrap()
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.unwrap()
.get(&signature_ids)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
debug!("Backing off from {:?}", signature_ids);
return Err(Error::BadServerResponse("bad signature, still backing off"));
}
}
trace!("Loading signing keys for {}", origin);
let mut result: BTreeMap<_, _> = services()
.globals
.signing_keys_for(origin)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if contains_all_ids(&result) {
return Ok(result);
}
debug!("Fetching signing keys for {} over federation", origin);
if let Some(server_key) = services()
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
services().globals.add_signing_key(origin, server_key.clone())?;
result.extend(
server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) {
return Ok(result);
}
}
for server in services().globals.trusted_servers() {
debug!("Asking {} for {}'s signing key", server, origin);
if let Some(server_keys) = services()
.sending
.send_federation_request(
server,
get_remote_server_keys::v2::Request::new(
origin,
MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime to large"),
)
.expect("time is valid"),
),
)
.await
.ok()
.map(|resp| {
resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
})
{
trace!("Got signing keys: {:?}", server_keys);
for k in server_keys {
services().globals.add_signing_key(origin, k.clone())?;
result.extend(
k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
}
if contains_all_ids(&result) {
return Ok(result);
}
}
}
drop(permit);
back_off(signature_ids);
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse(
"Failed to find public key for server",
))
}
}

@ -15,7 +15,7 @@ pub trait Data {
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
since: u64,
confirmed_user_ids: &mut dyn Iterator<Item=&UserId>,
) -> Result<()>;
fn lazy_load_reset(

@ -1,16 +1,18 @@
mod data;
use std::collections::HashSet;
use std::{collections::{HashSet, HashMap}, sync::Mutex};
pub use data::Data;
use ruma::{DeviceId, UserId, RoomId};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
lazy_load_waiting: Mutex<HashMap<(Box<UserId>, Box<DeviceId>, Box<RoomId>, u64), HashSet<Box<UserId>>>>,
}
impl<D: Data> Service<D> {
impl Service {
#[tracing::instrument(skip(self))]
pub fn lazy_load_was_sent_before(
&self,
@ -50,7 +52,18 @@ impl<D: Data> Service<D> {
room_id: &RoomId,
since: u64,
) -> Result<()> {
self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, since)
if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&(
user_id.to_owned(),
device_id.to_owned(),
room_id.to_owned(),
since,
)) {
self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|&u| &*u))?;
} else {
// Ignore
}
Ok(())
}
#[tracing::instrument(skip(self))]

@ -4,11 +4,11 @@ use ruma::RoomId;
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Checks if a room exists.
#[tracing::instrument(skip(self))]
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {

@ -18,22 +18,22 @@ pub mod user;
pub trait Data: alias::Data + auth_chain::Data + directory::Data + edus::Data + lazy_loading::Data + metadata::Data + outlier::Data + pdu_metadata::Data + search::Data + short::Data + state::Data + state_accessor::Data + state_cache::Data + state_compressor::Data + timeline::Data + user::Data {}
pub struct Service<D: Data> {
pub alias: alias::Service<D>,
pub auth_chain: auth_chain::Service<D>,
pub directory: directory::Service<D>,
pub edus: edus::Service<D>,
pub struct Service {
pub alias: alias::Service,
pub auth_chain: auth_chain::Service,
pub directory: directory::Service,
pub edus: edus::Service,
pub event_handler: event_handler::Service,
pub lazy_loading: lazy_loading::Service<D>,
pub metadata: metadata::Service<D>,
pub outlier: outlier::Service<D>,
pub pdu_metadata: pdu_metadata::Service<D>,
pub search: search::Service<D>,
pub short: short::Service<D>,
pub state: state::Service<D>,
pub state_accessor: state_accessor::Service<D>,
pub state_cache: state_cache::Service<D>,
pub state_compressor: state_compressor::Service<D>,
pub timeline: timeline::Service<D>,
pub user: user::Service<D>,
pub lazy_loading: lazy_loading::Service,
pub metadata: metadata::Service,
pub outlier: outlier::Service,
pub pdu_metadata: pdu_metadata::Service,
pub search: search::Service,
pub short: short::Service,
pub state: state::Service,
pub state_accessor: state_accessor::Service,
pub state_cache: state_cache::Service,
pub state_compressor: state_compressor::Service,
pub timeline: timeline::Service,
pub user: user::Service,
}

@ -4,11 +4,11 @@ use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{Result, PduEvent};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Returns the pdu from the outlier tree.
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.db.get_outlier_pdu_json(event_id)

@ -6,11 +6,11 @@ use ruma::{RoomId, EventId};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
#[tracing::instrument(skip(self, room_id, event_ids))]
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
self.db.mark_as_referenced(room_id, event_ids)

@ -2,7 +2,7 @@ use ruma::RoomId;
use crate::Result;
pub trait Data {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: u64, message_body: String) -> Result<()>;
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()>;
fn search_pdus<'a>(
&'a self,

@ -4,11 +4,16 @@ pub use data::Data;
use crate::Result;
use ruma::RoomId;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
#[tracing::instrument(skip(self))]
pub fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> {
self.db.index_pdu(shortroomid, pdu_id, message_body)
}
#[tracing::instrument(skip(self))]
pub fn search_pdus<'a>(
&'a self,

@ -6,11 +6,11 @@ use ruma::{EventId, events::StateEventType, RoomId};
use crate::{Result, Error, utils, services};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
pub fn get_or_create_shorteventid(
&self,
event_id: &EventId,

@ -1,6 +1,5 @@
use std::sync::Arc;
use std::{sync::MutexGuard, collections::HashSet};
use std::fmt::Debug;
use crate::Result;
use ruma::{EventId, RoomId};
@ -22,7 +21,7 @@ pub trait Data {
/// Replace the forward extremities of the room.
fn set_forward_extremities<'a>(&self,
room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'a EventId> + Debug,
event_ids: &dyn Iterator<Item = &'a EventId>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()>;
}

@ -10,11 +10,11 @@ use crate::{Result, services, PduEvent, Error, utils::calculate_hash};
use super::state_compressor::CompressedStateEvent;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Set the room to the given statehash and update caches.
pub fn force_state(
&self,
@ -23,6 +23,15 @@ impl<D: Data> Service<D> {
statediffnew: HashSet<CompressedStateEvent>,
statediffremoved: HashSet<CompressedStateEvent>,
) -> Result<()> {
let mutex_state = Arc::clone(
services().globals
.roomid_mutex_state
.write()
.unwrap()
.entry(body.room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
for event_id in statediffnew.into_iter().filter_map(|new| {
services().rooms.state_compressor.parse_compressed_state_event(new)
@ -70,7 +79,9 @@ impl<D: Data> Service<D> {
services().room.state_cache.update_joined_count(room_id)?;
self.db.set_room_state(room_id, shortstatehash);
self.db.set_room_state(room_id, shortstatehash, &state_lock);
drop(state_lock);
Ok(())
}

@ -6,11 +6,11 @@ use ruma::{events::StateEventType, RoomId, EventId};
use crate::{Result, PduEvent};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))]

@ -7,11 +7,11 @@ use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::Room
use crate::{Result, services, utils, Error};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Update current membership data.
#[tracing::instrument(skip(self, last_state))]
pub fn update_membership(

@ -1,10 +1,12 @@
use std::collections::HashSet;
use super::CompressedStateEvent;
use crate::Result;
pub struct StateDiff {
parent: Option<u64>,
added: Vec<CompressedStateEvent>,
removed: Vec<CompressedStateEvent>,
pub parent: Option<u64>,
pub added: HashSet<CompressedStateEvent>,
pub removed: HashSet<CompressedStateEvent>,
}
pub trait Data {

@ -8,13 +8,13 @@ use crate::{Result, utils, services};
use self::data::StateDiff;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
impl<D: Data> Service<D> {
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(skip(self))]
pub fn load_shortstatehash_info(

@ -20,11 +20,11 @@ use crate::{services, Result, service::pdu::{PduBuilder, EventHash}, Error, PduE
use super::state_compressor::CompressedStateEvent;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/*
/// Checks if a room exists.
#[tracing::instrument(skip(self))]

@ -4,11 +4,11 @@ use ruma::{RoomId, UserId};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.reset_notification_counts(user_id, room_id)
}

@ -4,11 +4,11 @@ pub use data::Data;
use ruma::{UserId, DeviceId, TransactionId};
use crate::Result;
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
pub fn add_txnid(
&self,
user_id: &UserId,

@ -6,11 +6,11 @@ use tracing::error;
use crate::{Result, utils, Error, services, api::client_server::SESSION_ID_LENGTH};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Creates a new Uiaa session. Make sure the session token is unique.
pub fn create(
&self,

@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use crate::Result;
use ruma::{UserId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, DeviceKeys, CrossSigningKey}, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition}, MxcUri};
pub trait Data {
pub trait Data: Send + Sync {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool>;
@ -138,16 +138,16 @@ pub trait Data {
device_id: &DeviceId,
) -> Result<Option<Raw<DeviceKeys>>>;
fn get_master_key<F: Fn(&UserId) -> bool>(
fn get_master_key(
&self,
user_id: &UserId,
allowed_signatures: F,
allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>>;
fn get_self_signing_key<F: Fn(&UserId) -> bool>(
fn get_self_signing_key(
&self,
user_id: &UserId,
allowed_signatures: F,
allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>>;
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>>;

@ -6,11 +6,11 @@ use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTi
use crate::{Result, Error, services};
pub struct Service<D: Data> {
db: D,
pub struct Service {
db: Box<dyn Data>,
}
impl<D: Data> Service<D> {
impl Service {
/// Check if a user has an account on this homeserver.
pub fn exists(&self, user_id: &UserId) -> Result<bool> {
self.db.exists(user_id)

Loading…
Cancel
Save