From 8f67c01efd33027828573c4373b1a173f49f2007 Mon Sep 17 00:00:00 2001 From: timokoesters Date: Sun, 3 May 2020 17:25:31 +0200 Subject: [PATCH] refactor: split database into multiple files, more error handling, cleaner code --- Cargo.lock | 94 +--- Cargo.toml | 2 +- Rocket.toml | 4 +- src/client_server.rs | 631 ++++++++++++--------- src/data.rs | 1018 ---------------------------------- src/database.rs | 289 ++-------- src/database/account_data.rs | 120 ++++ src/database/globals.rs | 61 ++ src/database/rooms.rs | 547 ++++++++++++++++++ src/database/rooms/edus.rs | 190 +++++++ src/database/users.rs | 144 +++++ src/error.rs | 36 ++ src/main.rs | 8 +- src/ruma_wrapper.rs | 8 +- src/server_server.rs | 36 +- src/test.rs | 4 +- src/utils.rs | 9 +- 17 files changed, 1572 insertions(+), 1629 deletions(-) delete mode 100644 src/data.rs create mode 100644 src/database/account_data.rs create mode 100644 src/database/globals.rs create mode 100644 src/database/rooms.rs create mode 100644 src/database/rooms/edus.rs create mode 100644 src/database/users.rs create mode 100644 src/error.rs diff --git a/Cargo.lock b/Cargo.lock index 17a0ad58..642e8059 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,14 +1,5 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -[[package]] -name = "aho-corasick" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8716408b8bc624ed7f65d223ddb9ac2d044c0547b6fa4b0d554f3a9540496ada" -dependencies = [ - "memchr", -] - [[package]] name = "arc-swap" version = "0.4.6" @@ -147,7 +138,6 @@ dependencies = [ "http", "js_int", "log", - "pretty_env_logger", "rand", "reqwest", "rocket", @@ -161,6 +151,7 @@ dependencies = [ "serde", "serde_json", "sled", + "thiserror", "tokio", ] @@ -298,19 +289,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "env_logger" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36" -dependencies = [ - "atty", - "humantime", - "log", - "regex", - "termcolor", -] - [[package]] name = "fnv" version = "1.0.6" @@ -533,15 +511,6 @@ version = "1.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd179ae861f0c2e53da70d892f5f3029f9594be0c41dc5269cd371691b1dc2f9" -[[package]] -name = "humantime" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" -dependencies = [ - "quick-error", -] - [[package]] name = "hyper" version = "0.13.5" @@ -937,16 +906,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" -[[package]] -name = "pretty_env_logger" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d" -dependencies = [ - "env_logger", - "log", -] - [[package]] name = "proc-macro-hack" version = "0.5.15" @@ -977,12 +936,6 @@ dependencies = [ "unicode-xid 0.2.0", ] -[[package]] -name = "quick-error" -version = "1.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" - [[package]] name = "quote" version = "0.6.13" @@ -1059,24 +1012,6 @@ dependencies = [ "rust-argon2 0.7.0", ] -[[package]] -name = "regex" -version = "1.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6020f034922e3194c711b82a627453881bc4682166cabb07134a10c26ba7692" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", - "thread_local", -] - -[[package]] -name = "regex-syntax" -version = "0.6.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" - [[package]] name = "remove_dir_all" version = "0.5.2" @@ -1556,21 +1491,23 @@ dependencies = [ ] [[package]] -name = "termcolor" -version = "1.1.0" +name = "thiserror" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb6bfa289a4d7c5766392812c0a1f4c1ba45afa1ad47803c11e1f407d846d75f" +checksum = "d12a1dae4add0f0d568eebc7bf142f145ba1aa2544cafb195c76f0f409091b60" dependencies = [ - "winapi-util", + "thiserror-impl", ] [[package]] -name = "thread_local" -version = "1.0.1" +name = "thiserror-impl" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" +checksum = "3f34e0c1caaa462fd840ec6b768946ea1e7842620d94fe29d5b847138f521269" dependencies = [ - "lazy_static", + "proc-macro2 1.0.10", + "quote 1.0.4", + "syn 1.0.18", ] [[package]] @@ -1887,15 +1824,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" -[[package]] -name = "winapi-util" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" -dependencies = [ - "winapi 0.3.8", -] - [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 7001ada5..4aa2f151 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,6 @@ ruma-api = "0.16.0-rc.3" ruma-events = "0.21.0-beta.1" ruma-signatures = { git = "https://github.com/ruma/ruma-signatures.git" } ruma-federation-api = { git = "https://github.com/ruma/ruma-federation-api.git" } -pretty_env_logger = "0.4.0" log = "0.4.8" sled = "0.31.0" directories = "2.0.2" @@ -31,3 +30,4 @@ rand = "0.7.3" rust-argon2 = "0.8.2" reqwest = "0.10.4" base64 = "0.12.0" +thiserror = "1.0.16" diff --git a/Rocket.toml b/Rocket.toml index 5db4a3d6..4a7d79a6 100644 --- a/Rocket.toml +++ b/Rocket.toml @@ -1,6 +1,6 @@ [global] -hostname = "conduit.rs" -port = 14004 +hostname = "matrixtesting.koesters.xyz:59003" +port = 59003 address = "0.0.0.0" [global.tls] diff --git a/src/client_server.rs b/src/client_server.rs index 58c5bda1..3d63ffd2 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -44,7 +44,7 @@ use ruma_events::{collections::only::Event as EduEvent, EventType}; use ruma_identifiers::{RoomId, UserId}; use serde_json::json; -use crate::{server_server, utils, Data, MatrixResult, Ruma}; +use crate::{server_server, utils, Database, MatrixResult, Ruma}; const GUEST_NAME_LENGTH: usize = 10; const DEVICE_ID_LENGTH: usize = 10; @@ -61,12 +61,12 @@ pub fn get_supported_versions_route() -> MatrixResult, + db: State<'_, Database>, body: Ruma, ) -> MatrixResult { // Validate user id let user_id: UserId = - match (*format!("@{}:{}", body.username.clone(), data.hostname())).try_into() { + match (*format!("@{}:{}", body.username.clone(), db.globals.hostname())).try_into() { Err(_) => { debug!("Username invalid"); return MatrixResult(Err(Error { @@ -79,7 +79,7 @@ pub fn get_register_available_route( }; // Check if username is creative enough - if data.user_exists(&user_id) { + if db.users.exists(&user_id).unwrap() { debug!("ID already taken"); return MatrixResult(Err(Error { kind: ErrorKind::UserInUse, @@ -96,7 +96,7 @@ pub fn get_register_available_route( #[post("/_matrix/client/r0/register", data = "")] pub fn register_route( - data: State, + db: State<'_, Database>, body: Ruma, ) -> MatrixResult { if body.auth.is_none() { @@ -117,7 +117,7 @@ pub fn register_route( body.username .clone() .unwrap_or_else(|| utils::random_string(GUEST_NAME_LENGTH)), - data.hostname() + db.globals.hostname() )) .try_into() { @@ -133,7 +133,7 @@ pub fn register_route( }; // Check if username is creative enough - if data.user_exists(&user_id) { + if db.users.exists(&user_id).unwrap() { debug!("ID already taken"); return MatrixResult(Err(UiaaResponse::MatrixError(Error { kind: ErrorKind::UserInUse, @@ -146,7 +146,7 @@ pub fn register_route( if let Ok(hash) = utils::calculate_hash(&password) { // Create user - data.user_add(&user_id, &hash); + db.users.create(&user_id, &hash).unwrap(); } else { return MatrixResult(Err(UiaaResponse::MatrixError(Error { kind: ErrorKind::InvalidParam, @@ -161,15 +161,17 @@ pub fn register_route( .clone() .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH)); - // Add device - data.device_add(&user_id, &device_id); - // Generate new token for the device let token = utils::random_string(TOKEN_LENGTH); - data.token_replace(&user_id, &device_id, token.clone()); + + // Add device + db + .users + .create_device(&user_id, &device_id, &token) + .unwrap(); // Initial data - data.room_userdata_update( + db.account_data.update( None, &user_id, EduEvent::PushRules(ruma_events::push_rules::PushRulesEvent { @@ -199,7 +201,9 @@ pub fn register_route( }, }, }), - ); + &db.globals, + ) + .unwrap(); MatrixResult(Ok(register::Response { access_token: Some(token), @@ -216,17 +220,17 @@ pub fn get_login_route() -> MatrixResult { } #[post("/_matrix/client/r0/login", data = "")] -pub fn login_route(data: State, body: Ruma) -> MatrixResult { +pub fn login_route(db: State<'_, Database>, body: Ruma) -> MatrixResult { // Validate login method let user_id = if let (login::UserInfo::MatrixId(mut username), login::LoginInfo::Password { password }) = (body.user.clone(), body.login_info.clone()) { if !username.contains(':') { - username = format!("@{}:{}", username, data.hostname()); + username = format!("@{}:{}", username, db.globals.hostname()); } if let Ok(user_id) = (*username).try_into() { - if let Some(hash) = data.password_hash_get(&user_id) { + if let Some(hash) = db.users.password_hash(&user_id).unwrap() { let hash_matches = argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); @@ -272,26 +276,26 @@ pub fn login_route(data: State, body: Ruma) -> MatrixResul .clone() .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH)); - // Add device - data.device_add(&user_id, &device_id); - // Generate a new token for the device let token = utils::random_string(TOKEN_LENGTH); - data.token_replace(&user_id, &device_id, token.clone()); + + // Add device + db + .users + .create_device(&user_id, &device_id, &token) + .unwrap(); MatrixResult(Ok(login::Response { user_id, access_token: token, - home_server: Some(data.hostname().to_owned()), + home_server: Some(db.globals.hostname().to_owned()), device_id, well_known: None, })) } -#[get("/_matrix/client/r0/capabilities", data = "")] -pub fn get_capabilities_route( - body: Ruma, -) -> MatrixResult { +#[get("/_matrix/client/r0/capabilities")] +pub fn get_capabilities_route() -> MatrixResult { // TODO //let mut available = BTreeMap::new(); //available.insert("5".to_owned(), get_capabilities::RoomVersionStability::Unstable); @@ -337,7 +341,7 @@ pub fn get_pushrules_all_route() -> MatrixResult { data = "" )] pub fn set_pushrule_route( - data: State, + db: State<'_, Database>, body: Ruma, _scope: String, _kind: String, @@ -345,7 +349,7 @@ pub fn set_pushrule_route( ) -> MatrixResult { // TODO let user_id = body.user_id.clone().expect("user is authenticated"); - data.room_userdata_update( + db.account_data.update( None, &user_id, EduEvent::PushRules(ruma_events::push_rules::PushRulesEvent { @@ -375,7 +379,9 @@ pub fn set_pushrule_route( }, }, }), - ); + &db.globals + ) + .unwrap(); MatrixResult(Ok(set_pushrule::Response)) } @@ -392,10 +398,8 @@ pub fn set_pushrule_enabled_route( #[get( "/_matrix/client/r0/user/<_user_id>/filter/<_filter_id>", - data = "" )] pub fn get_filter_route( - body: Ruma, _user_id: String, _filter_id: String, ) -> MatrixResult { @@ -411,9 +415,8 @@ pub fn get_filter_route( })) } -#[post("/_matrix/client/r0/user/<_user_id>/filter", data = "")] +#[post("/_matrix/client/r0/user/<_user_id>/filter")] pub fn create_filter_route( - body: Ruma, _user_id: String, ) -> MatrixResult { // TODO @@ -424,10 +427,8 @@ pub fn create_filter_route( #[put( "/_matrix/client/r0/user/<_user_id>/account_data/<_type>", - data = "" )] pub fn set_global_account_data_route( - body: Ruma, _user_id: String, _type: String, ) -> MatrixResult { @@ -436,10 +437,8 @@ pub fn set_global_account_data_route( #[get( "/_matrix/client/r0/user/<_user_id>/account_data/<_type>", - data = "" )] pub fn get_global_account_data_route( - body: Ruma, _user_id: String, _type: String, ) -> MatrixResult { @@ -453,15 +452,26 @@ pub fn get_global_account_data_route( #[put("/_matrix/client/r0/profile/<_user_id>/displayname", data = "")] pub fn set_displayname_route( - data: State, + db: State<'_, Database>, body: Ruma, _user_id: String, ) -> MatrixResult { let user_id = body.user_id.clone().expect("user is authenticated"); - // Send error on None - // Synapse returns a parsing error but the spec doesn't require this - if body.displayname.is_none() { + if let Some(displayname) = &body.displayname { + // Some("") will clear the displayname + if displayname == "" { + db.users.set_displayname(&user_id, None).unwrap(); + } else { + db + .users + .set_displayname(&user_id, Some(displayname.clone())) + .unwrap(); + // TODO: send a new m.presence event with the updated displayname + } + } else { + // Send error on None + // Synapse returns a parsing error but the spec doesn't require this debug!("Request was missing the displayname payload."); return MatrixResult(Err(Error { kind: ErrorKind::MissingParam, @@ -470,30 +480,17 @@ pub fn set_displayname_route( })); } - if let Some(displayname) = &body.displayname { - // Some("") will clear the displayname - if displayname == "" { - data.displayname_remove(&user_id); - } else { - data.displayname_set(&user_id, displayname.clone()); - // TODO send a new m.presence event with the updated displayname - } - } - MatrixResult(Ok(set_display_name::Response)) } -#[get( - "/_matrix/client/r0/profile//displayname", - data = "" -)] +#[get("/_matrix/client/r0/profile/<_user_id>/displayname", data = "")] pub fn get_displayname_route( - data: State, + db: State<'_, Database>, body: Ruma, - user_id_raw: String, + _user_id: String, ) -> MatrixResult { let user_id = (*body).user_id.clone(); - if !data.user_exists(&user_id) { + if !db.users.exists(&user_id).unwrap() { // Return 404 if we don't have a profile for this id debug!("Profile was not found."); return MatrixResult(Err(Error { @@ -502,7 +499,7 @@ pub fn get_displayname_route( status_code: http::StatusCode::NOT_FOUND, })); } - if let Some(displayname) = data.displayname_get(&user_id) { + if let Some(displayname) = db.users.displayname(&user_id).unwrap() { return MatrixResult(Ok(get_display_name::Response { displayname: Some(displayname), })); @@ -514,7 +511,7 @@ pub fn get_displayname_route( #[put("/_matrix/client/r0/profile/<_user_id>/avatar_url", data = "")] pub fn set_avatar_url_route( - data: State, + db: State<'_, Database>, body: Ruma, _user_id: String, ) -> MatrixResult { @@ -533,9 +530,12 @@ pub fn set_avatar_url_route( // TODO also make sure this is valid mxc:// format (not only starting with it) if body.avatar_url == "" { - data.avatar_url_remove(&user_id); + db.users.set_avatar_url(&user_id, None).unwrap(); } else { - data.avatar_url_set(&user_id, body.avatar_url.clone()); + db + .users + .set_avatar_url(&user_id, Some(body.avatar_url.clone())) + .unwrap(); // TODO send a new m.room.member join event with the updated avatar_url // TODO send a new m.presence event with the updated avatar_url } @@ -543,14 +543,14 @@ pub fn set_avatar_url_route( MatrixResult(Ok(set_avatar_url::Response)) } -#[get("/_matrix/client/r0/profile//avatar_url", data = "")] +#[get("/_matrix/client/r0/profile/<_user_id>/avatar_url", data = "")] pub fn get_avatar_url_route( - data: State, + db: State<'_, Database>, body: Ruma, - user_id_raw: String, + _user_id: String, ) -> MatrixResult { let user_id = (*body).user_id.clone(); - if !data.user_exists(&user_id) { + if !db.users.exists(&user_id).unwrap() { // Return 404 if we don't have a profile for this id debug!("Profile was not found."); return MatrixResult(Err(Error { @@ -559,7 +559,7 @@ pub fn get_avatar_url_route( status_code: http::StatusCode::NOT_FOUND, })); } - if let Some(avatar_url) = data.avatar_url_get(&user_id) { + if let Some(avatar_url) = db.users.avatar_url(&user_id).unwrap() { return MatrixResult(Ok(get_avatar_url::Response { avatar_url: Some(avatar_url), })); @@ -569,15 +569,15 @@ pub fn get_avatar_url_route( MatrixResult(Ok(get_avatar_url::Response { avatar_url: None })) } -#[get("/_matrix/client/r0/profile/", data = "")] +#[get("/_matrix/client/r0/profile/<_user_id>", data = "")] pub fn get_profile_route( - data: State, + db: State<'_, Database>, body: Ruma, - user_id_raw: String, + _user_id: String, ) -> MatrixResult { let user_id = (*body).user_id.clone(); - let avatar_url = data.avatar_url_get(&user_id); - let displayname = data.displayname_get(&user_id); + let avatar_url = db.users.avatar_url(&user_id).unwrap(); + let displayname = db.users.displayname(&user_id).unwrap(); if avatar_url.is_some() || displayname.is_some() { return MatrixResult(Ok(get_profile::Response { @@ -595,17 +595,16 @@ pub fn get_profile_route( })) } -#[put("/_matrix/client/r0/presence/<_user_id>/status", data = "")] +#[put("/_matrix/client/r0/presence/<_user_id>/status")] pub fn set_presence_route( - body: Ruma, _user_id: String, ) -> MatrixResult { // TODO MatrixResult(Ok(set_presence::Response)) } -#[post("/_matrix/client/r0/keys/query", data = "")] -pub fn get_keys_route(body: Ruma) -> MatrixResult { +#[post("/_matrix/client/r0/keys/query")] +pub fn get_keys_route() -> MatrixResult { // TODO MatrixResult(Ok(get_keys::Response { failures: BTreeMap::new(), @@ -613,11 +612,8 @@ pub fn get_keys_route(body: Ruma) -> MatrixResult, - body: Ruma, -) -> MatrixResult { +#[post("/_matrix/client/r0/keys/upload")] +pub fn upload_keys_route() -> MatrixResult { // TODO MatrixResult(Ok(upload_keys::Response { one_time_key_counts: BTreeMap::new(), @@ -626,12 +622,12 @@ pub fn upload_keys_route( #[post("/_matrix/client/r0/rooms/<_room_id>/read_markers", data = "")] pub fn set_read_marker_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, ) -> MatrixResult { let user_id = body.user_id.clone().expect("user is authenticated"); - data.room_userdata_update( + db.account_data.update( Some(&body.room_id), &user_id, EduEvent::FullyRead(ruma_events::fully_read::FullyReadEvent { @@ -640,10 +636,24 @@ pub fn set_read_marker_route( }, room_id: Some(body.room_id.clone()), }), - ); + &db.globals + ) + .unwrap(); if let Some(event) = &body.read_receipt { - data.room_read_set(&body.room_id, &user_id, event); + db + .rooms + .edus + .room_read_set( + &body.room_id, + &user_id, + db + .rooms + .get_pdu_count(event) + .unwrap() + .expect("TODO: what if a client specifies an invalid event"), + ) + .unwrap(); let mut user_receipts = BTreeMap::new(); user_receipts.insert( @@ -660,14 +670,19 @@ pub fn set_read_marker_route( }, ); - data.roomlatest_update( - &user_id, - &body.room_id, - EduEvent::Receipt(ruma_events::receipt::ReceiptEvent { - content: receipt_content, - room_id: None, // None because it can be inferred - }), - ); + db + .rooms + .edus + .roomlatest_update( + &user_id, + &body.room_id, + EduEvent::Receipt(ruma_events::receipt::ReceiptEvent { + content: receipt_content, + room_id: None, // None because it can be inferred + }), + &db.globals, + ) + .unwrap(); } MatrixResult(Ok(set_read_marker::Response)) } @@ -677,7 +692,7 @@ pub fn set_read_marker_route( data = "" )] pub fn create_typing_event_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, _user_id: String, @@ -691,14 +706,23 @@ pub fn create_typing_event_route( }); if body.typing { - data.roomactive_add( - edu, - &body.room_id, - body.timeout.map(|d| d.as_millis() as u64).unwrap_or(30000) - + utils::millis_since_unix_epoch().try_into().unwrap_or(0), - ); + db + .rooms + .edus + .roomactive_add( + edu, + &body.room_id, + body.timeout.map(|d| d.as_millis() as u64).unwrap_or(30000) + + utils::millis_since_unix_epoch().try_into().unwrap_or(0), + &db.globals, + ) + .unwrap(); } else { - data.roomactive_remove(edu, &body.room_id); + db + .rooms + .edus + .roomactive_remove(edu, &body.room_id) + .unwrap(); } MatrixResult(Ok(create_typing_event::Response)) @@ -706,66 +730,93 @@ pub fn create_typing_event_route( #[post("/_matrix/client/r0/createRoom", data = "")] pub fn create_room_route( - data: State, + db: State<'_, Database>, body: Ruma, ) -> MatrixResult { // TODO: check if room is unique - let room_id = RoomId::new(data.hostname()).expect("host is valid"); + let room_id = RoomId::new(db.globals.hostname()).expect("host is valid"); let user_id = body.user_id.clone().expect("user is authenticated"); - data.pdu_append( - room_id.clone(), - user_id.clone(), - EventType::RoomCreate, - json!({ "creator": user_id }), - None, - Some("".to_owned()), - ); - - data.room_join(&room_id, &user_id); - - data.pdu_append( - room_id.clone(), - user_id.clone(), - EventType::RoomPowerLevels, - json!({ - "ban": 50, - "events_default": 0, - "invite": 50, - "kick": 50, - "redact": 50, - "state_default": 50, - "users": { user_id.to_string(): 100 }, - "users_default": 0 - }), - None, - Some("".to_owned()), - ); - - if let Some(name) = &body.name { - data.pdu_append( + db + .rooms + .append_pdu( room_id.clone(), user_id.clone(), - EventType::RoomName, - json!({ "name": name }), + EventType::RoomCreate, + json!({ "creator": user_id }), None, Some("".to_owned()), - ); - } + &db.globals, + ) + .unwrap(); - if let Some(topic) = &body.topic { - data.pdu_append( + db + .rooms + .join( + &room_id, + &user_id, + db.users.displayname(&user_id).unwrap(), + &db.globals, + ) + .unwrap(); + + db + .rooms + .append_pdu( room_id.clone(), user_id.clone(), - EventType::RoomTopic, - json!({ "topic": topic }), + EventType::RoomPowerLevels, + json!({ + "ban": 50, + "events_default": 0, + "invite": 50, + "kick": 50, + "redact": 50, + "state_default": 50, + "users": { user_id.to_string(): 100 }, + "users_default": 0 + }), None, Some("".to_owned()), - ); + &db.globals, + ) + .unwrap(); + + if let Some(name) = &body.name { + db + .rooms + .append_pdu( + room_id.clone(), + user_id.clone(), + EventType::RoomName, + json!({ "name": name }), + None, + Some("".to_owned()), + &db.globals, + ) + .unwrap(); + } + + if let Some(topic) = &body.topic { + db + .rooms + .append_pdu( + room_id.clone(), + user_id.clone(), + EventType::RoomTopic, + json!({ "topic": topic }), + None, + Some("".to_owned()), + &db.globals, + ) + .unwrap(); } for user in &body.invite { - data.room_invite(&user_id, &room_id, user); + db + .rooms + .invite(&user_id, &room_id, user, &db.globals) + .unwrap(); } MatrixResult(Ok(create_room::Response { room_id })) @@ -773,12 +824,12 @@ pub fn create_room_route( #[get("/_matrix/client/r0/directory/room/<_room_alias>", data = "")] pub fn get_alias_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_alias: String, ) -> MatrixResult { // TODO - let room_id = if body.room_alias.server_name() == data.hostname() { + let room_id = if body.room_alias.server_name() == db.globals.hostname() { match body.room_alias.alias() { "conduit" => "!lgOCCXQKtXOAPlAlG5:conduit.rs", _ => { @@ -804,14 +855,22 @@ pub fn get_alias_route( #[post("/_matrix/client/r0/rooms/<_room_id>/join", data = "")] pub fn join_room_by_id_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, ) -> MatrixResult { - if data.room_join( - &body.room_id, - body.user_id.as_ref().expect("user is authenticated"), - ) { + let user_id = body.user_id.clone().expect("user is authenticated"); + + if db + .rooms + .join( + &body.room_id, + &user_id, + db.users.displayname(&user_id).unwrap(), + &db.globals, + ) + .is_ok() + { MatrixResult(Ok(join_room_by_id::Response { room_id: body.room_id.clone(), })) @@ -826,14 +885,16 @@ pub fn join_room_by_id_route( #[post("/_matrix/client/r0/join/<_room_id_or_alias>", data = "")] pub fn join_room_by_id_or_alias_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id_or_alias: String, ) -> MatrixResult { + let user_id = body.user_id.clone().expect("user is authenticated"); + let room_id = match RoomId::try_from(body.room_id_or_alias.clone()) { Ok(room_id) => room_id, Err(room_alias) => { - if room_alias.server_name() == data.hostname() { + if room_alias.server_name() == db.globals.hostname() { return MatrixResult(Err(Error { kind: ErrorKind::NotFound, message: "Room alias not found.".to_owned(), @@ -847,10 +908,16 @@ pub fn join_room_by_id_or_alias_route( } }; - if data.room_join( - &room_id, - body.user_id.as_ref().expect("user is authenticated"), - ) { + if db + .rooms + .join( + &room_id, + &user_id, + db.users.displayname(&user_id).unwrap(), + &db.globals, + ) + .is_ok() + { MatrixResult(Ok(join_room_by_id_or_alias::Response { room_id })) } else { MatrixResult(Err(Error { @@ -863,38 +930,45 @@ pub fn join_room_by_id_or_alias_route( #[post("/_matrix/client/r0/rooms/<_room_id>/leave", data = "")] pub fn leave_room_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, ) -> MatrixResult { let user_id = body.user_id.clone().expect("user is authenticated"); - data.room_leave(&user_id, &body.room_id, &user_id); + db + .rooms + .leave(&user_id, &body.room_id, &user_id, &db.globals) + .unwrap(); MatrixResult(Ok(leave_room::Response)) } #[post("/_matrix/client/r0/rooms/<_room_id>/forget", data = "")] pub fn forget_room_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, ) -> MatrixResult { let user_id = body.user_id.clone().expect("user is authenticated"); - data.room_forget(&body.room_id, &user_id); + db.rooms.forget(&body.room_id, &user_id).unwrap(); MatrixResult(Ok(forget_room::Response)) } #[post("/_matrix/client/r0/rooms/<_room_id>/invite", data = "")] pub fn invite_user_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, ) -> MatrixResult { if let invite_user::InvitationRecipient::UserId { user_id } = &body.recipient { - data.room_invite( - &body.user_id.as_ref().expect("user is authenticated"), - &body.room_id, - &user_id, - ); + db + .rooms + .invite( + &body.user_id.as_ref().expect("user is authenticated"), + &body.room_id, + &user_id, + &db.globals, + ) + .unwrap(); MatrixResult(Ok(invite_user::Response)) } else { MatrixResult(Err(Error { @@ -905,16 +979,16 @@ pub fn invite_user_route( } } -#[post("/_matrix/client/r0/publicRooms", data = "")] +#[post("/_matrix/client/r0/publicRooms")] pub async fn get_public_rooms_filtered_route( - data: State<'_, Data>, - body: Ruma, + db: State<'_, Database>, ) -> MatrixResult { - let mut chunk = data - .rooms_all() + let mut chunk = db + .rooms + .all_rooms() .into_iter() .map(|room_id| { - let state = data.room_state(&room_id); + let state = db.rooms.room_state(&room_id).unwrap(); directory::PublicRoomsChunk { aliases: Vec::new(), canonical_alias: None, @@ -923,7 +997,7 @@ pub async fn get_public_rooms_filtered_route( .and_then(|s| s.content.get("name")) .and_then(|n| n.as_str()) .map(|n| n.to_owned()), - num_joined_members: data.room_users_joined(&room_id).into(), + num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), room_id, topic: None, world_readable: false, @@ -937,7 +1011,7 @@ pub async fn get_public_rooms_filtered_route( chunk.extend_from_slice( &server_server::send_request( - &data, + &db, "privacytools.io".to_owned(), ruma_federation_api::v1::get_public_rooms::Request { limit: Some(20_u32.into()), @@ -965,13 +1039,14 @@ pub async fn get_public_rooms_filtered_route( #[post("/_matrix/client/r0/user_directory/search", data = "")] pub fn search_users_route( - data: State, + db: State<'_, Database>, body: Ruma, ) -> MatrixResult { MatrixResult(Ok(search_users::Response { - results: data - .users_all() - .into_iter() + results: db + .users + .iter() + .map(Result::unwrap) .filter(|user_id| user_id.to_string().contains(&body.search_term)) .map(|user_id| search_users::User { user_id, @@ -983,18 +1058,16 @@ pub fn search_users_route( })) } -#[get("/_matrix/client/r0/rooms/<_room_id>/members", data = "")] +#[get("/_matrix/client/r0/rooms/<_room_id>/members")] pub fn get_member_events_route( - body: Ruma, _room_id: String, ) -> MatrixResult { // TODO MatrixResult(Ok(get_member_events::Response { chunk: Vec::new() })) } -#[get("/_matrix/client/r0/thirdparty/protocols", data = "")] +#[get("/_matrix/client/r0/thirdparty/protocols")] pub fn get_protocols_route( - body: Ruma, ) -> MatrixResult { // TODO MatrixResult(Ok(get_protocols::Response { @@ -1007,7 +1080,7 @@ pub fn get_protocols_route( data = "" )] pub fn create_message_event_route( - data: State, + db: State<'_, Database>, _room_id: String, _event_type: String, _txn_id: String, @@ -1018,14 +1091,16 @@ pub fn create_message_event_route( let mut unsigned = serde_json::Map::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.clone().into()); - let event_id = data - .pdu_append( + let event_id = db + .rooms + .append_pdu( body.room_id.clone(), user_id.clone(), body.event_type.clone(), body.json_body.clone(), Some(unsigned), None, + &db.globals, ) .expect("message events are always okay"); @@ -1037,7 +1112,7 @@ pub fn create_message_event_route( data = "" )] pub fn create_state_event_for_key_route( - data: State, + db: State<'_, Database>, _room_id: String, _event_type: String, _state_key: String, @@ -1046,18 +1121,20 @@ pub fn create_state_event_for_key_route( let user_id = body.user_id.clone().expect("user is authenticated"); // Reponse of with/without key is the same - if let Some(event_id) = data.pdu_append( - body.room_id.clone(), - body.user_id.clone().expect("user is authenticated"), - body.event_type.clone(), - body.json_body.clone(), - None, - Some(body.state_key.clone()), - ) { - MatrixResult(Ok(create_state_event_for_key::Response { event_id })) - } else { - panic!("TODO: error missing permissions"); - } + let event_id = db + .rooms + .append_pdu( + body.room_id.clone(), + user_id, + body.event_type.clone(), + body.json_body.clone(), + None, + Some(body.state_key.clone()), + &db.globals, + ) + .unwrap(); + + MatrixResult(Ok(create_state_event_for_key::Response { event_id })) } #[put( @@ -1065,7 +1142,7 @@ pub fn create_state_event_for_key_route( data = "" )] pub fn create_state_event_for_empty_key_route( - data: State, + db: State<'_, Database>, _room_id: String, _event_type: String, body: Ruma, @@ -1073,39 +1150,46 @@ pub fn create_state_event_for_empty_key_route( let user_id = body.user_id.clone().expect("user is authenticated"); // Reponse of with/without key is the same - if let Some(event_id) = data.pdu_append( - body.room_id.clone(), - body.user_id.clone().expect("user is authenticated"), - body.event_type.clone(), - body.json_body.clone(), - None, - Some("".to_owned()), - ) { - MatrixResult(Ok(create_state_event_for_empty_key::Response { event_id })) - } else { - panic!("TODO: error missing permissions"); - } + let event_id = db + .rooms + .append_pdu( + body.room_id.clone(), + user_id, + body.event_type.clone(), + body.json_body.clone(), + None, + Some("".to_owned()), + &db.globals, + ) + .unwrap(); + + MatrixResult(Ok(create_state_event_for_empty_key::Response { event_id })) } #[get("/_matrix/client/r0/sync", data = "")] pub fn sync_route( - data: State, + db: State<'_, Database>, body: Ruma, ) -> MatrixResult { - std::thread::sleep(Duration::from_millis(1500)); + std::thread::sleep(Duration::from_millis(100)); let user_id = body.user_id.clone().expect("user is authenticated"); - let next_batch = data.last_pdu_index().to_string(); + let next_batch = db.globals.current_count().unwrap().to_string(); let mut joined_rooms = BTreeMap::new(); - let joined_roomids = data.rooms_joined(&user_id); let since = body .since .clone() .and_then(|string| string.parse().ok()) .unwrap_or(0); - for room_id in joined_roomids { - let mut pdus = data.pdus_since(&room_id, since); + for room_id in db.rooms.rooms_joined(&user_id) { + let room_id = room_id.unwrap(); + + let mut pdus = db + .rooms + .pdus_since(&room_id, since).unwrap() + .map(|r| r.unwrap()) + .collect::>(); let mut send_member_count = false; let mut send_full_state = false; @@ -1119,8 +1203,13 @@ pub fn sync_route( } } - let notification_count = if let Some(last_read) = data.room_read_get(&room_id, &user_id) { - Some((data.pdus_since(&room_id, last_read).len() as u32).into()) + let notification_count = if let Some(last_read) = db + .rooms + .edus + .room_read_get(&room_id, &user_id) + .unwrap() + { + Some((db.rooms.pdus_since(&room_id, last_read).unwrap().count() as u32).into()) } else { None }; @@ -1135,7 +1224,7 @@ pub fn sync_route( let prev_batch = pdus .first() - .and_then(|e| data.pdu_get_count(&e.event_id)) + .and_then(|e| db.rooms.get_pdu_count(&e.event_id).unwrap()) .map(|c| c.to_string()); let room_events = pdus @@ -1143,15 +1232,39 @@ pub fn sync_route( .map(|pdu| pdu.to_room_event()) .collect::>(); - let mut edus = data.roomlatests_since(&room_id, since); - edus.extend_from_slice(&data.roomactives_in(&room_id)); + let mut edus = db + .rooms + .edus + .roomactives_all(&room_id) + .map(|r| r.unwrap()) + .collect::>(); + + if edus.is_empty() { + edus.push( + EduEvent::Typing(ruma_events::typing::TypingEvent { + content: ruma_events::typing::TypingEventContent { + user_ids: Vec::new(), + }, + room_id: None, // None because it can be inferred + }) + .into(), + ); + } + + edus.extend( + db + .rooms + .edus + .roomlatests_since(&room_id, since).unwrap() + .map(|r| r.unwrap()), + ); joined_rooms.insert( room_id.clone().try_into().unwrap(), sync_events::JoinedRoom { account_data: Some(sync_events::AccountData { - events: data - .room_userdata_since(Some(&room_id), &user_id, since) + events: db.account_data + .changes_since(Some(&room_id), &user_id, since).unwrap() .into_iter() .map(|(_, v)| v) .collect(), @@ -1159,12 +1272,12 @@ pub fn sync_route( summary: sync_events::RoomSummary { heroes: Vec::new(), joined_member_count: if send_member_count { - Some(data.room_users_joined(&room_id).into()) + Some((db.rooms.room_members(&room_id).count() as u32).into()) } else { None }, invited_member_count: if send_member_count { - Some(data.room_users_invited(&room_id).into()) + Some((db.rooms.room_members_invited(&room_id).count() as u32).into()) } else { None }, @@ -1181,7 +1294,10 @@ pub fn sync_route( // TODO: state before timeline state: sync_events::State { events: if send_full_state { - data.room_state(&room_id) + db + .rooms + .room_state(&room_id) + .unwrap() .into_iter() .map(|(_, pdu)| pdu.to_state_event()) .collect() @@ -1195,12 +1311,28 @@ pub fn sync_route( } let mut left_rooms = BTreeMap::new(); - let left_roomids = data.rooms_left(&user_id); - for room_id in left_roomids { - let pdus = data.pdus_since(&room_id, since); - let room_events = pdus.into_iter().map(|pdu| pdu.to_room_event()).collect(); - let mut edus = data.roomlatests_since(&room_id, since); - edus.extend_from_slice(&data.roomactives_in(&room_id)); + for room_id in db.rooms.rooms_left(&user_id) { + let room_id = room_id.unwrap(); + let pdus = db.rooms.pdus_since(&room_id, since).unwrap(); + let room_events = pdus + .into_iter() + .map(|pdu| pdu.unwrap().to_room_event()) + .collect(); + + let mut edus = db + .rooms + .edus + .roomlatests_since(&room_id, since).unwrap() + .map(|r| r.unwrap()) + .collect::>(); + + edus.extend( + db + .rooms + .edus + .roomactives_all(&room_id) + .map(|r| r.unwrap()), + ); left_rooms.insert( room_id.clone().try_into().unwrap(), @@ -1217,11 +1349,13 @@ pub fn sync_route( } let mut invited_rooms = BTreeMap::new(); - for room_id in data.rooms_invited(&user_id) { - let events = data - .pdus_since(&room_id, since) + for room_id in db.rooms.rooms_invited(&user_id) { + let room_id = room_id.unwrap(); + let events = db + .rooms + .pdus_since(&room_id, since).unwrap() .into_iter() - .map(|pdu| pdu.to_stripped_state_event()) + .map(|pdu| pdu.unwrap().to_stripped_state_event()) .collect(); invited_rooms.insert( @@ -1241,8 +1375,8 @@ pub fn sync_route( }, presence: sync_events::Presence { events: Vec::new() }, account_data: sync_events::AccountData { - events: data - .room_userdata_since(None, &user_id, since) + events: db.account_data + .changes_since(None, &user_id, since).unwrap() .into_iter() .map(|(_, v)| v) .collect(), @@ -1255,7 +1389,7 @@ pub fn sync_route( #[get("/_matrix/client/r0/rooms/<_room_id>/messages", data = "")] pub fn get_message_events_route( - data: State, + db: State<'_, Database>, body: Ruma, _room_id: String, ) -> MatrixResult { @@ -1264,14 +1398,15 @@ pub fn get_message_events_route( } if let Ok(from) = body.from.clone().parse() { - let pdus = data.pdus_until( - &body.room_id, - from, - body.limit.map(|l| l.try_into().unwrap()).unwrap_or(10), - ); + let pdus = db + .rooms + .pdus_until(&body.room_id, from) + .take(body.limit.map(|l| l.try_into().unwrap()).unwrap_or(10_u32) as usize) + .map(|r| r.unwrap()) + .collect::>(); let prev_batch = pdus .last() - .and_then(|e| data.pdu_get_count(&e.event_id)) + .and_then(|e| db.rooms.get_pdu_count(&e.event_id).unwrap()) .map(|c| c.to_string()); let room_events = pdus .into_iter() @@ -1332,7 +1467,7 @@ pub fn get_media_config_route() -> MatrixResult { #[options("/<_segments..>")] pub fn options_route( - _segments: rocket::http::uri::Segments, + _segments: rocket::http::uri::Segments<'_>, ) -> MatrixResult { MatrixResult(Err(Error { kind: ErrorKind::NotFound, diff --git a/src/data.rs b/src/data.rs deleted file mode 100644 index 9b9c5412..00000000 --- a/src/data.rs +++ /dev/null @@ -1,1018 +0,0 @@ -use crate::{database::COUNTER, utils, Database, PduEvent}; -use ruma_events::{ - collections::only::Event as EduEvent, room::power_levels::PowerLevelsEventContent, EventJson, - EventType, -}; -use ruma_federation_api::RoomV3Pdu; -use ruma_identifiers::{EventId, RoomId, UserId}; -use serde_json::json; -use std::{ - collections::HashMap, - convert::{TryFrom, TryInto}, - mem, -}; - -pub struct Data { - hostname: String, - reqwest_client: reqwest::Client, - db: Database, -} - -impl Data { - /// Load an existing database or create a new one. - pub fn load_or_create(hostname: &str) -> Self { - let db = Database::load_or_create(hostname); - Self { - hostname: hostname.to_owned(), - reqwest_client: reqwest::Client::new(), - db, - } - } - - /// Get the hostname of the server. - pub fn hostname(&self) -> &str { - &self.hostname - } - - /// Get the hostname of the server. - pub fn reqwest_client(&self) -> &reqwest::Client { - &self.reqwest_client - } - - pub fn keypair(&self) -> &ruma_signatures::Ed25519KeyPair { - &self.db.keypair - } - - /// Check if a user has an account by looking for an assigned password. - pub fn user_exists(&self, user_id: &UserId) -> bool { - self.db - .userid_password - .contains_key(user_id.to_string()) - .unwrap() - } - - /// Create a new user account by assigning them a password. - pub fn user_add(&self, user_id: &UserId, hash: &str) { - self.db - .userid_password - .insert(user_id.to_string(), hash) - .unwrap(); - } - - /// Find out which user an access token belongs to. - pub fn user_from_token(&self, token: &str) -> Option { - self.db - .token_userid - .get(token) - .unwrap() - .and_then(|bytes| (*utils::string_from_bytes(&bytes)).try_into().ok()) - } - - pub fn users_all(&self) -> Vec { - self.db - .userid_password - .iter() - .keys() - .map(|k| UserId::try_from(&*utils::string_from_bytes(&k.unwrap())).unwrap()) - .collect() - } - - /// Gets password hash for given user id. - pub fn password_hash_get(&self, user_id: &UserId) -> Option { - self.db - .userid_password - .get(user_id.to_string()) - .unwrap() - .map(|bytes| utils::string_from_bytes(&bytes)) - } - - /// Removes a displayname. - pub fn displayname_remove(&self, user_id: &UserId) { - self.db - .userid_displayname - .remove(user_id.to_string()) - .unwrap(); - } - - /// Set a new displayname. - pub fn displayname_set(&self, user_id: &UserId, displayname: String) { - self.db - .userid_displayname - .insert(user_id.to_string(), &*displayname) - .unwrap(); - for room_id in self.rooms_joined(user_id) { - self.pdu_append( - room_id.clone(), - user_id.clone(), - EventType::RoomMember, - json!({"membership": "join", "displayname": displayname}), - None, - Some(user_id.to_string()), - ); - } - } - - /// Get a the displayname of a user. - pub fn displayname_get(&self, user_id: &UserId) -> Option { - self.db - .userid_displayname - .get(user_id.to_string()) - .unwrap() - .map(|bytes| utils::string_from_bytes(&bytes)) - } - - /// Removes a avatar_url. - pub fn avatar_url_remove(&self, user_id: &UserId) { - self.db - .userid_avatarurl - .remove(user_id.to_string()) - .unwrap(); - } - - /// Set a new avatar_url. - pub fn avatar_url_set(&self, user_id: &UserId, avatar_url: String) { - self.db - .userid_avatarurl - .insert(user_id.to_string(), &*avatar_url) - .unwrap(); - } - - /// Get a the avatar_url of a user. - pub fn avatar_url_get(&self, user_id: &UserId) -> Option { - self.db - .userid_avatarurl - .get(user_id.to_string()) - .unwrap() - .map(|bytes| utils::string_from_bytes(&bytes)) - } - - /// Add a new device to a user. - pub fn device_add(&self, user_id: &UserId, device_id: &str) { - if self - .db - .userid_deviceids - .get_iter(&user_id.to_string().as_bytes()) - .filter_map(|item| item.ok()) - .map(|(_key, value)| value) - .all(|device| device != device_id) - { - self.db - .userid_deviceids - .add(user_id.to_string().as_bytes(), device_id.into()); - } - } - - /// Replace the access token of one device. - pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) { - // Make sure the device id belongs to the user - debug_assert!(self - .db - .userid_deviceids - .get_iter(&user_id.to_string().as_bytes()) - .filter_map(|item| item.ok()) - .map(|(_key, value)| value) - .any(|device| device == device_id.as_bytes())); // Does the user have that device? - - // Remove old token - let mut key = user_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - if let Some(old_token) = self.db.userdeviceid_token.get(&key).unwrap() { - self.db.token_userid.remove(old_token).unwrap(); - // It will be removed from deviceid_token by the insert later - } - - // Assign token to device_id - self.db.userdeviceid_token.insert(key, &*token).unwrap(); - - // Assign token to user - self.db - .token_userid - .insert(token, &*user_id.to_string()) - .unwrap(); - } - - pub fn room_join(&self, room_id: &RoomId, user_id: &UserId) -> bool { - if !self.room_exists(room_id) - && !self - .db - .userid_joinroomids - .get_iter(user_id.to_string().as_bytes()) - .values() - .any(|r| r.unwrap() == room_id.to_string().as_bytes()) - { - return false; - } - - self.db.userid_joinroomids.add( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - self.db.roomid_joinuserids.add( - room_id.to_string().as_bytes(), - user_id.to_string().as_bytes().into(), - ); - self.db.userid_inviteroomids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes(), - ); - self.db.roomid_inviteuserids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes(), - ); - self.db.userid_leftroomids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - - let mut content = json!({"membership": "join"}); - if let Some(displayname) = self.displayname_get(user_id) { - content - .as_object_mut() - .unwrap() - .insert("displayname".to_owned(), displayname.into()); - } - - self.pdu_append( - room_id.clone(), - user_id.clone(), - EventType::RoomMember, - content, - None, - Some(user_id.to_string()), - ); - - true - } - - pub fn rooms_joined(&self, user_id: &UserId) -> Vec { - self.db - .userid_joinroomids - .get_iter(user_id.to_string().as_bytes()) - .values() - .map(|room_id| { - RoomId::try_from(&*utils::string_from_bytes(&room_id.unwrap())) - .expect("user joined valid room ids") - }) - .collect() - } - - /// Check if a room exists by looking for PDUs in that room. - pub fn room_exists(&self, room_id: &RoomId) -> bool { - // Create the first part of the full pdu id - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); // Add delimiter so we don't find rooms starting with the same id - - if let Some((key, _)) = self.db.pduid_pdu.get_gt(&prefix).unwrap() { - if key.starts_with(&prefix) { - true - } else { - false - } - } else { - false - } - } - - pub fn rooms_all(&self) -> Vec { - let mut room_ids = self - .db - .roomid_pduleaves - .iter_all() - .keys() - .map(|key| { - RoomId::try_from(&*utils::string_from_bytes( - &key.unwrap() - .iter() - .skip(1) // skip "d" - .copied() - .take_while(|&x| x != 0xff) // until delimiter - .collect::>(), - )) - .unwrap() - }) - .collect::>(); - room_ids.dedup(); - room_ids - } - - pub fn room_users_joined(&self, room_id: &RoomId) -> u32 { - self.db - .roomid_joinuserids - .get_iter(room_id.to_string().as_bytes()) - .count() as u32 - } - - pub fn room_users_invited(&self, room_id: &RoomId) -> u32 { - self.db - .roomid_inviteuserids - .get_iter(room_id.to_string().as_bytes()) - .count() as u32 - } - - pub fn room_state(&self, room_id: &RoomId) -> HashMap<(EventType, String), PduEvent> { - let mut hashmap = HashMap::new(); - for pdu in self - .db - .roomstateid_pdu - .scan_prefix(&room_id.to_string().as_bytes()) - .values() - .map(|value| serde_json::from_slice::(&value.unwrap()).unwrap()) - { - hashmap.insert( - ( - pdu.kind.clone(), - pdu.state_key - .clone() - .expect("state events have a state key"), - ), - pdu, - ); - } - hashmap - } - - pub fn room_leave(&self, sender: &UserId, room_id: &RoomId, user_id: &UserId) { - self.pdu_append( - room_id.clone(), - sender.clone(), - EventType::RoomMember, - json!({"membership": "leave"}), - None, - Some(user_id.to_string()), - ); - self.db.userid_inviteroomids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - self.db.roomid_inviteuserids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - self.db.userid_joinroomids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - self.db.roomid_joinuserids.remove_value( - room_id.to_string().as_bytes(), - user_id.to_string().as_bytes().into(), - ); - self.db.userid_leftroomids.add( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - } - - pub fn room_forget(&self, room_id: &RoomId, user_id: &UserId) { - self.db.userid_leftroomids.remove_value( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - } - - pub fn room_invite(&self, sender: &UserId, room_id: &RoomId, user_id: &UserId) { - self.pdu_append( - room_id.clone(), - sender.clone(), - EventType::RoomMember, - json!({"membership": "invite"}), - None, - Some(user_id.to_string()), - ); - self.db.userid_inviteroomids.add( - user_id.to_string().as_bytes(), - room_id.to_string().as_bytes().into(), - ); - self.db.roomid_inviteuserids.add( - room_id.to_string().as_bytes(), - user_id.to_string().as_bytes().into(), - ); - } - - pub fn rooms_invited(&self, user_id: &UserId) -> Vec { - self.db - .userid_inviteroomids - .get_iter(&user_id.to_string().as_bytes()) - .values() - .map(|key| RoomId::try_from(&*utils::string_from_bytes(&key.unwrap())).unwrap()) - .collect() - } - - pub fn rooms_left(&self, user_id: &UserId) -> Vec { - self.db - .userid_leftroomids - .get_iter(&user_id.to_string().as_bytes()) - .values() - .map(|key| RoomId::try_from(&*utils::string_from_bytes(&key.unwrap())).unwrap()) - .collect() - } - - pub fn pdu_get_count(&self, event_id: &EventId) -> Option { - self.db - .eventid_pduid - .get(event_id.to_string().as_bytes()) - .unwrap() - .map(|pdu_id| { - utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::()..pdu_id.len()]) - }) - } - - pub fn pdu_get(&self, event_id: &EventId) -> Option { - self.db - .eventid_pduid - .get(event_id.to_string().as_bytes()) - .unwrap() - .map(|pdu_id| { - serde_json::from_slice( - &self - .db - .pduid_pdu - .get(pdu_id) - .unwrap() - .expect("eventid_pduid in db is valid"), - ) - .expect("pdu is valid") - }) - } - - pub fn pdu_leaves_get(&self, room_id: &RoomId) -> Vec { - let event_ids = self - .db - .roomid_pduleaves - .get_iter(room_id.to_string().as_bytes()) - .values() - .map(|pdu_id| { - EventId::try_from(&*utils::string_from_bytes(&pdu_id.unwrap())) - .expect("pdu leaves are valid event ids") - }) - .collect(); - - event_ids - } - - pub fn pdu_leaves_replace(&self, room_id: &RoomId, event_id: &EventId) { - self.db - .roomid_pduleaves - .clear(room_id.to_string().as_bytes()); - - self.db.roomid_pduleaves.add( - &room_id.to_string().as_bytes(), - (*event_id.to_string()).into(), - ); - } - - /// Add a persisted data unit from this homeserver - pub fn pdu_append( - &self, - room_id: RoomId, - sender: UserId, - event_type: EventType, - content: serde_json::Value, - unsigned: Option>, - state_key: Option, - ) -> Option { - // Is the event authorized? - if state_key.is_some() { - if let Some(pdu) = self - .room_state(&room_id) - .get(&(EventType::RoomPowerLevels, "".to_owned())) - { - let power_levels = serde_json::from_value::>( - pdu.content.clone(), - ) - .unwrap() - .deserialize() - .unwrap(); - - match event_type { - EventType::RoomMember => { - // Member events are okay for now (TODO) - } - _ if power_levels - .users - .get(&sender) - .unwrap_or(&power_levels.users_default) - <= &0.into() => - { - // Not authorized - return None; - } - // User has sufficient power - _ => {} - } - } - } - - // prev_events are the leaves of the current graph. This method removes all leaves from the - // room and replaces them with our event - // TODO: Make sure this isn't called twice in parallel - let prev_events = self.pdu_leaves_get(&room_id); - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .map(|event_id| { - self.pdu_get(event_id) - .expect("pdu in prev_events is valid") - .depth - .into() - }) - .max() - .unwrap_or(0_u64) - + 1; - - let mut unsigned = unsigned.unwrap_or_default(); - // TODO: Optimize this to not load the whole room state? - if let Some(state_key) = &state_key { - if let Some(prev_pdu) = self - .room_state(&room_id) - .get(&(event_type.clone(), state_key.clone())) - { - unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); - } - } - - let mut pdu = PduEvent { - event_id: EventId::try_from("$thiswillbefilledinlater").unwrap(), - room_id: room_id.clone(), - sender: sender.clone(), - origin: self.hostname.clone(), - origin_server_ts: utils::millis_since_unix_epoch().try_into().unwrap(), - kind: event_type, - content, - state_key, - prev_events, - depth: depth.try_into().unwrap(), - auth_events: Vec::new(), - redacts: None, - unsigned, - hashes: ruma_federation_api::EventHash { - sha256: "aaa".to_owned(), - }, - signatures: HashMap::new(), - }; - - // Generate event id - pdu.event_id = EventId::try_from(&*format!( - "${}", - ruma_signatures::reference_hash(&serde_json::to_value(&pdu).unwrap()) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are correct"); - - let mut pdu_json = serde_json::to_value(&pdu).unwrap(); - ruma_signatures::hash_and_sign_event(self.hostname(), self.keypair(), &mut pdu_json) - .unwrap(); - - self.pdu_leaves_replace(&room_id, &pdu.event_id); - - // The count will go up regardless of the room_id - // This is also the next_batch/since value - // Increment the last index and use that - let index = utils::u64_from_bytes( - &self - .db - .global - .update_and_fetch(COUNTER, utils::increment) - .unwrap() - .unwrap(), - ); - - let mut pdu_id = room_id.to_string().as_bytes().to_vec(); - pdu_id.push(0xff); // Add delimiter so we don't find rooms starting with the same id - pdu_id.extend_from_slice(&index.to_be_bytes()); - - self.db - .pduid_pdu - .insert(&pdu_id, &*pdu_json.to_string()) - .unwrap(); - - self.db - .eventid_pduid - .insert(pdu.event_id.to_string(), pdu_id.clone()) - .unwrap(); - - if let Some(state_key) = pdu.state_key { - let mut key = room_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pdu.kind.to_string().as_bytes()); - key.push(0xff); - key.extend_from_slice(state_key.to_string().as_bytes()); - self.db - .roomstateid_pdu - .insert(key, &*pdu_json.to_string()) - .unwrap(); - } - - self.room_read_set(&room_id, &sender, &pdu.event_id); - - Some(pdu.event_id) - } - - /// Returns a vector of all PDUs in a room. - pub fn pdus_all(&self, room_id: &RoomId) -> Vec { - self.pdus_since(room_id, 0) - } - - pub fn last_pdu_index(&self) -> u64 { - utils::u64_from_bytes( - &self - .db - .global - .get(&COUNTER) - .unwrap() - .unwrap_or_else(|| (&0_u64.to_be_bytes()).into()), - ) - } - - /// Returns a vector of all events in a room that happened after the event with id `since`. - pub fn pdus_since(&self, room_id: &RoomId, since: u64) -> Vec { - // Create the first part of the full pdu id - let mut pdu_id = room_id.to_string().as_bytes().to_vec(); - pdu_id.push(0xff); // Add delimiter so we don't find rooms starting with the same id - pdu_id.extend_from_slice(&(since).to_be_bytes()); - - self.pdus_since_pduid(room_id, pdu_id) - } - - /// Returns a vector of all events in a room that happened after the event with id `since`. - pub fn pdus_since_pduid(&self, room_id: &RoomId, pdu_id: Vec) -> Vec { - let mut pdus = Vec::new(); - - // Create the first part of the full pdu id - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); // Add delimiter so we don't find rooms starting with the same id - - let mut current = pdu_id; - - while let Some((key, value)) = self.db.pduid_pdu.get_gt(¤t).unwrap() { - if key.starts_with(&prefix) { - current = key.to_vec(); - pdus.push(serde_json::from_slice(&value).expect("pdu in db is valid")); - } else { - break; - } - } - - pdus - } - - pub fn pdus_until(&self, room_id: &RoomId, until: u64, max: u32) -> Vec { - let mut pdus = Vec::new(); - - // Create the first part of the full pdu id - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); // Add delimiter so we don't find rooms starting with the same id - - let mut current = prefix.clone(); - current.extend_from_slice(&until.to_be_bytes()); - - while let Some((key, value)) = self.db.pduid_pdu.get_lt(¤t).unwrap() { - if pdus.len() < max as usize && key.starts_with(&prefix) { - current = key.to_vec(); - pdus.push(serde_json::from_slice(&value).expect("pdu in db is valid")); - } else { - break; - } - } - - pdus - } - - pub fn roomlatest_update(&self, user_id: &UserId, room_id: &RoomId, event: EduEvent) { - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); - - // Start with last - if let Some(mut current) = self - .db - .roomlatestid_roomlatest - .scan_prefix(&prefix) - .keys() - .next_back() - .map(|c| c.unwrap()) - { - // Remove old marker (There should at most one) - loop { - if !current.starts_with(&prefix) { - // We're in another room - break; - } - if current.rsplitn(2, |&b| b == 0xff).next().unwrap() - == user_id.to_string().as_bytes() - { - // This is the old room_latest - self.db.roomlatestid_roomlatest.remove(current).unwrap(); - break; - } - // Else, try the event before that - if let Some((k, _)) = self.db.roomlatestid_roomlatest.get_lt(current).unwrap() { - current = k; - } else { - break; - } - } - } - - // Increment the last index and use that - let index = utils::u64_from_bytes( - &self - .db - .global - .update_and_fetch(COUNTER, utils::increment) - .unwrap() - .unwrap(), - ); - - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&index.to_be_bytes()); - room_latest_id.push(0xff); - room_latest_id.extend_from_slice(&user_id.to_string().as_bytes()); - - self.db - .roomlatestid_roomlatest - .insert(room_latest_id, &*serde_json::to_string(&event).unwrap()) - .unwrap(); - } - - /// Returns a vector of the most recent read_receipts in a room that happened after the event with id `since`. - pub fn roomlatests_since(&self, room_id: &RoomId, since: u64) -> Vec> { - let mut room_latests = Vec::new(); - - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); - - let mut current = prefix.clone(); - current.extend_from_slice(&(since + 1).to_be_bytes()); - - while let Some((key, value)) = self.db.roomlatestid_roomlatest.get_gt(¤t).unwrap() { - if key.starts_with(&prefix) { - current = key.to_vec(); - room_latests.push( - serde_json::from_slice::>(&value) - .expect("room_latest in db is valid"), - ); - } else { - break; - } - } - - room_latests - } - - /// Returns a vector of the most recent read_receipts in a room that happened after the event with id `since`. - pub fn roomlatests_all(&self, room_id: &RoomId) -> Vec> { - self.roomlatests_since(room_id, 0) - } - - pub fn roomactive_add(&self, event: EduEvent, room_id: &RoomId, timeout: u64) { - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); - - let mut current = prefix.clone(); - - while let Some((key, _)) = self.db.roomactiveid_roomactive.get_gt(¤t).unwrap() { - if key.starts_with(&prefix) - && utils::u64_from_bytes(key.split(|&c| c == 0xff).nth(1).unwrap()) - > utils::millis_since_unix_epoch().try_into().unwrap() - { - current = key.to_vec(); - self.db.roomactiveid_roomactive.remove(¤t).unwrap(); - } else { - break; - } - } - - // Increment the last index and use that - let index = utils::u64_from_bytes( - &self - .db - .global - .update_and_fetch(COUNTER, utils::increment) - .unwrap() - .unwrap(), - ); - - let mut room_active_id = prefix; - room_active_id.extend_from_slice(&timeout.to_be_bytes()); - room_active_id.push(0xff); - room_active_id.extend_from_slice(&index.to_be_bytes()); - - self.db - .roomactiveid_roomactive - .insert(room_active_id, &*serde_json::to_string(&event).unwrap()) - .unwrap(); - } - - pub fn roomactive_remove(&self, event: EduEvent, room_id: &RoomId) { - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); - - let mut current = prefix.clone(); - - let json = serde_json::to_string(&event).unwrap(); - - while let Some((key, value)) = self.db.roomactiveid_roomactive.get_gt(¤t).unwrap() { - if key.starts_with(&prefix) { - current = key.to_vec(); - if value == json.as_bytes() { - self.db.roomactiveid_roomactive.remove(¤t).unwrap(); - break; - } - } else { - break; - } - } - } - - /// Returns a vector of the most recent read_receipts in a room that happened after the event with id `since`. - pub fn roomactives_in(&self, room_id: &RoomId) -> Vec> { - let mut room_actives = Vec::new(); - - let mut prefix = room_id.to_string().as_bytes().to_vec(); - prefix.push(0xff); - - let mut current = prefix.clone(); - current.extend_from_slice(&utils::millis_since_unix_epoch().to_be_bytes()); - - while let Some((key, value)) = self.db.roomactiveid_roomactive.get_gt(¤t).unwrap() { - if key.starts_with(&prefix) { - current = key.to_vec(); - room_actives.push( - serde_json::from_slice::>(&value) - .expect("room_active in db is valid"), - ); - } else { - break; - } - } - - if room_actives.is_empty() { - return vec![EduEvent::Typing(ruma_events::typing::TypingEvent { - content: ruma_events::typing::TypingEventContent { - user_ids: Vec::new(), - }, - room_id: None, // None because it can be inferred - }) - .into()]; - } else { - room_actives - } - } - - pub fn room_userdata_update( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - event: EduEvent, - ) { - let mut prefix = room_id - .map(|r| r.to_string()) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(&user_id.to_string().as_bytes()); - prefix.push(0xff); - - // Start with last - if let Some(mut current) = self - .db - .roomuserdataid_accountdata - .scan_prefix(&prefix) - .keys() - .next_back() - .map(|c| c.unwrap()) - { - // Remove old entry (there should be at most one) - loop { - if !current.starts_with(&prefix) { - // We're in another room or user - break; - } - if current.rsplit(|&b| b == 0xff).nth(2).unwrap() == user_id.to_string().as_bytes() - { - // This is the old room_latest - self.db.roomuserdataid_accountdata.remove(current).unwrap(); - break; - } - // Else, try the event before that - if let Some((k, _)) = self.db.roomuserdataid_accountdata.get_lt(current).unwrap() { - current = k; - } else { - break; - } - } - } - - // Increment the last index and use that - let index = utils::u64_from_bytes( - &self - .db - .global - .update_and_fetch(COUNTER, utils::increment) - .unwrap() - .unwrap(), - ); - - let mut key = prefix; - key.extend_from_slice(&index.to_be_bytes()); - - let json = serde_json::to_value(&event).unwrap(); - key.extend_from_slice(json["type"].as_str().unwrap().as_bytes()); - - self.db - .roomuserdataid_accountdata - .insert(key, &*json.to_string()) - .unwrap(); - } - - pub fn room_userdata_get( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - kind: &str, - ) -> Option> { - self.room_userdata_all(room_id, user_id).remove(kind) - } - - pub fn room_userdata_since( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - since: u64, - ) -> HashMap> { - let mut userdata = HashMap::new(); - - let mut prefix = room_id - .map(|r| r.to_string()) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(&user_id.to_string().as_bytes()); - prefix.push(0xff); - - let mut current = prefix.clone(); - current.extend_from_slice(&(since + 1).to_be_bytes()); - - while let Some((key, value)) = self.db.roomuserdataid_accountdata.get_gt(¤t).unwrap() - { - if key.starts_with(&prefix) { - current = key.to_vec(); - let json = serde_json::from_slice::(&value).unwrap(); - userdata.insert( - json["type"].as_str().unwrap().to_owned(), - serde_json::from_value::>(json) - .expect("userdata in db is valid"), - ); - } else { - break; - } - } - - userdata - } - - pub fn room_userdata_all( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - ) -> HashMap> { - self.room_userdata_since(room_id, user_id, 0) - } - - pub fn room_read_set( - &self, - room_id: &RoomId, - user_id: &UserId, - event_id: &EventId, - ) -> Option<()> { - let mut key = room_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&user_id.to_string().as_bytes()); - - self.db - .roomuserid_lastread - .insert(key, &self.pdu_get_count(event_id)?.to_be_bytes()) - .unwrap(); - - Some(()) - } - - pub fn room_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Option { - let mut key = room_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&user_id.to_string().as_bytes()); - - self.db - .roomuserid_lastread - .get(key) - .unwrap() - .map(|v| utils::u64_from_bytes(&v)) - } - - pub fn debug(&self) { - self.db.debug(); - } -} diff --git a/src/database.rs b/src/database.rs index 4551bc0a..47f0a560 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,92 +1,19 @@ -use crate::utils; +pub(self) mod account_data; +pub(self) mod globals; +pub(self) mod rooms; +pub(self) mod users; + use directories::ProjectDirs; -use sled::IVec; use std::fs::remove_dir_all; -pub struct MultiValue(sled::Tree); - -pub const COUNTER: &str = "c"; - -impl MultiValue { - /// Get an iterator over all values. - pub fn iter_all(&self) -> sled::Iter { - self.0.scan_prefix(b"d") - } - - /// Get an iterator over all values of this id. - pub fn get_iter(&self, id: &[u8]) -> sled::Iter { - // Data keys start with d - let mut key = vec![b'd']; - key.extend_from_slice(id.as_ref()); - key.push(0xff); // Add delimiter so we don't find keys starting with the same id - - self.0.scan_prefix(key) - } - - pub fn clear(&self, id: &[u8]) { - for key in self.get_iter(id).keys() { - self.0.remove(key.unwrap()).unwrap(); - } - } - - pub fn remove_value(&self, id: &[u8], value: &[u8]) { - if let Some(key) = self - .get_iter(id) - .find(|t| &t.as_ref().unwrap().1 == value) - .map(|t| t.unwrap().0) - { - self.0.remove(key).unwrap(); - } - } - - /// Add another value to the id. - pub fn add(&self, id: &[u8], value: IVec) { - // The new value will need a new index. We store the last used index in 'n' + id - let mut count_key: Vec = vec![b'n']; - count_key.extend_from_slice(id.as_ref()); - - // Increment the last index and use that - let index = self - .0 - .update_and_fetch(&count_key, utils::increment) - .unwrap() - .unwrap(); - - // Data keys start with d - let mut key = vec![b'd']; - key.extend_from_slice(id.as_ref()); - key.push(0xff); - key.extend_from_slice(&index); - - self.0.insert(key, value).unwrap(); - } -} - pub struct Database { - pub userid_password: sled::Tree, - pub userid_displayname: sled::Tree, - pub userid_avatarurl: sled::Tree, - pub userid_deviceids: MultiValue, - pub userdeviceid_token: sled::Tree, - pub token_userid: sled::Tree, - pub pduid_pdu: sled::Tree, // PduId = RoomId + Count - pub eventid_pduid: sled::Tree, - pub roomid_pduleaves: MultiValue, - pub roomstateid_pdu: sled::Tree, // Room + StateType + StateKey - pub roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type - pub roomuserid_lastread: sled::Tree, // RoomUserId = Room + User - pub roomid_joinuserids: MultiValue, - pub roomid_inviteuserids: MultiValue, - pub userid_joinroomids: MultiValue, - pub userid_inviteroomids: MultiValue, - pub userid_leftroomids: MultiValue, - // EDUs: - pub roomlatestid_roomlatest: sled::Tree, // Read Receipts, RoomLatestId = RoomId + Count + UserId TODO: Types - pub roomactiveid_roomactive: sled::Tree, // Typing, RoomActiveId = TimeoutTime + Count - pub globalallid_globalall: sled::Tree, // ToDevice, GlobalAllId = UserId + Count - pub globallatestid_globallatest: sled::Tree, // Presence, GlobalLatestId = Count + Type + UserId - pub keypair: ruma_signatures::Ed25519KeyPair, - pub global: sled::Db, + pub globals: globals::Globals, + pub users: users::Users, + pub rooms: rooms::Rooms, + pub account_data: account_data::AccountData, + //pub globalallid_globalall: sled::Tree, // ToDevice, GlobalAllId = UserId + Count + //pub globallatestid_globallatest: sled::Tree, // Presence, GlobalLatestId = Count + Type + UserId + pub _db: sled::Db, } impl Database { @@ -110,166 +37,38 @@ impl Database { let db = sled::open(&path).unwrap(); Self { - userid_password: db.open_tree("userid_password").unwrap(), - userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()), - userid_displayname: db.open_tree("userid_displayname").unwrap(), - userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(), - userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(), - token_userid: db.open_tree("token_userid").unwrap(), - pduid_pdu: db.open_tree("pduid_pdu").unwrap(), - eventid_pduid: db.open_tree("eventid_pduid").unwrap(), - roomid_pduleaves: MultiValue(db.open_tree("roomid_pduleaves").unwrap()), - roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(), - roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(), - roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(), - roomid_joinuserids: MultiValue(db.open_tree("roomid_joinuserids").unwrap()), - roomid_inviteuserids: MultiValue(db.open_tree("roomid_inviteuserids").unwrap()), - userid_joinroomids: MultiValue(db.open_tree("userid_joinroomids").unwrap()), - userid_inviteroomids: MultiValue(db.open_tree("userid_inviteroomids").unwrap()), - userid_leftroomids: MultiValue(db.open_tree("userid_leftroomids").unwrap()), - roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(), - roomactiveid_roomactive: db.open_tree("roomactiveid_roomactive").unwrap(), - globalallid_globalall: db.open_tree("globalallid_globalall").unwrap(), - globallatestid_globallatest: db.open_tree("globallatestid_globallatest").unwrap(), - keypair: ruma_signatures::Ed25519KeyPair::new( - &*db.update_and_fetch("keypair", utils::generate_keypair) - .unwrap() - .unwrap(), - "key1".to_owned(), - ) - .unwrap(), - global: db, - } - } - - pub fn debug(&self) { - println!("# UserId -> Password:"); - for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# UserId -> DeviceIds:"); - for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# UserId -> Displayname:"); - for (k, v) in self.userid_displayname.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# UserId -> AvatarURL:"); - for (k, v) in self.userid_avatarurl.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# UserId+DeviceId -> Token:"); - for (k, v) in self.userdeviceid_token.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# Token -> UserId:"); - for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# RoomId -> PDU leaves:"); - for (k, v) in self.roomid_pduleaves.iter_all().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# RoomStateId -> PDU:"); - for (k, v) in self.roomstateid_pdu.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# RoomId -> UserIds:"); - for (k, v) in self.roomid_joinuserids.iter_all().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# UserId -> RoomIds:"); - for (k, v) in self.userid_joinroomids.iter_all().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# PDU Id -> PDU:"); - for (k, v) in self.pduid_pdu.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# EventId -> PDU Id:"); - for (k, v) in self.eventid_pduid.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# RoomLatestId -> RoomLatest:"); - for (k, v) in self.roomlatestid_roomlatest.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# RoomActiveId -> RoomActives:"); - for (k, v) in self.roomactiveid_roomactive.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# GlobalAllId -> GlobalAll:"); - for (k, v) in self.globalallid_globalall.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); - } - println!("\n# GlobalLatestId -> GlobalLatest:"); - for (k, v) in self.globallatestid_globallatest.iter().map(|r| r.unwrap()) { - println!( - "{:?} -> {:?}", - String::from_utf8_lossy(&k), - String::from_utf8_lossy(&v), - ); + globals: globals::Globals::load(db.open_tree("global").unwrap(), hostname.to_owned()), + users: users::Users { + userid_password: db.open_tree("userid_password").unwrap(), + userdeviceid: db.open_tree("userdeviceid").unwrap(), + userid_displayname: db.open_tree("userid_displayname").unwrap(), + userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(), + userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(), + token_userid: db.open_tree("token_userid").unwrap(), + }, + rooms: rooms::Rooms { + edus: rooms::RoomEdus { + roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(), + roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(), + roomactiveid_roomactive: db.open_tree("roomactiveid_roomactive").unwrap(), + }, + pduid_pdu: db.open_tree("pduid_pdu").unwrap(), + eventid_pduid: db.open_tree("eventid_pduid").unwrap(), + roomid_pduleaves: db.open_tree("roomid_pduleaves").unwrap(), + roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(), + + userroomid_joined: db.open_tree("userroomid_joined").unwrap(), + roomuserid_joined: db.open_tree("roomuserid_joined").unwrap(), + userroomid_invited: db.open_tree("userroomid_invited").unwrap(), + roomuserid_invited: db.open_tree("roomuserid_invited").unwrap(), + userroomid_left: db.open_tree("userroomid_left").unwrap(), + }, + account_data: account_data::AccountData { + roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(), + }, + //globalallid_globalall: db.open_tree("globalallid_globalall").unwrap(), + //globallatestid_globallatest: db.open_tree("globallatestid_globallatest").unwrap(), + _db: db, } } } diff --git a/src/database/account_data.rs b/src/database/account_data.rs new file mode 100644 index 00000000..1d48232d --- /dev/null +++ b/src/database/account_data.rs @@ -0,0 +1,120 @@ +use crate::Result; +use ruma_events::{collections::only::Event as EduEvent, EventJson}; +use ruma_identifiers::{RoomId, UserId}; +use std::collections::HashMap; + +pub struct AccountData { + pub(super) roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type +} + +impl AccountData { + /// Places one event in the account data of the user and removes the previous entry. + pub fn update( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + event: EduEvent, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut prefix = room_id + .map(|r| r.to_string()) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&user_id.to_string().as_bytes()); + prefix.push(0xff); + + // Remove old entry + if let Some(old) = self + .roomuserdataid_accountdata + .scan_prefix(&prefix) + .keys() + .rev() + .filter_map(|r| r.ok()) + .take_while(|key| key.starts_with(&prefix)) + .filter(|key| { + key.split(|&b| b == 0xff) + .nth(1) + .filter(|&user| user == user_id.to_string().as_bytes()) + .is_some() + }) + .next() + { + // This is the old room_latest + self.roomuserdataid_accountdata.remove(old)?; + println!("removed old account data"); + } + + let mut key = prefix; + key.extend_from_slice(&globals.next_count()?.to_be_bytes()); + key.push(0xff); + let json = serde_json::to_value(&event)?; + key.extend_from_slice(json["type"].as_str().unwrap().as_bytes()); + + self.roomuserdataid_accountdata + .insert(key, &*json.to_string()) + .unwrap(); + + Ok(()) + } + + // TODO: Optimize + /// Searches the account data for a specific kind. + pub fn get( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + kind: &str, + ) -> Result>> { + Ok(self.all(room_id, user_id)?.remove(kind)) + } + + /// Returns all changes to the account data that happened after `since`. + pub fn changes_since( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + since: u64, + ) -> Result>> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(|r| r.to_string()) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&user_id.to_string().as_bytes()); + prefix.push(0xff); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since + 1).to_be_bytes()); + + for json in self + .roomuserdataid_accountdata + .range(&*first_possible..) + .filter_map(|r| r.ok()) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(_, v)| serde_json::from_slice::(&v).unwrap()) + { + userdata.insert( + json["type"].as_str().unwrap().to_owned(), + serde_json::from_value::>(json) + .expect("userdata in db is valid"), + ); + } + + Ok(userdata) + } + + /// Returns all account data. + pub fn all( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + ) -> Result>> { + self.changes_since(room_id, user_id, 0) + } +} diff --git a/src/database/globals.rs b/src/database/globals.rs new file mode 100644 index 00000000..f9e9999b --- /dev/null +++ b/src/database/globals.rs @@ -0,0 +1,61 @@ +use crate::{utils, Result}; + +pub const COUNTER: &str = "c"; + +pub struct Globals { + pub(super) globals: sled::Tree, + hostname: String, + keypair: ruma_signatures::Ed25519KeyPair, + reqwest_client: reqwest::Client, +} + +impl Globals { + pub fn load(globals: sled::Tree, hostname: String) -> Self { + let keypair = ruma_signatures::Ed25519KeyPair::new( + &*globals + .update_and_fetch("keypair", utils::generate_keypair) + .unwrap() + .unwrap(), + "key1".to_owned(), + ) + .unwrap(); + + Self { + globals, + hostname, + keypair, + reqwest_client: reqwest::Client::new(), + } + } + + /// Returns the hostname of the server. + pub fn hostname(&self) -> &str { + &self.hostname + } + + /// Returns this server's keypair. + pub fn keypair(&self) -> &ruma_signatures::Ed25519KeyPair { + &self.keypair + } + + /// Returns a reqwest client which can be used to send requests. + pub fn reqwest_client(&self) -> &reqwest::Client { + &self.reqwest_client + } + + pub fn next_count(&self) -> Result { + Ok(utils::u64_from_bytes( + &self + .globals + .update_and_fetch(COUNTER, utils::increment)? + .expect("utils::increment will always put in a value"), + )) + } + + pub fn current_count(&self) -> Result { + Ok(self + .globals + .get(COUNTER)? + .map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes))) + } +} diff --git a/src/database/rooms.rs b/src/database/rooms.rs new file mode 100644 index 00000000..f52888ca --- /dev/null +++ b/src/database/rooms.rs @@ -0,0 +1,547 @@ +mod edus; + +pub use edus::RoomEdus; + +use crate::{utils, Error, PduEvent, Result}; +use ruma_events::{room::power_levels::PowerLevelsEventContent, EventJson, EventType}; +use ruma_identifiers::{EventId, RoomId, UserId}; +use serde_json::json; +use std::{ + collections::HashMap, + convert::{TryFrom, TryInto}, + mem, +}; + +pub struct Rooms { + pub edus: edus::RoomEdus, + pub(super) pduid_pdu: sled::Tree, // PduId = RoomId + Count + pub(super) eventid_pduid: sled::Tree, + pub(super) roomid_pduleaves: sled::Tree, + pub(super) roomstateid_pdu: sled::Tree, // Room + StateType + StateKey + + pub(super) userroomid_joined: sled::Tree, + pub(super) roomuserid_joined: sled::Tree, + pub(super) userroomid_invited: sled::Tree, + pub(super) roomuserid_invited: sled::Tree, + pub(super) userroomid_left: sled::Tree, +} + +impl Rooms { + /// Checks if a room exists. + pub fn exists(&self, room_id: &RoomId) -> Result { + // Look for PDUs in that room. + + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + Ok(self + .pduid_pdu + .get_gt(&prefix)? + .filter(|(k, _)| k.starts_with(&prefix)) + .is_some()) + } + + // TODO: Remove and replace with public room dir + /// Returns a vector over all rooms. + pub fn all_rooms(&self) -> Vec { + let mut room_ids = self + .roomid_pduleaves + .iter() + .keys() + .map(|key| { + RoomId::try_from( + &*utils::string_from_bytes( + &key.unwrap() + .iter() + .copied() + .take_while(|&x| x != 0xff) // until delimiter + .collect::>(), + ) + .unwrap(), + ) + .unwrap() + }) + .collect::>(); + room_ids.dedup(); + room_ids + } + + /// Returns the full room state. + pub fn room_state(&self, room_id: &RoomId) -> Result> { + let mut hashmap = HashMap::new(); + for pdu in self + .roomstateid_pdu + .scan_prefix(&room_id.to_string().as_bytes()) + .values() + .map(|value| Ok::<_, Error>(serde_json::from_slice::(&value?)?)) + { + let pdu = pdu?; + hashmap.insert( + ( + pdu.kind.clone(), + pdu.state_key + .clone() + .expect("state events have a state key"), + ), + pdu, + ); + } + Ok(hashmap) + } + + /// Returns the `count` of this pdu's id. + pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { + Ok(self + .eventid_pduid + .get(event_id.to_string().as_bytes())? + .map(|pdu_id| { + utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::()..pdu_id.len()]) + })) + } + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.to_string().as_bytes())? + .map_or(Ok(None), |pdu_id| { + Ok(serde_json::from_slice( + &self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase( + "eventid_pduid points to nonexistent pdu", + ))?, + )?) + .map(Some) + }) + } + + /// Returns the leaf pdus of a room. + pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + let mut events = Vec::new(); + + for event in self + .roomid_pduleaves + .scan_prefix(prefix) + .values() + .map(|bytes| Ok::<_, Error>(EventId::try_from(&*utils::string_from_bytes(&bytes?)?)?)) + { + events.push(event?); + } + + Ok(events) + } + + /// Replace the leaves of a room with a new event. + pub fn replace_pdu_leaves(&self, room_id: &RoomId, event_id: &EventId) -> Result<()> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + for key in self.roomid_pduleaves.scan_prefix(&prefix).keys() { + self.roomid_pduleaves.remove(key?)?; + } + + prefix.extend_from_slice(event_id.to_string().as_bytes()); + self.roomid_pduleaves + .insert(&prefix, &*event_id.to_string())?; + + Ok(()) + } + + /// Creates a new persisted data unit and adds it to a room. + pub fn append_pdu( + &self, + room_id: RoomId, + sender: UserId, + event_type: EventType, + content: serde_json::Value, + unsigned: Option>, + state_key: Option, + globals: &super::globals::Globals, + ) -> Result { + // Is the event authorized? + if state_key.is_some() { + if let Some(pdu) = self + .room_state(&room_id)? + .get(&(EventType::RoomPowerLevels, "".to_owned())) + { + let power_levels = serde_json::from_value::>( + pdu.content.clone(), + )? + .deserialize()?; + + match event_type { + EventType::RoomMember => { + // Member events are okay for now (TODO) + } + _ if power_levels + .users + .get(&sender) + .unwrap_or(&power_levels.users_default) + <= &0.into() => + { + // Not authorized + return Err(Error::BadRequest("event not authorized")); + } + // User has sufficient power + _ => {} + } + } + } + + // prev_events are the leaves of the current graph. This method removes all leaves from the + // room and replaces them with our event + // TODO: Make sure this isn't called twice in parallel + let prev_events = self.get_pdu_leaves(&room_id)?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events + .iter() + .filter_map(|event_id| Some(self.get_pdu_json(event_id).ok()??.get("depth")?.as_u64()?)) + .max() + .unwrap_or(0_u64) + + 1; + + let mut unsigned = unsigned.unwrap_or_default(); + // TODO: Optimize this to not load the whole room state? + if let Some(state_key) = &state_key { + if let Some(prev_pdu) = self + .room_state(&room_id)? + .get(&(event_type.clone(), state_key.clone())) + { + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); + } + } + + let mut pdu = PduEvent { + event_id: EventId::try_from("$thiswillbefilledinlater").expect("we know this is valid"), + room_id: room_id.clone(), + sender: sender.clone(), + origin: globals.hostname().to_owned(), + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("this only fails many years in the future"), + kind: event_type, + content, + state_key, + prev_events, + depth: depth + .try_into() + .expect("depth can overflow and should be deprecated..."), + auth_events: Vec::new(), + redacts: None, + unsigned, + hashes: ruma_federation_api::EventHash { + sha256: "aaa".to_owned(), + }, + signatures: HashMap::new(), + }; + + // Generate event id + pdu.event_id = EventId::try_from(&*format!( + "${}", + ruma_signatures::reference_hash(&serde_json::to_value(&pdu)?) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are correct"); + + let mut pdu_json = serde_json::to_value(&pdu)?; + ruma_signatures::hash_and_sign_event(globals.hostname(), globals.keypair(), &mut pdu_json) + .expect("our new event can be hashed and signed"); + + self.replace_pdu_leaves(&room_id, &pdu.event_id)?; + + // Increment the last index and use that + // This is also the next_batch/since value + let index = globals.next_count()?; + + let mut pdu_id = room_id.to_string().as_bytes().to_vec(); + pdu_id.push(0xff); + pdu_id.extend_from_slice(&index.to_be_bytes()); + + self.pduid_pdu.insert(&pdu_id, &*pdu_json.to_string())?; + + self.eventid_pduid + .insert(pdu.event_id.to_string(), pdu_id.clone())?; + + if let Some(state_key) = pdu.state_key { + let mut key = room_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pdu.kind.to_string().as_bytes()); + key.push(0xff); + key.extend_from_slice(state_key.to_string().as_bytes()); + self.roomstateid_pdu.insert(key, &*pdu_json.to_string())?; + } + + self.edus.room_read_set(&room_id, &sender, index)?; + + Ok(pdu.event_id) + } + + /// Returns an iterator over all PDUs in a room. + pub fn all_pdus(&self, room_id: &RoomId) -> Result>> { + self.pdus_since(room_id, 0) + } + + /// Returns an iterator over all events in a room that happened after the event with id `since`. + pub fn pdus_since( + &self, + room_id: &RoomId, + since: u64, + ) -> Result>> { + // Create the first part of the full pdu id + let mut pdu_id = room_id.to_string().as_bytes().to_vec(); + pdu_id.push(0xff); + pdu_id.extend_from_slice(&(since).to_be_bytes()); + + self.pdus_since_pduid(room_id, &pdu_id) + } + + /// Returns an iterator over all events in a room that happened after the event with id `since`. + pub fn pdus_since_pduid( + &self, + room_id: &RoomId, + pdu_id: &[u8], + ) -> Result>> { + // Create the first part of the full pdu id + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + Ok(self + .pduid_pdu + .range(pdu_id..) + // Skip the first pdu if it's exactly at since, because we sent that last time + .skip(if self.pduid_pdu.get(pdu_id)?.is_some() { + 1 + } else { + 0 + }) + .filter_map(|r| r.ok()) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(_, v)| Ok(serde_json::from_slice(&v)?))) + } + + /// Returns an iterator over all events in a room that happened before the event with id + /// `until` in reverse-chronological order. + pub fn pdus_until( + &self, + room_id: &RoomId, + until: u64, + ) -> impl Iterator> { + // Create the first part of the full pdu id + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + let mut current = prefix.clone(); + current.extend_from_slice(&until.to_be_bytes()); + + let current: &[u8] = ¤t; + + self.pduid_pdu + .range(..current) + .rev() + .filter_map(|r| r.ok()) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(_, v)| Ok(serde_json::from_slice(&v)?)) + } + + /// Makes a user join a room. + pub fn join( + &self, + room_id: &RoomId, + user_id: &UserId, + displayname: Option, + globals: &super::globals::Globals, + ) -> Result<()> { + if !self.exists(room_id)? { + return Err(Error::BadRequest("room does not exist")); + } + + let mut userroom_id = user_id.to_string().as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.to_string().as_bytes()); + + let mut roomuser_id = room_id.to_string().as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.to_string().as_bytes()); + + self.userroomid_joined.insert(&userroom_id, &[])?; + self.roomuserid_joined.insert(&roomuser_id, &[])?; + self.userroomid_invited.remove(&userroom_id)?; + self.roomuserid_invited.remove(&roomuser_id)?; + self.userroomid_left.remove(&userroom_id)?; + + let mut content = json!({"membership": "join"}); + if let Some(displayname) = displayname { + content + .as_object_mut() + .unwrap() + .insert("displayname".to_owned(), displayname.into()); + } + + self.append_pdu( + room_id.clone(), + user_id.clone(), + EventType::RoomMember, + content, + None, + Some(user_id.to_string()), + globals, + )?; + + Ok(()) + } + + /// Makes a user leave a room. + pub fn leave( + &self, + sender: &UserId, + room_id: &RoomId, + user_id: &UserId, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut userroom_id = user_id.to_string().as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.to_string().as_bytes()); + + let mut roomuser_id = room_id.to_string().as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.to_string().as_bytes()); + + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invited.remove(&userroom_id)?; + self.roomuserid_invited.remove(&userroom_id)?; + self.userroomid_left.insert(&userroom_id, &[])?; + + self.append_pdu( + room_id.clone(), + sender.clone(), + EventType::RoomMember, + json!({"membership": "leave"}), + None, + Some(user_id.to_string()), + globals, + )?; + + Ok(()) + } + + /// Makes a user forget a room. + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + let mut userroom_id = user_id.to_string().as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.to_string().as_bytes()); + + self.userroomid_left.remove(userroom_id)?; + + Ok(()) + } + + /// Makes a user invite another user into room. + pub fn invite( + &self, + sender: &UserId, + room_id: &RoomId, + user_id: &UserId, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut userroom_id = user_id.to_string().as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.to_string().as_bytes()); + + let mut roomuser_id = room_id.to_string().as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.to_string().as_bytes()); + + self.userroomid_invited.insert(userroom_id, &[])?; + self.roomuserid_invited.insert(roomuser_id, &[])?; + + self.append_pdu( + room_id.clone(), + sender.clone(), + EventType::RoomMember, + json!({"membership": "invite"}), + None, + Some(user_id.to_string()), + globals, + )?; + + Ok(()) + } + + /// Returns an iterator over all rooms a user joined. + pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> { + self.roomuserid_joined + .scan_prefix(room_id.to_string()) + .values() + .map(|key| { + Ok(UserId::try_from(&*utils::string_from_bytes( + &key? + .rsplit(|&b| b == 0xff) + .next() + .ok_or(Error::BadDatabase("userroomid is invalid"))?, + )?)?) + }) + } + + /// Returns an iterator over all rooms a user joined. + pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator> { + self.roomuserid_invited + .scan_prefix(room_id.to_string()) + .keys() + .map(|key| { + Ok(UserId::try_from(&*utils::string_from_bytes( + &key? + .rsplit(|&b| b == 0xff) + .next() + .ok_or(Error::BadDatabase("userroomid is invalid"))?, + )?)?) + }) + } + + /// Returns an iterator over all rooms a user joined. + pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator> { + self.userroomid_joined + .scan_prefix(user_id.to_string()) + .keys() + .map(|key| { + Ok(RoomId::try_from(&*utils::string_from_bytes( + &key? + .rsplit(|&b| b == 0xff) + .next() + .ok_or(Error::BadDatabase("userroomid is invalid"))?, + )?)?) + }) + } + + /// Returns an iterator over all rooms a user was invited to. + pub fn rooms_invited(&self, user_id: &UserId) -> impl Iterator> { + self.userroomid_invited + .scan_prefix(&user_id.to_string()) + .keys() + .map(|key| { + Ok(RoomId::try_from(&*utils::string_from_bytes( + &key? + .rsplit(|&b| b == 0xff) + .next() + .ok_or(Error::BadDatabase("userroomid is invalid"))?, + )?)?) + }) + } + + /// Returns an iterator over all rooms a user left. + pub fn rooms_left(&self, user_id: &UserId) -> impl Iterator> { + self.userroomid_left + .scan_prefix(&user_id.to_string()) + .keys() + .map(|key| { + Ok(RoomId::try_from(&*utils::string_from_bytes( + &key? + .rsplit(|&b| b == 0xff) + .next() + .ok_or(Error::BadDatabase("userroomid is invalid"))?, + )?)?) + }) + } +} diff --git a/src/database/rooms/edus.rs b/src/database/rooms/edus.rs new file mode 100644 index 00000000..f2db5a44 --- /dev/null +++ b/src/database/rooms/edus.rs @@ -0,0 +1,190 @@ +use crate::{utils, Result}; +use ruma_events::{collections::only::Event as EduEvent, EventJson}; +use ruma_identifiers::{RoomId, UserId}; + +pub struct RoomEdus { + pub(in super::super) roomuserid_lastread: sled::Tree, // RoomUserId = Room + User + pub(in super::super) roomlatestid_roomlatest: sled::Tree, // Read Receipts, RoomLatestId = RoomId + Count + UserId + pub(in super::super) roomactiveid_roomactive: sled::Tree, // Typing, RoomActiveId = RoomId + TimeoutTime + Count +} + +impl RoomEdus { + /// Adds an event which will be saved until a new event replaces it (e.g. read receipt). + pub fn roomlatest_update( + &self, + user_id: &UserId, + room_id: &RoomId, + event: EduEvent, + globals: &super::super::globals::Globals, + ) -> Result<()> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + // Remove old entry + if let Some(old) = self + .roomlatestid_roomlatest + .scan_prefix(&prefix) + .keys() + .rev() + .filter_map(|r| r.ok()) + .take_while(|key| key.starts_with(&prefix)) + .find(|key| { + key.rsplit(|&b| b == 0xff).next().unwrap() == user_id.to_string().as_bytes() + }) + { + // This is the old room_latest + self.roomlatestid_roomlatest.remove(old)?; + } + + let mut room_latest_id = prefix; + room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes()); + room_latest_id.push(0xff); + room_latest_id.extend_from_slice(&user_id.to_string().as_bytes()); + + self.roomlatestid_roomlatest + .insert(room_latest_id, &*serde_json::to_string(&event)?)?; + + Ok(()) + } + + /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. + pub fn roomlatests_since( + &self, + room_id: &RoomId, + since: u64, + ) -> Result>>> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + let mut first_possible_edu = prefix.clone(); + first_possible_edu.extend_from_slice(&since.to_be_bytes()); + + Ok(self + .roomlatestid_roomlatest + .range(&*first_possible_edu..) + // Skip the first pdu if it's exactly at since, because we sent that last time + .skip( + if self + .roomlatestid_roomlatest + .get(first_possible_edu)? + .is_some() + { + 1 + } else { + 0 + }, + ) + .filter_map(|r| r.ok()) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(_, v)| Ok(serde_json::from_slice(&v)?))) + } + + /// Returns a vector of the most recent read_receipts in a room that happened after the event with id `since`. + pub fn roomlatests_all( + &self, + room_id: &RoomId, + ) -> Result>>> { + self.roomlatests_since(room_id, 0) + } + + /// Adds an event that will be saved until the `timeout` timestamp (e.g. typing notifications). + pub fn roomactive_add( + &self, + event: EduEvent, + room_id: &RoomId, + timeout: u64, + globals: &super::super::globals::Globals, + ) -> Result<()> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + // Cleanup all outdated edus before inserting a new one + for outdated_edu in self + .roomactiveid_roomactive + .scan_prefix(&prefix) + .keys() + .filter_map(|r| r.ok()) + .take_while(|k| { + utils::u64_from_bytes( + k.split(|&c| c == 0xff) + .nth(1) + .expect("roomactive has valid timestamp and delimiters"), + ) < utils::millis_since_unix_epoch() + }) + { + // This is an outdated edu (time > timestamp) + self.roomlatestid_roomlatest.remove(outdated_edu)?; + } + + let mut room_active_id = prefix; + room_active_id.extend_from_slice(&timeout.to_be_bytes()); + room_active_id.push(0xff); + room_active_id.extend_from_slice(&globals.next_count()?.to_be_bytes()); + + self.roomactiveid_roomactive + .insert(room_active_id, &*serde_json::to_string(&event)?)?; + + Ok(()) + } + + /// Removes an active event manually (before the timeout is reached). + pub fn roomactive_remove(&self, event: EduEvent, room_id: &RoomId) -> Result<()> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + let json = serde_json::to_string(&event)?; + + // Remove outdated entries + for outdated_edu in self + .roomactiveid_roomactive + .scan_prefix(&prefix) + .filter_map(|r| r.ok()) + .filter(|(_, v)| v == json.as_bytes()) + { + self.roomactiveid_roomactive.remove(outdated_edu.0)?; + } + + Ok(()) + } + + /// Returns an iterator over all active events (e.g. typing notifications). + pub fn roomactives_all( + &self, + room_id: &RoomId, + ) -> impl Iterator>> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + let mut first_active_edu = prefix.clone(); + first_active_edu.extend_from_slice(&utils::millis_since_unix_epoch().to_be_bytes()); + + self.roomactiveid_roomactive + .range(first_active_edu..) + .filter_map(|r| r.ok()) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(_, v)| Ok(serde_json::from_slice(&v)?)) + } + + /// Sets a private read marker at `count`. + pub fn room_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + let mut key = room_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&user_id.to_string().as_bytes()); + + self.roomuserid_lastread.insert(key, &count.to_be_bytes())?; + + Ok(()) + } + + /// Returns the private read marker. + pub fn room_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&user_id.to_string().as_bytes()); + + Ok(self + .roomuserid_lastread + .get(key)? + .map(|v| utils::u64_from_bytes(&v))) + } +} diff --git a/src/database/users.rs b/src/database/users.rs new file mode 100644 index 00000000..e3bf1d04 --- /dev/null +++ b/src/database/users.rs @@ -0,0 +1,144 @@ +use crate::{utils, Error, Result}; +use ruma_identifiers::UserId; +use std::convert::TryFrom; + +pub struct Users { + pub(super) userid_password: sled::Tree, + pub(super) userid_displayname: sled::Tree, + pub(super) userid_avatarurl: sled::Tree, + pub(super) userdeviceid: sled::Tree, + pub(super) userdeviceid_token: sled::Tree, + pub(super) token_userid: sled::Tree, +} + +impl Users { + /// Check if a user has an account on this homeserver. + pub fn exists(&self, user_id: &UserId) -> Result { + Ok(self.userid_password.contains_key(user_id.to_string())?) + } + + /// Create a new user account on this homeserver. + pub fn create(&self, user_id: &UserId, hash: &str) -> Result<()> { + self.userid_password.insert(user_id.to_string(), hash)?; + Ok(()) + } + + /// Find out which user an access token belongs to. + pub fn find_from_token(&self, token: &str) -> Result> { + self.token_userid.get(token)?.map_or(Ok(None), |bytes| { + utils::string_from_bytes(&bytes) + .and_then(|string| Ok(UserId::try_from(string)?)) + .map(Some) + }) + } + + /// Returns an iterator over all users on this homeserver. + pub fn iter(&self) -> impl Iterator> { + self.userid_password.iter().keys().map(|r| { + utils::string_from_bytes(&r?).and_then(|string| Ok(UserId::try_from(&*string)?)) + }) + } + + /// Returns the password hash for the given user. + pub fn password_hash(&self, user_id: &UserId) -> Result> { + self.userid_password + .get(user_id.to_string())? + .map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) + } + + /// Returns the displayname of a user on this homeserver. + pub fn displayname(&self, user_id: &UserId) -> Result> { + self.userid_displayname + .get(user_id.to_string())? + .map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) + } + + /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. + pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + if let Some(displayname) = displayname { + self.userid_displayname + .insert(user_id.to_string(), &*displayname)?; + } else { + self.userid_displayname.remove(user_id.to_string())?; + } + + Ok(()) + /* TODO: + for room_id in self.rooms_joined(user_id) { + self.pdu_append( + room_id.clone(), + user_id.clone(), + EventType::RoomMember, + json!({"membership": "join", "displayname": displayname}), + None, + Some(user_id.to_string()), + ); + } + */ + } + + /// Get a the avatar_url of a user. + pub fn avatar_url(&self, user_id: &UserId) -> Result> { + self.userid_avatarurl + .get(user_id.to_string())? + .map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + if let Some(avatar_url) = avatar_url { + self.userid_avatarurl + .insert(user_id.to_string(), &*avatar_url)?; + } else { + self.userid_avatarurl.remove(user_id.to_string())?; + } + + Ok(()) + } + + /// Adds a new device to a user. + pub fn create_device(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> { + if !self.exists(user_id)? { + return Err(Error::BadRequest( + "tried to create device for nonexistent user", + )); + } + + let mut key = user_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + + self.userdeviceid.insert(key, &[])?; + + self.set_token(user_id, device_id, token)?; + + Ok(()) + } + + /// Replaces the access token of one device. + pub fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> { + let mut key = user_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + + if self.userdeviceid.get(&key)?.is_none() { + return Err(Error::BadRequest( + "Tried to set token for nonexistent device", + )); + } + + // Remove old token + if let Some(old_token) = self.userdeviceid_token.get(&key)? { + self.token_userid.remove(old_token)?; + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to device_id + self.userdeviceid_token.insert(key, &*token)?; + + // Assign token to user + self.token_userid.insert(token, &*user_id.to_string())?; + + Ok(()) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..71fd9180 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,36 @@ +use thiserror::Error; + +pub type Result = std::result::Result; + +#[derive(Error, Debug)] +pub enum Error { + #[error("problem with the database")] + SledError { + #[from] + source: sled::Error, + }, + #[error("tried to parse invalid string")] + StringFromBytesError { + #[from] + source: std::string::FromUtf8Error, + }, + #[error("tried to parse invalid identifier")] + SerdeJsonError { + #[from] + source: serde_json::Error, + }, + #[error("tried to parse invalid identifier")] + RumaIdentifierError { + #[from] + source: ruma_identifiers::Error, + }, + #[error("tried to parse invalid event")] + RumaEventError { + #[from] + source: ruma_events::InvalidEvent, + }, + #[error("bad request")] + BadRequest(&'static str), + #[error("problem in that database")] + BadDatabase(&'static str), +} diff --git a/src/main.rs b/src/main.rs index db97599b..3452423e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,9 @@ #![feature(proc_macro_hygiene, decl_macro)] +#![warn(rust_2018_idioms)] mod client_server; -mod data; mod database; +mod error; mod pdu; mod ruma_wrapper; mod server_server; @@ -11,8 +12,8 @@ mod utils; #[cfg(test)] mod test; -pub use data::Data; pub use database::Database; +pub use error::{Error, Result}; pub use pdu::PduEvent; pub use ruma_wrapper::{MatrixResult, Ruma}; @@ -75,7 +76,7 @@ fn setup_rocket() -> rocket::Rocket { ) .attach(AdHoc::on_attach("Config", |rocket| { let hostname = rocket.config().get_str("hostname").unwrap_or("localhost"); - let data = Data::load_or_create(&hostname); + let data = Database::load_or_create(&hostname); Ok(rocket.manage(data)) })) @@ -86,7 +87,6 @@ fn main() { if let Err(_) = std::env::var("RUST_LOG") { std::env::set_var("RUST_LOG", "warn"); } - pretty_env_logger::init(); setup_rocket().launch().unwrap(); } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 753edea4..7568573d 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -27,21 +27,21 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma { type Borrowed = Self::Owned; fn transform<'r>( - _req: &'r Request, + _req: &'r Request<'_>, data: Data, ) -> TransformFuture<'r, Self::Owned, Self::Error> { Box::pin(async move { Transform::Owned(Success(data)) }) } fn from_data( - request: &'a Request, + request: &'a Request<'_>, outcome: Transformed<'a, Self>, ) -> FromDataFuture<'a, Self, Self::Error> { Box::pin(async move { let data = rocket::try_outcome!(outcome.owned()); let user_id = if T::METADATA.requires_authentication { - let data = request.guard::>().await.unwrap(); + let db = request.guard::>().await.unwrap(); // Get token from header or query value let token = match request @@ -56,7 +56,7 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma { }; // Check if token is valid - match data.user_from_token(&token) { + match db.users.find_from_token(&token).unwrap() { // TODO: M_UNKNOWN_TOKEN None => return Failure((Status::Unauthorized, ())), Some(user_id) => Some(user_id), diff --git a/src/server_server.rs b/src/server_server.rs index 394757a1..bb43957f 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,7 +1,7 @@ -use crate::{Data, MatrixResult}; +use crate::{Database, MatrixResult}; use http::header::{HeaderValue, AUTHORIZATION}; use log::error; -use rocket::{get, post, put, response::content::Json, State}; +use rocket::{get, response::content::Json, State}; use ruma_api::Endpoint; use ruma_client_api::error::Error; use ruma_federation_api::{v1::get_server_version, v2::get_server_keys}; @@ -12,9 +12,9 @@ use std::{ time::{Duration, SystemTime}, }; -pub async fn request_well_known(data: &crate::Data, destination: &str) -> Option { +pub async fn request_well_known(db: &crate::Database, destination: &str) -> Option { let body: serde_json::Value = serde_json::from_str( - &data + &db.globals .reqwest_client() .get(&format!( "https://{}/.well-known/matrix/server", @@ -32,14 +32,14 @@ pub async fn request_well_known(data: &crate::Data, destination: &str) -> Option } pub async fn send_request( - data: &crate::Data, + db: &crate::Database, destination: String, request: T, ) -> Option { let mut http_request: http::Request<_> = request.try_into().unwrap(); let actual_destination = "https://".to_owned() - + &request_well_known(data, &destination) + + &request_well_known(db, &destination) .await .unwrap_or(destination.clone() + ":8448"); *http_request.uri_mut() = (actual_destination + T::METADATA.path).parse().unwrap(); @@ -55,11 +55,11 @@ pub async fn send_request( request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); request_map.insert("uri".to_owned(), T::METADATA.path.into()); - request_map.insert("origin".to_owned(), data.hostname().into()); + request_map.insert("origin".to_owned(), db.globals.hostname().into()); request_map.insert("destination".to_owned(), destination.into()); let mut request_json = request_map.into(); - ruma_signatures::sign_json(data.hostname(), data.keypair(), &mut request_json).unwrap(); + ruma_signatures::sign_json(db.globals.hostname(), db.globals.keypair(), &mut request_json).unwrap(); let signatures = request_json["signatures"] .as_object() @@ -77,7 +77,7 @@ pub async fn send_request( AUTHORIZATION, HeaderValue::from_str(&format!( "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - data.hostname(), + db.globals.hostname(), s.0, s.1 )) @@ -85,7 +85,7 @@ pub async fn send_request( ); } - let reqwest_response = data.reqwest_client().execute(http_request.into()).await; + let reqwest_response = db.globals.reqwest_client().execute(http_request.into()).await; // Because reqwest::Response -> http::Response is complicated: match reqwest_response { @@ -120,7 +120,7 @@ pub async fn send_request( } #[get("/.well-known/matrix/server")] -pub fn well_known_server(data: State) -> Json { +pub fn well_known_server() -> Json { rocket::response::content::Json( json!({ "m.server": "matrixtesting.koesters.xyz:14004"}).to_string(), ) @@ -137,17 +137,17 @@ pub fn get_server_version() -> MatrixResult } #[get("/_matrix/key/v2/server")] -pub fn get_server_keys(data: State) -> Json { +pub fn get_server_keys(db: State<'_, Database>) -> Json { let mut verify_keys = BTreeMap::new(); verify_keys.insert( - format!("ed25519:{}", data.keypair().version()), + format!("ed25519:{}", db.globals.keypair().version()), get_server_keys::VerifyKey { - key: base64::encode_config(data.keypair().public_key(), base64::STANDARD_NO_PAD), + key: base64::encode_config(db.globals.keypair().public_key(), base64::STANDARD_NO_PAD), }, ); let mut response = serde_json::from_slice( http::Response::try_from(get_server_keys::Response { - server_name: data.hostname().to_owned(), + server_name: db.globals.hostname().to_owned(), verify_keys, old_verify_keys: BTreeMap::new(), signatures: BTreeMap::new(), @@ -157,11 +157,11 @@ pub fn get_server_keys(data: State) -> Json { .body(), ) .unwrap(); - ruma_signatures::sign_json(data.hostname(), data.keypair(), &mut response).unwrap(); + ruma_signatures::sign_json(db.globals.hostname(), db.globals.keypair(), &mut response).unwrap(); Json(response.to_string()) } #[get("/_matrix/key/v2/server/<_key_id>")] -pub fn get_server_keys_deprecated(data: State, _key_id: String) -> Json { - get_server_keys(data) +pub fn get_server_keys_deprecated(db: State<'_, Database>, _key_id: String) -> Json { + get_server_keys(db) } diff --git a/src/test.rs b/src/test.rs index 9f562149..1df3a9da 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,8 +1,6 @@ use super::*; -use rocket::{http::Status, local::Client}; -use ruma_client_api::error::ErrorKind; +use rocket::local::Client; use serde_json::{json, Value}; -use std::time::Duration; fn setup_client() -> Client { Database::try_remove("localhost"); diff --git a/src/utils.rs b/src/utils.rs index a3600360..45b9b085 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,4 @@ +use crate::Result; use argon2::{Config, Variant}; use rand::prelude::*; use std::{ @@ -32,13 +33,15 @@ pub fn generate_keypair(old: Option<&[u8]>) -> Option> { ) } +/// Parses the bytes into an u64. pub fn u64_from_bytes(bytes: &[u8]) -> u64 { let array: [u8; 8] = bytes.try_into().expect("bytes are valid u64"); u64::from_be_bytes(array) } -pub fn string_from_bytes(bytes: &[u8]) -> String { - String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8") +/// Parses the bytes into a string. +pub fn string_from_bytes(bytes: &[u8]) -> Result { + Ok(String::from_utf8(bytes.to_vec())?) } pub fn random_string(length: usize) -> String { @@ -49,7 +52,7 @@ pub fn random_string(length: usize) -> String { } /// Calculate a new hash for the given password -pub fn calculate_hash(password: &str) -> Result { +pub fn calculate_hash(password: &str) -> std::result::Result { let hashing_config = Config { variant: Variant::Argon2id, ..Default::default()