From fa3226898ccbd31475bf3f2467e729f3e19dcdfa Mon Sep 17 00:00:00 2001 From: timokoesters Date: Fri, 3 Apr 2020 17:27:08 +0200 Subject: [PATCH] feat: save pdus PDUs are saved in a pduid -> pdus map. roomid -> pduleaves keeps track of the leaves of the event graph and eventid -> pduid maps event ids to pdus. --- Cargo.lock | 26 ++++++++ Cargo.toml | 3 +- Rocket.toml | 8 ++- src/data.rs | 154 ++++++++++++++++++++++++++++++++++++++++++-- src/database.rs | 42 +++++++++--- src/main.rs | 89 ++++++++++++++++++++----- src/ruma_wrapper.rs | 36 +++++------ src/utils.rs | 5 ++ 8 files changed, 309 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1a1da0c8..19276fe7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -496,7 +496,9 @@ dependencies = [ "ruma-api", "ruma-client-api", "ruma-events", + "ruma-federation-api", "ruma-identifiers", + "ruma-signatures", "serde_json", "sled", ] @@ -875,6 +877,19 @@ dependencies = [ "syn 1.0.17", ] +[[package]] +name = "ruma-federation-api" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2a73a23c4d9243be91e101e1942f4d9cd913ef5156d756bafdfe2409ee23d72" +dependencies = [ + "js_int", + "ruma-events", + "ruma-identifiers", + "serde", + "serde_json", +] + [[package]] name = "ruma-identifiers" version = "0.14.1" @@ -886,6 +901,17 @@ dependencies = [ "url 2.1.1", ] +[[package]] +name = "ruma-signatures" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma-signatures.git#a08fc01c0bce63f913e1b4b1a673169d59738b63" +dependencies = [ + "base64 0.11.0", + "ring", + "serde_json", + "untrusted", +] + [[package]] name = "rust-argon2" version = "0.7.0" diff --git a/Cargo.toml b/Cargo.toml index 13b38f38..e01ca0d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,4 +19,5 @@ ruma-api = "0.15.0" ruma-events = "0.18.0" js_int = "0.1.3" serde_json = "1.0.50" -ruma-signatures = "0.5.0" +ruma-signatures = { git = "https://github.com/ruma/ruma-signatures.git" } +ruma-federation-api = "0.0.1" diff --git a/Rocket.toml b/Rocket.toml index d18ee979..f55e1072 100644 --- a/Rocket.toml +++ b/Rocket.toml @@ -1,3 +1,7 @@ +[global] +address = "0.0.0.0" +port = 14004 + #[global.tls] -#certs = "/etc/ssl/certs/ssl-cert-snakeoil.pem" -#key = "/etc/ssl/private/ssl-cert-snakeoil.key" +#certs = "/etc/letsencrypt/live/matrixtesting.koesters.xyz/fullchain.pem" +#key = "/etc/letsencrypt/live/matrixtesting.koesters.xyz/privkey.pem" diff --git a/src/data.rs b/src/data.rs index b7b9845b..28b8d05f 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,7 +1,9 @@ use crate::{utils, Database}; +use log::debug; use ruma_events::collections::all::Event; +use ruma_federation_api::RoomV3Pdu; use ruma_identifiers::{EventId, RoomId, UserId}; -use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; pub struct Data { hostname: String, @@ -99,14 +101,152 @@ impl Data { .unwrap(); } - /// Create a new room event. - pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) { - let mut key = room_id.to_string().as_bytes().to_vec(); - key.extend_from_slice(event_id.to_string().as_bytes()); + pub fn pdu_get(&self, event_id: &EventId) -> Option { self.db - .roomid_eventid_event - .insert(&key, &*serde_json::to_string(event).unwrap()) + .eventid_pduid + .get(event_id.to_string().as_bytes()) + .unwrap() + .map(|pdu_id| { + serde_json::from_slice( + &self + .db + .pduid_pdus + .get(pdu_id) + .unwrap() + .expect("eventid_pduid in db is valid"), + ) + .expect("pdu is valid") + }) + } + + // TODO: Make sure this isn't called twice in parallel + pub fn pdu_leaves_replace(&self, room_id: &RoomId, event_id: &EventId) -> Vec { + let event_ids = self + .db + .roomid_pduleaves + .get_iter(room_id.to_string().as_bytes()) + .values() + .map(|pdu_id| { + EventId::try_from(&*utils::string_from_bytes(&pdu_id.unwrap())) + .expect("pdu leaves are valid event ids") + }) + .collect(); + + self.db + .roomid_pduleaves + .clear(room_id.to_string().as_bytes()); + + self.db.roomid_pduleaves.add( + &room_id.to_string().as_bytes(), + (*event_id.to_string()).into(), + ); + + event_ids + } + + /// Add a persisted data unit from this homeserver + pub fn pdu_append(&self, event_id: &EventId, room_id: &RoomId, event: Event) { + // prev_events are the leaves of the current graph. This method removes all leaves from the + // room and replaces them with our event + let prev_events = self.pdu_leaves_replace(room_id, event_id); + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events + .iter() + .map(|event_id| { + self.pdu_get(event_id) + .expect("pdu in prev_events is valid") + .depth + .into() + }) + .max() + .unwrap_or(0_u64) + + 1; + + let mut pdu_value = serde_json::to_value(&event).expect("message event can be serialized"); + let pdu = pdu_value.as_object_mut().unwrap(); + + pdu.insert( + "prev_events".to_owned(), + prev_events + .iter() + .map(|id| id.to_string()) + .collect::>() + .into(), + ); + pdu.insert("origin".to_owned(), self.hostname().into()); + pdu.insert("depth".to_owned(), depth.into()); + pdu.insert("auth_events".to_owned(), vec!["$auth_eventid"].into()); // TODO + pdu.insert( + "hashes".to_owned(), + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".into(), + ); // TODO + pdu.insert("signatures".to_owned(), "signature".into()); // TODO + + // The new value will need a new index. We store the last used index in 'n' + id + let mut count_key: Vec = vec![b'n']; + count_key.extend_from_slice(&room_id.to_string().as_bytes()); + + // Increment the last index and use that + let index = utils::u64_from_bytes( + &self + .db + .pduid_pdus + .update_and_fetch(&count_key, utils::increment) + .unwrap() + .unwrap(), + ); + + let mut pdu_id = vec![b'd']; + pdu_id.extend_from_slice(room_id.to_string().as_bytes()); + + pdu_id.push(b'#'); // Add delimiter so we don't find rooms starting with the same id + pdu_id.extend_from_slice(index.to_string().as_bytes()); + + self.db + .pduid_pdus + .insert(&pdu_id, dbg!(&*serde_json::to_string(&pdu).unwrap())) .unwrap(); + + self.db + .eventid_pduid + .insert(event_id.to_string(), pdu_id.clone()) + .unwrap(); + } + + /// Returns a vector of all PDUs. + pub fn pdus_all(&self) -> Vec { + self.pdus_since( + self.db + .eventid_pduid + .iter() + .values() + .next() + .unwrap() + .map(|key| utils::string_from_bytes(&key)) + .expect("there should be at least one pdu"), + ) + } + + /// Returns a vector of all events that happened after the event with id `since`. + pub fn pdus_since(&self, since: String) -> Vec { + let mut pdus = Vec::new(); + + if let Some(room_id) = since.rsplitn(2, '#').nth(1) { + let mut current = since.clone(); + + while let Some((key, value)) = self.db.pduid_pdus.get_gt(current).unwrap() { + if key.starts_with(&room_id.to_string().as_bytes()) { + current = utils::string_from_bytes(&key); + } else { + break; + } + pdus.push(serde_json::from_slice(&value).expect("pdu is valid")); + } + } else { + debug!("event at `since` not found"); + } + pdus } pub fn debug(&self) { diff --git a/src/database.rs b/src/database.rs index 34ed72ba..b08dd3ce 100644 --- a/src/database.rs +++ b/src/database.rs @@ -15,11 +15,17 @@ impl MultiValue { // Data keys start with d let mut key = vec![b'd']; key.extend_from_slice(id.as_ref()); - key.push(0xff); // Add delimiter so we don't find usernames starting with the same id + key.push(0xff); // Add delimiter so we don't find keys starting with the same id self.0.scan_prefix(key) } + pub fn clear(&self, id: &[u8]) { + for key in self.get_iter(id).keys() { + self.0.remove(key.unwrap()).unwrap(); + } + } + /// Add another value to the id. pub fn add(&self, id: &[u8], value: IVec) { // The new value will need a new index. We store the last used index in 'n' + id @@ -48,7 +54,9 @@ pub struct Database { pub userid_deviceids: MultiValue, pub deviceid_token: sled::Tree, pub token_userid: sled::Tree, - pub roomid_eventid_event: sled::Tree, + pub pduid_pdus: sled::Tree, + pub roomid_pduleaves: MultiValue, + pub eventid_pduid: sled::Tree, _db: sled::Db, } @@ -67,7 +75,9 @@ impl Database { userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()), deviceid_token: db.open_tree("deviceid_token").unwrap(), token_userid: db.open_tree("token_userid").unwrap(), - roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(), + pduid_pdus: db.open_tree("pduid_pdus").unwrap(), + roomid_pduleaves: MultiValue(db.open_tree("roomid_pduleaves").unwrap()), + eventid_pduid: db.open_tree("eventid_pduid").unwrap(), _db: db, } } @@ -81,7 +91,7 @@ impl Database { String::from_utf8_lossy(&v), ); } - println!("# UserId -> DeviceIds:"); + println!("\n# UserId -> DeviceIds:"); for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) { println!( "{} -> {}", @@ -89,7 +99,7 @@ impl Database { String::from_utf8_lossy(&v), ); } - println!("# DeviceId -> Token:"); + println!("\n# DeviceId -> Token:"); for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) { println!( "{} -> {}", @@ -97,7 +107,7 @@ impl Database { String::from_utf8_lossy(&v), ); } - println!("# Token -> UserId:"); + println!("\n# Token -> UserId:"); for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) { println!( "{} -> {}", @@ -105,8 +115,24 @@ impl Database { String::from_utf8_lossy(&v), ); } - println!("# RoomId + EventId -> Event:"); - for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) { + println!("\n# RoomId -> PDU leaves:"); + for (k, v) in self.roomid_pduleaves.iter_all().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + println!("\n# PDU Id -> PDUs:"); + for (k, v) in self.pduid_pdus.iter().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + println!("\n# EventId -> PDU Id:"); + for (k, v) in self.eventid_pduid.iter().map(|r| r.unwrap()) { println!( "{} -> {}", String::from_utf8_lossy(&k), diff --git a/src/main.rs b/src/main.rs index 6cf5477b..c8b1cc8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,12 +8,12 @@ pub use data::Data; pub use database::Database; use log::debug; -use rocket::{get, post, put, routes, State}; +use rocket::{get, options, post, put, routes, State}; use ruma_client_api::{ error::{Error, ErrorKind}, r0::{ account::register, alias::get_alias, membership::join_room_by_id, - message::create_message_event, session::login, + message::create_message_event, session::login, sync::sync_events, }, unversioned::get_supported_versions, }; @@ -24,20 +24,13 @@ use serde_json::map::Map; use std::{ collections::HashMap, convert::{TryFrom, TryInto}, + path::PathBuf, }; #[get("/_matrix/client/versions")] fn get_supported_versions_route() -> MatrixResult { MatrixResult(Ok(get_supported_versions::Response { - versions: vec![ - "r0.0.1".to_owned(), - "r0.1.0".to_owned(), - "r0.2.0".to_owned(), - "r0.3.0".to_owned(), - "r0.4.0".to_owned(), - "r0.5.0".to_owned(), - "r0.6.0".to_owned(), - ], + versions: vec!["r0.6.0".to_owned()], unstable_features: HashMap::new(), })) } @@ -219,9 +212,9 @@ fn create_message_event_route( body: Ruma, ) -> MatrixResult { // Construct event - let event = Event::RoomMessage(MessageEvent { + let mut event = Event::RoomMessage(MessageEvent { content: body.data.clone().into_result().unwrap(), - event_id: event_id.clone(), + event_id: EventId::try_from("$thiswillbefilledinlater").unwrap(), origin_server_ts: utils::millis_since_unix_epoch(), room_id: Some(body.room_id.clone()), sender: body.user_id.clone().expect("user is authenticated"), @@ -229,18 +222,78 @@ fn create_message_event_route( }); // Generate event id - dbg!(ruma_signatures::reference_hash(event)); + let event_id = EventId::try_from(&*format!( + "${}", + ruma_signatures::reference_hash(&serde_json::to_value(&event).unwrap()) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are correct"); + + // Insert event id + if let Event::RoomMessage(message) = &mut event { + message.event_id = event_id.clone(); + } - let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap(); - data.event_add(&body.room_id, &event_id, &event); + // Add PDU to the graph + data.pdu_append(&event_id, &body.room_id, event); MatrixResult(Ok(create_message_event::Response { event_id })) } +#[get("/_matrix/client/r0/sync")] +fn sync_route(data: State) -> MatrixResult { + let pdus = data.pdus_all(); + let mut joined_rooms = HashMap::new(); + joined_rooms.insert( + "!roomid:localhost".try_into().unwrap(), + sync_events::JoinedRoom { + account_data: sync_events::AccountData { events: Vec::new() }, + summary: sync_events::RoomSummary { + heroes: Vec::new(), + joined_member_count: None, + invited_member_count: None, + }, + unread_notifications: sync_events::UnreadNotificationsCount { + highlight_count: None, + notification_count: None, + }, + timeline: sync_events::Timeline { + limited: None, + prev_batch: None, + events: todo!(), + }, + state: sync_events::State { events: Vec::new() }, + ephemeral: sync_events::Ephemeral { events: Vec::new() }, + }, + ); + + MatrixResult(Ok(sync_events::Response { + next_batch: String::new(), + rooms: sync_events::Rooms { + leave: Default::default(), + join: joined_rooms, + invite: Default::default(), + }, + presence: sync_events::Presence { events: Vec::new() }, + device_lists: Default::default(), + device_one_time_keys_count: Default::default(), + to_device: sync_events::ToDevice { events: Vec::new() }, + })) +} + +#[options("/<_segments..>")] +fn options_route(_segments: PathBuf) -> MatrixResult { + MatrixResult(Err(Error { + kind: ErrorKind::NotFound, + message: "Room not found.".to_owned(), + status_code: http::StatusCode::NOT_FOUND, + })) +} + fn main() { // Log info by default if let Err(_) = std::env::var("RUST_LOG") { - std::env::set_var("RUST_LOG", "info"); + std::env::set_var("RUST_LOG", "matrixserver=debug,info"); } pretty_env_logger::init(); @@ -257,6 +310,8 @@ fn main() { get_alias_route, join_room_by_id_route, create_message_event_route, + sync_route, + options_route, ], ) .manage(data) diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index e898137b..eda648ed 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,28 +1,26 @@ -use { - rocket::data::{FromDataSimple, Outcome}, - rocket::http::Status, - rocket::response::Responder, - rocket::Outcome::*, - rocket::Request, - rocket::State, - ruma_api::{ - error::{FromHttpRequestError, FromHttpResponseError}, - Endpoint, Outgoing, - }, - ruma_client_api::error::Error, - ruma_identifiers::UserId, - std::ops::Deref, - std::{ - convert::{TryFrom, TryInto}, - io::{Cursor, Read}, - }, +use rocket::{ + data::{FromDataSimple, Outcome}, + http::Status, + response::Responder, + Outcome::*, + Request, State, +}; +use ruma_api::{ + error::{FromHttpRequestError, FromHttpResponseError}, + Endpoint, Outgoing, +}; +use ruma_client_api::error::Error; +use ruma_identifiers::UserId; +use std::{ + convert::{TryFrom, TryInto}, + io::{Cursor, Read}, + ops::Deref, }; const MESSAGE_LIMIT: u64 = 65535; /// This struct converts rocket requests into ruma structs by converting them into http requests /// first. -#[derive(Debug)] pub struct Ruma { body: T::Incoming, pub user_id: Option, diff --git a/src/utils.rs b/src/utils.rs index f2ef6c4c..19f3f022 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -24,6 +24,11 @@ pub fn increment(old: Option<&[u8]>) -> Option> { Some(number.to_be_bytes().to_vec()) } +pub fn u64_from_bytes(bytes: &[u8]) -> u64 { + let array: [u8; 8] = bytes.try_into().expect("bytes are valid u64"); + u64::from_be_bytes(array) +} + pub fn string_from_bytes(bytes: &[u8]) -> String { String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8") }