feat: improved state store

merge-requests/11/head
Timo Kösters 4 years ago
parent 6e5b35ea92
commit 6606e41dde
No known key found for this signature in database
GPG Key ID: 24DA7517711A2BA4

@ -1,5 +1,5 @@
use super::State; use super::State;
use crate::{appservice_server, server_server, ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::{ use ruma::{
api::{ api::{
appservice, appservice,
@ -66,12 +66,14 @@ pub async fn get_alias_helper(
room_alias: &RoomAliasId, room_alias: &RoomAliasId,
) -> ConduitResult<get_alias::Response> { ) -> ConduitResult<get_alias::Response> {
if room_alias.server_name() != db.globals.server_name() { if room_alias.server_name() != db.globals.server_name() {
let response = server_server::send_request( let response = db
&db.globals, .sending
room_alias.server_name().to_owned(), .send_federation_request(
federation::query::get_room_information::v1::Request { room_alias }, &db.globals,
) room_alias.server_name().to_owned(),
.await?; federation::query::get_room_information::v1::Request { room_alias },
)
.await?;
return Ok(get_alias::Response::new(response.room_id, response.servers).into()); return Ok(get_alias::Response::new(response.room_id, response.servers).into());
} }
@ -81,13 +83,15 @@ pub async fn get_alias_helper(
Some(r) => room_id = Some(r), Some(r) => room_id = Some(r),
None => { None => {
for (_id, registration) in db.appservice.iter_all().filter_map(|r| r.ok()) { for (_id, registration) in db.appservice.iter_all().filter_map(|r| r.ok()) {
if appservice_server::send_request( if db
&db.globals, .sending
registration, .send_appservice_request(
appservice::query::query_room_alias::v1::Request { room_alias }, &db.globals,
) registration,
.await appservice::query::query_room_alias::v1::Request { room_alias },
.is_ok() )
.await
.is_ok()
{ {
room_id = Some(db.rooms.id_from_alias(&room_alias)?.ok_or_else(|| { room_id = Some(db.rooms.id_from_alias(&room_alias)?.ok_or_else(|| {
Error::bad_config("Appservice lied to us. Room does not exist.") Error::bad_config("Appservice lied to us. Room does not exist.")

@ -1,5 +1,5 @@
use super::State; use super::State;
use crate::{server_server, ConduitResult, Database, Error, Result, Ruma}; use crate::{ConduitResult, Database, Error, Result, Ruma};
use log::info; use log::info;
use ruma::{ use ruma::{
api::{ api::{
@ -133,19 +133,21 @@ pub async fn get_public_rooms_filtered_helper(
.clone() .clone()
.filter(|server| *server != db.globals.server_name().as_str()) .filter(|server| *server != db.globals.server_name().as_str())
{ {
let response = server_server::send_request( let response = db
&db.globals, .sending
other_server.to_owned(), .send_federation_request(
federation::directory::get_public_rooms_filtered::v1::Request { &db.globals,
limit, other_server.to_owned(),
since: since.as_deref(), federation::directory::get_public_rooms_filtered::v1::Request {
filter: Filter { limit,
generic_search_term: filter.generic_search_term.as_deref(), since: since.as_deref(),
filter: Filter {
generic_search_term: filter.generic_search_term.as_deref(),
},
room_network: RoomNetwork::Matrix,
}, },
room_network: RoomNetwork::Matrix, )
}, .await?;
)
.await?;
return Ok(get_public_rooms_filtered::Response { return Ok(get_public_rooms_filtered::Response {
chunk: response chunk: response

@ -1,7 +1,5 @@
use super::State; use super::State;
use crate::{ use crate::{database::media::FileMeta, utils, ConduitResult, Database, Error, Ruma};
database::media::FileMeta, server_server, utils, ConduitResult, Database, Error, Ruma,
};
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config},
@ -45,7 +43,11 @@ pub async fn create_content_route(
db.flush().await?; db.flush().await?;
Ok(create_content::Response { content_uri: mxc, blurhash: None }.into()) Ok(create_content::Response {
content_uri: mxc,
blurhash: None,
}
.into())
} }
#[cfg_attr( #[cfg_attr(
@ -71,16 +73,18 @@ pub async fn get_content_route(
} }
.into()) .into())
} else if &*body.server_name != db.globals.server_name() && body.allow_remote { } else if &*body.server_name != db.globals.server_name() && body.allow_remote {
let get_content_response = server_server::send_request( let get_content_response = db
&db.globals, .sending
body.server_name.clone(), .send_federation_request(
get_content::Request { &db.globals,
allow_remote: false, body.server_name.clone(),
server_name: &body.server_name, get_content::Request {
media_id: &body.media_id, allow_remote: false,
}, server_name: &body.server_name,
) media_id: &body.media_id,
.await?; },
)
.await?;
db.media.create( db.media.create(
mxc, mxc,
@ -118,19 +122,21 @@ pub async fn get_content_thumbnail_route(
)? { )? {
Ok(get_content_thumbnail::Response { file, content_type }.into()) Ok(get_content_thumbnail::Response { file, content_type }.into())
} else if &*body.server_name != db.globals.server_name() && body.allow_remote { } else if &*body.server_name != db.globals.server_name() && body.allow_remote {
let get_thumbnail_response = server_server::send_request( let get_thumbnail_response = db
&db.globals, .sending
body.server_name.clone(), .send_federation_request(
get_content_thumbnail::Request { &db.globals,
allow_remote: false, body.server_name.clone(),
height: body.height, get_content_thumbnail::Request {
width: body.width, allow_remote: false,
method: body.method, height: body.height,
server_name: &body.server_name, width: body.width,
media_id: &body.media_id, method: body.method,
}, server_name: &body.server_name,
) media_id: &body.media_id,
.await?; },
)
.await?;
db.media.upload_thumbnail( db.media.upload_thumbnail(
mxc, mxc,

@ -2,7 +2,7 @@ use super::State;
use crate::{ use crate::{
client_server, client_server,
pdu::{PduBuilder, PduEvent}, pdu::{PduBuilder, PduEvent},
server_server, utils, ConduitResult, Database, Error, Result, Ruma, utils, ConduitResult, Database, Error, Result, Ruma,
}; };
use log::warn; use log::warn;
use ruma::{ use ruma::{
@ -401,9 +401,10 @@ pub async fn get_member_events_route(
Ok(get_member_events::Response { Ok(get_member_events::Response {
chunk: db chunk: db
.rooms .rooms
.room_state_type(&body.room_id, &EventType::RoomMember)? .room_state_full(&body.room_id)?
.values() .iter()
.map(|pdu| pdu.to_member_event()) .filter(|(key, _)| key.0 == EventType::RoomMember)
.map(|(_, pdu)| pdu.to_member_event())
.collect(), .collect(),
} }
.into()) .into())
@ -463,16 +464,18 @@ async fn join_room_by_id_helper(
)); ));
for remote_server in servers { for remote_server in servers {
let make_join_response = server_server::send_request( let make_join_response = db
&db.globals, .sending
remote_server.clone(), .send_federation_request(
federation::membership::create_join_event_template::v1::Request { &db.globals,
room_id, remote_server.clone(),
user_id: sender_user, federation::membership::create_join_event_template::v1::Request {
ver: &[RoomVersionId::Version5, RoomVersionId::Version6], room_id,
}, user_id: sender_user,
) ver: &[RoomVersionId::Version5, RoomVersionId::Version6],
.await; },
)
.await;
make_join_response_and_server = make_join_response.map(|r| (r, remote_server)); make_join_response_and_server = make_join_response.map(|r| (r, remote_server));
@ -540,16 +543,18 @@ async fn join_room_by_id_helper(
// It has enough fields to be called a proper event now // It has enough fields to be called a proper event now
let join_event = join_event_stub; let join_event = join_event_stub;
let send_join_response = server_server::send_request( let send_join_response = db
&db.globals, .sending
remote_server.clone(), .send_federation_request(
federation::membership::create_join_event::v2::Request { &db.globals,
room_id, remote_server.clone(),
event_id: &event_id, federation::membership::create_join_event::v2::Request {
pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), room_id,
}, event_id: &event_id,
) pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()),
.await?; },
)
.await?;
let add_event_id = |pdu: &Raw<Pdu>| -> Result<(EventId, CanonicalJsonObject)> { let add_event_id = |pdu: &Raw<Pdu>| -> Result<(EventId, CanonicalJsonObject)> {
let mut value = serde_json::from_str(pdu.json().get()) let mut value = serde_json::from_str(pdu.json().get())
@ -694,7 +699,7 @@ async fn join_room_by_id_helper(
} }
} }
db.rooms.force_state(room_id, state)?; db.rooms.force_state(room_id, state, &db.globals)?;
} else { } else {
let event = member::MemberEventContent { let event = member::MemberEventContent {
membership: member::MembershipState::Join, membership: member::MembershipState::Join,

@ -1,7 +1,9 @@
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::{ use ruma::{
api::client::{error::ErrorKind, r0::read_marker::set_read_marker}, api::client::{
error::ErrorKind, r0::capabilities::get_capabilities, r0::read_marker::set_read_marker,
},
events::{AnyEphemeralRoomEvent, AnyEvent, EventType}, events::{AnyEphemeralRoomEvent, AnyEvent, EventType},
}; };
@ -76,3 +78,18 @@ pub async fn set_read_marker_route(
Ok(set_read_marker::Response.into()) Ok(set_read_marker::Response.into())
} }
#[cfg_attr(
feature = "conduit_bin",
post("/_matrix/client/r0/rooms/<_>/receipt/<_>/<_>", data = "<body>")
)]
pub async fn set_receipt_route(
db: State<'_, Database>,
body: Ruma<get_capabilities::Request>,
) -> ConduitResult<set_read_marker::Response> {
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.flush().await?;
Ok(set_read_marker::Response.into())
}

@ -102,9 +102,15 @@ pub async fn sync_events_route(
} }
// Database queries: // Database queries:
let encrypted_room = db
.rooms let current_state = db.rooms.room_state_full(&room_id)?;
.room_state_get(&room_id, &EventType::RoomEncryption, "")? let current_members = current_state
.iter()
.filter(|(key, _)| key.0 == EventType::RoomMember)
.map(|(key, value)| (&key.1, value)) // Only keep state key
.collect::<Vec<_>>();
let encrypted_room = current_state
.get(&(EventType::RoomEncryption, "".to_owned()))
.is_some(); .is_some();
// These type is Option<Option<_>>. The outer Option is None when there is no event between // These type is Option<Option<_>>. The outer Option is None when there is no event between
@ -117,45 +123,45 @@ pub async fn sync_events_route(
.as_ref() .as_ref()
.map(|pdu| db.rooms.pdu_state_hash(&pdu.as_ref().ok()?.0).ok()?); .map(|pdu| db.rooms.pdu_state_hash(&pdu.as_ref().ok()?.0).ok()?);
let since_members = since_state_hash.as_ref().map(|state_hash| { let since_state = since_state_hash.as_ref().map(|state_hash| {
state_hash.as_ref().and_then(|state_hash| { state_hash
db.rooms .as_ref()
.state_type(&state_hash, &EventType::RoomMember) .and_then(|state_hash| db.rooms.state_full(&room_id, &state_hash).ok())
.ok()
})
}); });
let since_encryption = since_state_hash.as_ref().map(|state_hash| { let since_encryption = since_state.as_ref().map(|state| {
state_hash.as_ref().and_then(|state_hash| { state
db.rooms .as_ref()
.state_get(&state_hash, &EventType::RoomEncryption, "") .map(|state| state.get(&(EventType::RoomEncryption, "".to_owned())))
.ok()
})
}); });
let current_members = db.rooms.room_state_type(&room_id, &EventType::RoomMember)?;
// Calculations: // Calculations:
let new_encrypted_room = let new_encrypted_room =
encrypted_room && since_encryption.map_or(false, |encryption| encryption.is_none()); encrypted_room && since_encryption.map_or(false, |encryption| encryption.is_none());
let send_member_count = since_members.as_ref().map_or(false, |since_members| { let send_member_count = since_state.as_ref().map_or(false, |since_state| {
since_members.as_ref().map_or(true, |since_members| { since_state.as_ref().map_or(true, |since_state| {
current_members.len() != since_members.len() current_members.len()
!= since_state
.iter()
.filter(|(key, _)| key.0 == EventType::RoomMember)
.count()
}) })
}); });
let since_sender_member = since_members.as_ref().map(|since_members| { let since_sender_member = since_state.as_ref().map(|since_state| {
since_members.as_ref().and_then(|members| { since_state.as_ref().and_then(|state| {
members.get(sender_user.as_str()).and_then(|pdu| { state
serde_json::from_value::<Raw<ruma::events::room::member::MemberEventContent>>( .get(&(EventType::RoomMember, sender_user.as_str().to_owned()))
pdu.content.clone(), .and_then(|pdu| {
) serde_json::from_value::<
.expect("Raw::from_value always works") Raw<ruma::events::room::member::MemberEventContent>,
.deserialize() >(pdu.content.clone())
.map_err(|_| Error::bad_database("Invalid PDU in database.")) .expect("Raw::from_value always works")
.ok() .deserialize()
}) .map_err(|_| Error::bad_database("Invalid PDU in database."))
.ok()
})
}) })
}); });
@ -170,30 +176,32 @@ pub async fn sync_events_route(
.membership; .membership;
let since_membership = let since_membership =
since_members since_state
.as_ref() .as_ref()
.map_or(MembershipState::Join, |members| { .map_or(MembershipState::Join, |since_state| {
members since_state
.as_ref() .as_ref()
.and_then(|members| { .and_then(|since_state| {
members.get(&user_id).and_then(|since_member| { since_state
serde_json::from_value::< .get(&(EventType::RoomMember, user_id.clone()))
Raw<ruma::events::room::member::MemberEventContent>, .and_then(|since_member| {
>( serde_json::from_value::<
since_member.content.clone() Raw<ruma::events::room::member::MemberEventContent>,
) >(
.expect("Raw::from_value always works") since_member.content.clone()
.deserialize() )
.map_err(|_| { .expect("Raw::from_value always works")
Error::bad_database("Invalid PDU in database.") .deserialize()
.map_err(|_| {
Error::bad_database("Invalid PDU in database.")
})
.ok()
}) })
.ok()
})
}) })
.map_or(MembershipState::Leave, |member| member.membership) .map_or(MembershipState::Leave, |member| member.membership)
}); });
let user_id = UserId::try_from(user_id) let user_id = UserId::try_from(user_id.clone())
.map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?;
match (since_membership, current_membership) { match (since_membership, current_membership) {
@ -456,7 +464,12 @@ pub async fn sync_events_route(
}) })
.and_then(|state_hash| { .and_then(|state_hash| {
db.rooms db.rooms
.state_get(&state_hash, &EventType::RoomMember, sender_user.as_str()) .state_get(
&room_id,
&state_hash,
&EventType::RoomMember,
sender_user.as_str(),
)
.ok()? .ok()?
.ok_or_else(|| Error::bad_database("State hash in db doesn't have a state.")) .ok_or_else(|| Error::bad_database("State hash in db doesn't have a state."))
.ok() .ok()

@ -20,6 +20,7 @@ use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::{convert::TryInto, fs::remove_dir_all}; use std::{convert::TryInto, fs::remove_dir_all};
use tokio::sync::Semaphore;
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct Config { pub struct Config {
@ -30,6 +31,8 @@ pub struct Config {
cache_capacity: u64, cache_capacity: u64,
#[serde(default = "default_max_request_size")] #[serde(default = "default_max_request_size")]
max_request_size: u32, max_request_size: u32,
#[serde(default = "default_max_concurrent_requests")]
max_concurrent_requests: u16,
#[serde(default)] #[serde(default)]
registration_disabled: bool, registration_disabled: bool,
#[serde(default)] #[serde(default)]
@ -39,7 +42,9 @@ pub struct Config {
} }
fn default_server_name() -> Box<ServerName> { fn default_server_name() -> Box<ServerName> {
"localhost".try_into().expect("") "localhost"
.try_into()
.expect("localhost is valid servername")
} }
fn default_cache_capacity() -> u64 { fn default_cache_capacity() -> u64 {
@ -50,6 +55,10 @@ fn default_max_request_size() -> u32 {
20 * 1024 * 1024 // Default to 20 MB 20 * 1024 * 1024 // Default to 20 MB
} }
fn default_max_concurrent_requests() -> u16 {
4
}
#[derive(Clone)] #[derive(Clone)]
pub struct Database { pub struct Database {
pub globals: globals::Globals, pub globals: globals::Globals,
@ -159,6 +168,7 @@ impl Database {
roomuserid_invited: db.open_tree("roomuserid_invited")?, roomuserid_invited: db.open_tree("roomuserid_invited")?,
userroomid_left: db.open_tree("userroomid_left")?, userroomid_left: db.open_tree("userroomid_left")?,
statekey_short: db.open_tree("statekey_short")?,
stateid_pduid: db.open_tree("stateid_pduid")?, stateid_pduid: db.open_tree("stateid_pduid")?,
pduid_statehash: db.open_tree("pduid_statehash")?, pduid_statehash: db.open_tree("pduid_statehash")?,
roomid_statehash: db.open_tree("roomid_statehash")?, roomid_statehash: db.open_tree("roomid_statehash")?,
@ -180,6 +190,7 @@ impl Database {
sending: sending::Sending { sending: sending::Sending {
servernamepduids: db.open_tree("servernamepduids")?, servernamepduids: db.open_tree("servernamepduids")?,
servercurrentpdus: db.open_tree("servercurrentpdus")?, servercurrentpdus: db.open_tree("servercurrentpdus")?,
maximum_requests: Arc::new(Semaphore::new(10)),
}, },
admin: admin::Admin { admin: admin::Admin {
sender: admin_sender, sender: admin_sender,

@ -4,6 +4,7 @@ use ruma::ServerName;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::RwLock; use std::sync::RwLock;
use std::time::Duration;
use trust_dns_resolver::TokioAsyncResolver; use trust_dns_resolver::TokioAsyncResolver;
pub const COUNTER: &str = "c"; pub const COUNTER: &str = "c";
@ -54,11 +55,18 @@ impl Globals {
} }
}; };
let reqwest_client = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(60 * 3))
.pool_max_idle_per_host(1)
.build()
.unwrap();
Ok(Self { Ok(Self {
globals, globals,
config, config,
keypair: Arc::new(keypair), keypair: Arc::new(keypair),
reqwest_client: reqwest::Client::new(), reqwest_client,
dns_resolver: TokioAsyncResolver::tokio_from_system_conf() dns_resolver: TokioAsyncResolver::tokio_from_system_conf()
.await .await
.map_err(|_| { .map_err(|_| {

@ -62,7 +62,8 @@ pub struct Rooms {
/// Remember the state hash at events in the past. /// Remember the state hash at events in the past.
pub(super) pduid_statehash: sled::Tree, pub(super) pduid_statehash: sled::Tree,
/// The state for a given state hash. /// The state for a given state hash.
pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + EventType + StateKey pub(super) statekey_short: sled::Tree, // StateKey = EventType + StateKey, Short = Count
pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + Short, PduId = Count (without roomid)
} }
impl StateStore for Rooms { impl StateStore for Rooms {
@ -106,21 +107,28 @@ impl StateStore for Rooms {
impl Rooms { impl Rooms {
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
pub fn state_full(&self, state_hash: &StateHashId) -> Result<StateMap<PduEvent>> { pub fn state_full(
&self,
room_id: &RoomId,
state_hash: &StateHashId,
) -> Result<StateMap<PduEvent>> {
self.stateid_pduid self.stateid_pduid
.scan_prefix(&state_hash) .scan_prefix(&state_hash)
.values() .values()
.map(|pduid| { .map(|pduid_short| {
self.pduid_pdu.get(&pduid?)?.map_or_else( let mut pduid = room_id.as_bytes().to_vec();
|| Err(Error::bad_database("Failed to find StateMap.")), pduid.push(0xff);
pduid.extend_from_slice(&pduid_short?);
self.pduid_pdu.get(&pduid)?.map_or_else(
|| Err(Error::bad_database("Failed to find PDU in state snapshot.")),
|b| { |b| {
serde_json::from_slice::<PduEvent>(&b) serde_json::from_slice::<PduEvent>(&b)
.map_err(|_| Error::bad_database("Invalid PDU in db.")) .map_err(|_| Error::bad_database("Invalid PDU in db."))
}, },
) )
}) })
.filter_map(|r| r.ok())
.map(|pdu| { .map(|pdu| {
let pdu = pdu?;
Ok(( Ok((
( (
pdu.kind.clone(), pdu.kind.clone(),
@ -135,64 +143,45 @@ impl Rooms {
.collect::<Result<StateMap<_>>>() .collect::<Result<StateMap<_>>>()
} }
/// Returns all state entries for this type.
pub fn state_type(
&self,
state_hash: &StateHashId,
event_type: &EventType,
) -> Result<HashMap<String, PduEvent>> {
let mut prefix = state_hash.to_vec();
prefix.push(0xff);
prefix.extend_from_slice(&event_type.to_string().as_bytes());
prefix.push(0xff);
let mut hashmap = HashMap::new();
for pdu in self
.stateid_pduid
.scan_prefix(&prefix)
.values()
.map(|pdu_id| {
Ok::<_, Error>(
serde_json::from_slice::<PduEvent>(&self.pduid_pdu.get(pdu_id?)?.ok_or_else(
|| Error::bad_database("PDU in state not found in database."),
)?)
.map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?,
)
})
{
let pdu = pdu?;
let state_key = pdu.state_key.clone().ok_or_else(|| {
Error::bad_database("Room state contains event without state_key.")
})?;
hashmap.insert(state_key, pdu);
}
Ok(hashmap)
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
pub fn state_get( pub fn state_get(
&self, &self,
room_id: &RoomId,
state_hash: &StateHashId, state_hash: &StateHashId,
event_type: &EventType, event_type: &EventType,
state_key: &str, state_key: &str,
) -> Result<Option<(IVec, PduEvent)>> { ) -> Result<Option<(IVec, PduEvent)>> {
let mut key = state_hash.to_vec(); let mut key = event_type.to_string().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&event_type.to_string().as_bytes());
key.push(0xff); key.push(0xff);
key.extend_from_slice(&state_key.as_bytes()); key.extend_from_slice(&state_key.as_bytes());
self.stateid_pduid.get(&key)?.map_or(Ok(None), |pdu_id| { let short = self.statekey_short.get(&key)?;
Ok::<_, Error>(Some((
pdu_id.clone(), if let Some(short) = short {
serde_json::from_slice::<PduEvent>( let mut stateid = state_hash.to_vec();
&self.pduid_pdu.get(&pdu_id)?.ok_or_else(|| { stateid.push(0xff);
Error::bad_database("PDU in state not found in database.") stateid.extend_from_slice(&short);
})?,
) self.stateid_pduid
.map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?, .get(&stateid)?
))) .map_or(Ok(None), |pdu_id_short| {
}) let mut pdu_id = room_id.as_bytes().to_vec();
pdu_id.push(0xff);
pdu_id.extend_from_slice(&pdu_id_short);
Ok::<_, Error>(Some((
pdu_id.clone().into(),
serde_json::from_slice::<PduEvent>(
&self.pduid_pdu.get(&pdu_id)?.ok_or_else(|| {
Error::bad_database("PDU in state not found in database.")
})?,
)
.map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?,
)))
})
} else {
return Ok(None);
}
} }
/// Returns the last state hash key added to the db. /// Returns the last state hash key added to the db.
@ -260,6 +249,7 @@ impl Rooms {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
state: HashMap<(EventType, String), Vec<u8>>, state: HashMap<(EventType, String), Vec<u8>>,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let state_hash = let state_hash =
self.calculate_hash(&state.values().map(|pdu_id| &**pdu_id).collect::<Vec<_>>())?; self.calculate_hash(&state.values().map(|pdu_id| &**pdu_id).collect::<Vec<_>>())?;
@ -267,11 +257,29 @@ impl Rooms {
prefix.push(0xff); prefix.push(0xff);
for ((event_type, state_key), pdu_id) in state { for ((event_type, state_key), pdu_id) in state {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
let short = match self.statekey_short.get(&statekey)? {
Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid short bytes in statekey_short."))?,
None => {
let short = globals.next_count()?;
self.statekey_short
.insert(&statekey, &short.to_be_bytes())?;
short
}
};
let pdu_id_short = pdu_id
.splitn(2, |&b| b == 0xff)
.nth(1)
.ok_or_else(|| Error::bad_database("Invalid pduid in state."))?;
let mut state_id = prefix.clone(); let mut state_id = prefix.clone();
state_id.extend_from_slice(&event_type.as_ref().as_bytes()); state_id.extend_from_slice(&short.to_be_bytes());
state_id.push(0xff); self.stateid_pduid.insert(state_id, pdu_id_short)?;
state_id.extend_from_slice(&state_key.as_bytes());
self.stateid_pduid.insert(state_id, pdu_id)?;
} }
self.roomid_statehash self.roomid_statehash
@ -283,25 +291,12 @@ impl Rooms {
/// Returns the full room state. /// Returns the full room state.
pub fn room_state_full(&self, room_id: &RoomId) -> Result<StateMap<PduEvent>> { pub fn room_state_full(&self, room_id: &RoomId) -> Result<StateMap<PduEvent>> {
if let Some(current_state_hash) = self.current_state_hash(room_id)? { if let Some(current_state_hash) = self.current_state_hash(room_id)? {
self.state_full(&current_state_hash) self.state_full(&room_id, &current_state_hash)
} else { } else {
Ok(BTreeMap::new()) Ok(BTreeMap::new())
} }
} }
/// Returns all state entries for this type.
pub fn room_state_type(
&self,
room_id: &RoomId,
event_type: &EventType,
) -> Result<HashMap<String, PduEvent>> {
if let Some(current_state_hash) = self.current_state_hash(room_id)? {
self.state_type(&current_state_hash, event_type)
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
pub fn room_state_get( pub fn room_state_get(
&self, &self,
@ -310,7 +305,7 @@ impl Rooms {
state_key: &str, state_key: &str,
) -> Result<Option<(IVec, PduEvent)>> { ) -> Result<Option<(IVec, PduEvent)>> {
if let Some(current_state_hash) = self.current_state_hash(room_id)? { if let Some(current_state_hash) = self.current_state_hash(room_id)? {
self.state_get(&current_state_hash, event_type, state_key) self.state_get(&room_id, &current_state_hash, event_type, state_key)
} else { } else {
Ok(None) Ok(None)
} }
@ -593,7 +588,12 @@ impl Rooms {
/// This adds all current state events (not including the incoming event) /// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `pduid_statehash`. /// to `stateid_pduid` and adds the incoming event to `pduid_statehash`.
/// The incoming event is the `pdu_id` passed to this method. /// The incoming event is the `pdu_id` passed to this method.
pub fn append_to_state(&self, new_pdu_id: &[u8], new_pdu: &PduEvent) -> Result<StateHashId> { pub fn append_to_state(
&self,
new_pdu_id: &[u8],
new_pdu: &PduEvent,
globals: &super::globals::Globals,
) -> Result<StateHashId> {
let old_state = let old_state =
if let Some(old_state_hash) = self.roomid_statehash.get(new_pdu.room_id.as_bytes())? { if let Some(old_state_hash) = self.roomid_statehash.get(new_pdu.room_id.as_bytes())? {
// Store state for event. The state does not include the event itself. // Store state for event. The state does not include the event itself.
@ -608,7 +608,7 @@ impl Rooms {
self.stateid_pduid self.stateid_pduid
.scan_prefix(&prefix) .scan_prefix(&prefix)
.filter_map(|pdu| pdu.map_err(|e| error!("{}", e)).ok()) .filter_map(|pdu| pdu.map_err(|e| error!("{}", e)).ok())
// Chop the old state_hash out leaving behind the (EventType, StateKey) // Chop the old state_hash out leaving behind the short key (u64)
.map(|(k, v)| (k.subslice(prefix.len(), k.len() - prefix.len()), v)) .map(|(k, v)| (k.subslice(prefix.len(), k.len() - prefix.len()), v))
.collect::<HashMap<IVec, IVec>>() .collect::<HashMap<IVec, IVec>>()
} else { } else {
@ -620,7 +620,23 @@ impl Rooms {
let mut pdu_key = new_pdu.kind.as_ref().as_bytes().to_vec(); let mut pdu_key = new_pdu.kind.as_ref().as_bytes().to_vec();
pdu_key.push(0xff); pdu_key.push(0xff);
pdu_key.extend_from_slice(state_key.as_bytes()); pdu_key.extend_from_slice(state_key.as_bytes());
new_state.insert(pdu_key.into(), new_pdu_id.into());
let short = match self.statekey_short.get(&pdu_key)? {
Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid short bytes in statekey_short."))?,
None => {
let short = globals.next_count()?;
self.statekey_short.insert(&pdu_key, &short.to_be_bytes())?;
short
}
};
let new_pdu_id_short = new_pdu_id
.splitn(2, |&b| b == 0xff)
.nth(1)
.ok_or_else(|| Error::bad_database("Invalid pduid in state."))?;
new_state.insert((&short.to_be_bytes()).into(), new_pdu_id_short.into());
let new_state_hash = let new_state_hash =
self.calculate_hash(&new_state.values().map(|b| &**b).collect::<Vec<_>>())?; self.calculate_hash(&new_state.values().map(|b| &**b).collect::<Vec<_>>())?;
@ -628,12 +644,10 @@ impl Rooms {
let mut key = new_state_hash.to_vec(); let mut key = new_state_hash.to_vec();
key.push(0xff); key.push(0xff);
// TODO: we could avoid writing to the DB on every state event by keeping for (short, short_pdu_id) in new_state {
// track of the delta and write that every so often
for (key_without_prefix, pdu_id) in new_state {
let mut state_id = key.clone(); let mut state_id = key.clone();
state_id.extend_from_slice(&key_without_prefix); state_id.extend_from_slice(&short);
self.stateid_pduid.insert(&state_id, &pdu_id)?; self.stateid_pduid.insert(&state_id, &short_pdu_id)?;
} }
self.roomid_statehash self.roomid_statehash
@ -887,7 +901,7 @@ impl Rooms {
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
self.append_to_state(&pdu_id, &pdu)?; self.append_to_state(&pdu_id, &pdu, &globals)?;
self.append_pdu( self.append_pdu(
&pdu, &pdu,

@ -1,21 +1,29 @@
use std::{collections::HashMap, convert::TryFrom, time::SystemTime}; use std::{
collections::HashMap,
convert::TryFrom,
fmt::Debug,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use crate::{appservice_server, server_server, utils, Error, PduEvent, Result}; use crate::{appservice_server, server_server, utils, Error, PduEvent, Result};
use federation::transactions::send_transaction_message; use federation::transactions::send_transaction_message;
use log::warn; use log::warn;
use rocket::futures::stream::{FuturesUnordered, StreamExt}; use rocket::futures::stream::{FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::{appservice, federation}, api::{appservice, federation, OutgoingRequest},
ServerName, ServerName,
}; };
use sled::IVec; use sled::IVec;
use tokio::select; use tokio::select;
use tokio::sync::Semaphore;
#[derive(Clone)] #[derive(Clone)]
pub struct Sending { pub struct Sending {
/// The state for a given state hash. /// The state for a given state hash.
pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+)ServerName + PduId pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+)ServerName + PduId
pub(super) servercurrentpdus: sled::Tree, // ServerCurrentPdus = (+)ServerName + PduId (pduid can be empty for reservation) pub(super) servercurrentpdus: sled::Tree, // ServerCurrentPdus = (+)ServerName + PduId (pduid can be empty for reservation)
pub(super) maximum_requests: Arc<Semaphore>,
} }
impl Sending { impl Sending {
@ -40,35 +48,7 @@ impl Sending {
for (server, pdu, is_appservice) in servercurrentpdus for (server, pdu, is_appservice) in servercurrentpdus
.iter() .iter()
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.map(|(key, _)| { .filter_map(|(key, _)| Self::parse_servercurrentpdus(key).ok())
let mut parts = key.splitn(2, |&b| b == 0xff);
let server = parts.next().expect("splitn always returns one element");
let pdu = parts.next().ok_or_else(|| {
Error::bad_database("Invalid bytes in servercurrentpdus.")
})?;
let server = utils::string_from_bytes(&server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction")
})?;
// Appservices start with a plus
let (server, is_appservice) = if server.starts_with("+") {
(&server[1..], true)
} else {
(&*server, false)
};
Ok::<_, Error>((
Box::<ServerName>::try_from(server).map_err(|_| {
Error::bad_database(
"Invalid server string in server_currenttransaction",
)
})?,
IVec::from(pdu),
is_appservice,
))
})
.filter_map(|r| r.ok())
.filter(|(_, pdu, _)| !pdu.is_empty()) // Skip reservation key .filter(|(_, pdu, _)| !pdu.is_empty()) // Skip reservation key
.take(50) .take(50)
// This should not contain more than 50 anyway // This should not contain more than 50 anyway
@ -90,6 +70,8 @@ impl Sending {
)); ));
} }
let mut last_failed_try: HashMap<Box<ServerName>, (u32, Instant)> = HashMap::new();
let mut subscriber = servernamepduids.watch_prefix(b""); let mut subscriber = servernamepduids.watch_prefix(b"");
loop { loop {
select! { select! {
@ -140,9 +122,24 @@ impl Sending {
// servercurrentpdus with the prefix should be empty now // servercurrentpdus with the prefix should be empty now
} }
} }
Err((server, _is_appservice, e)) => { Err((server, is_appservice, e)) => {
warn!("Couldn't send transaction to {}: {}", server, e) warn!("Couldn't send transaction to {}: {}", server, e);
// TODO: exponential backoff let mut prefix = if is_appservice {
"+".as_bytes().to_vec()
} else {
Vec::new()
};
prefix.extend_from_slice(server.as_bytes());
prefix.push(0xff);
last_failed_try.insert(server.clone(), match last_failed_try.get(&server) {
Some(last_failed) => {
(last_failed.0+1, Instant::now())
},
None => {
(1, Instant::now())
}
});
servercurrentpdus.remove(&prefix).unwrap();
} }
}; };
}, },
@ -174,8 +171,19 @@ impl Sending {
.ok() .ok()
.map(|pdu_id| (server, is_appservice, pdu_id)) .map(|pdu_id| (server, is_appservice, pdu_id))
) )
// TODO: exponential backoff
.filter(|(server, is_appservice, _)| { .filter(|(server, is_appservice, _)| {
if last_failed_try.get(server).map_or(false, |(tries, instant)| {
// Fail if a request has failed recently (exponential backoff)
let mut min_elapsed_duration = Duration::from_secs(60) * *tries * *tries;
if min_elapsed_duration > Duration::from_secs(60*60*24) {
min_elapsed_duration = Duration::from_secs(60*60*24);
}
instant.elapsed() < min_elapsed_duration
}) {
return false;
}
let mut prefix = if *is_appservice { let mut prefix = if *is_appservice {
"+".as_bytes().to_vec() "+".as_bytes().to_vec()
} else { } else {
@ -308,4 +316,63 @@ impl Sending {
.map_err(|e| (server, is_appservice, e)) .map_err(|e| (server, is_appservice, e))
} }
} }
fn parse_servercurrentpdus(key: IVec) -> Result<(Box<ServerName>, IVec, bool)> {
let mut parts = key.splitn(2, |&b| b == 0xff);
let server = parts.next().expect("splitn always returns one element");
let pdu = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(&server).map_err(|_| {
Error::bad_database("Invalid server bytes in server_currenttransaction")
})?;
// Appservices start with a plus
let (server, is_appservice) = if server.starts_with("+") {
(&server[1..], true)
} else {
(&*server, false)
};
Ok::<_, Error>((
Box::<ServerName>::try_from(server).map_err(|_| {
Error::bad_database("Invalid server string in server_currenttransaction")
})?,
IVec::from(pdu),
is_appservice,
))
}
pub async fn send_federation_request<T: OutgoingRequest>(
&self,
globals: &crate::database::globals::Globals,
destination: Box<ServerName>,
request: T,
) -> Result<T::IncomingResponse>
where
T: Debug,
{
let permit = self.maximum_requests.acquire().await;
let response = server_server::send_request(globals, destination, request).await;
drop(permit);
response
}
pub async fn send_appservice_request<T: OutgoingRequest>(
&self,
globals: &crate::database::globals::Globals,
registration: serde_yaml::Value,
request: T,
) -> Result<T::IncomingResponse>
where
T: Debug,
{
let permit = self.maximum_requests.acquire().await;
let response = appservice_server::send_request(globals, registration, request).await;
drop(permit);
response
}
} }

@ -121,7 +121,7 @@ impl log::Log for ConduitLogger {
fn log(&self, record: &log::Record<'_>) { fn log(&self, record: &log::Record<'_>) {
let output = format!("{} - {}", record.level(), record.args()); let output = format!("{} - {}", record.level(), record.args());
println!("{}", output); eprintln!("{}", output);
if self.enabled(record.metadata()) if self.enabled(record.metadata())
&& record && record

@ -18,7 +18,7 @@ pub use pdu::PduEvent;
pub use rocket::State; pub use rocket::State;
pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse}; pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse};
use rocket::{fairing::AdHoc, routes}; use rocket::{catch, catchers, fairing::AdHoc, routes, Request};
fn setup_rocket() -> rocket::Rocket { fn setup_rocket() -> rocket::Rocket {
// Force log level off, so we can use our own logger // Force log level off, so we can use our own logger
@ -70,6 +70,7 @@ fn setup_rocket() -> rocket::Rocket {
client_server::get_backup_key_sessions_route, client_server::get_backup_key_sessions_route,
client_server::get_backup_keys_route, client_server::get_backup_keys_route,
client_server::set_read_marker_route, client_server::set_read_marker_route,
client_server::set_receipt_route,
client_server::create_typing_event_route, client_server::create_typing_event_route,
client_server::create_room_route, client_server::create_room_route,
client_server::redact_event_route, client_server::redact_event_route,
@ -134,6 +135,7 @@ fn setup_rocket() -> rocket::Rocket {
server_server::get_profile_information_route, server_server::get_profile_information_route,
], ],
) )
.register(catchers![not_found_catcher])
.attach(AdHoc::on_attach("Config", |rocket| async { .attach(AdHoc::on_attach("Config", |rocket| async {
let data = let data =
Database::load_or_create(rocket.figment().extract().expect("config is valid")) Database::load_or_create(rocket.figment().extract().expect("config is valid"))
@ -157,3 +159,8 @@ fn setup_rocket() -> rocket::Rocket {
async fn main() { async fn main() {
setup_rocket().launch().await.unwrap(); setup_rocket().launch().await.unwrap();
} }
#[catch(404)]
fn not_found_catcher(_req: &'_ Request<'_>) -> String {
"404 Not Found".to_owned()
}

@ -490,7 +490,7 @@ pub async fn send_transaction_message_route<'a>(
pdu_id.push(0xff); pdu_id.push(0xff);
pdu_id.extend_from_slice(&count.to_be_bytes()); pdu_id.extend_from_slice(&count.to_be_bytes());
db.rooms.append_to_state(&pdu_id, &pdu)?; db.rooms.append_to_state(&pdu_id, &pdu, &db.globals)?;
db.rooms.append_pdu( db.rooms.append_pdu(
&pdu, &pdu,

Loading…
Cancel
Save