diff --git a/Cargo.lock b/Cargo.lock index c31894ae..76e727e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -243,6 +243,7 @@ dependencies = [ "image", "jsonwebtoken", "log", + "lru-cache", "opentelemetry", "opentelemetry-jaeger", "pretty_env_logger", @@ -271,9 +272,9 @@ dependencies = [ [[package]] name = "const-oid" -version = "0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "279bc8fc53f788a75c7804af68237d1fce02cde1e275a886a4b320604dc2aeda" +checksum = "44c32f031ea41b4291d695026c023b95d59db2d8a2c7640800ed56bc8f510f22" [[package]] name = "const_fn" @@ -393,9 +394,9 @@ dependencies = [ [[package]] name = "der" -version = "0.3.5" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eeb9d92785d1facb50567852ce75d0858630630e7eabea59cf7eb7474051087" +checksum = "49f215f706081a44cb702c71c39a52c05da637822e9c1645a50b7202689e982d" dependencies = [ "const-oid", ] @@ -1474,9 +1475,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkcs8" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9c2f795bc591cb3384cb64082a578b89207ac92bb89c9d98c1ea2ace7cd8110" +checksum = "09d156817ae0125e8aa5067710b0db24f0984830614f99875a70aa5e3b74db69" dependencies = [ "der", "spki", @@ -1882,8 +1883,8 @@ dependencies = [ [[package]] name = "ruma" -version = "0.1.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.2.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "assign", "js_int", @@ -1903,8 +1904,8 @@ dependencies = [ [[package]] name = "ruma-api" -version = "0.17.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.17.1" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "bytes", "http", @@ -1919,8 +1920,8 @@ dependencies = [ [[package]] name = "ruma-api-macros" -version = "0.17.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.17.1" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1930,8 +1931,8 @@ dependencies = [ [[package]] name = "ruma-appservice-api" -version = "0.2.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "ruma-api", "ruma-common", @@ -1944,8 +1945,8 @@ dependencies = [ [[package]] name = "ruma-client-api" -version = "0.10.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.11.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "assign", "bytes", @@ -1964,8 +1965,8 @@ dependencies = [ [[package]] name = "ruma-common" -version = "0.5.3" 
-source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.5.4" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "indexmap", "js_int", @@ -1979,8 +1980,8 @@ dependencies = [ [[package]] name = "ruma-events" -version = "0.22.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.23.1" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "indoc", "js_int", @@ -1994,8 +1995,8 @@ dependencies = [ [[package]] name = "ruma-events-macros" -version = "0.22.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.23.1" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2005,8 +2006,8 @@ dependencies = [ [[package]] name = "ruma-federation-api" -version = "0.1.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.2.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "js_int", "ruma-api", @@ -2020,8 +2021,8 @@ dependencies = [ [[package]] name = "ruma-identifiers" -version = "0.19.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.19.4" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "paste", "rand 0.8.3", @@ -2034,8 +2035,8 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" -version = "0.19.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.19.4" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "quote", "ruma-identifiers-validation", @@ -2045,12 +2046,12 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" [[package]] name = "ruma-identity-service-api" -version = "0.1.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.2.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "js_int", "ruma-api", @@ -2062,8 +2063,8 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" -version = "0.1.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.2.0" +source = 
"git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "js_int", "ruma-api", @@ -2077,8 +2078,8 @@ dependencies = [ [[package]] name = "ruma-serde" -version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.4.1" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "bytes", "form_urlencoded", @@ -2091,8 +2092,8 @@ dependencies = [ [[package]] name = "ruma-serde-macros" -version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.4.1" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2102,8 +2103,8 @@ dependencies = [ [[package]] name = "ruma-signatures" -version = "0.7.2" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.8.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "base64 0.13.0", "ed25519-dalek", @@ -2115,13 +2116,12 @@ dependencies = [ "sha2", "thiserror", "tracing", - "untrusted", ] [[package]] name = "ruma-state-res" -version = "0.1.0" -source = "git+https://github.com/ruma/ruma?rev=5a7e2cddcf257e367465cced51442c91e8f557c9#5a7e2cddcf257e367465cced51442c91e8f557c9" +version = "0.2.0" +source = "git+https://github.com/ruma/ruma?rev=174555857ef90d49e4b9a672be9e2fe0acdc2687#174555857ef90d49e4b9a672be9e2fe0acdc2687" dependencies = [ "itertools 0.10.0", "js_int", @@ -2130,7 +2130,6 @@ dependencies = [ "ruma-events", "ruma-identifiers", "ruma-serde", - "ruma-signatures", "serde", "serde_json", "thiserror", @@ -2444,9 +2443,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "spki" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dae7e047abc519c96350e9484a96c6bf1492348af912fd3446dd2dc323f6268" +checksum = "987637c5ae6b3121aba9d513f869bd2bff11c4cc086c22473befd6649c0bd521" dependencies = [ "der", ] diff --git a/Cargo.toml b/Cargo.toml index 4f7095de..426d242c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ repository = "https://gitlab.com/famedly/conduit" readme = "README.md" version = "0.1.0" edition = "2018" -rust = "1.50" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -18,7 +17,7 @@ rust = "1.50" rocket = { git = "https://github.com/SergioBenitez/Rocket.git", rev = "801e04bd5369eb39e126c75f6d11e1e9597304d8", features = ["tls"] } # Used to handle requests # Used for matrix spec type definitions and helpers -ruma = { git = "https://github.com/ruma/ruma", rev = "5a7e2cddcf257e367465cced51442c91e8f557c9", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +ruma = { git = "https://github.com/ruma/ruma", rev = "174555857ef90d49e4b9a672be9e2fe0acdc2687", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", 
"unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } # Used for long polling and federation sender, should be the same as rocket::tokio @@ -73,6 +72,7 @@ tracing-subscriber = "0.2.16" tracing-opentelemetry = "0.11.0" opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" +lru-cache = "0.1.2" [features] default = ["conduit_bin", "backend_sled"] diff --git a/rust-toolchain b/rust-toolchain index 5a5c7211..ba0a7191 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.50.0 +1.51.0 diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index be5501ab..1b6b1d7b 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -200,84 +200,84 @@ pub async fn get_public_rooms_filtered_helper( } } - let mut all_rooms = db - .rooms - .public_rooms() - .map(|room_id| { - let room_id = room_id?; + let mut all_rooms = + db.rooms + .public_rooms() + .map(|room_id| { + let room_id = room_id?; - let chunk = PublicRoomsChunk { - aliases: Vec::new(), - canonical_alias: db - .rooms - .room_state_get(&room_id, &EventType::RoomCanonicalAlias, "")? - .map_or(Ok::<_, Error>(None), |s| { - Ok( - serde_json::from_value::< + let chunk = PublicRoomsChunk { + aliases: Vec::new(), + canonical_alias: db + .rooms + .room_state_get(&room_id, &EventType::RoomCanonicalAlias, "")? + .map_or(Ok::<_, Error>(None), |s| { + Ok(serde_json::from_value::< Raw, - >(s.content) + >(s.content.clone()) .expect("from_value::> can never fail") .deserialize() .map_err(|_| { Error::bad_database("Invalid canonical alias event in database.") })? - .alias, - ) - })?, - name: db - .rooms - .room_state_get(&room_id, &EventType::RoomName, "")? - .map_or(Ok::<_, Error>(None), |s| { - Ok( - serde_json::from_value::>(s.content) - .expect("from_value::> can never fail") - .deserialize() - .map_err(|_| { - Error::bad_database("Invalid room name event in database.") - })? - .name() - .map(|n| n.to_owned()), - ) - })?, - num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), - topic: db - .rooms - .room_state_get(&room_id, &EventType::RoomTopic, "")? - .map_or(Ok::<_, Error>(None), |s| { - Ok(Some( - serde_json::from_value::>(s.content) + .alias) + })?, + name: db + .rooms + .room_state_get(&room_id, &EventType::RoomName, "")? + .map_or(Ok::<_, Error>(None), |s| { + Ok(serde_json::from_value::>( + s.content.clone(), + ) + .expect("from_value::> can never fail") + .deserialize() + .map_err(|_| { + Error::bad_database("Invalid room name event in database.") + })? + .name() + .map(|n| n.to_owned())) + })?, + num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), + topic: db + .rooms + .room_state_get(&room_id, &EventType::RoomTopic, "")? + .map_or(Ok::<_, Error>(None), |s| { + Ok(Some( + serde_json::from_value::>( + s.content.clone(), + ) .expect("from_value::> can never fail") .deserialize() .map_err(|_| { Error::bad_database("Invalid room topic event in database.") })? .topic, - )) - })?, - world_readable: db - .rooms - .room_state_get(&room_id, &EventType::RoomHistoryVisibility, "")? - .map_or(Ok::<_, Error>(false), |s| { - Ok(serde_json::from_value::< - Raw, - >(s.content) - .expect("from_value::> can never fail") - .deserialize() - .map_err(|_| { - Error::bad_database( - "Invalid room history visibility event in database.", - ) - })? 
- .history_visibility - == history_visibility::HistoryVisibility::WorldReadable) - })?, - guest_can_join: db - .rooms - .room_state_get(&room_id, &EventType::RoomGuestAccess, "")? - .map_or(Ok::<_, Error>(false), |s| { - Ok( + )) + })?, + world_readable: db + .rooms + .room_state_get(&room_id, &EventType::RoomHistoryVisibility, "")? + .map_or(Ok::<_, Error>(false), |s| { + Ok(serde_json::from_value::< + Raw, + >(s.content.clone()) + .expect("from_value::> can never fail") + .deserialize() + .map_err(|_| { + Error::bad_database( + "Invalid room history visibility event in database.", + ) + })? + .history_visibility + == history_visibility::HistoryVisibility::WorldReadable) + })?, + guest_can_join: db + .rooms + .room_state_get(&room_id, &EventType::RoomGuestAccess, "")? + .map_or(Ok::<_, Error>(false), |s| { + Ok( serde_json::from_value::>( - s.content, + s.content.clone(), ) .expect("from_value::> can never fail") .deserialize() @@ -287,61 +287,63 @@ pub async fn get_public_rooms_filtered_helper( .guest_access == guest_access::GuestAccess::CanJoin, ) - })?, - avatar_url: db - .rooms - .room_state_get(&room_id, &EventType::RoomAvatar, "")? - .map(|s| { - Ok::<_, Error>( - serde_json::from_value::>(s.content) + })?, + avatar_url: db + .rooms + .room_state_get(&room_id, &EventType::RoomAvatar, "")? + .map(|s| { + Ok::<_, Error>( + serde_json::from_value::>( + s.content.clone(), + ) .expect("from_value::> can never fail") .deserialize() .map_err(|_| { Error::bad_database("Invalid room avatar event in database.") })? .url, - ) - }) - .transpose()? - // url is now an Option so we must flatten - .flatten(), - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { - if let Some(query) = filter - .generic_search_term - .as_ref() - .map(|q| q.to_lowercase()) - { - if let Some(name) = &chunk.name { - if name.to_lowercase().contains(&query) { - return true; + ) + }) + .transpose()? 
+ // url is now an Option so we must flatten + .flatten(), + room_id, + }; + Ok(chunk) + }) + .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms + .filter(|chunk| { + if let Some(query) = filter + .generic_search_term + .as_ref() + .map(|q| q.to_lowercase()) + { + if let Some(name) = &chunk.name { + if name.to_lowercase().contains(&query) { + return true; + } } - } - if let Some(topic) = &chunk.topic { - if topic.to_lowercase().contains(&query) { - return true; + if let Some(topic) = &chunk.topic { + if topic.to_lowercase().contains(&query) { + return true; + } } - } - if let Some(canonical_alias) = &chunk.canonical_alias { - if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; + if let Some(canonical_alias) = &chunk.canonical_alias { + if canonical_alias.as_str().to_lowercase().contains(&query) { + return true; + } } - } - false - } else { - // No search term - true - } - }) - // We need to collect all, so we can sort by member count - .collect::>(); + false + } else { + // No search term + true + } + }) + // We need to collect all, so we can sort by member count + .collect::>(); all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 2dfa0776..5c57b68a 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -25,7 +25,7 @@ use ruma::{ EventType, }, serde::{to_canonical_value, CanonicalJsonObject, CanonicalJsonValue, Raw}, - state_res::{self, EventMap, RoomVersion}, + state_res::{self, RoomVersion}, uint, EventId, RoomId, RoomVersionId, ServerName, UserId, }; use std::{ @@ -189,7 +189,8 @@ pub async fn kick_user_route( ErrorKind::BadState, "Cannot kick member that's not in the room.", ))? - .content, + .content + .clone(), ) .expect("Raw::from_value always works") .deserialize() @@ -245,11 +246,12 @@ pub async fn ban_user_route( third_party_invite: None, }), |event| { - let mut event = - serde_json::from_value::>(event.content) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + let mut event = serde_json::from_value::>( + event.content.clone(), + ) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid member event in database."))?; event.membership = ruma::events::room::member::MembershipState::Ban; Ok(event) }, @@ -295,7 +297,8 @@ pub async fn unban_user_route( ErrorKind::BadState, "Cannot unban a user who is not banned.", ))? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() @@ -753,7 +756,7 @@ pub async fn invite_helper( let create_prev_event = if prev_events.len() == 1 && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) { - create_event.map(Arc::new) + create_event } else { None }; @@ -792,10 +795,10 @@ pub async fn invite_helper( let mut unsigned = BTreeMap::new(); if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? 
{ - unsigned.insert("prev_content".to_owned(), prev_pdu.content); + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), ); } @@ -880,7 +883,6 @@ pub async fn invite_helper( .await?; let pub_key_map = RwLock::new(BTreeMap::new()); - let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and hashes checks let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&response.event) { @@ -903,26 +905,19 @@ pub async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = server_server::handle_incoming_pdu( - &origin, - &event_id, - value, - true, - &db, - &pub_key_map, - &mut auth_cache, - ) - .await - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + let pdu_id = + server_server::handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) + .await + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Error while handling incoming PDU.", + ) + })? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; for server in db .rooms diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 32bb6083..4e9a37b6 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -53,7 +53,8 @@ pub async fn set_displayname_route( room.", ) })? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() @@ -154,7 +155,8 @@ pub async fn set_avatar_url_route( room.", ) })? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 3f913249..b33b5500 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -362,7 +362,8 @@ pub async fn upgrade_room_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content, + .content + .clone(), ) .expect("Raw::from_value always works") .deserialize() @@ -463,7 +464,8 @@ pub async fn upgrade_room_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomPowerLevels, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content, + .content + .clone(), ) .expect("database contains invalid PDU") .deserialize() diff --git a/src/client_server/state.rs b/src/client_server/state.rs index c431ac0d..be52834a 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -92,7 +92,7 @@ pub async fn get_state_events_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomHistoryVisibility, "")? .map(|event| { - serde_json::from_value::(event.content) + serde_json::from_value::(event.content.clone()) .map_err(|_| { Error::bad_database( "Invalid room history visibility event in database.", @@ -139,7 +139,7 @@ pub async fn get_state_events_for_key_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomHistoryVisibility, "")? 
.map(|event| { - serde_json::from_value::(event.content) + serde_json::from_value::(event.content.clone()) .map_err(|_| { Error::bad_database( "Invalid room history visibility event in database.", @@ -165,7 +165,7 @@ pub async fn get_state_events_for_key_route( ))?; Ok(get_state_events_for_key::Response { - content: serde_json::from_value(event.content) + content: serde_json::from_value(event.content.clone()) .map_err(|_| Error::bad_database("Invalid event content in database"))?, } .into()) @@ -190,7 +190,7 @@ pub async fn get_state_events_for_empty_key_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomHistoryVisibility, "")? .map(|event| { - serde_json::from_value::(event.content) + serde_json::from_value::(event.content.clone()) .map_err(|_| { Error::bad_database( "Invalid room history visibility event in database.", @@ -216,7 +216,7 @@ pub async fn get_state_events_for_empty_key_route( ))?; Ok(get_state_events_for_key::Response { - content: serde_json::from_value(event.content) + content: serde_json::from_value(event.content.clone()) .map_err(|_| Error::bad_database("Invalid event content in database"))?, } .into()) diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 1c078e91..69511fa1 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,21 +1,22 @@ use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma}; +use crate::{ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ - api::client::r0::sync::sync_events, + api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, events::{room::member::MembershipState, AnySyncEphemeralRoomEvent, EventType}, serde::Raw, - RoomId, UserId, + DeviceId, RoomId, UserId, }; - -#[cfg(feature = "conduit_bin")] -use rocket::{get, tokio}; use std::{ - collections::{hash_map, BTreeMap, HashMap, HashSet}, + collections::{btree_map::Entry, hash_map, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, sync::Arc, time::Duration, }; +use tokio::sync::watch::Sender; + +#[cfg(feature = "conduit_bin")] +use rocket::{get, tokio}; /// # `GET /_matrix/client/r0/sync` /// @@ -36,21 +37,134 @@ use std::{ pub async fn sync_events_route( db: State<'_, Arc>, body: Ruma>, -) -> ConduitResult { +) -> std::result::Result, RumaResponse> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let mut rx = match db + .globals + .sync_receivers + .write() + .unwrap() + .entry((sender_user.clone(), sender_device.clone())) + { + Entry::Vacant(v) => { + let (tx, rx) = tokio::sync::watch::channel(None); + + tokio::spawn(sync_helper_wrapper( + Arc::clone(&db), + sender_user.clone(), + sender_device.clone(), + body.since.clone(), + body.full_state, + body.timeout, + tx, + )); + + v.insert((body.since.clone(), rx)).1.clone() + } + Entry::Occupied(mut o) => { + if o.get().0 != body.since { + let (tx, rx) = tokio::sync::watch::channel(None); + + tokio::spawn(sync_helper_wrapper( + Arc::clone(&db), + sender_user.clone(), + sender_device.clone(), + body.since.clone(), + body.full_state, + body.timeout, + tx, + )); + + o.insert((body.since.clone(), rx.clone())); + + rx + } else { + o.get().1.clone() + } + } + }; + + let we_have_to_wait = rx.borrow().is_none(); + if we_have_to_wait { + let _ = rx.changed().await; + } + + let result = match rx + .borrow() + .as_ref() + .expect("When sync channel changes it's always set to some") + { + Ok(response) => Ok(response.clone()), + 
Err(error) => Err(error.to_response()), + }; + + result +} + +pub async fn sync_helper_wrapper( + db: Arc, + sender_user: UserId, + sender_device: Box, + since: Option, + full_state: bool, + timeout: Option, + tx: Sender>>, +) { + let r = sync_helper( + Arc::clone(&db), + sender_user.clone(), + sender_device.clone(), + since.clone(), + full_state, + timeout, + ) + .await; + + if let Ok((_, caching_allowed)) = r { + if !caching_allowed { + match db + .globals + .sync_receivers + .write() + .unwrap() + .entry((sender_user, sender_device)) + { + Entry::Occupied(o) => { + // Only remove if the device didn't start a different /sync already + if o.get().0 == since { + o.remove(); + } + } + Entry::Vacant(_) => {} + } + } + } + + let _ = tx.send(Some(r.map(|(r, _)| r.into()))); +} + +async fn sync_helper( + db: Arc, + sender_user: UserId, + sender_device: Box, + since: Option, + full_state: bool, + timeout: Option, + // bool = caching allowed +) -> std::result::Result<(sync_events::Response, bool), Error> { // TODO: match body.set_presence { db.rooms.edus.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = db.watch(sender_user, sender_device); + let watcher = db.watch(&sender_user, &sender_device); - let next_batch = db.globals.current_count()?.to_string(); + let next_batch = db.globals.current_count()?; + let next_batch_string = next_batch.to_string(); let mut joined_rooms = BTreeMap::new(); - let since = body - .since + let since = since .clone() .and_then(|string| string.parse().ok()) .unwrap_or(0); @@ -114,10 +228,11 @@ pub async fn sync_events_route( // since and the current room state, meaning there should be no updates. // The inner Option is None when there is an event, but there is no state hash associated // with it. This can happen for the RoomCreate event, so all updates should arrive. 
- let first_pdu_before_since = db.rooms.pdus_until(sender_user, &room_id, since).next(); + let first_pdu_before_since = db.rooms.pdus_until(&sender_user, &room_id, since).next(); + let pdus_after_since = db .rooms - .pdus_after(sender_user, &room_id, since) + .pdus_after(&sender_user, &room_id, since) .next() .is_some(); @@ -256,11 +371,11 @@ pub async fn sync_events_route( .flatten() .filter(|user_id| { // Don't send key updates from the sender to the sender - sender_user != user_id + &sender_user != user_id }) .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&db, sender_user, user_id, &room_id) + !share_encrypted_room(&db, &sender_user, user_id, &room_id) .unwrap_or(false) }), ); @@ -335,7 +450,7 @@ pub async fn sync_events_route( let state_events = if joined_since_last_sync { current_state - .into_iter() + .iter() .map(|(_, pdu)| pdu.to_sync_state_event()) .collect() } else { @@ -520,7 +635,7 @@ pub async fn sync_events_route( account_data: sync_events::RoomAccountData { events: Vec::new() }, timeline: sync_events::Timeline { limited: false, - prev_batch: Some(next_batch.clone()), + prev_batch: Some(next_batch_string.clone()), events: Vec::new(), }, state: sync_events::State { @@ -573,10 +688,10 @@ pub async fn sync_events_route( // Remove all to-device events the device received *last time* db.users - .remove_to_device_events(sender_user, sender_device, since)?; + .remove_to_device_events(&sender_user, &sender_device, since)?; let response = sync_events::Response { - next_batch, + next_batch: next_batch_string, rooms: sync_events::Rooms { leave: left_rooms, join: joined_rooms, @@ -604,20 +719,22 @@ pub async fn sync_events_route( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_user)? > since + device_one_time_keys_count: if db.users.last_one_time_keys_update(&sender_user)? > since || since == 0 { - db.users.count_one_time_keys(sender_user, sender_device)? + db.users.count_one_time_keys(&sender_user, &sender_device)? 
} else { BTreeMap::new() }, to_device: sync_events::ToDevice { - events: db.users.get_to_device_events(sender_user, sender_device)?, + events: db + .users + .get_to_device_events(&sender_user, &sender_device)?, }, }; // TODO: Retry the endpoint instead of returning (waiting for #118) - if !body.full_state + if !full_state && response.rooms.is_empty() && response.presence.is_empty() && response.account_data.is_empty() @@ -627,14 +744,15 @@ pub async fn sync_events_route( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or_default(); + let mut duration = timeout.unwrap_or_default(); if duration.as_secs() > 30 { duration = Duration::from_secs(30); } let _ = tokio::time::timeout(duration, watcher).await; + Ok((response, false)) + } else { + Ok((response, since != next_batch)) // Only cache if we made progress } - - Ok(response.into()) } #[tracing::instrument(skip(db))] diff --git a/src/database.rs b/src/database.rs index 0ea4d784..ec4052cb 100644 --- a/src/database.rs +++ b/src/database.rs @@ -18,6 +18,7 @@ use crate::{utils, Error, Result}; use abstraction::DatabaseEngine; use directories::ProjectDirs; use log::error; +use lru_cache::LruCache; use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; @@ -194,6 +195,7 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, prevevent_parent: builder.open_tree("prevevent_parent")?, + pdu_cache: RwLock::new(LruCache::new(1_000_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/globals.rs b/src/database/globals.rs index db166e98..eef478a1 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,8 +1,11 @@ -use crate::{database::Config, utils, Error, Result}; +use crate::{database::Config, utils, ConduitResult, Error, Result}; use log::{error, info}; use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, + api::{ + client::r0::sync::sync_events, + federation::discovery::{ServerSigningKeys, VerifyKey}, + }, + DeviceId, EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, UserId, }; use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ @@ -35,6 +38,15 @@ pub struct Globals { pub bad_event_ratelimiter: Arc>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, + pub sync_receivers: RwLock< + BTreeMap< + (UserId, Box), + ( + Option, + tokio::sync::watch::Receiver>>, + ), // since, rx + >, + >, } struct MatrixServerVerifier { @@ -155,6 +167,7 @@ impl Globals { bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), + sync_receivers: RwLock::new(BTreeMap::new()), }; fs::create_dir_all(s.get_media_folder())?; diff --git a/src/database/pusher.rs b/src/database/pusher.rs index 358c3c98..a27bf2ce 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -203,7 +203,7 @@ pub fn get_actions<'a>( .rooms .room_state_get(&pdu.room_id, &EventType::RoomPowerLevels, "")? .map(|ev| { - serde_json::from_value(ev.content) + serde_json::from_value(ev.content.clone()) .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) }) .transpose()? 
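
The `sync.rs` and `globals.rs` changes above replace per-request sync computation with a shared in-flight result: requests are keyed by `(UserId, Box<DeviceId>)` plus the `since` token, the first caller spawns `sync_helper_wrapper` and publishes the result on a `tokio::sync::watch` channel, and concurrent callers with the same key just await the stored receiver. Below is a minimal standalone sketch of that deduplication pattern, not the patch's actual code: `Key`, `expensive_sync`, and `deduped_sync` are placeholder names, and it assumes `tokio = { version = "1", features = ["full"] }`.

```rust
use std::{
    collections::{btree_map::Entry, BTreeMap},
    sync::{Arc, RwLock},
    time::Duration,
};
use tokio::sync::watch;

type Key = (String, String); // e.g. (user_id, device_id) as plain strings for the sketch
type Shared = Arc<RwLock<BTreeMap<Key, watch::Receiver<Option<String>>>>>;

// Stand-in for building a full sync response.
async fn expensive_sync(key: Key) -> String {
    tokio::time::sleep(Duration::from_millis(200)).await;
    format!("response for {}/{}", key.0, key.1)
}

async fn deduped_sync(map: Shared, key: Key) -> String {
    let mut rx = {
        let mut guard = map.write().unwrap();
        match guard.entry(key.clone()) {
            // First caller: spawn the worker and store the receiver for others.
            Entry::Vacant(v) => {
                let (tx, rx) = watch::channel(None);
                let k = key.clone();
                tokio::spawn(async move {
                    let result = expensive_sync(k).await;
                    let _ = tx.send(Some(result)); // wakes every waiting receiver
                });
                v.insert(rx.clone());
                rx
            }
            // Later callers with the same key: share the in-flight computation.
            Entry::Occupied(o) => o.get().clone(),
        }
        // write guard dropped here, before any await
    };

    // Wait until the worker publishes a value, then read it.
    if rx.borrow().is_none() {
        let _ = rx.changed().await;
    }
    rx.borrow().clone().expect("worker always sends Some")
}

#[tokio::main]
async fn main() {
    let map: Shared = Arc::default();
    let key = ("@alice:example.org".to_owned(), "DEVICE".to_owned());
    // Two concurrent requests; only one expensive_sync actually runs.
    let (a, b) = tokio::join!(
        deduped_sync(map.clone(), key.clone()),
        deduped_sync(map.clone(), key.clone())
    );
    assert_eq!(a, b);
    println!("{}", a);
}
```

The real route additionally stores the `since` token with the receiver so a request with a different token starts a fresh computation, and `sync_helper_wrapper` removes the entry again when the result is not cacheable.
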
diff --git a/src/database/rooms.rs b/src/database/rooms.rs index f19d4b99..e23b8046 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -5,6 +5,7 @@ use member::MembershipState; use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; use log::{debug, error, warn}; +use lru_cache::LruCache; use regex::Regex; use ring::digest; use ruma::{ @@ -23,7 +24,7 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem, - sync::Arc, + sync::{Arc, RwLock}, }; use super::{abstraction::Tree, admin::AdminCommand, pusher}; @@ -81,6 +82,8 @@ pub struct Rooms { /// RoomId + EventId -> Parent PDU EventId. pub(super) prevevent_parent: Arc, + + pub(super) pdu_cache: RwLock>>, } impl Rooms { @@ -105,8 +108,8 @@ impl Rooms { pub fn state_full( &self, shortstatehash: u64, - ) -> Result> { - Ok(self + ) -> Result>> { + let state = self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) @@ -133,7 +136,9 @@ impl Rooms { )) }) .filter_map(|r| r.ok()) - .collect()) + .collect(); + + Ok(state) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -179,7 +184,7 @@ impl Rooms { shortstatehash: u64, event_type: &EventType, state_key: &str, - ) -> Result> { + ) -> Result>> { self.state_get_id(shortstatehash, event_type, state_key)? .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) } @@ -234,7 +239,7 @@ impl Rooms { let mut events = StateMap::new(); for (event_type, state_key) in auth_events { if let Some(pdu) = self.room_state_get(room_id, &event_type, &state_key)? { - events.insert((event_type, state_key), Arc::new(pdu)); + events.insert((event_type, state_key), pdu); } else { // This is okay because when creating a new room some events were not created yet debug!( @@ -396,7 +401,7 @@ impl Rooms { pub fn room_state_full( &self, room_id: &RoomId, - ) -> Result> { + ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.state_full(current_shortstatehash) } else { @@ -426,7 +431,7 @@ impl Rooms { room_id: &RoomId, event_type: &EventType, state_key: &str, - ) -> Result> { + ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.state_get(current_shortstatehash, event_type, state_key) } else { @@ -514,21 +519,42 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid + pub fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.write().unwrap().get_mut(&event_id) { + return Ok(Some(Arc::clone(p))); + } + + if let Some(pdu) = self + .eventid_pduid .get(event_id.as_bytes())? .map_or_else::, _, _>( - || self.eventid_outlierpdu.get(event_id.as_bytes()), + || { + let r = self.eventid_outlierpdu.get(event_id.as_bytes()); + r + }, |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + let r = Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) + })?)); + r }, )? .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + let r = serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new); + r }) - .transpose() + .transpose()? 
+ { + self.pdu_cache + .write() + .unwrap() + .insert(event_id.clone(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } } /// Returns the pdu. @@ -663,7 +689,7 @@ impl Rooms { unsigned.insert( "prev_content".to_owned(), CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content) + utils::to_canonical_object(prev_state.content.clone()) .expect("event is valid, we just created it"), ), ); @@ -1204,7 +1230,7 @@ impl Rooms { let create_prev_event = if prev_events.len() == 1 && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) { - create_event.map(Arc::new) + create_event } else { None }; @@ -1235,10 +1261,10 @@ impl Rooms { let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { if let Some(prev_pdu) = self.room_state_get(&room_id, &event_type, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content); + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), ); } } @@ -1583,7 +1609,7 @@ impl Rooms { .and_then(|create| { serde_json::from_value::< Raw, - >(create.content) + >(create.content.clone()) .expect("Raw::from_value always works") .deserialize() .ok() @@ -1764,7 +1790,8 @@ impl Rooms { ErrorKind::BadState, "Cannot leave a room you are not a member of.", ))? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() diff --git a/src/error.rs b/src/error.rs index 4f363fff..501c77d1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -61,7 +61,6 @@ pub enum Error { BadDatabase(&'static str), #[error("uiaa")] Uiaa(UiaaInfo), - #[error("{0}: {1}")] BadRequest(ErrorKind, &'static str), #[error("{0}")] @@ -80,19 +79,16 @@ impl Error { } } -#[cfg(feature = "conduit_bin")] -impl<'r, 'o> Responder<'r, 'o> for Error -where - 'o: 'r, -{ - fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { +impl Error { + pub fn to_response(&self) -> RumaResponse { if let Self::Uiaa(uiaainfo) = self { - return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo)).respond_to(r); + return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); } - if let Self::FederationError(origin, mut error) = self { + if let Self::FederationError(origin, error) = self { + let mut error = error.clone(); error.message = format!("Answer from {}: {}", origin, error.message); - return RumaResponse::from(error).respond_to(r); + return RumaResponse(UiaaResponse::MatrixError(error)); } let message = format!("{}", self); @@ -119,11 +115,20 @@ where warn!("{}: {}", status_code, message); - RumaResponse::from(RumaError { + RumaResponse(UiaaResponse::MatrixError(RumaError { kind, message, status_code, - }) - .respond_to(r) + })) + } +} + +#[cfg(feature = "conduit_bin")] +impl<'r, 'o> Responder<'r, 'o> for Error +where + 'o: 'r, +{ + fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { + self.to_response().respond_to(r) } } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 2912a578..8c22f79b 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,6 +1,6 @@ use crate::Error; use ruma::{ - api::OutgoingResponse, + api::{client::r0::uiaa::UiaaResponse, OutgoingResponse}, identifiers::{DeviceId, UserId}, signatures::CanonicalJsonValue, Outgoing, ServerName, @@ -335,49 +335,60 @@ impl Deref for Ruma { /// This struct converts ruma responses into rocket 
http responses. pub type ConduitResult = std::result::Result, Error>; -pub struct RumaResponse(pub T); +pub fn response(response: RumaResponse) -> response::Result<'static> { + let http_response = response + .0 + .try_into_http_response::>() + .map_err(|_| Status::InternalServerError)?; -impl From for RumaResponse { + let mut response = rocket::response::Response::build(); + + let status = http_response.status(); + response.raw_status(status.into(), ""); + + for header in http_response.headers() { + response.raw_header(header.0.to_string(), header.1.to_str().unwrap().to_owned()); + } + + let http_body = http_response.into_body(); + + response.sized_body(http_body.len(), Cursor::new(http_body)); + + response.raw_header("Access-Control-Allow-Origin", "*"); + response.raw_header( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, OPTIONS", + ); + response.raw_header( + "Access-Control-Allow-Headers", + "Origin, X-Requested-With, Content-Type, Accept, Authorization", + ); + response.raw_header("Access-Control-Max-Age", "86400"); + response.ok() +} + +#[derive(Clone)] +pub struct RumaResponse(pub T); + +impl From for RumaResponse { fn from(t: T) -> Self { Self(t) } } +impl From for RumaResponse { + fn from(t: Error) -> Self { + t.to_response() + } +} + #[cfg(feature = "conduit_bin")] impl<'r, 'o, T> Responder<'r, 'o> for RumaResponse where - T: Send + OutgoingResponse, 'o: 'r, + T: OutgoingResponse, { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { - let http_response = self - .0 - .try_into_http_response::>() - .map_err(|_| Status::InternalServerError)?; - - let mut response = rocket::response::Response::build(); - - let status = http_response.status(); - response.raw_status(status.into(), ""); - - for header in http_response.headers() { - response.raw_header(header.0.to_string(), header.1.to_str().unwrap().to_owned()); - } - - let http_body = http_response.into_body(); - - response.sized_body(http_body.len(), Cursor::new(http_body)); - - response.raw_header("Access-Control-Allow-Origin", "*"); - response.raw_header( - "Access-Control-Allow-Methods", - "GET, POST, PUT, DELETE, OPTIONS", - ); - response.raw_header( - "Access-Control-Allow-Headers", - "Origin, X-Requested-With, Content-Type, Accept, Authorization", - ); - response.raw_header("Access-Control-Max-Age", "86400"); - response.ok() + response(self) } } diff --git a/src/server_server.rs b/src/server_server.rs index 961cc9d8..2bcfd2b3 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -45,7 +45,7 @@ use ruma::{ receipt::ReceiptType, serde::Raw, signatures::{CanonicalJsonObject, CanonicalJsonValue}, - state_res::{self, Event, EventMap, RoomVersion, StateMap}, + state_res::{self, Event, RoomVersion, StateMap}, to_device::DeviceIdOrAllDevices, uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UserId, @@ -612,7 +612,7 @@ pub async fn send_transaction_message_route( // TODO: This could potentially also be some sort of trie (suffix tree) like structure so // that once an auth event is known it would know (using indexes maybe) all of the auth // events that it references. 
- let mut auth_cache = EventMap::new(); + // let mut auth_cache = EventMap::new(); for pdu in &body.pdus { // We do not add the event_id field to the pdu here because of signature and hashes checks @@ -627,17 +627,9 @@ pub async fn send_transaction_message_route( let start_time = Instant::now(); resolved_map.insert( event_id.clone(), - handle_incoming_pdu( - &body.origin, - &event_id, - value, - true, - &db, - &pub_key_map, - &mut auth_cache, - ) - .await - .map(|_| ()), + handle_incoming_pdu(&body.origin, &event_id, value, true, &db, &pub_key_map) + .await + .map(|_| ()), ); let elapsed = start_time.elapsed(); @@ -820,7 +812,6 @@ pub fn handle_incoming_pdu<'a>( is_timeline_event: bool, db: &'a Database, pub_key_map: &'a RwLock>>, - auth_cache: &'a mut EventMap>, ) -> AsyncRecursiveResult<'a, Option>, String> { Box::pin(async move { // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json @@ -908,15 +899,9 @@ pub fn handle_incoming_pdu<'a>( // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // EDIT: Step 5 is not applied anymore because it failed too often debug!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_events( - db, - origin, - &incoming_pdu.auth_events, - pub_key_map, - auth_cache, - ) - .await - .map_err(|e| e.to_string())?; + fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, pub_key_map) + .await + .map_err(|e| e.to_string())?; // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events debug!( @@ -927,9 +912,13 @@ pub fn handle_incoming_pdu<'a>( // Build map of auth events let mut auth_events = BTreeMap::new(); for id in &incoming_pdu.auth_events { - let auth_event = auth_cache.get(id).ok_or_else(|| { - "Auth event not found, event failed recursive auth checks.".to_string() - })?; + let auth_event = db + .rooms + .get_pdu(id) + .map_err(|e| e.to_string())? + .ok_or_else(|| { + "Auth event not found, event failed recursive auth checks.".to_string() + })?; match auth_events.entry(( auth_event.kind.clone(), @@ -963,10 +952,10 @@ pub fn handle_incoming_pdu<'a>( let previous_create = if incoming_pdu.auth_events.len() == 1 && incoming_pdu.prev_events == incoming_pdu.auth_events { - auth_cache - .get(&incoming_pdu.auth_events[0]) - .cloned() - .filter(|maybe_create| **maybe_create == create_event) + db.rooms + .get_pdu(&incoming_pdu.auth_events[0]) + .map_err(|e| e.to_string())? + .filter(|maybe_create| **maybe_create == *create_event) } else { None }; @@ -1008,7 +997,6 @@ pub fn handle_incoming_pdu<'a>( debug!("Requesting state at event."); let mut state_at_incoming_event = None; - let mut incoming_auth_events = Vec::new(); if incoming_pdu.prev_events.len() == 1 { let prev_event = &incoming_pdu.prev_events[0]; @@ -1031,7 +1019,7 @@ pub fn handle_incoming_pdu<'a>( state_vec.push(prev_event.clone()); } state_at_incoming_event = Some( - fetch_and_handle_events(db, origin, &state_vec, pub_key_map, auth_cache) + fetch_and_handle_events(db, origin, &state_vec, pub_key_map) .await .map_err(|_| "Failed to fetch state events locally".to_owned())? 
.into_iter() @@ -1069,18 +1057,12 @@ pub fn handle_incoming_pdu<'a>( { Ok(res) => { debug!("Fetching state events at event."); - let state_vec = match fetch_and_handle_events( - &db, - origin, - &res.pdu_ids, - pub_key_map, - auth_cache, - ) - .await - { - Ok(state) => state, - Err(_) => return Err("Failed to fetch state events.".to_owned()), - }; + let state_vec = + match fetch_and_handle_events(&db, origin, &res.pdu_ids, pub_key_map).await + { + Ok(state) => state, + Err(_) => return Err("Failed to fetch state events.".to_owned()), + }; let mut state = BTreeMap::new(); for pdu in state_vec { @@ -1106,14 +1088,8 @@ pub fn handle_incoming_pdu<'a>( } debug!("Fetching auth chain events at event."); - incoming_auth_events = match fetch_and_handle_events( - &db, - origin, - &res.auth_chain_ids, - pub_key_map, - auth_cache, - ) - .await + match fetch_and_handle_events(&db, origin, &res.auth_chain_ids, pub_key_map) + .await { Ok(state) => state, Err(_) => return Err("Failed to fetch auth chain.".to_owned()), @@ -1181,15 +1157,12 @@ pub fn handle_incoming_pdu<'a>( let mut leaf_state = db .rooms .state_full(pdu_shortstatehash) - .map_err(|_| "Failed to ask db for room state.".to_owned())? - .into_iter() - .map(|(k, v)| (k, Arc::new(v))) - .collect::>(); + .map_err(|_| "Failed to ask db for room state.".to_owned())?; if let Some(state_key) = &leaf_pdu.state_key { // Now it's the state after let key = (leaf_pdu.kind.clone(), state_key.clone()); - leaf_state.insert(key, Arc::new(leaf_pdu)); + leaf_state.insert(key, leaf_pdu); } fork_states.insert(leaf_state); @@ -1209,10 +1182,7 @@ pub fn handle_incoming_pdu<'a>( let current_state = db .rooms .room_state_full(&room_id) - .map_err(|_| "Failed to load room state.".to_owned())? - .into_iter() - .map(|(k, v)| (k, Arc::new(v))) - .collect::>(); + .map_err(|_| "Failed to load room state.".to_owned())?; fork_states.insert(current_state.clone()); @@ -1249,14 +1219,8 @@ pub fn handle_incoming_pdu<'a>( for map in &fork_states { let mut state_auth = vec![]; for auth_id in map.values().flat_map(|pdu| &pdu.auth_events) { - match fetch_and_handle_events( - &db, - origin, - &[auth_id.clone()], - pub_key_map, - auth_cache, - ) - .await + match fetch_and_handle_events(&db, origin, &[auth_id.clone()], pub_key_map) + .await { // This should always contain exactly one element when Ok Ok(events) => state_auth.extend_from_slice(&events), @@ -1265,31 +1229,9 @@ pub fn handle_incoming_pdu<'a>( } } } - auth_cache.extend( - map.iter() - .map(|pdu| (pdu.1.event_id.clone(), pdu.1.clone())), - ); auth_events.push(state_auth); } - // Add everything we will need to event_map - auth_cache.extend( - auth_events - .iter() - .map(|pdus| pdus.iter().map(|pdu| (pdu.event_id.clone(), pdu.clone()))) - .flatten(), - ); - auth_cache.extend( - incoming_auth_events - .into_iter() - .map(|pdu| (pdu.event_id().clone(), pdu)), - ); - auth_cache.extend( - state_after - .into_iter() - .map(|(_, pdu)| (pdu.event_id().clone(), pdu)), - ); - match state_res::StateResolution::resolve( &room_id, room_version_id, @@ -1305,7 +1247,13 @@ pub fn handle_incoming_pdu<'a>( .into_iter() .map(|pdus| pdus.into_iter().map(|pdu| pdu.event_id().clone()).collect()) .collect(), - auth_cache, + &|id| { + let res = db.rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }, ) { Ok(new_state) => new_state, Err(_) => { @@ -1365,21 +1313,16 @@ pub fn handle_incoming_pdu<'a>( /// Find the event and auth it. 
Once the event is validated (steps 1 - 8) /// it is appended to the outliers Tree. /// -/// a. Look in the auth_cache -/// b. Look in the main timeline (pduid_pdu tree) -/// c. Look at outlier pdu tree -/// d. Ask origin server over federation -/// e. TODO: Ask other servers over federation? -/// -/// If the event is unknown to the `auth_cache` it is added. This guarantees that any -/// event we need to know of will be present. +/// a. Look in the main timeline (pduid_pdu tree) +/// b. Look at outlier pdu tree +/// c. Ask origin server over federation +/// d. TODO: Ask other servers over federation? //#[tracing::instrument(skip(db, key_map, auth_cache))] pub(crate) fn fetch_and_handle_events<'a>( db: &'a Database, origin: &'a ServerName, events: &'a [EventId], pub_key_map: &'a RwLock>>, - auth_cache: &'a mut EventMap>, ) -> AsyncRecursiveResult<'a, Vec>, Error> { Box::pin(async move { let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { @@ -1403,84 +1346,73 @@ pub(crate) fn fetch_and_handle_events<'a>( continue; } } - // a. Look at auth cache - let pdu = match auth_cache.get(id) { + + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu checks both) + let pdu = match db.rooms.get_pdu(&id)? { Some(pdu) => { - // We already have the auth chain for events in cache - pdu.clone() + trace!("Found {} in db", id); + pdu } - // b. Look in the main timeline (pduid_pdu tree) - // c. Look at outlier pdu tree - // (get_pdu checks both) - None => match db.rooms.get_pdu(&id)? { - Some(pdu) => { - trace!("Found {} in db", id); - // We need to fetch the auth chain - let _ = fetch_and_handle_events( - db, + None => { + // c. Ask origin server over federation + debug!("Fetching {} over federation.", id); + match db + .sending + .send_federation_request( + &db.globals, origin, - &pdu.auth_events, - pub_key_map, - auth_cache, + get_event::v1::Request { event_id: &id }, ) - .await?; - Arc::new(pdu) - } - None => { - // d. 
Ask origin server over federation - debug!("Fetching {} over federation.", id); - match db - .sending - .send_federation_request( - &db.globals, + .await + { + Ok(res) => { + debug!("Got {} over federation", id); + let (event_id, mut value) = + crate::pdu::gen_event_id_canonical_json(&res.pdu)?; + // This will also fetch the auth chain + match handle_incoming_pdu( origin, - get_event::v1::Request { event_id: &id }, + &event_id, + value.clone(), + false, + db, + pub_key_map, ) .await - { - Ok(res) => { - debug!("Got {} over federation", id); - let (event_id, mut value) = - crate::pdu::gen_event_id_canonical_json(&res.pdu)?; - // This will also fetch the auth chain - match handle_incoming_pdu( - origin, - &event_id, - value.clone(), - false, - db, - pub_key_map, - auth_cache, - ) - .await - { - Ok(_) => { - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.into()), - ); - - Arc::new(serde_json::from_value( - serde_json::to_value(value).expect("canonicaljsonobject is valid value"), - ).expect("This is possible because handle_incoming_pdu worked")) - } - Err(e) => { - warn!("Authentication of event {} failed: {:?}", id, e); - back_off(id.clone()); - continue; - } + { + Ok(_) => { + value.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.into()), + ); + + Arc::new( + serde_json::from_value( + serde_json::to_value(value) + .expect("canonicaljsonobject is valid value"), + ) + .expect( + "This is possible because handle_incoming_pdu worked", + ), + ) + } + Err(e) => { + warn!("Authentication of event {} failed: {:?}", id, e); + back_off(id.clone()); + continue; } } - Err(_) => { - warn!("Failed to fetch event: {}", id); - back_off(id.clone()); - continue; - } + } + Err(_) => { + warn!("Failed to fetch event: {}", id); + back_off(id.clone()); + continue; } } - }, + } }; - auth_cache.entry(id.clone()).or_insert_with(|| pdu.clone()); pdus.push(pdu); } Ok(pdus) @@ -1838,7 +1770,7 @@ pub fn get_event_authorization_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); let pdu_json = PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&event_id)?.unwrap(), @@ -1901,7 +1833,7 @@ pub fn get_room_state_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); let pdu_json = PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&event_id)?.unwrap(), @@ -1954,7 +1886,7 @@ pub fn get_room_state_ids_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); } else { warn!("Could not find pdu mentioned in auth events."); } @@ -2022,7 +1954,7 @@ pub fn create_join_event_template_route( let create_prev_event = if prev_events.len() == 1 && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) { - create_event.map(Arc::new) + create_event } else { None }; @@ -2066,10 +1998,10 @@ pub fn create_join_event_template_route( let mut unsigned = BTreeMap::new(); if let Some(prev_pdu) = db.rooms.room_state_get(&body.room_id, &kind, &state_key)? 
{ - unsigned.insert("prev_content".to_owned(), prev_pdu.content); + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), ); } @@ -2161,7 +2093,7 @@ pub async fn create_join_event_route( ))?; let pub_key_map = RwLock::new(BTreeMap::new()); - let mut auth_cache = EventMap::new(); + // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and hashes checks let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&body.pdu) { @@ -2184,26 +2116,18 @@ pub async fn create_join_event_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = handle_incoming_pdu( - &origin, - &event_id, - value, - true, - &db, - &pub_key_map, - &mut auth_cache, - ) - .await - .map_err(|_| { - Error::BadRequest( + let pdu_id = handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) + .await + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Error while handling incoming PDU.", + ) + })? + .ok_or(Error::BadRequest( ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + "Could not accept incoming PDU as timeline event.", + ))?; let state_ids = db.rooms.state_full_ids(shortstatehash)?; @@ -2220,7 +2144,7 @@ pub async fn create_join_event_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); } else { warn!("Could not find pdu mentioned in auth events."); }
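
For reference, the `pdu_cache` added to `Rooms` above is a plain read-through cache: `get_pdu` checks a `RwLock<LruCache<EventId, Arc<PduEvent>>>` first, falls back to the sled trees, and inserts the freshly deserialized event so later auth-chain and state-resolution lookups avoid the database. The following is a simplified sketch of that pattern under stated assumptions, not Conduit's code: the `Store` trait, `CachedStore`, and string keys stand in for the sled trees and `EventId`/`PduEvent`; it only assumes `lru-cache = "0.1.2"`, whose API exposes `get_mut` rather than `get`.

```rust
use lru_cache::LruCache;
use std::sync::{Arc, RwLock};

pub trait Store {
    /// Slow lookup, e.g. a sled tree read plus JSON deserialization.
    fn load(&self, key: &str) -> Option<String>;
}

pub struct CachedStore<S> {
    backend: S,
    cache: RwLock<LruCache<String, Arc<String>>>,
}

impl<S: Store> CachedStore<S> {
    pub fn new(backend: S, capacity: usize) -> Self {
        Self {
            backend,
            cache: RwLock::new(LruCache::new(capacity)),
        }
    }

    pub fn get(&self, key: &str) -> Option<Arc<String>> {
        // 1. Fast path: lru-cache 0.1.x only exposes `get_mut`, so a write
        //    lock is taken even for reads (as in the diff's `pdu_cache`).
        if let Some(hit) = self.cache.write().unwrap().get_mut(key) {
            return Some(Arc::clone(hit));
        }

        // 2. Slow path: load from the backend and share it via `Arc`.
        let value = Arc::new(self.backend.load(key)?);
        self.cache
            .write()
            .unwrap()
            .insert(key.to_owned(), Arc::clone(&value));
        Some(value)
    }
}

// Example backend that just echoes the key.
struct Echo;
impl Store for Echo {
    fn load(&self, key: &str) -> Option<String> {
        Some(format!("value for {}", key))
    }
}

fn main() {
    let store = CachedStore::new(Echo, 1_000); // the patch sizes the cache at 1_000_000 entries
    let a = store.get("$event_a").unwrap();
    let b = store.get("$event_a").unwrap(); // second call is served from the cache
    assert!(Arc::ptr_eq(&a, &b));
    println!("{}", a);
}
```

Returning `Arc`-wrapped values is what lets the patch drop the separate `auth_cache: EventMap` that previously had to be threaded through `handle_incoming_pdu` and state resolution: any caller can now fetch a cheap shared handle via `get_pdu` instead.
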