diff --git a/src/database.rs b/src/database.rs index 84ca68dc..8b29b221 100644 --- a/src/database.rs +++ b/src/database.rs @@ -250,8 +250,7 @@ impl Database { }, uiaa: uiaa::Uiaa { userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, - userdevicesessionid_uiaarequest: builder - .open_tree("userdevicesessionid_uiaarequest")?, + userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), }, rooms: rooms::Rooms { edus: rooms::RoomEdus { @@ -755,6 +754,15 @@ impl Database { println!("Migration: 9 -> 10 finished"); } + + if db.globals.database_version()? < 11 { + db._db + .open_tree("userdevicesessionid_uiaarequest")? + .clear()?; + db.globals.bump_database_version(11)?; + + println!("Migration: 10 -> 11 finished"); + } } let guard = db.read().await; diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs index 1c0fb566..772dab9e 100644 --- a/src/database/uiaa.rs +++ b/src/database/uiaa.rs @@ -1,4 +1,6 @@ +use std::collections::BTreeMap; use std::sync::Arc; +use std::sync::RwLock; use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; use ruma::{ @@ -18,7 +20,8 @@ use super::abstraction::Tree; pub struct Uiaa { pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: Arc, // UiaaRequest = canonical json value + pub(super) userdevicesessionid_uiaarequest: + RwLock, Box, String), CanonicalJsonValue>>, } impl Uiaa { @@ -147,16 +150,17 @@ impl Uiaa { session: &str, request: &CanonicalJsonValue, ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - self.userdevicesessionid_uiaarequest.insert( - &userdevicesessionid, - &serde_json::to_vec(request).expect("json value to vec always works"), - )?; + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + ( + user_id.to_owned(), + device_id.to_owned(), + session.to_string(), + ), + request.to_owned(), + ); Ok(()) } @@ -167,22 +171,16 @@ impl Uiaa { device_id: &DeviceId, session: &str, ) -> Result> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - self.userdevicesessionid_uiaarequest - .get(&userdevicesessionid)? - .map(|bytes| { - serde_json::from_str::( - &utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid uiaa request bytes in db."))?, - ) - .map_err(|_| Error::bad_database("Invalid uiaa request in db.")) - }) - .transpose() + Ok(self + .userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&( + user_id.to_owned(), + device_id.to_owned(), + session.to_string(), + )) + .map(|j| j.to_owned())) } fn update_uiaa_session(