diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs new file mode 100644 index 00000000..6b906c24 --- /dev/null +++ b/src/database/key_value/pusher.rs @@ -0,0 +1,348 @@ +use crate::{Database, Error, PduEvent, Result}; +use bytes::BytesMut; +use ruma::{ + api::{ + client::push::{get_pushers, set_pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, + }, + events::{ + room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, + AnySyncRoomEvent, RoomEventType, StateEventType, + }, + push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + serde::Raw, + uint, RoomId, UInt, UserId, +}; +use tracing::{error, info, warn}; + +use std::{fmt::Debug, mem, sync::Arc}; + +use super::abstraction::Tree; + +pub struct PushData { + /// UserId + pushkey -> Pusher + pub(super) senderkey_pusher: Arc, +} + +impl PushData { + #[tracing::instrument(skip(self, sender, pusher))] + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { + let mut key = sender.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pusher.pushkey.as_bytes()); + + // There are 2 kinds of pushers but the spec says: null deletes the pusher. + if pusher.kind.is_none() { + return self + .senderkey_pusher + .remove(&key) + .map(|_| ()) + .map_err(Into::into); + } + + self.senderkey_pusher.insert( + &key, + &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), + )?; + + Ok(()) + } + + #[tracing::instrument(skip(self, senderkey))] + pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { + self.senderkey_pusher + .get(senderkey)? + .map(|push| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .transpose() + } + + #[tracing::instrument(skip(self, sender))] + pub fn get_pushers(&self, sender: &UserId) -> Result> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .collect() + } + + #[tracing::instrument(skip(self, sender))] + pub fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator> + 'a { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) + } +} + +#[tracing::instrument(skip(globals, destination, request))] +pub async fn send_request( + globals: &crate::database::globals::Globals, + destination: &str, + request: T, +) -> Result +where + T: Debug, +{ + let destination = destination.replace("/_matrix/push/v1/notify", ""); + + let http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(|body| body.freeze()); + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + // TODO: we could keep this very short and let expo backoff do it's thing... + //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + + let url = reqwest_request.url().clone(); + let response = globals.default_client().execute(reqwest_request).await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!( + "Push gateway returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Push gateway returned bad response.") + }) + } + Err(e) => Err(e.into()), + } +} + +#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] +pub async fn send_push_notice( + user: &UserId, + unread: UInt, + pusher: &get_pushers::v3::Pusher, + ruleset: Ruleset, + pdu: &PduEvent, + db: &Database, +) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); + + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + for action in get_actions( + user, + &ruleset, + &power_levels, + &pdu.to_sync_room_event(), + &pdu.room_id, + db, + )? { + let n = match action { + Action::DontNotify => false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + } + }; + + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } + + notify = Some(n); + } + + if notify == Some(true) { + send_notice(unread, pusher, tweaks, pdu, db).await?; + } + // Else the event triggered no actions + + Ok(()) +} + +#[tracing::instrument(skip(user, ruleset, pdu, db))] +pub fn get_actions<'a>( + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw, + room_id: &RoomId, + db: &Database, +) -> Result<&'a [Action]> { + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: 10_u32.into(), // TODO: get member count efficiently + user_display_name: db + .users + .displayname(user)? + .unwrap_or_else(|| user.localpart().to_owned()), + users_power_levels: power_levels.users.clone(), + default_power_level: power_levels.users_default, + notification_power_levels: power_levels.notifications.clone(), + }; + + Ok(ruleset.get_actions(pdu, &ctx)) +} + +#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] +async fn send_notice( + unread: UInt, + pusher: &get_pushers::v3::Pusher, + tweaks: Vec, + event: &PduEvent, + db: &Database, +) -> Result<()> { + // TODO: email + if pusher.kind == PusherKind::Email { + return Ok(()); + } + + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 2. can pusher/devices have conflicting formats + let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); + let url = if let Some(url) = &pusher.data.url { + url + } else { + error!("Http Pusher must have URL specified."); + return Ok(()); + }; + + let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); + let mut data_minus_url = pusher.data.clone(); + // The url must be stripped off according to spec + data_minus_url.url = None; + device.data = data_minus_url; + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } + + let d = &[device]; + let mut notifi = Notification::new(d); + + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some(&event.event_id); + notifi.room_id = Some(&event.room_id); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); + + if event.kind == RoomEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High + } + + if event_id_only { + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } else { + notifi.sender = Some(&event.sender); + notifi.event_type = Some(&event.kind); + let content = serde_json::value::to_raw_value(&event.content).ok(); + notifi.content = content.as_deref(); + + if event.kind == RoomEventType::RoomMember { + notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + } + + let user_name = db.users.displayname(&event.sender)?; + notifi.sender_display_name = user_name.as_deref(); + + let room_name = if let Some(room_name_pdu) = + db.rooms + .room_state_get(&event.room_id, &StateEventType::RoomName, "")? + { + serde_json::from_str::(room_name_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid room name event in database."))? + .name + } else { + None + }; + + notifi.room_name = room_name.as_deref(); + + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } + + // TODO: email + + Ok(()) +} diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs new file mode 100644 index 00000000..5b423d2d --- /dev/null +++ b/src/database/key_value/rooms/timeline.rs @@ -0,0 +1,937 @@ + + /// Checks if a room exists. + #[tracing::instrument(skip(self))] + pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Look for PDUs in that room. + self.pduid_pdu + .iter_from(&prefix, false) + .filter(|(k, _)| k.starts_with(&prefix)) + .map(|(_, pdu)| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid first PDU in db.")) + .map(Arc::new) + }) + .next() + .transpose() + } + + #[tracing::instrument(skip(self))] + pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + match self + .lasttimelinecount_cache + .lock() + .unwrap() + .entry(room_id.to_owned()) + { + hash_map::Entry::Vacant(v) => { + if let Some(last_count) = self + .pdus_until(&sender_user, &room_id, u64::MAX)? + .filter_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) + .map(|(pduid, _)| self.pdu_count(&pduid)) + .next() + { + Ok(*v.insert(last_count?)) + } else { + Ok(0) + } + } + hash_map::Entry::Occupied(o) => Ok(*o.get()), + } + } + + // TODO Is this the same as the function above? + #[tracing::instrument(skip(self))] + pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.pduid_pdu + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|b| self.pdu_count(&b.0)) + .transpose() + .map(|op| op.unwrap_or_default()) + } + + + + /// Returns the `count` of this pdu's id. + pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pdu_id| self.pdu_count(&pdu_id)) + .transpose() + } + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the json of a pdu. + pub fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu's id. + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { + self.eventid_pduid.get(event_id.as_bytes()) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { + return Ok(Some(Arc::clone(p))); + } + + if let Some(pdu) = self + .eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new) + }) + .transpose()? + { + self.pdu_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } + } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the pdu as a `BTreeMap`. + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the `count` of this pdu's id. + pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) + } + + /// Removes a pdu and creates a new one with the same id. + #[tracing::instrument(skip(self))] + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { + if self.pduid_pdu.get(pdu_id)?.is_some() { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), + )?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::NotFound, + "PDU does not exist.", + )) + } + } + + /// Creates a new persisted data unit and adds it to a room. + /// + /// By this point the incoming event should be fully authenticated, no auth happens + /// in `append_pdu`. + /// + /// Returns pdu id + #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] + pub fn append_pdu<'a>( + &self, + pdu: &PduEvent, + mut pdu_json: CanonicalJsonObject, + leaves: impl IntoIterator + Debug, + db: &Database, + ) -> Result> { + let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); + + // Make unsigned fields correct. This is not properly documented in the spec, but state + // events need to have previous content in the unsigned field, so clients can easily + // interpret things like membership changes + if let Some(state_key) = &pdu.state_key { + if let CanonicalJsonValue::Object(unsigned) = pdu_json + .entry("unsigned".to_owned()) + .or_insert_with(|| CanonicalJsonValue::Object(Default::default())) + { + if let Some(shortstatehash) = self.pdu_shortstatehash(&pdu.event_id).unwrap() { + if let Some(prev_state) = self + .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) + .unwrap() + { + unsigned.insert( + "prev_content".to_owned(), + CanonicalJsonValue::Object( + utils::to_canonical_object(prev_state.content.clone()) + .expect("event is valid, we just created it"), + ), + ); + } + } + } else { + error!("Invalid unsigned type in pdu."); + } + } + + // We must keep track of all events that have been referenced. + self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + self.replace_pdu_leaves(&pdu.room_id, leaves)?; + + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(pdu.room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + + let count1 = db.globals.next_count()?; + // Mark as read first so the sending client doesn't get a notification even if appending + // fails + self.edus + .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; + self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; + + let count2 = db.globals.next_count()?; + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); + pdu_id.extend_from_slice(&count2.to_be_bytes()); + + // There's a brief moment of time here where the count is updated but the pdu does not + // exist. This could theoretically lead to dropped pdus, but it's extremely rare + // + // Update: We fixed this using insert_lock + + self.pduid_pdu.insert( + &pdu_id, + &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), + )?; + self.lasttimelinecount_cache + .lock() + .unwrap() + .insert(pdu.room_id.clone(), count2); + + self.eventid_pduid + .insert(pdu.event_id.as_bytes(), &pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; + + drop(insert_lock); + + // See if the event matches any known pushers + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + let sync_pdu = pdu.to_sync_room_event(); + + let mut notifies = Vec::new(); + let mut highlights = Vec::new(); + + for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { + // Don't notify the user of their own events + if user == &pdu.sender { + continue; + } + + let rules_for_user = db + .account_data + .get( + None, + user, + GlobalAccountDataEventType::PushRules.to_string().into(), + )? + .map(|ev: PushRulesEvent| ev.content.global) + .unwrap_or_else(|| Ruleset::server_default(user)); + + let mut highlight = false; + let mut notify = false; + + for action in pusher::get_actions( + user, + &rules_for_user, + &power_levels, + &sync_pdu, + &pdu.room_id, + db, + )? { + match action { + Action::DontNotify => notify = false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => notify = true, + Action::SetTweak(Tweak::Highlight(true)) => { + highlight = true; + } + _ => {} + }; + } + + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(pdu.room_id.as_bytes()); + + if notify { + notifies.push(userroom_id.clone()); + } + + if highlight { + highlights.push(userroom_id); + } + + for senderkey in db.pusher.get_pusher_senderkeys(user) { + db.sending.send_push_pdu(&*pdu_id, senderkey)?; + } + } + + self.userroomid_notificationcount + .increment_batch(&mut notifies.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights.into_iter())?; + + match pdu.kind { + RoomEventType::RoomRedaction => { + if let Some(redact_id) = &pdu.redacts { + self.redact_pdu(redact_id, pdu)?; + } + } + RoomEventType::RoomMember => { + if let Some(state_key) = &pdu.state_key { + #[derive(Deserialize)] + struct ExtractMembership { + membership: MembershipState, + } + + // if the state_key fails + let target_user_id = UserId::parse(state_key.clone()) + .expect("This state_key was previously validated"); + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + let invite_state = match content.membership { + MembershipState::Invite => { + let state = self.calculate_invite_state(pdu)?; + Some(state) + } + _ => None, + }; + + // Update our membership info, we do this here incase a user is invited + // and immediately leaves we need the DB to record the invite event for auth + self.update_membership( + &pdu.room_id, + &target_user_id, + content.membership, + &pdu.sender, + invite_state, + db, + true, + )?; + } + } + RoomEventType::RoomMessage => { + #[derive(Deserialize)] + struct ExtractBody<'a> { + #[serde(borrow)] + body: Option>, + } + + let content = serde_json::from_str::>(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + if let Some(body) = content.body { + DB.rooms.search.index_pdu(room_id, pdu_id, body)?; + + let admin_room = self.id_from_alias( + <&RoomAliasId>::try_from( + format!("#admins:{}", db.globals.server_name()).as_str(), + ) + .expect("#admins:server_name is a valid room alias"), + )?; + let server_user = format!("@conduit:{}", db.globals.server_name()); + + let to_conduit = body.starts_with(&format!("{}: ", server_user)); + + // This will evaluate to false if the emergency password is set up so that + // the administrator can execute commands as conduit + let from_conduit = + pdu.sender == server_user && db.globals.emergency_password().is_none(); + + if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { + db.admin.process_message(body.to_string()); + } + } + } + _ => {} + } + + for appservice in db.appservice.all()? { + if self.appservice_in_room(room_id, &appservice, db)? { + db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + continue; + } + + // If the RoomMember event has a non-empty state_key, it is targeted at someone. + // If it is our appservice user, we send this PDU to it. + if pdu.kind == RoomEventType::RoomMember { + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + if let Some(appservice_uid) = appservice + .1 + .get("sender_localpart") + .and_then(|string| string.as_str()) + .and_then(|string| { + UserId::parse_with_server_name(string, db.globals.server_name()).ok() + }) + { + if state_key_uid == &appservice_uid { + db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + continue; + } + } + } + } + + if let Some(namespaces) = appservice.1.get("namespaces") { + let users = namespaces + .get("users") + .and_then(|users| users.as_sequence()) + .map_or_else(Vec::new, |users| { + users + .iter() + .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) + .collect::>() + }); + let aliases = namespaces + .get("aliases") + .and_then(|aliases| aliases.as_sequence()) + .map_or_else(Vec::new, |aliases| { + aliases + .iter() + .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) + .collect::>() + }); + let rooms = namespaces + .get("rooms") + .and_then(|rooms| rooms.as_sequence()); + + let matching_users = |users: &Regex| { + users.is_match(pdu.sender.as_str()) + || pdu.kind == RoomEventType::RoomMember + && pdu + .state_key + .as_ref() + .map_or(false, |state_key| users.is_match(state_key)) + }; + let matching_aliases = |aliases: &Regex| { + self.room_aliases(room_id) + .filter_map(|r| r.ok()) + .any(|room_alias| aliases.is_match(room_alias.as_str())) + }; + + if aliases.iter().any(matching_aliases) + || rooms.map_or(false, |rooms| rooms.contains(&room_id.as_str().into())) + || users.iter().any(matching_users) + { + db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + } + } + } + + + Ok(pdu_id) + } + + pub fn create_hash_and_sign_event( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + db: &Database, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> (PduEvent, CanonicalJsonObj) { + let PduBuilder { + event_type, + content, + unsigned, + state_key, + redacts, + } = pdu_builder; + + let prev_events: Vec<_> = db + .rooms + .get_pdu_leaves(room_id)? + .into_iter() + .take(20) + .collect(); + + let create_event = db + .rooms + .room_state_get(room_id, &StateEventType::RoomCreate, "")?; + + let create_event_content: Option = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()?; + + // If there was no create event yet, assume we are creating a room with the default + // version right now + let room_version_id = create_event_content + .map_or(db.globals.default_room_version(), |create_event| { + create_event.room_version + }); + let room_version = + RoomVersion::new(&room_version_id).expect("room version is supported"); + + let auth_events = + self.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events + .iter() + .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) + .max() + .unwrap_or_else(|| uint!(0)) + + uint!(1); + + let mut unsigned = unsigned.unwrap_or_default(); + + if let Some(state_key) = &state_key { + if let Some(prev_pdu) = + self.room_state_get(room_id, &event_type.to_string().into(), state_key)? + { + unsigned.insert( + "prev_content".to_owned(), + serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), + ); + unsigned.insert( + "prev_sender".to_owned(), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + ); + } + } + + let pdu = PduEvent { + event_id: ruma::event_id!("$thiswillbefilledinlater").into(), + room_id: room_id.to_owned(), + sender: sender_user.to_owned(), + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"), + kind: event_type, + content, + state_key, + prev_events, + depth, + auth_events: auth_events + .iter() + .map(|(_, pdu)| pdu.event_id.clone()) + .collect(), + redacts, + unsigned: if unsigned.is_empty() { + None + } else { + Some(to_raw_value(&unsigned).expect("to_raw_value always works")) + }, + hashes: EventHash { + sha256: "aaa".to_owned(), + }, + signatures: None, + }; + + let auth_check = state_res::auth_check( + &room_version, + &pdu, + None::, // TODO: third_party_invite + |k, s| auth_events.get(&(k.clone(), s.to_owned())), + ) + .map_err(|e| { + error!("{:?}", e); + Error::bad_database("Auth check failed.") + })?; + + if !auth_check { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Event is not authorized.", + )); + } + + // Hash and sign + let mut pdu_json = + utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); + + pdu_json.remove("event_id"); + + // Add origin because synapse likes that (and it's required in the spec) + pdu_json.insert( + "origin".to_owned(), + to_canonical_value(db.globals.server_name()) + .expect("server name is a valid CanonicalJsonValue"), + ); + + match ruma::signatures::hash_and_sign_event( + db.globals.server_name().as_str(), + db.globals.keypair(), + &mut pdu_json, + &room_version_id, + ) { + Ok(_) => {} + Err(e) => { + return match e { + ruma::signatures::Error::PduSize => Err(Error::BadRequest( + ErrorKind::TooLarge, + "Message is too long", + )), + _ => Err(Error::BadRequest( + ErrorKind::Unknown, + "Signing event failed", + )), + } + } + } + + // Generate event id + pdu.event_id = EventId::parse_arc(format!( + "${}", + ruma::signatures::reference_hash(&pdu_json, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + pdu_json.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), + ); + + // Generate short event id + let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; + } + + /// Creates a new persisted data unit and adds it to a room. This function takes a + /// roomid_mutex_state, meaning that only this function is able to mutate the room state. + #[tracing::instrument(skip(self, db, _mutex_lock))] + pub fn build_and_append_pdu( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + db: &Database, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result> { + + let (pdu, pdu_json) = create_hash_and_sign_event()?; + + + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehashid = self.append_to_state(&pdu, &db.globals)?; + + let pdu_id = self.append_pdu( + &pdu, + pdu_json, + // Since this PDU references all pdu_leaves we can update the leaves + // of the room + iter::once(&*pdu.event_id), + db, + )?; + + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + self.set_room_state(room_id, statehashid)?; + + let mut servers: HashSet> = + self.room_servers(room_id).filter_map(|r| r.ok()).collect(); + + // In case we are kicking or banning a user, we need to inform their server of the change + if pdu.kind == RoomEventType::RoomMember { + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + servers.insert(Box::from(state_key_uid.server_name())); + } + } + + // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above + servers.remove(db.globals.server_name()); + + db.sending.send_pdu(servers.into_iter(), &pdu_id)?; + + Ok(pdu.event_id) + } + + /// Append the incoming event setting the state snapshot to the state from the + /// server that sent the event. + #[tracing::instrument(skip_all)] + fn append_incoming_pdu<'a>( + db: &Database, + pdu: &PduEvent, + pdu_json: CanonicalJsonObject, + new_room_leaves: impl IntoIterator + Clone + Debug, + state_ids_compressed: HashSet, + soft_fail: bool, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result>> { + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + db.rooms.set_event_state( + &pdu.event_id, + &pdu.room_id, + state_ids_compressed, + &db.globals, + )?; + + if soft_fail { + db.rooms + .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + db.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; + return Ok(None); + } + + let pdu_id = db.rooms.append_pdu(pdu, pdu_json, new_room_leaves, db)?; + + Ok(Some(pdu_id)) + } + + /// Returns an iterator over all PDUs in a room. + #[tracing::instrument(skip(self))] + pub fn all_pdus<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result, PduEvent)>> + 'a> { + self.pdus_since(user_id, room_id, 0) + } + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result, PduEvent)>> + 'a> { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Skip the first pdu if it's exactly at since, because we sent that last time + let mut first_pdu_id = prefix.clone(); + first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(&first_pdu_id, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result, PduEvent)>> + 'a> { + // Create the first part of the full pdu id + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Returns an iterator over all events and their token in a room that happened after the event + /// with id `from` in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result, PduEvent)>> + 'a> { + // Create the first part of the full pdu id + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(current, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Replace a PDU with the redacted form. + #[tracing::instrument(skip(self, reason))] + pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { + if let Some(pdu_id) = self.get_pdu_id(event_id)? { + let mut pdu = self + .get_pdu_from_id(&pdu_id)? + .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + pdu.redact(reason)?; + self.replace_pdu(&pdu_id, &pdu)?; + } + // If event does not exist, just noop + Ok(()) + } + diff --git a/src/service/users.rs b/src/database/key_value/users.rs similarity index 100% rename from src/service/users.rs rename to src/database/key_value/users.rs diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs new file mode 100644 index 00000000..6b906c24 --- /dev/null +++ b/src/service/pusher/data.rs @@ -0,0 +1,348 @@ +use crate::{Database, Error, PduEvent, Result}; +use bytes::BytesMut; +use ruma::{ + api::{ + client::push::{get_pushers, set_pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, + }, + events::{ + room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, + AnySyncRoomEvent, RoomEventType, StateEventType, + }, + push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + serde::Raw, + uint, RoomId, UInt, UserId, +}; +use tracing::{error, info, warn}; + +use std::{fmt::Debug, mem, sync::Arc}; + +use super::abstraction::Tree; + +pub struct PushData { + /// UserId + pushkey -> Pusher + pub(super) senderkey_pusher: Arc, +} + +impl PushData { + #[tracing::instrument(skip(self, sender, pusher))] + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { + let mut key = sender.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pusher.pushkey.as_bytes()); + + // There are 2 kinds of pushers but the spec says: null deletes the pusher. + if pusher.kind.is_none() { + return self + .senderkey_pusher + .remove(&key) + .map(|_| ()) + .map_err(Into::into); + } + + self.senderkey_pusher.insert( + &key, + &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), + )?; + + Ok(()) + } + + #[tracing::instrument(skip(self, senderkey))] + pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { + self.senderkey_pusher + .get(senderkey)? + .map(|push| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .transpose() + } + + #[tracing::instrument(skip(self, sender))] + pub fn get_pushers(&self, sender: &UserId) -> Result> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .collect() + } + + #[tracing::instrument(skip(self, sender))] + pub fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator> + 'a { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) + } +} + +#[tracing::instrument(skip(globals, destination, request))] +pub async fn send_request( + globals: &crate::database::globals::Globals, + destination: &str, + request: T, +) -> Result +where + T: Debug, +{ + let destination = destination.replace("/_matrix/push/v1/notify", ""); + + let http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(|body| body.freeze()); + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + // TODO: we could keep this very short and let expo backoff do it's thing... + //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + + let url = reqwest_request.url().clone(); + let response = globals.default_client().execute(reqwest_request).await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!( + "Push gateway returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Push gateway returned bad response.") + }) + } + Err(e) => Err(e.into()), + } +} + +#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] +pub async fn send_push_notice( + user: &UserId, + unread: UInt, + pusher: &get_pushers::v3::Pusher, + ruleset: Ruleset, + pdu: &PduEvent, + db: &Database, +) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); + + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + for action in get_actions( + user, + &ruleset, + &power_levels, + &pdu.to_sync_room_event(), + &pdu.room_id, + db, + )? { + let n = match action { + Action::DontNotify => false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + } + }; + + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } + + notify = Some(n); + } + + if notify == Some(true) { + send_notice(unread, pusher, tweaks, pdu, db).await?; + } + // Else the event triggered no actions + + Ok(()) +} + +#[tracing::instrument(skip(user, ruleset, pdu, db))] +pub fn get_actions<'a>( + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw, + room_id: &RoomId, + db: &Database, +) -> Result<&'a [Action]> { + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: 10_u32.into(), // TODO: get member count efficiently + user_display_name: db + .users + .displayname(user)? + .unwrap_or_else(|| user.localpart().to_owned()), + users_power_levels: power_levels.users.clone(), + default_power_level: power_levels.users_default, + notification_power_levels: power_levels.notifications.clone(), + }; + + Ok(ruleset.get_actions(pdu, &ctx)) +} + +#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] +async fn send_notice( + unread: UInt, + pusher: &get_pushers::v3::Pusher, + tweaks: Vec, + event: &PduEvent, + db: &Database, +) -> Result<()> { + // TODO: email + if pusher.kind == PusherKind::Email { + return Ok(()); + } + + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 2. can pusher/devices have conflicting formats + let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); + let url = if let Some(url) = &pusher.data.url { + url + } else { + error!("Http Pusher must have URL specified."); + return Ok(()); + }; + + let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); + let mut data_minus_url = pusher.data.clone(); + // The url must be stripped off according to spec + data_minus_url.url = None; + device.data = data_minus_url; + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } + + let d = &[device]; + let mut notifi = Notification::new(d); + + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some(&event.event_id); + notifi.room_id = Some(&event.room_id); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); + + if event.kind == RoomEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High + } + + if event_id_only { + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } else { + notifi.sender = Some(&event.sender); + notifi.event_type = Some(&event.kind); + let content = serde_json::value::to_raw_value(&event.content).ok(); + notifi.content = content.as_deref(); + + if event.kind == RoomEventType::RoomMember { + notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + } + + let user_name = db.users.displayname(&event.sender)?; + notifi.sender_display_name = user_name.as_deref(); + + let room_name = if let Some(room_name_pdu) = + db.rooms + .room_state_get(&event.room_id, &StateEventType::RoomName, "")? + { + serde_json::from_str::(room_name_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid room name event in database."))? + .name + } else { + None + }; + + notifi.room_name = room_name.as_deref(); + + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } + + // TODO: email + + Ok(()) +} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs new file mode 100644 index 00000000..6b906c24 --- /dev/null +++ b/src/service/pusher/mod.rs @@ -0,0 +1,348 @@ +use crate::{Database, Error, PduEvent, Result}; +use bytes::BytesMut; +use ruma::{ + api::{ + client::push::{get_pushers, set_pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, + }, + events::{ + room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, + AnySyncRoomEvent, RoomEventType, StateEventType, + }, + push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + serde::Raw, + uint, RoomId, UInt, UserId, +}; +use tracing::{error, info, warn}; + +use std::{fmt::Debug, mem, sync::Arc}; + +use super::abstraction::Tree; + +pub struct PushData { + /// UserId + pushkey -> Pusher + pub(super) senderkey_pusher: Arc, +} + +impl PushData { + #[tracing::instrument(skip(self, sender, pusher))] + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { + let mut key = sender.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(pusher.pushkey.as_bytes()); + + // There are 2 kinds of pushers but the spec says: null deletes the pusher. + if pusher.kind.is_none() { + return self + .senderkey_pusher + .remove(&key) + .map(|_| ()) + .map_err(Into::into); + } + + self.senderkey_pusher.insert( + &key, + &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), + )?; + + Ok(()) + } + + #[tracing::instrument(skip(self, senderkey))] + pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { + self.senderkey_pusher + .get(senderkey)? + .map(|push| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .transpose() + } + + #[tracing::instrument(skip(self, sender))] + pub fn get_pushers(&self, sender: &UserId) -> Result> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| { + serde_json::from_slice(&*push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .collect() + } + + #[tracing::instrument(skip(self, sender))] + pub fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator> + 'a { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) + } +} + +#[tracing::instrument(skip(globals, destination, request))] +pub async fn send_request( + globals: &crate::database::globals::Globals, + destination: &str, + request: T, +) -> Result +where + T: Debug, +{ + let destination = destination.replace("/_matrix/push/v1/notify", ""); + + let http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(|body| body.freeze()); + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + // TODO: we could keep this very short and let expo backoff do it's thing... + //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + + let url = reqwest_request.url().clone(); + let response = globals.default_client().execute(reqwest_request).await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!( + "Push gateway returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Push gateway returned bad response.") + }) + } + Err(e) => Err(e.into()), + } +} + +#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] +pub async fn send_push_notice( + user: &UserId, + unread: UInt, + pusher: &get_pushers::v3::Pusher, + ruleset: Ruleset, + pdu: &PduEvent, + db: &Database, +) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); + + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + for action in get_actions( + user, + &ruleset, + &power_levels, + &pdu.to_sync_room_event(), + &pdu.room_id, + db, + )? { + let n = match action { + Action::DontNotify => false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + } + }; + + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } + + notify = Some(n); + } + + if notify == Some(true) { + send_notice(unread, pusher, tweaks, pdu, db).await?; + } + // Else the event triggered no actions + + Ok(()) +} + +#[tracing::instrument(skip(user, ruleset, pdu, db))] +pub fn get_actions<'a>( + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw, + room_id: &RoomId, + db: &Database, +) -> Result<&'a [Action]> { + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: 10_u32.into(), // TODO: get member count efficiently + user_display_name: db + .users + .displayname(user)? + .unwrap_or_else(|| user.localpart().to_owned()), + users_power_levels: power_levels.users.clone(), + default_power_level: power_levels.users_default, + notification_power_levels: power_levels.notifications.clone(), + }; + + Ok(ruleset.get_actions(pdu, &ctx)) +} + +#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] +async fn send_notice( + unread: UInt, + pusher: &get_pushers::v3::Pusher, + tweaks: Vec, + event: &PduEvent, + db: &Database, +) -> Result<()> { + // TODO: email + if pusher.kind == PusherKind::Email { + return Ok(()); + } + + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 2. can pusher/devices have conflicting formats + let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); + let url = if let Some(url) = &pusher.data.url { + url + } else { + error!("Http Pusher must have URL specified."); + return Ok(()); + }; + + let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); + let mut data_minus_url = pusher.data.clone(); + // The url must be stripped off according to spec + data_minus_url.url = None; + device.data = data_minus_url; + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } + + let d = &[device]; + let mut notifi = Notification::new(d); + + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some(&event.event_id); + notifi.room_id = Some(&event.room_id); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); + + if event.kind == RoomEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High + } + + if event_id_only { + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } else { + notifi.sender = Some(&event.sender); + notifi.event_type = Some(&event.kind); + let content = serde_json::value::to_raw_value(&event.content).ok(); + notifi.content = content.as_deref(); + + if event.kind == RoomEventType::RoomMember { + notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + } + + let user_name = db.users.displayname(&event.sender)?; + notifi.sender_display_name = user_name.as_deref(); + + let room_name = if let Some(room_name_pdu) = + db.rooms + .room_state_get(&event.room_id, &StateEventType::RoomName, "")? + { + serde_json::from_str::(room_name_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid room name event in database."))? + .name + } else { + None + }; + + notifi.room_name = room_name.as_deref(); + + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } + + // TODO: email + + Ok(()) +} diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs new file mode 100644 index 00000000..5b423d2d --- /dev/null +++ b/src/service/rooms/timeline/data.rs @@ -0,0 +1,937 @@ + + /// Checks if a room exists. + #[tracing::instrument(skip(self))] + pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Look for PDUs in that room. + self.pduid_pdu + .iter_from(&prefix, false) + .filter(|(k, _)| k.starts_with(&prefix)) + .map(|(_, pdu)| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid first PDU in db.")) + .map(Arc::new) + }) + .next() + .transpose() + } + + #[tracing::instrument(skip(self))] + pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + match self + .lasttimelinecount_cache + .lock() + .unwrap() + .entry(room_id.to_owned()) + { + hash_map::Entry::Vacant(v) => { + if let Some(last_count) = self + .pdus_until(&sender_user, &room_id, u64::MAX)? + .filter_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) + .map(|(pduid, _)| self.pdu_count(&pduid)) + .next() + { + Ok(*v.insert(last_count?)) + } else { + Ok(0) + } + } + hash_map::Entry::Occupied(o) => Ok(*o.get()), + } + } + + // TODO Is this the same as the function above? + #[tracing::instrument(skip(self))] + pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.pduid_pdu + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|b| self.pdu_count(&b.0)) + .transpose() + .map(|op| op.unwrap_or_default()) + } + + + + /// Returns the `count` of this pdu's id. + pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pdu_id| self.pdu_count(&pdu_id)) + .transpose() + } + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the json of a pdu. + pub fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu's id. + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { + self.eventid_pduid.get(event_id.as_bytes()) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { + return Ok(Some(Arc::clone(p))); + } + + if let Some(pdu) = self + .eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new) + }) + .transpose()? + { + self.pdu_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } + } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the pdu as a `BTreeMap`. + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the `count` of this pdu's id. + pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) + } + + /// Removes a pdu and creates a new one with the same id. + #[tracing::instrument(skip(self))] + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { + if self.pduid_pdu.get(pdu_id)?.is_some() { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), + )?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::NotFound, + "PDU does not exist.", + )) + } + } + + /// Creates a new persisted data unit and adds it to a room. + /// + /// By this point the incoming event should be fully authenticated, no auth happens + /// in `append_pdu`. + /// + /// Returns pdu id + #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] + pub fn append_pdu<'a>( + &self, + pdu: &PduEvent, + mut pdu_json: CanonicalJsonObject, + leaves: impl IntoIterator + Debug, + db: &Database, + ) -> Result> { + let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); + + // Make unsigned fields correct. This is not properly documented in the spec, but state + // events need to have previous content in the unsigned field, so clients can easily + // interpret things like membership changes + if let Some(state_key) = &pdu.state_key { + if let CanonicalJsonValue::Object(unsigned) = pdu_json + .entry("unsigned".to_owned()) + .or_insert_with(|| CanonicalJsonValue::Object(Default::default())) + { + if let Some(shortstatehash) = self.pdu_shortstatehash(&pdu.event_id).unwrap() { + if let Some(prev_state) = self + .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) + .unwrap() + { + unsigned.insert( + "prev_content".to_owned(), + CanonicalJsonValue::Object( + utils::to_canonical_object(prev_state.content.clone()) + .expect("event is valid, we just created it"), + ), + ); + } + } + } else { + error!("Invalid unsigned type in pdu."); + } + } + + // We must keep track of all events that have been referenced. + self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + self.replace_pdu_leaves(&pdu.room_id, leaves)?; + + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(pdu.room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + + let count1 = db.globals.next_count()?; + // Mark as read first so the sending client doesn't get a notification even if appending + // fails + self.edus + .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; + self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; + + let count2 = db.globals.next_count()?; + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); + pdu_id.extend_from_slice(&count2.to_be_bytes()); + + // There's a brief moment of time here where the count is updated but the pdu does not + // exist. This could theoretically lead to dropped pdus, but it's extremely rare + // + // Update: We fixed this using insert_lock + + self.pduid_pdu.insert( + &pdu_id, + &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), + )?; + self.lasttimelinecount_cache + .lock() + .unwrap() + .insert(pdu.room_id.clone(), count2); + + self.eventid_pduid + .insert(pdu.event_id.as_bytes(), &pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; + + drop(insert_lock); + + // See if the event matches any known pushers + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + let sync_pdu = pdu.to_sync_room_event(); + + let mut notifies = Vec::new(); + let mut highlights = Vec::new(); + + for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { + // Don't notify the user of their own events + if user == &pdu.sender { + continue; + } + + let rules_for_user = db + .account_data + .get( + None, + user, + GlobalAccountDataEventType::PushRules.to_string().into(), + )? + .map(|ev: PushRulesEvent| ev.content.global) + .unwrap_or_else(|| Ruleset::server_default(user)); + + let mut highlight = false; + let mut notify = false; + + for action in pusher::get_actions( + user, + &rules_for_user, + &power_levels, + &sync_pdu, + &pdu.room_id, + db, + )? { + match action { + Action::DontNotify => notify = false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => notify = true, + Action::SetTweak(Tweak::Highlight(true)) => { + highlight = true; + } + _ => {} + }; + } + + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(pdu.room_id.as_bytes()); + + if notify { + notifies.push(userroom_id.clone()); + } + + if highlight { + highlights.push(userroom_id); + } + + for senderkey in db.pusher.get_pusher_senderkeys(user) { + db.sending.send_push_pdu(&*pdu_id, senderkey)?; + } + } + + self.userroomid_notificationcount + .increment_batch(&mut notifies.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights.into_iter())?; + + match pdu.kind { + RoomEventType::RoomRedaction => { + if let Some(redact_id) = &pdu.redacts { + self.redact_pdu(redact_id, pdu)?; + } + } + RoomEventType::RoomMember => { + if let Some(state_key) = &pdu.state_key { + #[derive(Deserialize)] + struct ExtractMembership { + membership: MembershipState, + } + + // if the state_key fails + let target_user_id = UserId::parse(state_key.clone()) + .expect("This state_key was previously validated"); + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + let invite_state = match content.membership { + MembershipState::Invite => { + let state = self.calculate_invite_state(pdu)?; + Some(state) + } + _ => None, + }; + + // Update our membership info, we do this here incase a user is invited + // and immediately leaves we need the DB to record the invite event for auth + self.update_membership( + &pdu.room_id, + &target_user_id, + content.membership, + &pdu.sender, + invite_state, + db, + true, + )?; + } + } + RoomEventType::RoomMessage => { + #[derive(Deserialize)] + struct ExtractBody<'a> { + #[serde(borrow)] + body: Option>, + } + + let content = serde_json::from_str::>(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + if let Some(body) = content.body { + DB.rooms.search.index_pdu(room_id, pdu_id, body)?; + + let admin_room = self.id_from_alias( + <&RoomAliasId>::try_from( + format!("#admins:{}", db.globals.server_name()).as_str(), + ) + .expect("#admins:server_name is a valid room alias"), + )?; + let server_user = format!("@conduit:{}", db.globals.server_name()); + + let to_conduit = body.starts_with(&format!("{}: ", server_user)); + + // This will evaluate to false if the emergency password is set up so that + // the administrator can execute commands as conduit + let from_conduit = + pdu.sender == server_user && db.globals.emergency_password().is_none(); + + if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { + db.admin.process_message(body.to_string()); + } + } + } + _ => {} + } + + for appservice in db.appservice.all()? { + if self.appservice_in_room(room_id, &appservice, db)? { + db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + continue; + } + + // If the RoomMember event has a non-empty state_key, it is targeted at someone. + // If it is our appservice user, we send this PDU to it. + if pdu.kind == RoomEventType::RoomMember { + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + if let Some(appservice_uid) = appservice + .1 + .get("sender_localpart") + .and_then(|string| string.as_str()) + .and_then(|string| { + UserId::parse_with_server_name(string, db.globals.server_name()).ok() + }) + { + if state_key_uid == &appservice_uid { + db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + continue; + } + } + } + } + + if let Some(namespaces) = appservice.1.get("namespaces") { + let users = namespaces + .get("users") + .and_then(|users| users.as_sequence()) + .map_or_else(Vec::new, |users| { + users + .iter() + .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) + .collect::>() + }); + let aliases = namespaces + .get("aliases") + .and_then(|aliases| aliases.as_sequence()) + .map_or_else(Vec::new, |aliases| { + aliases + .iter() + .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) + .collect::>() + }); + let rooms = namespaces + .get("rooms") + .and_then(|rooms| rooms.as_sequence()); + + let matching_users = |users: &Regex| { + users.is_match(pdu.sender.as_str()) + || pdu.kind == RoomEventType::RoomMember + && pdu + .state_key + .as_ref() + .map_or(false, |state_key| users.is_match(state_key)) + }; + let matching_aliases = |aliases: &Regex| { + self.room_aliases(room_id) + .filter_map(|r| r.ok()) + .any(|room_alias| aliases.is_match(room_alias.as_str())) + }; + + if aliases.iter().any(matching_aliases) + || rooms.map_or(false, |rooms| rooms.contains(&room_id.as_str().into())) + || users.iter().any(matching_users) + { + db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + } + } + } + + + Ok(pdu_id) + } + + pub fn create_hash_and_sign_event( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + db: &Database, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> (PduEvent, CanonicalJsonObj) { + let PduBuilder { + event_type, + content, + unsigned, + state_key, + redacts, + } = pdu_builder; + + let prev_events: Vec<_> = db + .rooms + .get_pdu_leaves(room_id)? + .into_iter() + .take(20) + .collect(); + + let create_event = db + .rooms + .room_state_get(room_id, &StateEventType::RoomCreate, "")?; + + let create_event_content: Option = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()?; + + // If there was no create event yet, assume we are creating a room with the default + // version right now + let room_version_id = create_event_content + .map_or(db.globals.default_room_version(), |create_event| { + create_event.room_version + }); + let room_version = + RoomVersion::new(&room_version_id).expect("room version is supported"); + + let auth_events = + self.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events + .iter() + .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) + .max() + .unwrap_or_else(|| uint!(0)) + + uint!(1); + + let mut unsigned = unsigned.unwrap_or_default(); + + if let Some(state_key) = &state_key { + if let Some(prev_pdu) = + self.room_state_get(room_id, &event_type.to_string().into(), state_key)? + { + unsigned.insert( + "prev_content".to_owned(), + serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), + ); + unsigned.insert( + "prev_sender".to_owned(), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + ); + } + } + + let pdu = PduEvent { + event_id: ruma::event_id!("$thiswillbefilledinlater").into(), + room_id: room_id.to_owned(), + sender: sender_user.to_owned(), + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"), + kind: event_type, + content, + state_key, + prev_events, + depth, + auth_events: auth_events + .iter() + .map(|(_, pdu)| pdu.event_id.clone()) + .collect(), + redacts, + unsigned: if unsigned.is_empty() { + None + } else { + Some(to_raw_value(&unsigned).expect("to_raw_value always works")) + }, + hashes: EventHash { + sha256: "aaa".to_owned(), + }, + signatures: None, + }; + + let auth_check = state_res::auth_check( + &room_version, + &pdu, + None::, // TODO: third_party_invite + |k, s| auth_events.get(&(k.clone(), s.to_owned())), + ) + .map_err(|e| { + error!("{:?}", e); + Error::bad_database("Auth check failed.") + })?; + + if !auth_check { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Event is not authorized.", + )); + } + + // Hash and sign + let mut pdu_json = + utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); + + pdu_json.remove("event_id"); + + // Add origin because synapse likes that (and it's required in the spec) + pdu_json.insert( + "origin".to_owned(), + to_canonical_value(db.globals.server_name()) + .expect("server name is a valid CanonicalJsonValue"), + ); + + match ruma::signatures::hash_and_sign_event( + db.globals.server_name().as_str(), + db.globals.keypair(), + &mut pdu_json, + &room_version_id, + ) { + Ok(_) => {} + Err(e) => { + return match e { + ruma::signatures::Error::PduSize => Err(Error::BadRequest( + ErrorKind::TooLarge, + "Message is too long", + )), + _ => Err(Error::BadRequest( + ErrorKind::Unknown, + "Signing event failed", + )), + } + } + } + + // Generate event id + pdu.event_id = EventId::parse_arc(format!( + "${}", + ruma::signatures::reference_hash(&pdu_json, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + pdu_json.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), + ); + + // Generate short event id + let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; + } + + /// Creates a new persisted data unit and adds it to a room. This function takes a + /// roomid_mutex_state, meaning that only this function is able to mutate the room state. + #[tracing::instrument(skip(self, db, _mutex_lock))] + pub fn build_and_append_pdu( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + db: &Database, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result> { + + let (pdu, pdu_json) = create_hash_and_sign_event()?; + + + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehashid = self.append_to_state(&pdu, &db.globals)?; + + let pdu_id = self.append_pdu( + &pdu, + pdu_json, + // Since this PDU references all pdu_leaves we can update the leaves + // of the room + iter::once(&*pdu.event_id), + db, + )?; + + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + self.set_room_state(room_id, statehashid)?; + + let mut servers: HashSet> = + self.room_servers(room_id).filter_map(|r| r.ok()).collect(); + + // In case we are kicking or banning a user, we need to inform their server of the change + if pdu.kind == RoomEventType::RoomMember { + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + servers.insert(Box::from(state_key_uid.server_name())); + } + } + + // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above + servers.remove(db.globals.server_name()); + + db.sending.send_pdu(servers.into_iter(), &pdu_id)?; + + Ok(pdu.event_id) + } + + /// Append the incoming event setting the state snapshot to the state from the + /// server that sent the event. + #[tracing::instrument(skip_all)] + fn append_incoming_pdu<'a>( + db: &Database, + pdu: &PduEvent, + pdu_json: CanonicalJsonObject, + new_room_leaves: impl IntoIterator + Clone + Debug, + state_ids_compressed: HashSet, + soft_fail: bool, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result>> { + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + db.rooms.set_event_state( + &pdu.event_id, + &pdu.room_id, + state_ids_compressed, + &db.globals, + )?; + + if soft_fail { + db.rooms + .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + db.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; + return Ok(None); + } + + let pdu_id = db.rooms.append_pdu(pdu, pdu_json, new_room_leaves, db)?; + + Ok(Some(pdu_id)) + } + + /// Returns an iterator over all PDUs in a room. + #[tracing::instrument(skip(self))] + pub fn all_pdus<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result, PduEvent)>> + 'a> { + self.pdus_since(user_id, room_id, 0) + } + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result, PduEvent)>> + 'a> { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Skip the first pdu if it's exactly at since, because we sent that last time + let mut first_pdu_id = prefix.clone(); + first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(&first_pdu_id, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result, PduEvent)>> + 'a> { + // Create the first part of the full pdu id + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Returns an iterator over all events and their token in a room that happened after the event + /// with id `from` in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result, PduEvent)>> + 'a> { + // Create the first part of the full pdu id + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(self + .pduid_pdu + .iter_from(current, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + })) + } + + /// Replace a PDU with the redacted form. + #[tracing::instrument(skip(self, reason))] + pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { + if let Some(pdu_id) = self.get_pdu_id(event_id)? { + let mut pdu = self + .get_pdu_from_id(&pdu_id)? + .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + pdu.redact(reason)?; + self.replace_pdu(&pdu_id, &pdu)?; + } + // If event does not exist, just noop + Ok(()) + } + diff --git a/src/service/users/data.rs b/src/service/users/data.rs new file mode 100644 index 00000000..7c15f1d8 --- /dev/null +++ b/src/service/users/data.rs @@ -0,0 +1,1101 @@ +use crate::{utils, Error, Result}; +use ruma::{ + api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::{AnyToDeviceEvent, StateEventType}, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, + UInt, UserId, +}; +use std::{collections::BTreeMap, mem, sync::Arc}; +use tracing::warn; + +use super::abstraction::Tree; + +pub struct Users { + pub(super) userid_password: Arc, + pub(super) userid_displayname: Arc, + pub(super) userid_avatarurl: Arc, + pub(super) userid_blurhash: Arc, + pub(super) userdeviceid_token: Arc, + pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists + pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 + pub(super) token_userdeviceid: Arc, + + pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId + pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count + pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count + pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) + pub(super) userid_masterkeyid: Arc, + pub(super) userid_selfsigningkeyid: Arc, + pub(super) userid_usersigningkeyid: Arc, + + pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId + + pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count +} + +impl Users { + /// Check if a user has an account on this homeserver. + #[tracing::instrument(skip(self, user_id))] + pub fn exists(&self, user_id: &UserId) -> Result { + Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) + } + + /// Check if account is deactivated + #[tracing::instrument(skip(self, user_id))] + pub fn is_deactivated(&self, user_id: &UserId) -> Result { + Ok(self + .userid_password + .get(user_id.as_bytes())? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "User does not exist.", + ))? + .is_empty()) + } + + /// Check if a user is an admin + #[tracing::instrument(skip(self, user_id, rooms, globals))] + pub fn is_admin( + &self, + user_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result { + let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); + + rooms.is_joined(user_id, &admin_room_id) + } + + /// Create a new user account on this homeserver. + #[tracing::instrument(skip(self, user_id, password))] + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.set_password(user_id, password)?; + Ok(()) + } + + /// Returns the number of users registered on this server. + #[tracing::instrument(skip(self))] + pub fn count(&self) -> Result { + Ok(self.userid_password.iter().count()) + } + + /// Find out which user an access token belongs to. + #[tracing::instrument(skip(self, token))] + pub fn find_from_token(&self, token: &str) -> Result, String)>> { + self.token_userdeviceid + .get(token.as_bytes())? + .map_or(Ok(None), |bytes| { + let mut parts = bytes.split(|&b| b == 0xff); + let user_bytes = parts.next().ok_or_else(|| { + Error::bad_database("User ID in token_userdeviceid is invalid.") + })?; + let device_bytes = parts.next().ok_or_else(|| { + Error::bad_database("Device ID in token_userdeviceid is invalid.") + })?; + + Ok(Some(( + UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { + Error::bad_database("User ID in token_userdeviceid is invalid unicode.") + })?) + .map_err(|_| { + Error::bad_database("User ID in token_userdeviceid is invalid.") + })?, + utils::string_from_bytes(device_bytes).map_err(|_| { + Error::bad_database("Device ID in token_userdeviceid is invalid.") + })?, + ))) + }) + } + + /// Returns an iterator over all users on this homeserver. + #[tracing::instrument(skip(self))] + pub fn iter(&self) -> impl Iterator>> + '_ { + self.userid_password.iter().map(|(bytes, _)| { + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in userid_password is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) + }) + } + + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is greater then zero. + #[tracing::instrument(skip(self))] + pub fn list_local_users(&self) -> Result> { + let users: Vec = self + .userid_password + .iter() + .filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw)) + .collect(); + Ok(users) + } + + /// Will only return with Some(username) if the password was not empty and the + /// username could be successfully parsed. + /// If utils::string_from_bytes(...) returns an error that username will be skipped + /// and the error will be logged. + #[tracing::instrument(skip(self))] + fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { + // A valid password is not empty + if password.is_empty() { + None + } else { + match utils::string_from_bytes(username) { + Ok(u) => Some(u), + Err(e) => { + warn!( + "Failed to parse username while calling get_local_users(): {}", + e.to_string() + ); + None + } + } + } + } + + /// Returns the password hash for the given user. + #[tracing::instrument(skip(self, user_id))] + pub fn password_hash(&self, user_id: &UserId) -> Result> { + self.userid_password + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Password hash in db is not valid string.") + })?)) + }) + } + + /// Hash and set the user's password to the Argon2 hash + #[tracing::instrument(skip(self, user_id, password))] + pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + if let Some(password) = password { + if let Ok(hash) = utils::calculate_hash(password) { + self.userid_password + .insert(user_id.as_bytes(), hash.as_bytes())?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Password does not meet the requirements.", + )) + } + } else { + self.userid_password.insert(user_id.as_bytes(), b"")?; + Ok(()) + } + } + + /// Returns the displayname of a user on this homeserver. + #[tracing::instrument(skip(self, user_id))] + pub fn displayname(&self, user_id: &UserId) -> Result> { + self.userid_displayname + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Displayname in db is invalid.") + })?)) + }) + } + + /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. + #[tracing::instrument(skip(self, user_id, displayname))] + pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + if let Some(displayname) = displayname { + self.userid_displayname + .insert(user_id.as_bytes(), displayname.as_bytes())?; + } else { + self.userid_displayname.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the avatar_url of a user. + #[tracing::instrument(skip(self, user_id))] + pub fn avatar_url(&self, user_id: &UserId) -> Result>> { + self.userid_avatarurl + .get(user_id.as_bytes())? + .map(|bytes| { + let s = utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + s.try_into() + .map_err(|_| Error::bad_database("Avatar URL in db is invalid.")) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + #[tracing::instrument(skip(self, user_id, avatar_url))] + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { + if let Some(avatar_url) = avatar_url { + self.userid_avatarurl + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; + } else { + self.userid_avatarurl.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the blurhash of a user. + #[tracing::instrument(skip(self, user_id))] + pub fn blurhash(&self, user_id: &UserId) -> Result> { + self.userid_blurhash + .get(user_id.as_bytes())? + .map(|bytes| { + let s = utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + + Ok(s) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + #[tracing::instrument(skip(self, user_id, blurhash))] + pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + if let Some(blurhash) = blurhash { + self.userid_blurhash + .insert(user_id.as_bytes(), blurhash.as_bytes())?; + } else { + self.userid_blurhash.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Adds a new device to a user. + #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] + pub fn create_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + initial_device_display_name: Option, + ) -> Result<()> { + // This method should never be called for nonexistent users. + assert!(self.exists(user_id)?); + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(&Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: None, // TODO + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }) + .expect("Device::to_string never fails."), + )?; + + self.set_token(user_id, device_id, token)?; + + Ok(()) + } + + /// Removes a device from a user. + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Remove tokens + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.userdeviceid_token.remove(&userdeviceid)?; + self.token_userdeviceid.remove(&old_token)?; + } + + // Remove todevice events + let mut prefix = userdeviceid.clone(); + prefix.push(0xff); + + for (key, _) in self.todeviceid_events.scan_prefix(prefix) { + self.todeviceid_events.remove(&key)?; + } + + // TODO: Remove onetimekeys + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.remove(&userdeviceid)?; + + Ok(()) + } + + /// Returns an iterator over all device ids of this user. + #[tracing::instrument(skip(self, user_id))] + pub fn all_device_ids<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator>> + 'a { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + // All devices have metadata + self.userdeviceid_metadata + .scan_prefix(prefix) + .map(|(bytes, _)| { + Ok(utils::string_from_bytes( + bytes + .rsplit(|&b| b == 0xff) + .next() + .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? + .into()) + }) + } + + /// Replaces the access token of one device. + #[tracing::instrument(skip(self, user_id, device_id, token))] + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // All devices have metadata + assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); + + // Remove old token + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.token_userdeviceid.remove(&old_token)?; + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to user device combination + self.userdeviceid_token + .insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid + .insert(token.as_bytes(), &userdeviceid)?; + + Ok(()) + } + + #[tracing::instrument(skip( + self, + user_id, + device_id, + one_time_key_key, + one_time_key_value, + globals + ))] + pub fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + + // All devices have metadata + // Only existing devices should be able to call this. + assert!(self.userdeviceid_metadata.get(&key)?.is_some()); + + key.push(0xff); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), + ); + + self.onetimekeyid_onetimekeys.insert( + &key, + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + )?; + + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self.userid_lastonetimekeyupdate + .get(user_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") + }) + }) + .unwrap_or(Ok(0)) + } + + #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] + pub fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &DeviceKeyAlgorithm, + globals: &super::globals::Globals, + ) -> Result, Raw)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + + self.onetimekeyid_onetimekeys + .scan_prefix(prefix) + .next() + .map(|(key, value)| { + self.onetimekeyid_onetimekeys.remove(&key)?; + + Ok(( + serde_json::from_slice( + &*key + .rsplit(|&b| b == 0xff) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, + serde_json::from_slice(&*value) + .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, + )) + }) + .transpose() + } + + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + let mut counts = BTreeMap::new(); + + for algorithm in + self.onetimekeyid_onetimekeys + .scan_prefix(userdeviceid) + .map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::>( + &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { + Error::bad_database("OneTimeKey ID in db is invalid.") + })?, + ) + .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? + .algorithm(), + ) + }) + { + *counts.entry(algorithm?).or_default() += UInt::from(1_u32); + } + + Ok(counts) + } + + #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] + pub fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.insert( + &userdeviceid, + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + )?; + + self.mark_device_key_update(user_id, rooms, globals)?; + + Ok(()) + } + + #[tracing::instrument(skip( + self, + master_key, + self_signing_key, + user_signing_key, + rooms, + globals + ))] + pub fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Raw, + self_signing_key: &Option>, + user_signing_key: &Option>, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + // TODO: Check signatures + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + // Master key + let mut master_key_ids = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))? + .keys + .into_values(); + + let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained no key.", + ))?; + + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + + self.keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes())?; + + self.userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key)?; + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") + })? + .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained no key.", + ))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.keyid_key.insert( + &self_signing_key_key, + self_signing_key.json().get().as_bytes(), + )?; + + self.userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key)?; + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") + })? + .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained no key.", + ))?; + + if user_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained more than one key.", + )); + } + + let mut user_signing_key_key = prefix; + user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + + self.keyid_key.insert( + &user_signing_key_key, + user_signing_key.json().get().as_bytes(), + )?; + + self.userid_usersigningkeyid + .insert(user_id.as_bytes(), &user_signing_key_key)?; + } + + self.mark_device_key_update(user_id, rooms, globals)?; + + Ok(()) + } + + #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] + pub fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut key = target_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(key_id.as_bytes()); + + let mut cross_signing_key: serde_json::Value = + serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Tried to sign nonexistent key.", + ))?) + .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? + .as_object_mut() + .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? + .entry(sender_id.to_owned()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? + .insert(signature.0, signature.1.into()); + + self.keyid_key.insert( + &key, + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + )?; + + // TODO: Should we notify about this change? + self.mark_device_key_update(target_id, rooms, globals)?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_or_room_id, from, to))] + pub fn keys_changed<'a>( + &'a self, + user_or_room_id: &str, + from: u64, + to: Option, + ) -> impl Iterator>> + 'a { + let mut prefix = user_or_room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let mut start = prefix.clone(); + start.extend_from_slice(&(from + 1).to_be_bytes()); + + let to = to.unwrap_or(u64::MAX); + + self.keychangeid_userid + .iter_from(&start, false) + .take_while(move |(k, _)| { + k.starts_with(&prefix) + && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { + if let Ok(c) = utils::u64_from_bytes(current) { + c <= to + } else { + warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + false + } + } else { + warn!("BadDatabase: Could not parse keychangeid_userid"); + false + } + }) + .map(|(_, bytes)| { + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) + }) + } + + #[tracing::instrument(skip(self, user_id, rooms, globals))] + pub fn mark_device_key_update( + &self, + user_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let count = globals.next_count()?.to_be_bytes(); + for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { + // Don't send key updates to unencrypted rooms + if rooms + .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? + .is_none() + { + continue; + } + + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn get_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("DeviceKeys in db are invalid.") + })?)) + }) + } + + #[tracing::instrument(skip(self, user_id, allowed_signatures))] + pub fn get_master_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>> { + self.userid_masterkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + let mut cross_signing_key = serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; + + Ok(Some(Raw::from_json( + serde_json::value::to_raw_value(&cross_signing_key) + .expect("Value to RawValue serialization"), + ))) + }) + }) + } + + #[tracing::instrument(skip(self, user_id, allowed_signatures))] + pub fn get_self_signing_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>> { + self.userid_selfsigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + let mut cross_signing_key = serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; + + Ok(Some(Raw::from_json( + serde_json::value::to_raw_value(&cross_signing_key) + .expect("Value to RawValue serialization"), + ))) + }) + }) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + self.userid_usersigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("CrossSigningKey in db is invalid.") + })?)) + }) + }) + } + + #[tracing::instrument(skip( + self, + sender, + target_user_id, + target_device_id, + event_type, + content, + globals + ))] + pub fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut key = target_user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(target_device_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(&globals.next_count()?.to_be_bytes()); + + let mut json = serde_json::Map::new(); + json.insert("type".to_owned(), event_type.to_owned().into()); + json.insert("sender".to_owned(), sender.to_string().into()); + json.insert("content".to_owned(), content); + + let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); + + self.todeviceid_events.insert(&key, &value)?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn get_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>> { + let mut events = Vec::new(); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + + for (_, value) in self.todeviceid_events.scan_prefix(prefix) { + events.push( + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, + ); + } + + Ok(events) + } + + #[tracing::instrument(skip(self, user_id, device_id, until))] + pub fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: u64, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + + let mut last = prefix.clone(); + last.extend_from_slice(&until.to_be_bytes()); + + for (key, _) in self + .todeviceid_events + .iter_from(&last, true) // this includes last + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(key, _)| { + Ok::<_, Error>(( + key.clone(), + utils::u64_from_bytes(&key[key.len() - mem::size_of::()..key.len()]) + .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, + )) + }) + .filter_map(|r| r.ok()) + .take_while(|&(_, count)| count <= until) + { + self.todeviceid_events.remove(&key)?; + } + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id, device_id, device))] + pub fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, + ) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Only existing devices should be able to call this. + assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), + )?; + + Ok(()) + } + + /// Get device metadata. + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn get_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userdeviceid_metadata + .get(&userdeviceid)? + .map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("Metadata in userdeviceid_metadata is invalid.") + })?)) + }) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + self.userid_devicelistversion + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) + .map(Some) + }) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn all_devices_metadata<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator> + 'a { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + + self.userdeviceid_metadata + .scan_prefix(key) + .map(|(_, bytes)| { + serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) + }) + } + + /// Deactivate account + #[tracing::instrument(skip(self, user_id))] + pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + for device_id in self.all_device_ids(user_id) { + self.remove_device(user_id, &device_id?)?; + } + + // Set the password to "" to indicate a deactivated account. Hashes will never result in an + // empty string, so the user will not be able to log in again. Systems like changing the + // password without logging in should check if the account is deactivated. + self.userid_password.insert(user_id.as_bytes(), &[])?; + + // TODO: Unhook 3PID + Ok(()) + } + + /// Creates a new sync filter. Returns the filter id. + #[tracing::instrument(skip(self))] + pub fn create_filter( + &self, + user_id: &UserId, + filter: &IncomingFilterDefinition, + ) -> Result { + let filter_id = utils::random_string(4); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filter_id.as_bytes()); + + self.userfilterid_filter.insert( + &key, + &serde_json::to_vec(&filter).expect("filter is valid json"), + )?; + + Ok(filter_id) + } + + #[tracing::instrument(skip(self))] + pub fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filter_id.as_bytes()); + + let raw = self.userfilterid_filter.get(&key)?; + + if let Some(raw) = raw { + serde_json::from_slice(&raw) + .map_err(|_| Error::bad_database("Invalid filter event in db.")) + } else { + Ok(None) + } + } +} + +/// Ensure that a user only sees signatures from themselves and the target user +fn clean_signatures bool>( + cross_signing_key: &mut serde_json::Value, + user_id: &UserId, + allowed_signatures: F, +) -> Result<(), Error> { + if let Some(signatures) = cross_signing_key + .get_mut("signatures") + .and_then(|v| v.as_object_mut()) + { + // Don't allocate for the full size of the current signatures, but require + // at most one resize if nothing is dropped + let new_capacity = signatures.len() / 2; + for (user, signature) in + mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) + { + let id = <&UserId>::try_from(user.as_str()) + .map_err(|_| Error::bad_database("Invalid user ID in database."))?; + if id == user_id || allowed_signatures(id) { + signatures.insert(user, signature); + } + } + } + + Ok(()) +} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs new file mode 100644 index 00000000..7c15f1d8 --- /dev/null +++ b/src/service/users/mod.rs @@ -0,0 +1,1101 @@ +use crate::{utils, Error, Result}; +use ruma::{ + api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::{AnyToDeviceEvent, StateEventType}, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, + UInt, UserId, +}; +use std::{collections::BTreeMap, mem, sync::Arc}; +use tracing::warn; + +use super::abstraction::Tree; + +pub struct Users { + pub(super) userid_password: Arc, + pub(super) userid_displayname: Arc, + pub(super) userid_avatarurl: Arc, + pub(super) userid_blurhash: Arc, + pub(super) userdeviceid_token: Arc, + pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists + pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 + pub(super) token_userdeviceid: Arc, + + pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId + pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count + pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count + pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) + pub(super) userid_masterkeyid: Arc, + pub(super) userid_selfsigningkeyid: Arc, + pub(super) userid_usersigningkeyid: Arc, + + pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId + + pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count +} + +impl Users { + /// Check if a user has an account on this homeserver. + #[tracing::instrument(skip(self, user_id))] + pub fn exists(&self, user_id: &UserId) -> Result { + Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) + } + + /// Check if account is deactivated + #[tracing::instrument(skip(self, user_id))] + pub fn is_deactivated(&self, user_id: &UserId) -> Result { + Ok(self + .userid_password + .get(user_id.as_bytes())? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "User does not exist.", + ))? + .is_empty()) + } + + /// Check if a user is an admin + #[tracing::instrument(skip(self, user_id, rooms, globals))] + pub fn is_admin( + &self, + user_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result { + let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); + + rooms.is_joined(user_id, &admin_room_id) + } + + /// Create a new user account on this homeserver. + #[tracing::instrument(skip(self, user_id, password))] + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.set_password(user_id, password)?; + Ok(()) + } + + /// Returns the number of users registered on this server. + #[tracing::instrument(skip(self))] + pub fn count(&self) -> Result { + Ok(self.userid_password.iter().count()) + } + + /// Find out which user an access token belongs to. + #[tracing::instrument(skip(self, token))] + pub fn find_from_token(&self, token: &str) -> Result, String)>> { + self.token_userdeviceid + .get(token.as_bytes())? + .map_or(Ok(None), |bytes| { + let mut parts = bytes.split(|&b| b == 0xff); + let user_bytes = parts.next().ok_or_else(|| { + Error::bad_database("User ID in token_userdeviceid is invalid.") + })?; + let device_bytes = parts.next().ok_or_else(|| { + Error::bad_database("Device ID in token_userdeviceid is invalid.") + })?; + + Ok(Some(( + UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { + Error::bad_database("User ID in token_userdeviceid is invalid unicode.") + })?) + .map_err(|_| { + Error::bad_database("User ID in token_userdeviceid is invalid.") + })?, + utils::string_from_bytes(device_bytes).map_err(|_| { + Error::bad_database("Device ID in token_userdeviceid is invalid.") + })?, + ))) + }) + } + + /// Returns an iterator over all users on this homeserver. + #[tracing::instrument(skip(self))] + pub fn iter(&self) -> impl Iterator>> + '_ { + self.userid_password.iter().map(|(bytes, _)| { + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in userid_password is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) + }) + } + + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is greater then zero. + #[tracing::instrument(skip(self))] + pub fn list_local_users(&self) -> Result> { + let users: Vec = self + .userid_password + .iter() + .filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw)) + .collect(); + Ok(users) + } + + /// Will only return with Some(username) if the password was not empty and the + /// username could be successfully parsed. + /// If utils::string_from_bytes(...) returns an error that username will be skipped + /// and the error will be logged. + #[tracing::instrument(skip(self))] + fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { + // A valid password is not empty + if password.is_empty() { + None + } else { + match utils::string_from_bytes(username) { + Ok(u) => Some(u), + Err(e) => { + warn!( + "Failed to parse username while calling get_local_users(): {}", + e.to_string() + ); + None + } + } + } + } + + /// Returns the password hash for the given user. + #[tracing::instrument(skip(self, user_id))] + pub fn password_hash(&self, user_id: &UserId) -> Result> { + self.userid_password + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Password hash in db is not valid string.") + })?)) + }) + } + + /// Hash and set the user's password to the Argon2 hash + #[tracing::instrument(skip(self, user_id, password))] + pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + if let Some(password) = password { + if let Ok(hash) = utils::calculate_hash(password) { + self.userid_password + .insert(user_id.as_bytes(), hash.as_bytes())?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Password does not meet the requirements.", + )) + } + } else { + self.userid_password.insert(user_id.as_bytes(), b"")?; + Ok(()) + } + } + + /// Returns the displayname of a user on this homeserver. + #[tracing::instrument(skip(self, user_id))] + pub fn displayname(&self, user_id: &UserId) -> Result> { + self.userid_displayname + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Displayname in db is invalid.") + })?)) + }) + } + + /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. + #[tracing::instrument(skip(self, user_id, displayname))] + pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + if let Some(displayname) = displayname { + self.userid_displayname + .insert(user_id.as_bytes(), displayname.as_bytes())?; + } else { + self.userid_displayname.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the avatar_url of a user. + #[tracing::instrument(skip(self, user_id))] + pub fn avatar_url(&self, user_id: &UserId) -> Result>> { + self.userid_avatarurl + .get(user_id.as_bytes())? + .map(|bytes| { + let s = utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + s.try_into() + .map_err(|_| Error::bad_database("Avatar URL in db is invalid.")) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + #[tracing::instrument(skip(self, user_id, avatar_url))] + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { + if let Some(avatar_url) = avatar_url { + self.userid_avatarurl + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; + } else { + self.userid_avatarurl.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the blurhash of a user. + #[tracing::instrument(skip(self, user_id))] + pub fn blurhash(&self, user_id: &UserId) -> Result> { + self.userid_blurhash + .get(user_id.as_bytes())? + .map(|bytes| { + let s = utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + + Ok(s) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + #[tracing::instrument(skip(self, user_id, blurhash))] + pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + if let Some(blurhash) = blurhash { + self.userid_blurhash + .insert(user_id.as_bytes(), blurhash.as_bytes())?; + } else { + self.userid_blurhash.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Adds a new device to a user. + #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] + pub fn create_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + initial_device_display_name: Option, + ) -> Result<()> { + // This method should never be called for nonexistent users. + assert!(self.exists(user_id)?); + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(&Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: None, // TODO + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }) + .expect("Device::to_string never fails."), + )?; + + self.set_token(user_id, device_id, token)?; + + Ok(()) + } + + /// Removes a device from a user. + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Remove tokens + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.userdeviceid_token.remove(&userdeviceid)?; + self.token_userdeviceid.remove(&old_token)?; + } + + // Remove todevice events + let mut prefix = userdeviceid.clone(); + prefix.push(0xff); + + for (key, _) in self.todeviceid_events.scan_prefix(prefix) { + self.todeviceid_events.remove(&key)?; + } + + // TODO: Remove onetimekeys + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.remove(&userdeviceid)?; + + Ok(()) + } + + /// Returns an iterator over all device ids of this user. + #[tracing::instrument(skip(self, user_id))] + pub fn all_device_ids<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator>> + 'a { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + // All devices have metadata + self.userdeviceid_metadata + .scan_prefix(prefix) + .map(|(bytes, _)| { + Ok(utils::string_from_bytes( + bytes + .rsplit(|&b| b == 0xff) + .next() + .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? + .into()) + }) + } + + /// Replaces the access token of one device. + #[tracing::instrument(skip(self, user_id, device_id, token))] + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // All devices have metadata + assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); + + // Remove old token + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.token_userdeviceid.remove(&old_token)?; + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to user device combination + self.userdeviceid_token + .insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid + .insert(token.as_bytes(), &userdeviceid)?; + + Ok(()) + } + + #[tracing::instrument(skip( + self, + user_id, + device_id, + one_time_key_key, + one_time_key_value, + globals + ))] + pub fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + + // All devices have metadata + // Only existing devices should be able to call this. + assert!(self.userdeviceid_metadata.get(&key)?.is_some()); + + key.push(0xff); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), + ); + + self.onetimekeyid_onetimekeys.insert( + &key, + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + )?; + + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self.userid_lastonetimekeyupdate + .get(user_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") + }) + }) + .unwrap_or(Ok(0)) + } + + #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] + pub fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &DeviceKeyAlgorithm, + globals: &super::globals::Globals, + ) -> Result, Raw)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + + self.onetimekeyid_onetimekeys + .scan_prefix(prefix) + .next() + .map(|(key, value)| { + self.onetimekeyid_onetimekeys.remove(&key)?; + + Ok(( + serde_json::from_slice( + &*key + .rsplit(|&b| b == 0xff) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, + serde_json::from_slice(&*value) + .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, + )) + }) + .transpose() + } + + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + let mut counts = BTreeMap::new(); + + for algorithm in + self.onetimekeyid_onetimekeys + .scan_prefix(userdeviceid) + .map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::>( + &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { + Error::bad_database("OneTimeKey ID in db is invalid.") + })?, + ) + .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? + .algorithm(), + ) + }) + { + *counts.entry(algorithm?).or_default() += UInt::from(1_u32); + } + + Ok(counts) + } + + #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] + pub fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.insert( + &userdeviceid, + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + )?; + + self.mark_device_key_update(user_id, rooms, globals)?; + + Ok(()) + } + + #[tracing::instrument(skip( + self, + master_key, + self_signing_key, + user_signing_key, + rooms, + globals + ))] + pub fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Raw, + self_signing_key: &Option>, + user_signing_key: &Option>, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + // TODO: Check signatures + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + // Master key + let mut master_key_ids = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))? + .keys + .into_values(); + + let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained no key.", + ))?; + + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + + self.keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes())?; + + self.userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key)?; + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") + })? + .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained no key.", + ))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.keyid_key.insert( + &self_signing_key_key, + self_signing_key.json().get().as_bytes(), + )?; + + self.userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key)?; + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") + })? + .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained no key.", + ))?; + + if user_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained more than one key.", + )); + } + + let mut user_signing_key_key = prefix; + user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + + self.keyid_key.insert( + &user_signing_key_key, + user_signing_key.json().get().as_bytes(), + )?; + + self.userid_usersigningkeyid + .insert(user_id.as_bytes(), &user_signing_key_key)?; + } + + self.mark_device_key_update(user_id, rooms, globals)?; + + Ok(()) + } + + #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] + pub fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut key = target_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(key_id.as_bytes()); + + let mut cross_signing_key: serde_json::Value = + serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Tried to sign nonexistent key.", + ))?) + .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? + .as_object_mut() + .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? + .entry(sender_id.to_owned()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? + .insert(signature.0, signature.1.into()); + + self.keyid_key.insert( + &key, + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + )?; + + // TODO: Should we notify about this change? + self.mark_device_key_update(target_id, rooms, globals)?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_or_room_id, from, to))] + pub fn keys_changed<'a>( + &'a self, + user_or_room_id: &str, + from: u64, + to: Option, + ) -> impl Iterator>> + 'a { + let mut prefix = user_or_room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let mut start = prefix.clone(); + start.extend_from_slice(&(from + 1).to_be_bytes()); + + let to = to.unwrap_or(u64::MAX); + + self.keychangeid_userid + .iter_from(&start, false) + .take_while(move |(k, _)| { + k.starts_with(&prefix) + && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { + if let Ok(c) = utils::u64_from_bytes(current) { + c <= to + } else { + warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + false + } + } else { + warn!("BadDatabase: Could not parse keychangeid_userid"); + false + } + }) + .map(|(_, bytes)| { + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) + }) + } + + #[tracing::instrument(skip(self, user_id, rooms, globals))] + pub fn mark_device_key_update( + &self, + user_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let count = globals.next_count()?.to_be_bytes(); + for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { + // Don't send key updates to unencrypted rooms + if rooms + .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? + .is_none() + { + continue; + } + + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn get_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("DeviceKeys in db are invalid.") + })?)) + }) + } + + #[tracing::instrument(skip(self, user_id, allowed_signatures))] + pub fn get_master_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>> { + self.userid_masterkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + let mut cross_signing_key = serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; + + Ok(Some(Raw::from_json( + serde_json::value::to_raw_value(&cross_signing_key) + .expect("Value to RawValue serialization"), + ))) + }) + }) + } + + #[tracing::instrument(skip(self, user_id, allowed_signatures))] + pub fn get_self_signing_key bool>( + &self, + user_id: &UserId, + allowed_signatures: F, + ) -> Result>> { + self.userid_selfsigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + let mut cross_signing_key = serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; + + Ok(Some(Raw::from_json( + serde_json::value::to_raw_value(&cross_signing_key) + .expect("Value to RawValue serialization"), + ))) + }) + }) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + self.userid_usersigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("CrossSigningKey in db is invalid.") + })?)) + }) + }) + } + + #[tracing::instrument(skip( + self, + sender, + target_user_id, + target_device_id, + event_type, + content, + globals + ))] + pub fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, + globals: &super::globals::Globals, + ) -> Result<()> { + let mut key = target_user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(target_device_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(&globals.next_count()?.to_be_bytes()); + + let mut json = serde_json::Map::new(); + json.insert("type".to_owned(), event_type.to_owned().into()); + json.insert("sender".to_owned(), sender.to_string().into()); + json.insert("content".to_owned(), content); + + let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); + + self.todeviceid_events.insert(&key, &value)?; + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn get_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result>> { + let mut events = Vec::new(); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + + for (_, value) in self.todeviceid_events.scan_prefix(prefix) { + events.push( + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, + ); + } + + Ok(events) + } + + #[tracing::instrument(skip(self, user_id, device_id, until))] + pub fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: u64, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + + let mut last = prefix.clone(); + last.extend_from_slice(&until.to_be_bytes()); + + for (key, _) in self + .todeviceid_events + .iter_from(&last, true) // this includes last + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(key, _)| { + Ok::<_, Error>(( + key.clone(), + utils::u64_from_bytes(&key[key.len() - mem::size_of::()..key.len()]) + .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, + )) + }) + .filter_map(|r| r.ok()) + .take_while(|&(_, count)| count <= until) + { + self.todeviceid_events.remove(&key)?; + } + + Ok(()) + } + + #[tracing::instrument(skip(self, user_id, device_id, device))] + pub fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, + ) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Only existing devices should be able to call this. + assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), + )?; + + Ok(()) + } + + /// Get device metadata. + #[tracing::instrument(skip(self, user_id, device_id))] + pub fn get_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xff); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userdeviceid_metadata + .get(&userdeviceid)? + .map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("Metadata in userdeviceid_metadata is invalid.") + })?)) + }) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + self.userid_devicelistversion + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) + .map(Some) + }) + } + + #[tracing::instrument(skip(self, user_id))] + pub fn all_devices_metadata<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator> + 'a { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + + self.userdeviceid_metadata + .scan_prefix(key) + .map(|(_, bytes)| { + serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) + }) + } + + /// Deactivate account + #[tracing::instrument(skip(self, user_id))] + pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + for device_id in self.all_device_ids(user_id) { + self.remove_device(user_id, &device_id?)?; + } + + // Set the password to "" to indicate a deactivated account. Hashes will never result in an + // empty string, so the user will not be able to log in again. Systems like changing the + // password without logging in should check if the account is deactivated. + self.userid_password.insert(user_id.as_bytes(), &[])?; + + // TODO: Unhook 3PID + Ok(()) + } + + /// Creates a new sync filter. Returns the filter id. + #[tracing::instrument(skip(self))] + pub fn create_filter( + &self, + user_id: &UserId, + filter: &IncomingFilterDefinition, + ) -> Result { + let filter_id = utils::random_string(4); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filter_id.as_bytes()); + + self.userfilterid_filter.insert( + &key, + &serde_json::to_vec(&filter).expect("filter is valid json"), + )?; + + Ok(filter_id) + } + + #[tracing::instrument(skip(self))] + pub fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filter_id.as_bytes()); + + let raw = self.userfilterid_filter.get(&key)?; + + if let Some(raw) = raw { + serde_json::from_slice(&raw) + .map_err(|_| Error::bad_database("Invalid filter event in db.")) + } else { + Ok(None) + } + } +} + +/// Ensure that a user only sees signatures from themselves and the target user +fn clean_signatures bool>( + cross_signing_key: &mut serde_json::Value, + user_id: &UserId, + allowed_signatures: F, +) -> Result<(), Error> { + if let Some(signatures) = cross_signing_key + .get_mut("signatures") + .and_then(|v| v.as_object_mut()) + { + // Don't allocate for the full size of the current signatures, but require + // at most one resize if nothing is dropped + let new_capacity = signatures.len() / 2; + for (user, signature) in + mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) + { + let id = <&UserId>::try_from(user.as_str()) + .map_err(|_| Error::bad_database("Invalid user ID in database."))?; + if id == user_id || allowed_signatures(id) { + signatures.insert(user, signature); + } + } + } + + Ok(()) +}