diff --git a/Cargo.lock b/Cargo.lock index 8527dba3..fe80fb6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1263,7 +1263,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.1.0" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "ruma-api", "ruma-client-api", @@ -1277,7 +1277,7 @@ dependencies = [ [[package]] name = "ruma-api" version = "0.16.1" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "http", "percent-encoding 2.1.0", @@ -1292,7 +1292,7 @@ dependencies = [ [[package]] name = "ruma-api-macros" version = "0.16.1" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "proc-macro2 1.0.18", "quote 1.0.6", @@ -1302,7 +1302,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.9.0" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "http", "js_int", @@ -1319,7 +1319,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.1.3" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "matches", "ruma-serde", @@ -1356,8 +1356,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff655a4cb7d43b60b18e07a601889836c1c12854bb16f4c083826b664fdc55aa" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "js_int", "matches", @@ -1382,7 +1381,7 @@ dependencies = [ [[package]] name = "ruma-serde" version = "0.2.2" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "dtoa", "itoa", @@ -1395,7 +1394,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.6.0-dev.1" -source = "git+https://github.com/ruma/ruma?rev=f6fb971329a4a5a7faeebf7ea47a86cd19e580f4#f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +source = "git+https://github.com/ruma/ruma?rev=12388c3fbc8ba2a685cbf0fe810c633c827f5b2c#12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" dependencies = [ "base64 0.12.1", "ring", diff --git a/Cargo.toml b/Cargo.toml index 3c5c9fab..38c7530e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,12 +30,13 @@ image = { version = "0.23.4", default-features = false, features = ["jpeg", "png [dependencies.ruma] git = "https://github.com/ruma/ruma" -rev = "f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" +rev = "12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" +#path = "../ruma/ruma" features = ["rand", "client-api", "federation-api"] # These are required only until ruma-events and ruma-federation-api are merged into ruma/ruma [patch.crates-io] -ruma-api = { git = "https://github.com/ruma/ruma", rev = "f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" } -ruma-common = { git = "https://github.com/ruma/ruma", rev = "f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" } -ruma-events = { git = "https://github.com/ruma/ruma-events", rev = "c1ee72d" } -ruma-serde = { git = "https://github.com/ruma/ruma", rev = "f6fb971329a4a5a7faeebf7ea47a86cd19e580f4" } +ruma-common = { git = "https://github.com/ruma/ruma", rev = "12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" } +ruma-serde = { git = "https://github.com/ruma/ruma", rev = "12388c3fbc8ba2a685cbf0fe810c633c827f5b2c" } +#ruma-common = { path = "../ruma/ruma-common" } +#ruma-serde = { path = "../ruma/ruma-serde" } diff --git a/Rocket-example.toml b/Rocket-example.toml index 924b540e..30a3c3a0 100644 --- a/Rocket-example.toml +++ b/Rocket-example.toml @@ -1,11 +1,13 @@ [global] server_name = "your.server.name" -port = 8448 -address = "0.0.0.0" +#registration_disabled = true # Default path is in this user's data #database_path = "/home/timo/MyConduitServer" +port = 14004 +address = "0.0.0.0" + #[global.tls] #certs = "/etc/letsencrypt/live/your.server.name/fullchain.pem" #key = "/etc/letsencrypt/live/your.server.name/privkey.pem" diff --git a/src/client_server.rs b/src/client_server.rs index 057b473d..6e0c40a2 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -56,15 +56,15 @@ use ruma::{ room::{canonical_alias, guest_access, history_visibility, join_rules, member, redaction}, EventJson, EventType, }, - identifiers::{DeviceId, RoomAliasId, RoomId, RoomVersionId, UserId}, + identifiers::{RoomAliasId, RoomId, RoomVersionId, UserId}, }; use serde_json::{json, value::RawValue}; const GUEST_NAME_LENGTH: usize = 10; const DEVICE_ID_LENGTH: usize = 10; -const SESSION_ID_LENGTH: usize = 256; const TOKEN_LENGTH: usize = 256; const MXC_LENGTH: usize = 256; +const SESSION_ID_LENGTH: usize = 256; #[get("/_matrix/client/versions")] pub fn get_supported_versions_route() -> MatrixResult { @@ -117,15 +117,11 @@ pub fn register_route( db: State<'_, Database>, body: Ruma, ) -> MatrixResult { - if body.auth.is_none() { - return MatrixResult(Err(UiaaResponse::AuthResponse(UiaaInfo { - flows: vec![AuthFlow { - stages: vec!["m.login.dummy".to_owned()], - }], - completed: vec![], - params: RawValue::from_string("{}".to_owned()).unwrap(), - session: Some(utils::random_string(SESSION_ID_LENGTH)), - auth_error: None, + if db.globals.registration_disabled() { + return MatrixResult(Err(UiaaResponse::MatrixError(Error { + kind: ErrorKind::Unknown, + message: "Registration has been disabled.".to_owned(), + status_code: http::StatusCode::FORBIDDEN, }))); } @@ -161,6 +157,31 @@ pub fn register_route( }))); } + // UIAA + let uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec!["m.login.dummy".to_owned()], + }], + completed: Vec::new(), + params: Default::default(), + session: Some(utils::random_string(SESSION_ID_LENGTH)), + auth_error: None, + }; + + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = db + .uiaa + .try_auth(&user_id, "", auth, &uiaainfo, &db.users, &db.globals) + .unwrap(); + if !worked { + return MatrixResult(Err(UiaaResponse::AuthResponse(uiaainfo))); + } + // Success! + } else { + db.uiaa.create(&user_id, "", &uiaainfo).unwrap(); + return MatrixResult(Err(UiaaResponse::AuthResponse(uiaainfo))); + } + let password = body.password.clone().unwrap_or_default(); if let Ok(hash) = utils::calculate_hash(&password) { @@ -575,7 +596,7 @@ pub fn get_displayname_route( body: Ruma, _user_id: String, ) -> MatrixResult { - let user_id = (*body).user_id.clone(); + let user_id = body.body.user_id.clone(); MatrixResult(Ok(get_display_name::Response { displayname: db.users.displayname(&user_id).unwrap(), })) @@ -666,7 +687,7 @@ pub fn get_avatar_url_route( body: Ruma, _user_id: String, ) -> MatrixResult { - let user_id = (*body).user_id.clone(); + let user_id = body.body.user_id.clone(); MatrixResult(Ok(get_avatar_url::Response { avatar_url: db.users.avatar_url(&user_id).unwrap(), })) @@ -678,7 +699,7 @@ pub fn get_profile_route( body: Ruma, _user_id: String, ) -> MatrixResult { - let user_id = (*body).user_id.clone(); + let user_id = body.body.user_id.clone(); let avatar_url = db.users.avatar_url(&user_id).unwrap(); let displayname = db.users.displayname(&user_id).unwrap(); @@ -2316,50 +2337,51 @@ pub fn sync_route( ); } - joined_rooms.insert( - room_id.clone().try_into().unwrap(), - sync_events::JoinedRoom { - account_data: Some(sync_events::AccountData { - events: db - .account_data - .changes_since(Some(&room_id), &user_id, since) - .unwrap() - .into_iter() - .map(|(_, v)| v) - .collect(), - }), - summary: sync_events::RoomSummary { - heroes, - joined_member_count: joined_member_count.map(|n| (n as u32).into()), - invited_member_count: invited_member_count.map(|n| (n as u32).into()), - }, - unread_notifications: sync_events::UnreadNotificationsCount { - highlight_count: None, - notification_count, - }, - timeline: sync_events::Timeline { - limited: if limited || joined_since_last_sync { - Some(true) - } else { - None - }, - prev_batch, - events: room_events, + let joined_room = sync_events::JoinedRoom { + account_data: sync_events::AccountData { + events: db + .account_data + .changes_since(Some(&room_id), &user_id, since) + .unwrap() + .into_iter() + .map(|(_, v)| v) + .collect(), + }, + summary: sync_events::RoomSummary { + heroes, + joined_member_count: joined_member_count.map(|n| (n as u32).into()), + invited_member_count: invited_member_count.map(|n| (n as u32).into()), + }, + unread_notifications: sync_events::UnreadNotificationsCount { + highlight_count: None, + notification_count, + }, + timeline: sync_events::Timeline { + limited: if limited || joined_since_last_sync { + Some(true) + } else { + None }, - // TODO: state before timeline - state: sync_events::State { - events: if joined_since_last_sync { - state - .into_iter() - .map(|(_, pdu)| pdu.to_state_event()) - .collect() - } else { - Vec::new() - }, + prev_batch, + events: room_events, + }, + // TODO: state before timeline + state: sync_events::State { + events: if joined_since_last_sync { + state + .into_iter() + .map(|(_, pdu)| pdu.to_state_event()) + .collect() + } else { + Vec::new() }, - ephemeral: sync_events::Ephemeral { events: edus }, }, - ); + ephemeral: sync_events::Ephemeral { events: edus }, + }; + + if !joined_room.is_empty() { + joined_rooms.insert(room_id.clone().try_into().unwrap(), joined_room); + } } let mut left_rooms = BTreeMap::new(); @@ -2368,6 +2390,7 @@ pub fn sync_route( let pdus = db.rooms.pdus_since(&room_id, since).unwrap(); let room_events = pdus.map(|pdu| pdu.unwrap().to_room_event()).collect(); + // TODO: Only until leave point let mut edus = db .rooms .edus @@ -2394,38 +2417,40 @@ pub fn sync_route( ); } - left_rooms.insert( - room_id.clone().try_into().unwrap(), - sync_events::LeftRoom { - account_data: Some(sync_events::AccountData { events: Vec::new() }), - timeline: sync_events::Timeline { - limited: Some(false), - prev_batch: Some(next_batch.clone()), - events: room_events, - }, - state: sync_events::State { events: Vec::new() }, + let left_room = sync_events::LeftRoom { + account_data: sync_events::AccountData { events: Vec::new() }, + timeline: sync_events::Timeline { + limited: Some(false), + prev_batch: Some(next_batch.clone()), + events: room_events, }, - ); + state: sync_events::State { events: Vec::new() }, + }; + + if !left_room.is_empty() { + left_rooms.insert(room_id.clone().try_into().unwrap(), left_room); + } } let mut invited_rooms = BTreeMap::new(); for room_id in db.rooms.rooms_invited(&user_id) { let room_id = room_id.unwrap(); - invited_rooms.insert( - room_id.clone(), - sync_events::InvitedRoom { - invite_state: sync_events::InviteState { - events: db - .rooms - .room_state(&room_id) - .unwrap() - .into_iter() - .map(|(_, pdu)| pdu.to_stripped_state_event()) - .collect(), - }, + let invited_room = sync_events::InvitedRoom { + invite_state: sync_events::InviteState { + events: db + .rooms + .room_state(&room_id) + .unwrap() + .into_iter() + .map(|(_, pdu)| pdu.to_stripped_state_event()) + .collect(), }, - ); + }; + + if !invited_room.is_empty() { + invited_rooms.insert(room_id.clone(), invited_room); + } } MatrixResult(Ok(sync_events::Response { @@ -2460,17 +2485,16 @@ pub fn sync_route( .map(|(_, v)| v) .collect(), }, - device_lists: if since != 0 { - Some(sync_events::DeviceLists { - changed: db - .users + device_lists: sync_events::DeviceLists { + changed: if since != 0 { + db.users .device_keys_changed(since) .map(|u| u.unwrap()) - .collect(), - left: Vec::new(), // TODO - }) - } else { - None // TODO: left + .collect() + } else { + Vec::new() + }, + left: Vec::new(), // TODO }, device_one_time_keys_count: Default::default(), // TODO to_device: sync_events::ToDevice { @@ -2816,14 +2840,18 @@ pub fn get_devices_route( MatrixResult(Ok(get_devices::Response { devices })) } -#[get("/_matrix/client/r0/devices/", data = "")] +#[get("/_matrix/client/r0/devices/<_device_id>", data = "")] pub fn get_device_route( db: State<'_, Database>, body: Ruma, - device_id: DeviceId, + _device_id: String, ) -> MatrixResult { let user_id = body.user_id.as_ref().expect("user is authenticated"); - let device = db.users.get_device_metadata(&user_id, &device_id).unwrap(); + + let device = db + .users + .get_device_metadata(&user_id, &body.body.device_id) + .unwrap(); match device { None => MatrixResult(Err(Error { @@ -2835,14 +2863,18 @@ pub fn get_device_route( } } -#[put("/_matrix/client/r0/devices/", data = "")] +#[put("/_matrix/client/r0/devices/<_device_id>", data = "")] pub fn update_device_route( db: State<'_, Database>, body: Ruma, - device_id: DeviceId, + _device_id: String, ) -> MatrixResult { let user_id = body.user_id.as_ref().expect("user is authenticated"); - let device = db.users.get_device_metadata(&user_id, &device_id).unwrap(); + + let device = db + .users + .get_device_metadata(&user_id, &body.body.device_id) + .unwrap(); match device { None => MatrixResult(Err(Error { @@ -2854,7 +2886,7 @@ pub fn update_device_route( device.display_name = body.display_name.clone(); db.users - .update_device_metadata(&user_id, &device_id, &device) + .update_device_metadata(&user_id, &body.body.device_id, &device) .unwrap(); MatrixResult(Ok(update_device::Response)) @@ -2862,14 +2894,50 @@ pub fn update_device_route( } } -#[delete("/_matrix/client/r0/devices/", data = "")] +#[delete("/_matrix/client/r0/devices/<_device_id>", data = "")] pub fn delete_device_route( db: State<'_, Database>, body: Ruma, - device_id: DeviceId, -) -> MatrixResult { + _device_id: String, +) -> MatrixResult { let user_id = body.user_id.as_ref().expect("user is authenticated"); - db.users.remove_device(&user_id, &device_id).unwrap(); + let device_id = body.device_id.as_ref().expect("user is authenticated"); + + // UIAA + let uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec!["m.login.password".to_owned()], + }], + completed: Vec::new(), + params: Default::default(), + session: Some(utils::random_string(SESSION_ID_LENGTH)), + auth_error: None, + }; + + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = db + .uiaa + .try_auth( + &user_id, + &device_id, + auth, + &uiaainfo, + &db.users, + &db.globals, + ) + .unwrap(); + if !worked { + return MatrixResult(Err(UiaaResponse::AuthResponse(uiaainfo))); + } + // Success! + } else { + db.uiaa.create(&user_id, &device_id, &uiaainfo).unwrap(); + return MatrixResult(Err(UiaaResponse::AuthResponse(uiaainfo))); + } + + db.users + .remove_device(&user_id, &body.body.device_id) + .unwrap(); MatrixResult(Ok(delete_device::Response)) } @@ -2878,8 +2946,42 @@ pub fn delete_device_route( pub fn delete_devices_route( db: State<'_, Database>, body: Ruma, -) -> MatrixResult { +) -> MatrixResult { let user_id = body.user_id.as_ref().expect("user is authenticated"); + let device_id = body.device_id.as_ref().expect("user is authenticated"); + + // UIAA + let uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec!["m.login.password".to_owned()], + }], + completed: Vec::new(), + params: Default::default(), + session: Some(utils::random_string(SESSION_ID_LENGTH)), + auth_error: None, + }; + + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = db + .uiaa + .try_auth( + &user_id, + &device_id, + auth, + &uiaainfo, + &db.users, + &db.globals, + ) + .unwrap(); + if !worked { + return MatrixResult(Err(UiaaResponse::AuthResponse(uiaainfo))); + } + // Success! + } else { + db.uiaa.create(&user_id, &device_id, &uiaainfo).unwrap(); + return MatrixResult(Err(UiaaResponse::AuthResponse(uiaainfo))); + } + for device_id in &body.devices { db.users.remove_device(&user_id, &device_id).unwrap() } diff --git a/src/database.rs b/src/database.rs index dc78ba9a..34af8fc2 100644 --- a/src/database.rs +++ b/src/database.rs @@ -3,9 +3,11 @@ pub(self) mod global_edus; pub(self) mod globals; pub(self) mod media; pub(self) mod rooms; +pub(self) mod uiaa; pub(self) mod users; use directories::ProjectDirs; +use log::info; use std::fs::remove_dir_all; use rocket::Config; @@ -13,6 +15,7 @@ use rocket::Config; pub struct Database { pub globals: globals::Globals, pub users: users::Users, + pub uiaa: uiaa::Uiaa, pub rooms: rooms::Rooms, pub account_data: account_data::AccountData, pub global_edus: global_edus::GlobalEdus, @@ -47,13 +50,10 @@ impl Database { }); let db = sled::open(&path).unwrap(); - log::info!("Opened sled database at {}", path); + info!("Opened sled database at {}", path); Self { - globals: globals::Globals::load( - db.open_tree("global").unwrap(), - server_name.to_owned(), - ), + globals: globals::Globals::load(db.open_tree("global").unwrap(), config), users: users::Users { userid_password: db.open_tree("userid_password").unwrap(), userid_displayname: db.open_tree("userid_displayname").unwrap(), @@ -66,6 +66,9 @@ impl Database { devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid").unwrap(), todeviceid_events: db.open_tree("todeviceid_events").unwrap(), }, + uiaa: uiaa::Uiaa { + userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo").unwrap(), + }, rooms: rooms::Rooms { edus: rooms::RoomEdus { roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(), // "Private" read receipt diff --git a/src/database/globals.rs b/src/database/globals.rs index 93d5794c..08ab411f 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -4,13 +4,14 @@ pub const COUNTER: &str = "c"; pub struct Globals { pub(super) globals: sled::Tree, - server_name: String, keypair: ruma::signatures::Ed25519KeyPair, reqwest_client: reqwest::Client, + server_name: String, + registration_disabled: bool, } impl Globals { - pub fn load(globals: sled::Tree, server_name: String) -> Self { + pub fn load(globals: sled::Tree, config: &rocket::Config) -> Self { let keypair = ruma::signatures::Ed25519KeyPair::new( &*globals .update_and_fetch("keypair", utils::generate_keypair) @@ -22,17 +23,16 @@ impl Globals { Self { globals, - server_name, keypair, reqwest_client: reqwest::Client::new(), + server_name: config + .get_str("server_name") + .unwrap_or("localhost") + .to_owned(), + registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), } } - /// Returns the server_name of the server. - pub fn server_name(&self) -> &str { - &self.server_name - } - /// Returns this server's keypair. pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair @@ -58,4 +58,12 @@ impl Globals { .get(COUNTER)? .map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes))) } + + pub fn server_name(&self) -> &str { + &self.server_name + } + + pub fn registration_disabled(&self) -> bool { + self.registration_disabled + } } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 5d9da485..fa422def 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -421,7 +421,7 @@ impl Rooms { auth_events: Vec::new(), redacts: redacts.clone(), unsigned, - hashes: ruma::api::federation::EventHash { + hashes: ruma::api::federation::pdu::EventHash { sha256: "aaa".to_owned(), }, signatures: HashMap::new(), diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs new file mode 100644 index 00000000..6cd25b99 --- /dev/null +++ b/src/database/uiaa.rs @@ -0,0 +1,161 @@ +use crate::{utils, Error, Result}; +use js_int::UInt; +use log::debug; +use ruma::{ + api::client::{ + error::ErrorKind, + r0::{ + device::Device, + keys::{AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey}, + uiaa::{AuthData, AuthFlow, UiaaInfo, UiaaResponse}, + }, + }, + events::{to_device::AnyToDeviceEvent, EventJson, EventType}, + identifiers::UserId, +}; +use serde_json::value::RawValue; +use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime}; + +pub struct Uiaa { + pub(super) userdeviceid_uiaainfo: sled::Tree, // User-interactive authentication +} + +impl Uiaa { + /// Creates a new Uiaa session. Make sure the session token is unique. + pub fn create(&self, user_id: &UserId, device_id: &str, uiaainfo: &UiaaInfo) -> Result<()> { + self.update_uiaa_session(user_id, device_id, Some(uiaainfo)) + } + + pub fn try_auth( + &self, + user_id: &UserId, + device_id: &str, + auth: &AuthData, + uiaainfo: &UiaaInfo, + users: &super::users::Users, + globals: &super::globals::Globals, + ) -> Result<(bool, UiaaInfo)> { + if let AuthData::DirectRequest { + kind, + session, + auth_parameters, + } = &auth + { + let mut uiaainfo = session + .as_ref() + .map(|session| { + Ok::<_, Error>(self.get_uiaa_session(&user_id, &"".to_owned(), session)?) + }) + .unwrap_or(Ok(uiaainfo.clone()))?; + + // Find out what the user completed + match &**kind { + "m.login.password" => { + if auth_parameters["identifier"]["type"] != "m.id.user" { + panic!("identifier not supported"); + } + + let user_id = UserId::parse_with_server_name( + auth_parameters["identifier"]["user"].as_str().unwrap(), + globals.server_name(), + )?; + let password = auth_parameters["password"].as_str().unwrap(); + + // Check if password is correct + if let Some(hash) = users.password_hash(&user_id)? { + let hash_matches = + argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); + + if !hash_matches { + debug!("Invalid password."); + uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + } + + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push("m.login.password".to_owned()); + } + "m.login.dummy" => { + uiaainfo.completed.push("m.login.dummy".to_owned()); + } + k => panic!("type not supported: {}", k), + } + + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! + completed = true; + } + + if !completed { + self.update_uiaa_session(user_id, device_id, Some(&uiaainfo))?; + return Ok((false, uiaainfo)); + } + + // UIAA was successful! Remove this session and return true + self.update_uiaa_session(user_id, device_id, None)?; + return Ok((true, uiaainfo)); + } else { + panic!("FallbackAcknowledgement is not supported yet"); + } + } + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.userdeviceid_uiaainfo + .insert(&userdeviceid, &*serde_json::to_string(&uiaainfo)?)?; + } else { + self.userdeviceid_uiaainfo.remove(&userdeviceid)?; + } + + Ok(()) + } + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &str, + session: &str, + ) -> Result { + let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + let uiaainfo = serde_json::from_slice::( + &self + .userdeviceid_uiaainfo + .get(&userdeviceid)? + .ok_or(Error::BadRequest("session does not exist"))?, + )?; + + if uiaainfo + .session + .as_ref() + .filter(|&s| s == session) + .is_none() + { + return Err(Error::BadRequest("wrong session token")); + } + + Ok(uiaainfo) + } +} diff --git a/src/database/users.rs b/src/database/users.rs index 8893b102..5c474556 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -6,7 +6,7 @@ use ruma::{ keys::{AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey}, }, events::{to_device::AnyToDeviceEvent, EventJson, EventType}, - identifiers::{DeviceId, UserId}, + identifiers::UserId, }; use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime}; @@ -113,7 +113,7 @@ impl Users { pub fn create_device( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, token: &str, initial_device_display_name: Option, ) -> Result<()> { @@ -130,7 +130,7 @@ impl Users { self.userdeviceid_metadata.insert( userdeviceid, serde_json::to_string(&Device { - device_id: device_id.clone(), + device_id: device_id.to_owned(), display_name: initial_device_display_name, last_seen_ip: None, // TODO last_seen_ts: Some(SystemTime::now()), @@ -144,7 +144,7 @@ impl Users { } /// Removes a device from a user. - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub fn remove_device(&self, user_id: &UserId, device_id: &str) -> Result<()> { let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -173,7 +173,7 @@ impl Users { } /// Returns an iterator over all device ids of this user. - pub fn all_device_ids(&self, user_id: &UserId) -> impl Iterator> { + pub fn all_device_ids(&self, user_id: &UserId) -> impl Iterator> { let mut prefix = user_id.to_string().as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata @@ -191,7 +191,7 @@ impl Users { } /// Replaces the access token of one device. - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + pub fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> { let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -219,7 +219,7 @@ impl Users { pub fn add_one_time_key( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, one_time_key_key: &AlgorithmAndDeviceId, one_time_key_value: &OneTimeKey, ) -> Result<()> { @@ -248,7 +248,7 @@ impl Users { pub fn take_one_time_key( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, key_algorithm: &KeyAlgorithm, ) -> Result> { let mut prefix = user_id.to_string().as_bytes().to_vec(); @@ -282,7 +282,7 @@ impl Users { pub fn count_one_time_keys( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, ) -> Result> { let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); userdeviceid.push(0xff); @@ -315,7 +315,7 @@ impl Users { pub fn add_device_keys( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, device_keys: &DeviceKeys, globals: &super::globals::Globals, ) -> Result<()> { @@ -335,7 +335,7 @@ impl Users { pub fn get_device_keys( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, ) -> impl Iterator> { let mut key = user_id.to_string().as_bytes().to_vec(); key.push(0xff); @@ -376,7 +376,7 @@ impl Users { &self, sender: &UserId, target_user_id: &UserId, - target_device_id: &DeviceId, + target_device_id: &str, event_type: &EventType, content: serde_json::Value, globals: &super::globals::Globals, @@ -401,7 +401,7 @@ impl Users { pub fn take_to_device_events( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, max: usize, ) -> Result>> { let mut events = Vec::new(); @@ -423,7 +423,7 @@ impl Users { pub fn update_device_metadata( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: &str, device: &Device, ) -> Result<()> { let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); @@ -441,11 +441,7 @@ impl Users { } /// Get device metadata. - pub fn get_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { + pub fn get_device_metadata(&self, user_id: &UserId, device_id: &str) -> Result> { let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); diff --git a/src/pdu.rs b/src/pdu.rs index 6ee0fd52..454d27f3 100644 --- a/src/pdu.rs +++ b/src/pdu.rs @@ -1,6 +1,6 @@ use js_int::UInt; use ruma::{ - api::federation::EventHash, + api::federation::pdu::EventHash, events::{ collections::all::{RoomEvent, StateEvent}, stripped::AnyStrippedStateEvent,