Merge pull request 'refactor: better error handling' (#116) from error-handling into master

merge-requests/22/head
Timo Kösters 4 years ago
commit 4c98079c4c

File diff suppressed because it is too large Load Diff

@ -6,6 +6,7 @@ pub(self) mod rooms;
pub(self) mod uiaa; pub(self) mod uiaa;
pub(self) mod users; pub(self) mod users;
use crate::{Error, Result};
use directories::ProjectDirs; use directories::ProjectDirs;
use log::info; use log::info;
use std::fs::remove_dir_all; use std::fs::remove_dir_all;
@ -25,84 +26,92 @@ pub struct Database {
impl Database { impl Database {
/// Tries to remove the old database but ignores all errors. /// Tries to remove the old database but ignores all errors.
pub fn try_remove(server_name: &str) { pub fn try_remove(server_name: &str) -> Result<()> {
let mut path = ProjectDirs::from("xyz", "koesters", "conduit") let mut path = ProjectDirs::from("xyz", "koesters", "conduit")
.unwrap() .ok_or(Error::BadConfig(
"The OS didn't return a valid home directory path.",
))?
.data_dir() .data_dir()
.to_path_buf(); .to_path_buf();
path.push(server_name); path.push(server_name);
let _ = remove_dir_all(path); let _ = remove_dir_all(path);
Ok(())
} }
/// Load an existing database or create a new one. /// Load an existing database or create a new one.
pub fn load_or_create(config: &Config) -> Self { pub fn load_or_create(config: &Config) -> Result<Self> {
let server_name = config.get_str("server_name").unwrap_or("localhost"); let server_name = config.get_str("server_name").unwrap_or("localhost");
let path = config let path = config
.get_str("database_path") .get_str("database_path")
.map(|x| x.to_owned()) .map(|x| Ok::<_, Error>(x.to_owned()))
.unwrap_or_else(|_| { .unwrap_or_else(|_| {
let path = ProjectDirs::from("xyz", "koesters", "conduit") let path = ProjectDirs::from("xyz", "koesters", "conduit")
.unwrap() .ok_or(Error::BadConfig(
"The OS didn't return a valid home directory path.",
))?
.data_dir() .data_dir()
.join(server_name); .join(server_name);
path.to_str().unwrap().to_owned()
});
let db = sled::open(&path).unwrap(); Ok(path
.to_str()
.ok_or(Error::BadConfig("Database path contains invalid unicode."))?
.to_owned())
})?;
let db = sled::open(&path)?;
info!("Opened sled database at {}", path); info!("Opened sled database at {}", path);
Self { Ok(Self {
globals: globals::Globals::load(db.open_tree("global").unwrap(), config), globals: globals::Globals::load(db.open_tree("global")?, config)?,
users: users::Users { users: users::Users {
userid_password: db.open_tree("userid_password").unwrap(), userid_password: db.open_tree("userid_password")?,
userid_displayname: db.open_tree("userid_displayname").unwrap(), userid_displayname: db.open_tree("userid_displayname")?,
userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(), userid_avatarurl: db.open_tree("userid_avatarurl")?,
userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(), userdeviceid_token: db.open_tree("userdeviceid_token")?,
userdeviceid_metadata: db.open_tree("userdeviceid_metadata").unwrap(), userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?,
token_userdeviceid: db.open_tree("token_userdeviceid").unwrap(), token_userdeviceid: db.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys").unwrap(), onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?,
userdeviceid_devicekeys: db.open_tree("userdeviceid_devicekeys").unwrap(), userdeviceid_devicekeys: db.open_tree("userdeviceid_devicekeys")?,
devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid").unwrap(), devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid")?,
todeviceid_events: db.open_tree("todeviceid_events").unwrap(), todeviceid_events: db.open_tree("todeviceid_events")?,
}, },
uiaa: uiaa::Uiaa { uiaa: uiaa::Uiaa {
userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo").unwrap(), userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo")?,
}, },
rooms: rooms::Rooms { rooms: rooms::Rooms {
edus: rooms::RoomEdus { edus: rooms::RoomEdus {
roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(), // "Private" read receipt roomuserid_lastread: db.open_tree("roomuserid_lastread")?, // "Private" read receipt
roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(), // Read receipts roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest")?, // Read receipts
roomactiveid_userid: db.open_tree("roomactiveid_userid").unwrap(), // Typing notifs roomactiveid_userid: db.open_tree("roomactiveid_userid")?, // Typing notifs
roomid_lastroomactiveupdate: db roomid_lastroomactiveupdate: db.open_tree("roomid_lastroomactiveupdate")?,
.open_tree("roomid_lastroomactiveupdate")
.unwrap(),
}, },
pduid_pdu: db.open_tree("pduid_pdu").unwrap(), pduid_pdu: db.open_tree("pduid_pdu")?,
eventid_pduid: db.open_tree("eventid_pduid").unwrap(), eventid_pduid: db.open_tree("eventid_pduid")?,
roomid_pduleaves: db.open_tree("roomid_pduleaves").unwrap(), roomid_pduleaves: db.open_tree("roomid_pduleaves")?,
roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(), roomstateid_pdu: db.open_tree("roomstateid_pdu")?,
alias_roomid: db.open_tree("alias_roomid").unwrap(), alias_roomid: db.open_tree("alias_roomid")?,
aliasid_alias: db.open_tree("alias_roomid").unwrap(), aliasid_alias: db.open_tree("alias_roomid")?,
publicroomids: db.open_tree("publicroomids").unwrap(), publicroomids: db.open_tree("publicroomids")?,
userroomid_joined: db.open_tree("userroomid_joined").unwrap(), userroomid_joined: db.open_tree("userroomid_joined")?,
roomuserid_joined: db.open_tree("roomuserid_joined").unwrap(), roomuserid_joined: db.open_tree("roomuserid_joined")?,
userroomid_invited: db.open_tree("userroomid_invited").unwrap(), userroomid_invited: db.open_tree("userroomid_invited")?,
roomuserid_invited: db.open_tree("roomuserid_invited").unwrap(), roomuserid_invited: db.open_tree("roomuserid_invited")?,
userroomid_left: db.open_tree("userroomid_left").unwrap(), userroomid_left: db.open_tree("userroomid_left")?,
}, },
account_data: account_data::AccountData { account_data: account_data::AccountData {
roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(), roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?,
}, },
global_edus: global_edus::GlobalEdus { global_edus: global_edus::GlobalEdus {
presenceid_presence: db.open_tree("presenceid_presence").unwrap(), // Presence presenceid_presence: db.open_tree("presenceid_presence")?, // Presence
}, },
media: media::Media { media: media::Media {
mediaid_file: db.open_tree("mediaid_file").unwrap(), mediaid_file: db.open_tree("mediaid_file")?,
}, },
_db: db, _db: db,
} })
} }
} }

@ -1,5 +1,6 @@
use crate::{utils, Error, Result}; use crate::{utils, Error, Result};
use ruma::{ use ruma::{
api::client::error::ErrorKind,
events::{collections::only::Event as EduEvent, EventJson, EventType}, events::{collections::only::Event as EduEvent, EventJson, EventType},
identifiers::{RoomId, UserId}, identifiers::{RoomId, UserId},
}; };
@ -20,7 +21,10 @@ impl AccountData {
globals: &super::globals::Globals, globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
if json.get("content").is_none() { if json.get("content").is_none() {
return Err(Error::BadRequest("json needs to have a content field")); return Err(Error::BadRequest(
ErrorKind::BadJson,
"Json needs to have a content field.",
));
} }
json.insert("type".to_owned(), kind.to_string().into()); json.insert("type".to_owned(), kind.to_string().into());
@ -62,9 +66,10 @@ impl AccountData {
key.push(0xff); key.push(0xff);
key.extend_from_slice(kind.to_string().as_bytes()); key.extend_from_slice(kind.to_string().as_bytes());
self.roomuserdataid_accountdata self.roomuserdataid_accountdata.insert(
.insert(key, &*serde_json::to_string(&json)?) key,
.unwrap(); &*serde_json::to_string(&json).expect("Map::to_string always works"),
)?;
Ok(()) Ok(())
} }
@ -109,17 +114,20 @@ impl AccountData {
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| { .map(|(k, v)| {
Ok::<_, Error>(( Ok::<_, Error>((
EventType::try_from(utils::string_from_bytes( EventType::try_from(
k.rsplit(|&b| b == 0xff) utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
.next() || Error::bad_database("RoomUserData ID in db is invalid."),
.ok_or(Error::BadDatabase("roomuserdataid is invalid"))?, )?)
)?) .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
.map_err(|_| Error::BadDatabase("roomuserdataid is invalid"))?, )
serde_json::from_slice::<EventJson<EduEvent>>(&v).unwrap(), .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
serde_json::from_slice::<EventJson<EduEvent>>(&v).map_err(|_| {
Error::bad_database("Database contains invalid account data.")
})?,
)) ))
}) })
{ {
let (kind, data) = r.unwrap(); let (kind, data) = r?;
userdata.insert(kind, data); userdata.insert(kind, data);
} }

@ -1,4 +1,4 @@
use crate::Result; use crate::{Error, Result};
use ruma::events::EventJson; use ruma::events::EventJson;
pub struct GlobalEdus { pub struct GlobalEdus {
@ -21,7 +21,10 @@ impl GlobalEdus {
.rev() .rev()
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.find(|key| { .find(|key| {
key.rsplit(|&b| b == 0xff).next().unwrap() == presence.sender.to_string().as_bytes() key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element")
== presence.sender.to_string().as_bytes()
}) })
{ {
// This is the old global_latest // This is the old global_latest
@ -32,8 +35,10 @@ impl GlobalEdus {
presence_id.push(0xff); presence_id.push(0xff);
presence_id.extend_from_slice(&presence.sender.to_string().as_bytes()); presence_id.extend_from_slice(&presence.sender.to_string().as_bytes());
self.presenceid_presence self.presenceid_presence.insert(
.insert(presence_id, &*serde_json::to_string(&presence)?)?; presence_id,
&*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"),
)?;
Ok(()) Ok(())
} }
@ -50,6 +55,9 @@ impl GlobalEdus {
.presenceid_presence .presenceid_presence
.range(&*first_possible_edu..) .range(&*first_possible_edu..)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))) .map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("Invalid presence event in db."))?)
}))
} }
} }

@ -1,4 +1,4 @@
use crate::{utils, Result}; use crate::{utils, Error, Result};
pub const COUNTER: &str = "c"; pub const COUNTER: &str = "c";
@ -11,17 +11,16 @@ pub struct Globals {
} }
impl Globals { impl Globals {
pub fn load(globals: sled::Tree, config: &rocket::Config) -> Self { pub fn load(globals: sled::Tree, config: &rocket::Config) -> Result<Self> {
let keypair = ruma::signatures::Ed25519KeyPair::new( let keypair = ruma::signatures::Ed25519KeyPair::new(
&*globals &*globals
.update_and_fetch("keypair", utils::generate_keypair) .update_and_fetch("keypair", utils::generate_keypair)?
.unwrap() .expect("utils::generate_keypair always returns Some"),
.unwrap(),
"key1".to_owned(), "key1".to_owned(),
) )
.unwrap(); .map_err(|_| Error::bad_database("Private or public keys are invalid."))?;
Self { Ok(Self {
globals, globals,
keypair, keypair,
reqwest_client: reqwest::Client::new(), reqwest_client: reqwest::Client::new(),
@ -30,7 +29,7 @@ impl Globals {
.unwrap_or("localhost") .unwrap_or("localhost")
.to_owned(), .to_owned(),
registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), registration_disabled: config.get_bool("registration_disabled").unwrap_or(false),
} })
} }
/// Returns this server's keypair. /// Returns this server's keypair.
@ -49,14 +48,15 @@ impl Globals {
.globals .globals
.update_and_fetch(COUNTER, utils::increment)? .update_and_fetch(COUNTER, utils::increment)?
.expect("utils::increment will always put in a value"), .expect("utils::increment will always put in a value"),
)) )
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
} }
pub fn current_count(&self) -> Result<u64> { pub fn current_count(&self) -> Result<u64> {
Ok(self self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
.globals Ok(utils::u64_from_bytes(&bytes)
.get(COUNTER)? .map_err(|_| Error::bad_database("Count has invalid bytes."))?)
.map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes))) })
} }
pub fn server_name(&self) -> &str { pub fn server_name(&self) -> &str {

@ -43,16 +43,20 @@ impl Media {
let content_type = utils::string_from_bytes( let content_type = utils::string_from_bytes(
parts parts
.next() .next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?, .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?,
)?; )
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?;
let filename_bytes = parts let filename_bytes = parts
.next() .next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?; .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let filename = if filename_bytes.is_empty() { let filename = if filename_bytes.is_empty() {
None None
} else { } else {
Some(utils::string_from_bytes(filename_bytes)?) Some(utils::string_from_bytes(filename_bytes).map_err(|_| {
Error::bad_database("Filename in mediaid_file is invalid unicode.")
})?)
}; };
Ok(Some((filename, content_type, file.to_vec()))) Ok(Some((filename, content_type, file.to_vec())))
@ -89,16 +93,21 @@ impl Media {
let content_type = utils::string_from_bytes( let content_type = utils::string_from_bytes(
parts parts
.next() .next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?, .ok_or_else(|| Error::bad_database("Invalid Media ID in db"))?,
)?; )
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?;
let filename_bytes = parts let filename_bytes = parts
.next() .next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?; .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let filename = if filename_bytes.is_empty() { let filename = if filename_bytes.is_empty() {
None None
} else { } else {
Some(utils::string_from_bytes(filename_bytes)?) Some(
utils::string_from_bytes(filename_bytes)
.map_err(|_| Error::bad_database("Filename in db is invalid."))?,
)
}; };
Ok(Some((filename, content_type, file.to_vec()))) Ok(Some((filename, content_type, file.to_vec())))
@ -110,16 +119,20 @@ impl Media {
let content_type = utils::string_from_bytes( let content_type = utils::string_from_bytes(
parts parts
.next() .next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?, .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?,
)?; )
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?;
let filename_bytes = parts let filename_bytes = parts
.next() .next()
.ok_or(Error::BadDatabase("mediaid is invalid"))?; .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let filename = if filename_bytes.is_empty() { let filename = if filename_bytes.is_empty() {
None None
} else { } else {
Some(utils::string_from_bytes(filename_bytes)?) Some(utils::string_from_bytes(filename_bytes).map_err(|_| {
Error::bad_database("Filename in mediaid_file is invalid unicode.")
})?)
}; };
if let Ok(image) = image::load_from_memory(&file) { if let Ok(image) = image::load_from_memory(&file) {
@ -132,7 +145,7 @@ impl Media {
let width_index = thumbnail_key let width_index = thumbnail_key
.iter() .iter()
.position(|&b| b == 0xff) .position(|&b| b == 0xff)
.ok_or(Error::BadDatabase("mediaid is invalid"))? .ok_or_else(|| Error::bad_database("Media in db is invalid."))?
+ 1; + 1;
let mut widthheight = width.to_be_bytes().to_vec(); let mut widthheight = width.to_be_bytes().to_vec();
widthheight.extend_from_slice(&height.to_be_bytes()); widthheight.extend_from_slice(&height.to_be_bytes());

@ -5,6 +5,7 @@ pub use edus::RoomEdus;
use crate::{utils, Error, PduEvent, Result}; use crate::{utils, Error, PduEvent, Result};
use log::error; use log::error;
use ruma::{ use ruma::{
api::client::error::ErrorKind,
events::{ events::{
room::{ room::{
join_rules, member, join_rules, member,
@ -61,30 +62,34 @@ impl Rooms {
.roomstateid_pdu .roomstateid_pdu
.scan_prefix(&room_id.to_string().as_bytes()) .scan_prefix(&room_id.to_string().as_bytes())
.values() .values()
.map(|value| Ok::<_, Error>(serde_json::from_slice::<PduEvent>(&value?)?)) .map(|value| {
Ok::<_, Error>(
serde_json::from_slice::<PduEvent>(&value?)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
)
})
{ {
let pdu = pdu?; let pdu = pdu?;
hashmap.insert( let state_key = pdu.state_key.clone().ok_or_else(|| {
( Error::bad_database("Room state contains event without state_key.")
pdu.kind.clone(), })?;
pdu.state_key hashmap.insert((pdu.kind.clone(), state_key), pdu);
.clone()
.expect("state events have a state key"),
),
pdu,
);
} }
Ok(hashmap) Ok(hashmap)
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> { pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
Ok(self self.eventid_pduid
.eventid_pduid
.get(event_id.to_string().as_bytes())? .get(event_id.to_string().as_bytes())?
.map(|pdu_id| { .map_or(Ok(None), |pdu_id| {
utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()]) Ok(Some(
})) utils::u64_from_bytes(
&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()],
)
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?,
))
})
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
@ -92,11 +97,12 @@ impl Rooms {
self.eventid_pduid self.eventid_pduid
.get(event_id.to_string().as_bytes())? .get(event_id.to_string().as_bytes())?
.map_or(Ok(None), |pdu_id| { .map_or(Ok(None), |pdu_id| {
Ok(Some(serde_json::from_slice( Ok(Some(
&self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase( serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| {
"eventid_pduid points to nonexistent pdu", Error::bad_database("eventid_pduid points to nonexistent pdu.")
))?, })?)
)?)) .map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
}) })
} }
@ -112,28 +118,37 @@ impl Rooms {
self.eventid_pduid self.eventid_pduid
.get(event_id.to_string().as_bytes())? .get(event_id.to_string().as_bytes())?
.map_or(Ok(None), |pdu_id| { .map_or(Ok(None), |pdu_id| {
Ok(Some(serde_json::from_slice( Ok(Some(
&self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase( serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| {
"eventid_pduid points to nonexistent pdu", Error::bad_database("eventid_pduid points to nonexistent pdu.")
))?, })?)
)?)) .map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
}) })
} }
/// Returns the pdu. /// Returns the pdu.
pub fn get_pdu_from_id(&self, pdu_id: &IVec) -> Result<Option<PduEvent>> { pub fn get_pdu_from_id(&self, pdu_id: &IVec) -> Result<Option<PduEvent>> {
self.pduid_pdu self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
.get(pdu_id)? Ok(Some(
.map_or(Ok(None), |pdu| Ok(Some(serde_json::from_slice(&pdu)?))) serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
} }
/// Returns the pdu. /// Removes a pdu and creates a new one with the same id.
pub fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> { fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(&pdu_id)?.is_some() { if self.pduid_pdu.get(&pdu_id)?.is_some() {
self.pduid_pdu self.pduid_pdu.insert(
.insert(&pdu_id, &*serde_json::to_string(pdu)?)?; &pdu_id,
&*serde_json::to_string(pdu).expect("PduEvent::to_string always works"),
)?;
Ok(()) Ok(())
} else { } else {
Err(Error::BadRequest("pdu does not exist")) Err(Error::BadRequest(
ErrorKind::NotFound,
"PDU does not exist.",
))
} }
} }
@ -148,7 +163,14 @@ impl Rooms {
.roomid_pduleaves .roomid_pduleaves
.scan_prefix(prefix) .scan_prefix(prefix)
.values() .values()
.map(|bytes| Ok::<_, Error>(EventId::try_from(&*utils::string_from_bytes(&bytes?)?)?)) .map(|bytes| {
Ok::<_, Error>(
EventId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?,
)
})
{ {
events.push(event?); events.push(event?);
} }
@ -214,174 +236,205 @@ impl Rooms {
Ok( Ok(
serde_json::from_value::<EventJson<PowerLevelsEventContent>>( serde_json::from_value::<EventJson<PowerLevelsEventContent>>(
power_levels.content.clone(), power_levels.content.clone(),
)? )
.deserialize()?, .expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| Error::bad_database("Invalid PowerLevels event in db."))?,
) )
}, },
)?; )?;
{ let sender_membership = self
let sender_membership = self .room_state(&room_id)?
.room_state(&room_id)? .get(&(EventType::RoomMember, sender.to_string()))
.get(&(EventType::RoomMember, sender.to_string())) .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { Ok(
Ok( serde_json::from_value::<EventJson<member::MemberEventContent>>(
serde_json::from_value::<EventJson<member::MemberEventContent>>( pdu.content.clone(),
pdu.content.clone(), )
)? .expect("EventJson::from_value always works.")
.deserialize()? .deserialize()
.membership, .map_err(|_| Error::bad_database("Invalid Member event in db."))?
.membership,
)
})?;
let sender_power = power_levels.users.get(&sender).map_or_else(
|| {
if sender_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);
if !match event_type {
EventType::RoomMember => {
let target_user_id = UserId::try_from(&**state_key).map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"State key of member event does not contain user id.",
) )
})?; })?;
let sender_power = power_levels.users.get(&sender).map_or_else( let current_membership = self
|| { .room_state(&room_id)?
if sender_membership != member::MembershipState::Join { .get(&(EventType::RoomMember, target_user_id.to_string()))
None .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| {
} else { Ok(
Some(&power_levels.users_default) serde_json::from_value::<EventJson<member::MemberEventContent>>(
} pdu.content.clone(),
}, )
// If it's okay, wrap with Some(_) .expect("EventJson::from_value always works.")
Some, .deserialize()
); .map_err(|_| Error::bad_database("Invalid Member event in db."))?
.membership,
if !match event_type { )
EventType::RoomMember => { })?;
let target_user_id = UserId::try_from(&**state_key)?;
let target_membership = serde_json::from_value::<
let current_membership = self EventJson<member::MemberEventContent>,
.room_state(&room_id)? >(content.clone())
.get(&(EventType::RoomMember, target_user_id.to_string())) .expect("EventJson::from_value always works.")
.map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { .deserialize()
Ok(serde_json::from_value::< .map_err(|_| Error::bad_database("Invalid Member event in db."))?
EventJson<member::MemberEventContent>, .membership;
>(pdu.content.clone())?
.deserialize()? let target_power = power_levels.users.get(&target_user_id).map_or_else(
.membership) || {
})?; if target_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);
let target_membership = serde_json::from_value::< let join_rules =
EventJson<member::MemberEventContent>, self.room_state(&room_id)?
>(content.clone())?
.deserialize()?
.membership;
let target_power = power_levels.users.get(&target_user_id).map_or_else(
|| {
if target_membership != member::MembershipState::Join {
None
} else {
Some(&power_levels.users_default)
}
},
// If it's okay, wrap with Some(_)
Some,
);
let join_rules = self
.room_state(&room_id)?
.get(&(EventType::RoomJoinRules, "".to_owned())) .get(&(EventType::RoomJoinRules, "".to_owned()))
.map_or(join_rules::JoinRule::Public, |pdu| { .map_or(Ok::<_, Error>(join_rules::JoinRule::Public), |pdu| {
serde_json::from_value::< Ok(serde_json::from_value::<
EventJson<join_rules::JoinRulesEventContent>, EventJson<join_rules::JoinRulesEventContent>,
>(pdu.content.clone()) >(pdu.content.clone())
.unwrap() .expect("EventJson::from_value always works.")
.deserialize() .deserialize()
.unwrap() .map_err(|_| {
.join_rule Error::bad_database("Database contains invalid JoinRules event")
}); })?
.join_rule)
let authorized = if target_membership == member::MembershipState::Join { })?;
let mut prev_events = prev_events.iter();
let prev_event = self let authorized = if target_membership == member::MembershipState::Join {
.get_pdu(prev_events.next().ok_or(Error::BadRequest( let mut prev_events = prev_events.iter();
"membership can't be the first event", let prev_event = self
))?)? .get_pdu(prev_events.next().ok_or(Error::BadRequest(
.ok_or(Error::BadDatabase("pdu leave points to valid event"))?; ErrorKind::Unknown,
if prev_event.kind == EventType::RoomCreate "Membership can't be the first event",
&& prev_event.prev_events.is_empty() ))?)?
{ .ok_or_else(|| {
true Error::bad_database("PDU leaf points to invalid event!")
} else if sender != target_user_id { })?;
false if prev_event.kind == EventType::RoomCreate
} else if let member::MembershipState::Ban = current_membership { && prev_event.prev_events.is_empty()
false {
} else { true
join_rules == join_rules::JoinRule::Invite } else if sender != target_user_id {
&& (current_membership == member::MembershipState::Join false
|| current_membership == member::MembershipState::Invite) } else if let member::MembershipState::Ban = current_membership {
|| join_rules == join_rules::JoinRule::Public false
} } else {
} else if target_membership == member::MembershipState::Invite { join_rules == join_rules::JoinRule::Invite
if let Some(third_party_invite_json) = content.get("third_party_invite") && (current_membership == member::MembershipState::Join
{ || current_membership == member::MembershipState::Invite)
if current_membership == member::MembershipState::Ban { || join_rules == join_rules::JoinRule::Public
false }
} else { } else if target_membership == member::MembershipState::Invite {
let _third_party_invite = if let Some(third_party_invite_json) = content.get("third_party_invite") {
serde_json::from_value::<member::ThirdPartyInvite>( if current_membership == member::MembershipState::Ban {
third_party_invite_json.clone(),
)?;
todo!("handle third party invites");
}
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
{
false
} else {
sender_power
.filter(|&p| p >= &power_levels.invite)
.is_some()
}
} else if target_membership == member::MembershipState::Leave {
if sender == target_user_id {
current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Invite
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
{
false
} else {
sender_power.filter(|&p| p >= &power_levels.kick).is_some()
&& target_power < sender_power
}
} else if target_membership == member::MembershipState::Ban {
if sender_membership != member::MembershipState::Join {
false false
} else { } else {
sender_power.filter(|&p| p >= &power_levels.ban).is_some() let _third_party_invite =
&& target_power < sender_power serde_json::from_value::<member::ThirdPartyInvite>(
third_party_invite_json.clone(),
)
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"ThirdPartyInvite is invalid",
)
})?;
todo!("handle third party invites");
} }
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
{
false
} else { } else {
sender_power
.filter(|&p| p >= &power_levels.invite)
.is_some()
}
} else if target_membership == member::MembershipState::Leave {
if sender == target_user_id {
current_membership == member::MembershipState::Join
|| current_membership == member::MembershipState::Invite
} else if sender_membership != member::MembershipState::Join
|| current_membership == member::MembershipState::Ban
&& sender_power.filter(|&p| p < &power_levels.ban).is_some()
{
false false
}; } else {
sender_power.filter(|&p| p >= &power_levels.kick).is_some()
if authorized { && target_power < sender_power
// Update our membership info
self.update_membership(&room_id, &target_user_id, &target_membership)?;
} }
} else if target_membership == member::MembershipState::Ban {
if sender_membership != member::MembershipState::Join {
false
} else {
sender_power.filter(|&p| p >= &power_levels.ban).is_some()
&& target_power < sender_power
}
} else {
false
};
authorized if authorized {
// Update our membership info
self.update_membership(&room_id, &target_user_id, &target_membership)?;
} }
EventType::RoomCreate => prev_events.is_empty(),
// Not allow any of the following events if the sender is not joined. authorized
_ if sender_membership != member::MembershipState::Join => false,
_ => {
// TODO
sender_power.unwrap_or(&power_levels.users_default)
>= &power_levels.state_default
}
} {
error!("Unauthorized");
// Not authorized
return Err(Error::BadRequest("event not authorized"));
} }
EventType::RoomCreate => prev_events.is_empty(),
// Not allow any of the following events if the sender is not joined.
_ if sender_membership != member::MembershipState::Join => false,
_ => {
// TODO
sender_power.unwrap_or(&power_levels.users_default)
>= &power_levels.state_default
}
} {
error!("Unauthorized");
// Not authorized
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Event is not authorized",
));
} }
} else if !self.is_joined(&sender, &room_id)? { } else if !self.is_joined(&sender, &room_id)? {
return Err(Error::BadRequest("event not authorized")); // TODO: auth rules apply to all events, not only those with a state key
error!("Unauthorized");
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Event is not authorized",
));
} }
// Our depth is the maximum depth of prev_events + 1 // Our depth is the maximum depth of prev_events + 1
@ -410,14 +463,14 @@ impl Rooms {
origin: globals.server_name().to_owned(), origin: globals.server_name().to_owned(),
origin_server_ts: utils::millis_since_unix_epoch() origin_server_ts: utils::millis_since_unix_epoch()
.try_into() .try_into()
.expect("this only fails many years in the future"), .expect("time is valid"),
kind: event_type.clone(), kind: event_type.clone(),
content: content.clone(), content: content.clone(),
state_key, state_key,
prev_events, prev_events,
depth: depth depth: depth
.try_into() .try_into()
.expect("depth can overflow and should be deprecated..."), .map_err(|_| Error::bad_database("Depth is invalid"))?,
auth_events: Vec::new(), auth_events: Vec::new(),
redacts: redacts.clone(), redacts: redacts.clone(),
unsigned, unsigned,
@ -430,18 +483,20 @@ impl Rooms {
// Generate event id // Generate event id
pdu.event_id = EventId::try_from(&*format!( pdu.event_id = EventId::try_from(&*format!(
"${}", "${}",
ruma::signatures::reference_hash(&serde_json::to_value(&pdu)?) ruma::signatures::reference_hash(
.expect("ruma can calculate reference hashes") &serde_json::to_value(&pdu).expect("event is valid, we just created it")
)
.expect("ruma can calculate reference hashes")
)) ))
.expect("ruma's reference hashes are correct"); .expect("ruma's reference hashes are valid event ids");
let mut pdu_json = serde_json::to_value(&pdu)?; let mut pdu_json = serde_json::to_value(&pdu).expect("event is valid, we just created it");
ruma::signatures::hash_and_sign_event( ruma::signatures::hash_and_sign_event(
globals.server_name(), globals.server_name(),
globals.keypair(), globals.keypair(),
&mut pdu_json, &mut pdu_json,
) )
.expect("our new event can be hashed and signed"); .expect("event is valid, we just created it");
self.replace_pdu_leaves(&room_id, &pdu.event_id)?; self.replace_pdu_leaves(&room_id, &pdu.event_id)?;
@ -473,8 +528,15 @@ impl Rooms {
// TODO: Reason // TODO: Reason
let _reason = serde_json::from_value::< let _reason = serde_json::from_value::<
EventJson<redaction::RedactionEventContent>, EventJson<redaction::RedactionEventContent>,
>(content)? >(content)
.deserialize()? .expect("EventJson::from_value always works.")
.deserialize()
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid redaction event content.",
)
})?
.reason; .reason;
self.redact_pdu(&redact_id)?; self.redact_pdu(&redact_id)?;
@ -528,7 +590,10 @@ impl Rooms {
}) })
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))) .map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?)
}))
} }
/// Returns an iterator over all events in a room that happened before the event with id /// Returns an iterator over all events in a room that happened before the event with id
@ -552,7 +617,10 @@ impl Rooms {
.rev() .rev()
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)) .map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?)
})
} }
/// Returns an iterator over all events in a room that happened after the event with id /// Returns an iterator over all events in a room that happened after the event with id
@ -575,7 +643,10 @@ impl Rooms {
.range(current..) .range(current..)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)) .map(|(_, v)| {
Ok(serde_json::from_slice(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?)
})
} }
/// Replace a PDU with the redacted form. /// Replace a PDU with the redacted form.
@ -583,12 +654,15 @@ impl Rooms {
if let Some(pdu_id) = self.get_pdu_id(event_id)? { if let Some(pdu_id) = self.get_pdu_id(event_id)? {
let mut pdu = self let mut pdu = self
.get_pdu_from_id(&pdu_id)? .get_pdu_from_id(&pdu_id)?
.ok_or(Error::BadDatabase("pduid points to invalid pdu"))?; .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
pdu.redact(); pdu.redact()?;
self.replace_pdu(&pdu_id, &pdu)?; self.replace_pdu(&pdu_id, &pdu)?;
Ok(()) Ok(())
} else { } else {
Err(Error::BadRequest("eventid does not exist")) Err(Error::BadRequest(
ErrorKind::NotFound,
"Event ID does not exist.",
))
} }
} }
@ -664,7 +738,10 @@ impl Rooms {
let room_id = self let room_id = self
.alias_roomid .alias_roomid
.remove(alias.alias())? .remove(alias.alias())?
.ok_or(Error::BadRequest("Alias does not exist"))?; .ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Alias does not exist.",
))?;
for key in self.aliasid_alias.scan_prefix(room_id).keys() { for key in self.aliasid_alias.scan_prefix(room_id).keys() {
self.aliasid_alias.remove(key?)?; self.aliasid_alias.remove(key?)?;
@ -678,7 +755,12 @@ impl Rooms {
self.alias_roomid self.alias_roomid
.get(alias.alias())? .get(alias.alias())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
Ok(Some(RoomId::try_from(utils::string_from_bytes(&bytes)?)?)) Ok(Some(
RoomId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))?,
))
}) })
} }
@ -689,7 +771,10 @@ impl Rooms {
self.aliasid_alias self.aliasid_alias
.scan_prefix(prefix) .scan_prefix(prefix)
.values() .values()
.map(|bytes| Ok(RoomAliasId::try_from(utils::string_from_bytes(&bytes?)?)?)) .map(|bytes| {
Ok(serde_json::from_slice(&bytes?)
.map_err(|_| Error::bad_database("Alias in aliasid_alias is invalid."))?)
})
} }
pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> {
@ -707,54 +792,76 @@ impl Rooms {
} }
pub fn public_rooms(&self) -> impl Iterator<Item = Result<RoomId>> { pub fn public_rooms(&self) -> impl Iterator<Item = Result<RoomId>> {
self.publicroomids self.publicroomids.iter().keys().map(|bytes| {
.iter() Ok(
.keys() RoomId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
.map(|bytes| Ok(RoomId::try_from(utils::string_from_bytes(&bytes?)?)?)) Error::bad_database("Room ID in publicroomids is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))?,
)
})
} }
/// Returns an iterator over all rooms a user joined. /// Returns an iterator over all joined members of a room.
pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> {
self.roomuserid_joined self.roomuserid_joined
.scan_prefix(room_id.to_string()) .scan_prefix(room_id.to_string())
.values() .keys()
.map(|key| { .map(|key| {
Ok(UserId::try_from(&*utils::string_from_bytes( Ok(UserId::try_from(
&key? utils::string_from_bytes(
.rsplit(|&b| b == 0xff) &key?
.next() .rsplit(|&b| b == 0xff)
.ok_or(Error::BadDatabase("userroomid is invalid"))?, .next()
)?)?) .expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))?)
}) })
} }
/// Returns an iterator over all rooms a user joined. /// Returns an iterator over all invited members of a room.
pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> {
self.roomuserid_invited self.roomuserid_invited
.scan_prefix(room_id.to_string()) .scan_prefix(room_id.to_string())
.keys() .keys()
.map(|key| { .map(|key| {
Ok(UserId::try_from(&*utils::string_from_bytes( Ok(UserId::try_from(
&key? utils::string_from_bytes(
.rsplit(|&b| b == 0xff) &key?
.next() .rsplit(|&b| b == 0xff)
.ok_or(Error::BadDatabase("userroomid is invalid"))?, .next()
)?)?) .expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))?)
}) })
} }
/// Returns an iterator over all rooms a user joined. /// Returns an iterator over all left members of a room.
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> { pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
self.userroomid_joined self.userroomid_joined
.scan_prefix(user_id.to_string()) .scan_prefix(user_id.to_string())
.keys() .keys()
.map(|key| { .map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes( Ok(RoomId::try_from(
&key? utils::string_from_bytes(
.rsplit(|&b| b == 0xff) &key?
.next() .rsplit(|&b| b == 0xff)
.ok_or(Error::BadDatabase("userroomid is invalid"))?, .next()
)?)?) .expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))?)
}) })
} }
@ -764,12 +871,18 @@ impl Rooms {
.scan_prefix(&user_id.to_string()) .scan_prefix(&user_id.to_string())
.keys() .keys()
.map(|key| { .map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes( Ok(RoomId::try_from(
&key? utils::string_from_bytes(
.rsplit(|&b| b == 0xff) &key?
.next() .rsplit(|&b| b == 0xff)
.ok_or(Error::BadDatabase("userroomid is invalid"))?, .next()
)?)?) .expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?)
}) })
} }
@ -779,12 +892,18 @@ impl Rooms {
.scan_prefix(&user_id.to_string()) .scan_prefix(&user_id.to_string())
.keys() .keys()
.map(|key| { .map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes( Ok(RoomId::try_from(
&key? utils::string_from_bytes(
.rsplit(|&b| b == 0xff) &key?
.next() .rsplit(|&b| b == 0xff)
.ok_or(Error::BadDatabase("userroomid is invalid"))?, .next()
)?)?) .expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_left is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_left is invalid."))?)
}) })
} }

@ -33,7 +33,10 @@ impl RoomEdus {
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.take_while(|key| key.starts_with(&prefix)) .take_while(|key| key.starts_with(&prefix))
.find(|key| { .find(|key| {
key.rsplit(|&b| b == 0xff).next().unwrap() == user_id.to_string().as_bytes() key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element")
== user_id.to_string().as_bytes()
}) })
{ {
// This is the old room_latest // This is the old room_latest
@ -45,8 +48,10 @@ impl RoomEdus {
room_latest_id.push(0xff); room_latest_id.push(0xff);
room_latest_id.extend_from_slice(&user_id.to_string().as_bytes()); room_latest_id.extend_from_slice(&user_id.to_string().as_bytes());
self.roomlatestid_roomlatest self.roomlatestid_roomlatest.insert(
.insert(room_latest_id, &*serde_json::to_string(&event)?)?; room_latest_id,
&*serde_json::to_string(&event).expect("EduEvent::to_string always works"),
)?;
Ok(()) Ok(())
} }
@ -68,7 +73,11 @@ impl RoomEdus {
.range(&*first_possible_edu..) .range(&*first_possible_edu..)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))) .map(|(_, v)| {
Ok(serde_json::from_slice(&v).map_err(|_| {
Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid.")
})?)
}))
} }
/// Sets a user as typing until the timeout timestamp is reached or roomactive_remove is /// Sets a user as typing until the timeout timestamp is reached or roomactive_remove is
@ -152,17 +161,21 @@ impl RoomEdus {
.roomactiveid_userid .roomactiveid_userid
.scan_prefix(&prefix) .scan_prefix(&prefix)
.keys() .keys()
.filter_map(|r| r.ok()) .map(|key| {
.take_while(|k| { let key = key?;
utils::u64_from_bytes( Ok::<_, Error>((
k.split(|&c| c == 0xff) key.clone(),
.nth(1) utils::u64_from_bytes(key.split(|&b| b == 0xff).nth(1).ok_or_else(|| {
.expect("roomactive has valid timestamp and delimiters"), Error::bad_database("RoomActive has invalid timestamp or delimiters.")
) < current_timestamp })?)
.map_err(|_| Error::bad_database("RoomActive has invalid timestamp bytes."))?,
))
}) })
.filter_map(|r| r.ok())
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
{ {
// This is an outdated edu (time > timestamp) // This is an outdated edu (time > timestamp)
self.roomactiveid_userid.remove(outdated_edu)?; self.roomactiveid_userid.remove(outdated_edu.0)?;
found_outdated = true; found_outdated = true;
} }
@ -187,7 +200,11 @@ impl RoomEdus {
Ok(self Ok(self
.roomid_lastroomactiveupdate .roomid_lastroomactiveupdate
.get(&room_id.to_string().as_bytes())? .get(&room_id.to_string().as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes)) .map_or(Ok::<_, Error>(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
})?))
})?
.unwrap_or(0)) .unwrap_or(0))
} }
@ -202,7 +219,16 @@ impl RoomEdus {
.roomactiveid_userid .roomactiveid_userid
.scan_prefix(prefix) .scan_prefix(prefix)
.values() .values()
.map(|user_id| Ok::<_, Error>(UserId::try_from(utils::string_from_bytes(&user_id?)?)?)) .map(|user_id| {
Ok::<_, Error>(
UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| {
Error::bad_database("User ID in roomactiveid_userid is invalid unicode.")
})?)
.map_err(|_| {
Error::bad_database("User ID in roomactiveid_userid is invalid.")
})?,
)
})
{ {
user_ids.push(user_id?); user_ids.push(user_id?);
} }
@ -230,9 +256,10 @@ impl RoomEdus {
key.push(0xff); key.push(0xff);
key.extend_from_slice(&user_id.to_string().as_bytes()); key.extend_from_slice(&user_id.to_string().as_bytes());
Ok(self self.roomuserid_lastread.get(key)?.map_or(Ok(None), |v| {
.roomuserid_lastread Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
.get(key)? Error::bad_database("Invalid private read marker bytes")
.map(|v| utils::u64_from_bytes(&v))) })?))
})
} }
} }

@ -43,15 +43,51 @@ impl Uiaa {
// Find out what the user completed // Find out what the user completed
match &**kind { match &**kind {
"m.login.password" => { "m.login.password" => {
if auth_parameters["identifier"]["type"] != "m.id.user" { let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest(
panic!("identifier not supported"); ErrorKind::MissingParam,
"m.login.password needs identifier.",
))?;
let identifier_type = identifier.get("type").ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Identifier needs a type.",
))?;
if identifier_type != "m.id.user" {
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
));
} }
let user_id = UserId::parse_with_server_name( let username = identifier
auth_parameters["identifier"]["user"].as_str().unwrap(), .get("user")
globals.server_name(), .ok_or(Error::BadRequest(
)?; ErrorKind::MissingParam,
let password = auth_parameters["password"].as_str().unwrap(); "Identifier needs user field.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"User is not a string.",
))?;
let user_id = UserId::parse_with_server_name(username, globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?;
let password = auth_parameters
.get("password")
.ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Password is missing.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"Password is not a string.",
))?;
// Check if password is correct // Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? { if let Some(hash) = users.password_hash(&user_id)? {
@ -59,7 +95,6 @@ impl Uiaa {
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
if !hash_matches { if !hash_matches {
debug!("Invalid password.");
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden, kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(), message: "Invalid username or password.".to_owned(),
@ -113,8 +148,10 @@ impl Uiaa {
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
if let Some(uiaainfo) = uiaainfo { if let Some(uiaainfo) = uiaainfo {
self.userdeviceid_uiaainfo self.userdeviceid_uiaainfo.insert(
.insert(&userdeviceid, &*serde_json::to_string(&uiaainfo)?)?; &userdeviceid,
&*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"),
)?;
} else { } else {
self.userdeviceid_uiaainfo.remove(&userdeviceid)?; self.userdeviceid_uiaainfo.remove(&userdeviceid)?;
} }
@ -136,8 +173,12 @@ impl Uiaa {
&self &self
.userdeviceid_uiaainfo .userdeviceid_uiaainfo
.get(&userdeviceid)? .get(&userdeviceid)?
.ok_or(Error::BadRequest("session does not exist"))?, .ok_or(Error::BadRequest(
)?; ErrorKind::Forbidden,
"UIAA session does not exist.",
))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))?;
if uiaainfo if uiaainfo
.session .session
@ -145,7 +186,10 @@ impl Uiaa {
.filter(|&s| s == session) .filter(|&s| s == session)
.is_none() .is_none()
{ {
return Err(Error::BadRequest("wrong session token")); return Err(Error::BadRequest(
ErrorKind::Forbidden,
"UIAA session token invalid.",
));
} }
Ok(uiaainfo) Ok(uiaainfo)

@ -43,24 +43,36 @@ impl Users {
.get(token)? .get(token)?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xff); let mut parts = bytes.split(|&b| b == 0xff);
let user_bytes = parts let user_bytes = parts.next().ok_or_else(|| {
.next() Error::bad_database("User ID in token_userdeviceid is invalid.")
.ok_or(Error::BadDatabase("token_userdeviceid value invalid"))?; })?;
let device_bytes = parts let device_bytes = parts.next().ok_or_else(|| {
.next() Error::bad_database("Device ID in token_userdeviceid is invalid.")
.ok_or(Error::BadDatabase("token_userdeviceid value invalid"))?; })?;
Ok(Some(( Ok(Some((
UserId::try_from(utils::string_from_bytes(&user_bytes)?)?, UserId::try_from(utils::string_from_bytes(&user_bytes).map_err(|_| {
utils::string_from_bytes(&device_bytes)?, Error::bad_database("User ID in token_userdeviceid is invalid unicode.")
})?)
.map_err(|_| {
Error::bad_database("User ID in token_userdeviceid is invalid.")
})?,
utils::string_from_bytes(&device_bytes).map_err(|_| {
Error::bad_database("Device ID in token_userdeviceid is invalid.")
})?,
))) )))
}) })
} }
/// Returns an iterator over all users on this homeserver. /// Returns an iterator over all users on this homeserver.
pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> { pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> {
self.userid_password.iter().keys().map(|r| { self.userid_password.iter().keys().map(|bytes| {
utils::string_from_bytes(&r?).and_then(|string| Ok(UserId::try_from(&*string)?)) Ok(
UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?,
)
}) })
} }
@ -68,14 +80,22 @@ impl Users {
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password self.userid_password
.get(user_id.to_string())? .get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) .map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
})?))
})
} }
/// Returns the displayname of a user on this homeserver. /// Returns the displayname of a user on this homeserver.
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname self.userid_displayname
.get(user_id.to_string())? .get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) .map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Displayname in db is invalid.")
})?))
})
} }
/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
@ -94,7 +114,11 @@ impl Users {
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<String>> { pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_avatarurl self.userid_avatarurl
.get(user_id.to_string())? .get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) .map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Avatar URL in db is invalid.")
})?))
})
} }
/// Sets a new avatar_url or removes it if avatar_url is None. /// Sets a new avatar_url or removes it if avatar_url is None.
@ -117,11 +141,8 @@ impl Users {
token: &str, token: &str,
initial_device_display_name: Option<String>, initial_device_display_name: Option<String>,
) -> Result<()> { ) -> Result<()> {
if !self.exists(user_id)? { // This method should never be called for nonexistent users.
return Err(Error::BadRequest( assert!(self.exists(user_id)?);
"tried to create device for nonexistent user",
));
}
let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); let mut userdeviceid = user_id.to_string().as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xff);
@ -134,7 +155,8 @@ impl Users {
display_name: initial_device_display_name, display_name: initial_device_display_name,
last_seen_ip: None, // TODO last_seen_ip: None, // TODO
last_seen_ts: Some(SystemTime::now()), last_seen_ts: Some(SystemTime::now()),
})? })
.expect("Device::to_string never fails.")
.as_bytes(), .as_bytes(),
)?; )?;
@ -185,23 +207,22 @@ impl Users {
&*bytes? &*bytes?
.rsplit(|&b| b == 0xff) .rsplit(|&b| b == 0xff)
.next() .next()
.ok_or(Error::BadDatabase("userdeviceid is invalid"))?, .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
)?) )
.map_err(|_| {
Error::bad_database("Device ID in userdeviceid_metadata is invalid.")
})?)
}) })
} }
/// Replaces the access token of one device. /// Replaces the access token of one device.
pub fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> { fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> {
let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); let mut userdeviceid = user_id.to_string().as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
// All devices have metadata // All devices have metadata
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());
return Err(Error::BadRequest(
"Tried to set token for nonexistent device",
));
}
// Remove old token // Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
@ -228,19 +249,23 @@ impl Users {
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
// All devices have metadata // All devices have metadata
if self.userdeviceid_metadata.get(&key)?.is_none() { // Only existing devices should be able to call this.
return Err(Error::BadRequest( assert!(self.userdeviceid_metadata.get(&key)?.is_some());
"Tried to set token for nonexistent device",
));
}
key.push(0xff); key.push(0xff);
// TODO: Use AlgorithmAndDeviceId::to_string when it's available (and update everything, // TODO: Use AlgorithmAndDeviceId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore) // because there are no wrapping quotation marks anymore)
key.extend_from_slice(&serde_json::to_string(one_time_key_key)?.as_bytes()); key.extend_from_slice(
&serde_json::to_string(one_time_key_key)
self.onetimekeyid_onetimekeys .expect("AlgorithmAndDeviceId::to_string always works")
.insert(&key, &*serde_json::to_string(&one_time_key_value)?)?; .as_bytes(),
);
self.onetimekeyid_onetimekeys.insert(
&key,
&*serde_json::to_string(&one_time_key_value)
.expect("OneTimeKey::to_string always works"),
)?;
Ok(()) Ok(())
} }
@ -271,9 +296,11 @@ impl Users {
&*key &*key
.rsplit(|&b| b == 0xff) .rsplit(|&b| b == 0xff)
.next() .next()
.ok_or(Error::BadDatabase("onetimekeyid is invalid"))?, .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
)?, )
serde_json::from_slice(&*value)?, .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&*value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
)) ))
}) })
.transpose() .transpose()
@ -297,11 +324,11 @@ impl Users {
.map(|bytes| { .map(|bytes| {
Ok::<_, Error>( Ok::<_, Error>(
serde_json::from_slice::<AlgorithmAndDeviceId>( serde_json::from_slice::<AlgorithmAndDeviceId>(
&*bytes? &*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
.rsplit(|&b| b == 0xff) Error::bad_database("OneTimeKey ID in db is invalid.")
.next() })?,
.ok_or(Error::BadDatabase("onetimekeyid is invalid"))?, )
)? .map_err(|_| Error::bad_database("AlgorithmAndDeviceID in db is invalid."))?
.0, .0,
) )
}) })
@ -323,8 +350,10 @@ impl Users {
userdeviceid.push(0xff); userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_devicekeys self.userdeviceid_devicekeys.insert(
.insert(&userdeviceid, &*serde_json::to_string(&device_keys)?)?; &userdeviceid,
&*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"),
)?;
self.devicekeychangeid_userid self.devicekeychangeid_userid
.insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?; .insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?;
@ -344,14 +373,28 @@ impl Users {
self.userdeviceid_devicekeys self.userdeviceid_devicekeys
.scan_prefix(key) .scan_prefix(key)
.values() .values()
.map(|bytes| Ok(serde_json::from_slice(&bytes?)?)) .map(|bytes| {
Ok(serde_json::from_slice(&bytes?)
.map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?)
})
} }
pub fn device_keys_changed(&self, since: u64) -> impl Iterator<Item = Result<UserId>> { pub fn device_keys_changed(&self, since: u64) -> impl Iterator<Item = Result<UserId>> {
self.devicekeychangeid_userid self.devicekeychangeid_userid
.range(since.to_be_bytes()..) .range(since.to_be_bytes()..)
.values() .values()
.map(|bytes| Ok(UserId::try_from(utils::string_from_bytes(&bytes?)?)?)) .map(|bytes| {
Ok(
UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
Error::bad_database(
"User ID in devicekeychangeid_userid is invalid unicode.",
)
})?)
.map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid.")
})?,
)
})
} }
pub fn all_device_keys( pub fn all_device_keys(
@ -366,9 +409,14 @@ impl Users {
let userdeviceid = utils::string_from_bytes( let userdeviceid = utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xff)
.next() .next()
.ok_or(Error::BadDatabase("userdeviceid is invalid"))?, .ok_or_else(|| Error::bad_database("UserDeviceID in db is invalid."))?,
)?; )
Ok((userdeviceid, serde_json::from_slice(&*value)?)) .map_err(|_| Error::bad_database("UserDeviceId in db is invalid."))?;
Ok((
userdeviceid,
serde_json::from_slice(&*value)
.map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
))
}) })
} }
@ -392,8 +440,10 @@ impl Users {
json.insert("sender".to_owned(), sender.to_string().into()); json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content); json.insert("content".to_owned(), content);
self.todeviceid_events self.todeviceid_events.insert(
.insert(&key, &*serde_json::to_string(&json)?)?; &key,
&*serde_json::to_string(&json).expect("Map::to_string always works"),
)?;
Ok(()) Ok(())
} }
@ -413,7 +463,10 @@ impl Users {
for result in self.todeviceid_events.scan_prefix(&prefix).take(max) { for result in self.todeviceid_events.scan_prefix(&prefix).take(max) {
let (key, value) = result?; let (key, value) = result?;
events.push(serde_json::from_slice(&*value)?); events.push(
serde_json::from_slice(&*value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
self.todeviceid_events.remove(key)?; self.todeviceid_events.remove(key)?;
} }
@ -430,12 +483,15 @@ impl Users {
userdeviceid.push(0xff); userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { // Only existing devices should be able to call this.
return Err(Error::BadRequest("device does not exist")); assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());
}
self.userdeviceid_metadata self.userdeviceid_metadata.insert(
.insert(userdeviceid, serde_json::to_string(device)?.as_bytes())?; userdeviceid,
serde_json::to_string(device)
.expect("Device::to_string always works")
.as_bytes(),
)?;
Ok(()) Ok(())
} }
@ -448,7 +504,11 @@ impl Users {
self.userdeviceid_metadata self.userdeviceid_metadata
.get(&userdeviceid)? .get(&userdeviceid)?
.map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes)?))) .map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?))
})
} }
pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> { pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> {
@ -458,6 +518,10 @@ impl Users {
self.userdeviceid_metadata self.userdeviceid_metadata
.scan_prefix(key) .scan_prefix(key)
.values() .values()
.map(|bytes| Ok(serde_json::from_slice::<Device>(&bytes?)?)) .map(|bytes| {
Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| {
Error::bad_database("Device in userdeviceid_metadata is invalid.")
})?)
})
} }
} }

@ -1,41 +1,88 @@
use crate::RumaResponse;
use http::StatusCode;
use log::error;
use rocket::{
response::{self, Responder},
Request,
};
use ruma::api::client::{
error::{Error as RumaError, ErrorKind},
r0::uiaa::{UiaaInfo, UiaaResponse},
};
use thiserror::Error; use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
#[error("problem with the database")] #[error("There was a problem with the connection to the database.")]
SledError { SledError {
#[from] #[from]
source: sled::Error, source: sled::Error,
}, },
#[error("tried to parse invalid string")] #[error("Could not generate an image.")]
StringFromBytesError {
#[from]
source: std::string::FromUtf8Error,
},
#[error("tried to parse invalid identifier")]
SerdeJsonError {
#[from]
source: serde_json::Error,
},
#[error("tried to parse invalid identifier")]
RumaIdentifierError {
#[from]
source: ruma::identifiers::Error,
},
#[error("tried to parse invalid event")]
RumaEventError {
#[from]
source: ruma::events::InvalidEvent,
},
#[error("could not generate image")]
ImageError { ImageError {
#[from] #[from]
source: image::error::ImageError, source: image::error::ImageError,
}, },
#[error("bad request")] #[error("{0}")]
BadRequest(&'static str), BadConfig(&'static str),
#[error("problem in that database")] #[error("{0}")]
/// Don't create this directly. Use Error::bad_database instead.
BadDatabase(&'static str), BadDatabase(&'static str),
#[error("uiaa")]
Uiaa(UiaaInfo),
#[error("{0}: {1}")]
BadRequest(ErrorKind, &'static str),
#[error("{0}")]
Conflict(&'static str), // This is only needed for when a room alias already exists
}
impl Error {
pub fn bad_database(message: &'static str) -> Self {
error!("BadDatabase: {}", message);
Self::BadDatabase(message)
}
}
#[rocket::async_trait]
impl<'r> Responder<'r> for Error {
async fn respond_to(self, r: &'r Request<'_>) -> response::Result<'r> {
if let Self::Uiaa(uiaainfo) = &self {
return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo.clone()))
.respond_to(r)
.await;
}
let message = format!("{}", self);
use ErrorKind::*;
let (kind, status_code) = match self {
Self::BadRequest(kind, _) => (
kind,
match kind {
Forbidden | GuestAccessForbidden | ThreepidAuthFailed | ThreepidDenied => {
StatusCode::FORBIDDEN
}
Unauthorized | UnknownToken | MissingToken => StatusCode::UNAUTHORIZED,
NotFound => StatusCode::NOT_FOUND,
LimitExceeded => StatusCode::TOO_MANY_REQUESTS,
UserDeactivated => StatusCode::FORBIDDEN,
TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
_ => StatusCode::BAD_REQUEST,
},
),
Self::Conflict(_) => (Unknown, StatusCode::CONFLICT),
_ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR),
};
RumaResponse::from(RumaError {
kind,
message,
status_code,
})
.respond_to(r)
.await
}
} }

@ -12,7 +12,7 @@ mod utils;
pub use database::Database; pub use database::Database;
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use pdu::PduEvent; pub use pdu::PduEvent;
pub use ruma_wrapper::{MatrixResult, Ruma}; pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse};
use rocket::{fairing::AdHoc, routes}; use rocket::{fairing::AdHoc, routes};
@ -95,7 +95,7 @@ fn setup_rocket() -> rocket::Rocket {
], ],
) )
.attach(AdHoc::on_attach("Config", |rocket| { .attach(AdHoc::on_attach("Config", |rocket| {
let data = Database::load_or_create(&rocket.config()); let data = Database::load_or_create(&rocket.config()).expect("valid config");
Ok(rocket.manage(data)) Ok(rocket.manage(data))
})) }))

@ -1,3 +1,4 @@
use crate::{Error, Result};
use js_int::UInt; use js_int::UInt;
use ruma::{ use ruma::{
api::federation::pdu::EventHash, api::federation::pdu::EventHash,
@ -36,7 +37,7 @@ pub struct PduEvent {
} }
impl PduEvent { impl PduEvent {
pub fn redact(&mut self) { pub fn redact(&mut self) -> Result<()> {
self.unsigned.clear(); self.unsigned.clear();
let allowed = match self.kind { let allowed = match self.kind {
EventType::RoomMember => vec!["membership"], EventType::RoomMember => vec!["membership"],
@ -56,7 +57,11 @@ impl PduEvent {
_ => vec![], _ => vec![],
}; };
let old_content = self.content.as_object_mut().unwrap(); // TODO error let old_content = self
.content
.as_object_mut()
.ok_or_else(|| Error::bad_database("PDU in db has invalid content."))?;
let mut new_content = serde_json::Map::new(); let mut new_content = serde_json::Map::new();
for key in allowed { for key in allowed {
@ -71,21 +76,23 @@ impl PduEvent {
); );
self.content = new_content.into(); self.content = new_content.into();
Ok(())
} }
pub fn to_room_event(&self) -> EventJson<RoomEvent> { pub fn to_room_event(&self) -> EventJson<RoomEvent> {
// Can only fail in rare circumstances that won't ever happen here, see let json = serde_json::to_string(&self).expect("PDUs are always valid");
// https://docs.rs/serde_json/1.0.50/serde_json/fn.to_string.html serde_json::from_str::<EventJson<RoomEvent>>(&json)
let json = serde_json::to_string(&self).unwrap(); .expect("EventJson::from_str always works")
// EventJson's deserialize implementation always returns `Ok(...)`
serde_json::from_str::<EventJson<RoomEvent>>(&json).unwrap()
} }
pub fn to_state_event(&self) -> EventJson<StateEvent> { pub fn to_state_event(&self) -> EventJson<StateEvent> {
let json = serde_json::to_string(&self).unwrap(); let json = serde_json::to_string(&self).expect("PDUs are always valid");
serde_json::from_str::<EventJson<StateEvent>>(&json).unwrap() serde_json::from_str::<EventJson<StateEvent>>(&json)
.expect("EventJson::from_str always works")
} }
pub fn to_stripped_state_event(&self) -> EventJson<AnyStrippedStateEvent> { pub fn to_stripped_state_event(&self) -> EventJson<AnyStrippedStateEvent> {
let json = serde_json::to_string(&self).unwrap(); let json = serde_json::to_string(&self).expect("PDUs are always valid");
serde_json::from_str::<EventJson<AnyStrippedStateEvent>>(&json).unwrap() serde_json::from_str::<EventJson<AnyStrippedStateEvent>>(&json)
.expect("EventJson::from_str always works")
} }
} }

@ -1,4 +1,4 @@
use crate::utils; use crate::{utils, Error};
use log::warn; use log::warn;
use rocket::{ use rocket::{
data::{Data, FromData, FromDataFuture, Transform, TransformFuture, Transformed}, data::{Data, FromData, FromDataFuture, Transform, TransformFuture, Transformed},
@ -42,7 +42,10 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma<T> {
let data = rocket::try_outcome!(outcome.owned()); let data = rocket::try_outcome!(outcome.owned());
let (user_id, device_id) = if T::METADATA.requires_authentication { let (user_id, device_id) = if T::METADATA.requires_authentication {
let db = request.guard::<State<'_, crate::Database>>().await.unwrap(); let db = request
.guard::<State<'_, crate::Database>>()
.await
.expect("database was loaded");
// Get token from header or query value // Get token from header or query value
let token = match request let token = match request
@ -108,32 +111,24 @@ impl<T> Deref for Ruma<T> {
} }
/// This struct converts ruma responses into rocket http responses. /// This struct converts ruma responses into rocket http responses.
pub struct MatrixResult<T, E = ruma::api::client::Error>(pub std::result::Result<T, E>); pub type ConduitResult<T> = std::result::Result<RumaResponse<T>, Error>;
impl<T, E> TryInto<http::Response<Vec<u8>>> for MatrixResult<T, E> pub struct RumaResponse<T: TryInto<http::Response<Vec<u8>>>>(pub T);
where
T: TryInto<http::Response<Vec<u8>>>,
E: Into<http::Response<Vec<u8>>>,
{
type Error = T::Error;
fn try_into(self) -> Result<http::Response<Vec<u8>>, T::Error> { impl<T: TryInto<http::Response<Vec<u8>>>> From<T> for RumaResponse<T> {
match self.0 { fn from(t: T) -> Self {
Ok(t) => t.try_into(), Self(t)
Err(e) => Ok(e.into()),
}
} }
} }
#[rocket::async_trait] #[rocket::async_trait]
impl<'r, T, E> Responder<'r> for MatrixResult<T, E> impl<'r, T> Responder<'r> for RumaResponse<T>
where where
T: Send + TryInto<http::Response<Vec<u8>>>, T: Send + TryInto<http::Response<Vec<u8>>>,
T::Error: Send, T::Error: Send,
E: Into<http::Response<Vec<u8>>> + Send,
{ {
async fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { async fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> {
let http_response: Result<http::Response<_>, _> = self.try_into(); let http_response: Result<http::Response<_>, _> = self.0.try_into();
match http_response { match http_response {
Ok(http_response) => { Ok(http_response) => {
let mut response = rocket::response::Response::build(); let mut response = rocket::response::Response::build();
@ -165,11 +160,3 @@ where
} }
} }
} }
impl<T, E> Deref for MatrixResult<T, E> {
type Target = Result<T, E>;
fn deref(&self) -> &Self::Target {
&self.0
}
}

@ -1,4 +1,3 @@
use crate::Result;
use argon2::{Config, Variant}; use argon2::{Config, Variant};
use rand::prelude::*; use rand::prelude::*;
use std::{ use std::{
@ -9,39 +8,38 @@ use std::{
pub fn millis_since_unix_epoch() -> u64 { pub fn millis_since_unix_epoch() -> u64 {
SystemTime::now() SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .expect("time is valid")
.as_millis() as u64 .as_millis() as u64
} }
pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
let number = match old { let number = match old.map(|bytes| bytes.try_into()) {
Some(bytes) => { Some(Ok(bytes)) => {
let array: [u8; 8] = bytes.try_into().unwrap(); let number = u64::from_be_bytes(bytes);
let number = u64::from_be_bytes(array);
number + 1 number + 1
} }
None => 1, // Start at one. since 0 should return the first event in the db _ => 1, // Start at one. since 0 should return the first event in the db
}; };
Some(number.to_be_bytes().to_vec()) Some(number.to_be_bytes().to_vec())
} }
pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> { pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> {
Some( Some(old.map(|s| s.to_vec()).unwrap_or_else(|| {
old.map(|s| s.to_vec()) ruma::signatures::Ed25519KeyPair::generate()
.unwrap_or_else(|| ruma::signatures::Ed25519KeyPair::generate().unwrap()), .expect("Ed25519KeyPair generation always works (?)")
) }))
} }
/// Parses the bytes into an u64. /// Parses the bytes into an u64.
pub fn u64_from_bytes(bytes: &[u8]) -> u64 { pub fn u64_from_bytes(bytes: &[u8]) -> Result<u64, std::array::TryFromSliceError> {
let array: [u8; 8] = bytes.try_into().expect("bytes are valid u64"); let array: [u8; 8] = bytes.try_into()?;
u64::from_be_bytes(array) Ok(u64::from_be_bytes(array))
} }
/// Parses the bytes into a string. /// Parses the bytes into a string.
pub fn string_from_bytes(bytes: &[u8]) -> Result<String> { pub fn string_from_bytes(bytes: &[u8]) -> Result<String, std::string::FromUtf8Error> {
Ok(String::from_utf8(bytes.to_vec())?) String::from_utf8(bytes.to_vec())
} }
pub fn random_string(length: usize) -> String { pub fn random_string(length: usize) -> String {
@ -52,7 +50,7 @@ pub fn random_string(length: usize) -> String {
} }
/// Calculate a new hash for the given password /// Calculate a new hash for the given password
pub fn calculate_hash(password: &str) -> std::result::Result<String, argon2::Error> { pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> {
let hashing_config = Config { let hashing_config = Config {
variant: Variant::Argon2id, variant: Variant::Argon2id,
..Default::default() ..Default::default()

@ -35,8 +35,6 @@ POST /rooms/:room_id/invite can send an invite
PUT /rooms/:room_id/state/m.room.power_levels can set levels PUT /rooms/:room_id/state/m.room.power_levels can set levels
PUT power_levels should not explode if the old power levels were empty PUT power_levels should not explode if the old power levels were empty
Both GET and PUT work Both GET and PUT work
Room creation reports m.room.create to myself
Room creation reports m.room.member to myself
Version responds 200 OK with valid structure Version responds 200 OK with valid structure
PUT /profile/:user_id/displayname sets my name PUT /profile/:user_id/displayname sets my name
GET /profile/:user_id/displayname publicly accessible GET /profile/:user_id/displayname publicly accessible
@ -78,3 +76,4 @@ User directory correctly update on display name change
User in shared private room does appear in user directory User in shared private room does appear in user directory
User in dir while user still shares private rooms User in dir while user still shares private rooms
POST /rooms/:room_id/ban can ban a user POST /rooms/:room_id/ban can ban a user
Alternative server names do not cause a routing loop

Loading…
Cancel
Save