diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs new file mode 100644 index 00000000..ae26a7c4 --- /dev/null +++ b/src/database/key_value/rooms/state_accessor.rs @@ -0,0 +1,160 @@ + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. + #[tracing::instrument(skip(self))] + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + let mut result = BTreeMap::new(); + let mut i = 0; + for compressed in full_state.into_iter() { + let parsed = self.parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) + } + + #[tracing::instrument(skip(self))] + pub async fn state_full( + &self, + shortstatehash: u64, + ) -> Result>> { + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state { + let (_, eventid) = self.parse_compressed_state_event(compressed)?; + if let Some(pdu) = self.get_pdu(&eventid)? { + result.insert( + ( + pdu.kind.to_string().into(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? + .clone(), + ), + pdu, + ); + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| { + self.parse_compressed_state_event(compressed) + .ok() + .map(|(_, id)| id) + })) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + } + + /// Returns the state hash for this pdu. + pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + self.eventid_shorteventid + .get(event_id.as_bytes())? + .map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "Invalid shortstatehash bytes in shorteventid_shortstatehash", + ) + }) + }) + .transpose() + }) + } + + /// Returns the full room state. + #[tracing::instrument(skip(self))] + pub async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_full(current_shortstatehash).await + } else { + Ok(HashMap::new()) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs new file mode 100644 index 00000000..976ab5b3 --- /dev/null +++ b/src/database/key_value/rooms/user.rs @@ -0,0 +1,114 @@ + + #[tracing::instrument(skip(self))] + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) + .unwrap_or(Ok(0)) + } + + #[tracing::instrument(skip(self))] + pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_highlightcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) + .unwrap_or(Ok(0)) + } + + pub fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()> { + let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .insert(&key, &shortstatehash.to_be_bytes()) + } + + pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") + }) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + pub fn get_shared_rooms<'a>( + &'a self, + users: Vec>, + ) -> Result>> + 'a> { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xff) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .0 + + 1; // +1 because the room id starts AFTER the separator + + let room_id = key[roomid_index..].to_vec(); + + Ok::<_, Error>(room_id) + }) + .filter_map(|r| r.ok()) + }); + + // We use the default compare function because keys are sorted correctly (not reversed) + Ok(utils::common_elements(iterators, Ord::cmp) + .expect("users is not empty") + .map(|bytes| { + RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { + Error::bad_database("Invalid RoomId bytes in userroomid_joined") + })?) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + })) + } + diff --git a/src/service/uiaa.rs b/src/database/key_value/uiaa.rs similarity index 100% rename from src/service/uiaa.rs rename to src/database/key_value/uiaa.rs diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs new file mode 100644 index 00000000..ae26a7c4 --- /dev/null +++ b/src/service/rooms/state_accessor/data.rs @@ -0,0 +1,160 @@ + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. + #[tracing::instrument(skip(self))] + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + let mut result = BTreeMap::new(); + let mut i = 0; + for compressed in full_state.into_iter() { + let parsed = self.parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) + } + + #[tracing::instrument(skip(self))] + pub async fn state_full( + &self, + shortstatehash: u64, + ) -> Result>> { + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state { + let (_, eventid) = self.parse_compressed_state_event(compressed)?; + if let Some(pdu) = self.get_pdu(&eventid)? { + result.insert( + ( + pdu.kind.to_string().into(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? + .clone(), + ), + pdu, + ); + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| { + self.parse_compressed_state_event(compressed) + .ok() + .map(|(_, id)| id) + })) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + } + + /// Returns the state hash for this pdu. + pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + self.eventid_shorteventid + .get(event_id.as_bytes())? + .map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "Invalid shortstatehash bytes in shorteventid_shortstatehash", + ) + }) + }) + .transpose() + }) + } + + /// Returns the full room state. + #[tracing::instrument(skip(self))] + pub async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_full(current_shortstatehash).await + } else { + Ok(HashMap::new()) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs new file mode 100644 index 00000000..976ab5b3 --- /dev/null +++ b/src/service/rooms/user/data.rs @@ -0,0 +1,114 @@ + + #[tracing::instrument(skip(self))] + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) + .unwrap_or(Ok(0)) + } + + #[tracing::instrument(skip(self))] + pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_highlightcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) + .unwrap_or(Ok(0)) + } + + pub fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()> { + let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .insert(&key, &shortstatehash.to_be_bytes()) + } + + pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") + }) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + pub fn get_shared_rooms<'a>( + &'a self, + users: Vec>, + ) -> Result>> + 'a> { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xff) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .0 + + 1; // +1 because the room id starts AFTER the separator + + let room_id = key[roomid_index..].to_vec(); + + Ok::<_, Error>(room_id) + }) + .filter_map(|r| r.ok()) + }); + + // We use the default compare function because keys are sorted correctly (not reversed) + Ok(utils::common_elements(iterators, Ord::cmp) + .expect("users is not empty") + .map(|bytes| { + RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { + Error::bad_database("Invalid RoomId bytes in userroomid_joined") + })?) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + })) + } + diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs new file mode 100644 index 00000000..12373139 --- /dev/null +++ b/src/service/uiaa/data.rs @@ -0,0 +1,227 @@ +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; + +use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; +use ruma::{ + api::client::{ + error::ErrorKind, + uiaa::{ + AuthType, IncomingAuthData, IncomingPassword, + IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, + }, + }, + signatures::CanonicalJsonValue, + DeviceId, UserId, +}; +use tracing::error; + +use super::abstraction::Tree; + +pub struct Uiaa { + pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication + pub(super) userdevicesessionid_uiaarequest: + RwLock, Box, String), CanonicalJsonValue>>, +} + +impl Uiaa { + /// Creates a new Uiaa session. Make sure the session token is unique. + pub fn create( + &self, + user_id: &UserId, + device_id: &DeviceId, + uiaainfo: &UiaaInfo, + json_body: &CanonicalJsonValue, + ) -> Result<()> { + self.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) + json_body, + )?; + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ) + } + + pub fn try_auth( + &self, + user_id: &UserId, + device_id: &DeviceId, + auth: &IncomingAuthData, + uiaainfo: &UiaaInfo, + users: &super::users::Users, + globals: &super::globals::Globals, + ) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = auth + .session() + .map(|session| self.get_uiaa_session(user_id, device_id, session)) + .unwrap_or_else(|| Ok(uiaainfo.clone()))?; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } + + match auth { + // Find out what the user completed + IncomingAuthData::Password(IncomingPassword { + identifier, + password, + .. + }) => { + let username = match identifier { + UserIdOrLocalpart(username) => username, + _ => { + return Err(Error::BadRequest( + ErrorKind::Unrecognized, + "Identifier type not recognized.", + )) + } + }; + + let user_id = + UserId::parse_with_server_name(username.clone(), globals.server_name()) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") + })?; + + // Check if password is correct + if let Some(hash) = users.password_hash(&user_id)? { + let hash_matches = + argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); + + if !hash_matches { + uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + } + + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + } + IncomingAuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + } + k => error!("type not supported: {:?}", k), + } + + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! + completed = true; + } + + if !completed { + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + Some(&uiaainfo), + )?; + return Ok((false, uiaainfo)); + } + + // UIAA was successful! Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + )?; + Ok((true, uiaainfo)) + } + + fn set_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + request: &CanonicalJsonValue, + ) -> Result<()> { + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); + + Ok(()) + } + + pub fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option { + self.userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .map(|j| j.to_owned()) + } + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + )?; + } else { + self.userdevicesessionid_uiaainfo + .remove(&userdevicesessionid)?; + } + + Ok(()) + } + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Result { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + serde_json::from_slice( + &self + .userdevicesessionid_uiaainfo + .get(&userdevicesessionid)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "UIAA session does not exist.", + ))?, + ) + .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + } +} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs new file mode 100644 index 00000000..12373139 --- /dev/null +++ b/src/service/uiaa/mod.rs @@ -0,0 +1,227 @@ +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; + +use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; +use ruma::{ + api::client::{ + error::ErrorKind, + uiaa::{ + AuthType, IncomingAuthData, IncomingPassword, + IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, + }, + }, + signatures::CanonicalJsonValue, + DeviceId, UserId, +}; +use tracing::error; + +use super::abstraction::Tree; + +pub struct Uiaa { + pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication + pub(super) userdevicesessionid_uiaarequest: + RwLock, Box, String), CanonicalJsonValue>>, +} + +impl Uiaa { + /// Creates a new Uiaa session. Make sure the session token is unique. + pub fn create( + &self, + user_id: &UserId, + device_id: &DeviceId, + uiaainfo: &UiaaInfo, + json_body: &CanonicalJsonValue, + ) -> Result<()> { + self.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) + json_body, + )?; + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ) + } + + pub fn try_auth( + &self, + user_id: &UserId, + device_id: &DeviceId, + auth: &IncomingAuthData, + uiaainfo: &UiaaInfo, + users: &super::users::Users, + globals: &super::globals::Globals, + ) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = auth + .session() + .map(|session| self.get_uiaa_session(user_id, device_id, session)) + .unwrap_or_else(|| Ok(uiaainfo.clone()))?; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } + + match auth { + // Find out what the user completed + IncomingAuthData::Password(IncomingPassword { + identifier, + password, + .. + }) => { + let username = match identifier { + UserIdOrLocalpart(username) => username, + _ => { + return Err(Error::BadRequest( + ErrorKind::Unrecognized, + "Identifier type not recognized.", + )) + } + }; + + let user_id = + UserId::parse_with_server_name(username.clone(), globals.server_name()) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") + })?; + + // Check if password is correct + if let Some(hash) = users.password_hash(&user_id)? { + let hash_matches = + argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); + + if !hash_matches { + uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + } + + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + } + IncomingAuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + } + k => error!("type not supported: {:?}", k), + } + + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! + completed = true; + } + + if !completed { + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + Some(&uiaainfo), + )?; + return Ok((false, uiaainfo)); + } + + // UIAA was successful! Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + )?; + Ok((true, uiaainfo)) + } + + fn set_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + request: &CanonicalJsonValue, + ) -> Result<()> { + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); + + Ok(()) + } + + pub fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option { + self.userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .map(|j| j.to_owned()) + } + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + )?; + } else { + self.userdevicesessionid_uiaainfo + .remove(&userdevicesessionid)?; + } + + Ok(()) + } + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Result { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + serde_json::from_slice( + &self + .userdevicesessionid_uiaainfo + .get(&userdevicesessionid)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "UIAA session does not exist.", + ))?, + ) + .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + } +}