diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index c88e0a86..52e074c2 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -5,7 +5,6 @@ use crate::{ server_server, utils, ConduitResult, Database, Error, Result, Ruma, }; use member::{MemberEventContent, MembershipState}; -use rocket::futures; use ruma::{ api::{ client::{ @@ -667,14 +666,19 @@ async fn join_room_by_id_helper( let mut state = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); - for result in futures::future::join_all( - send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), + server_server::fetch_join_signing_keys( + &send_join_response, + &room_version, + &pub_key_map, + &db, ) - .await + .await?; + + for result in send_join_response + .room_state + .state + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)) { let (event_id, value) = match result { Ok(t) => t, @@ -723,14 +727,11 @@ async fn join_room_by_id_helper( &db, )?; - for result in futures::future::join_all( - send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), - ) - .await + for result in send_join_response + .room_state + .auth_chain + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)) { let (event_id, value) = match result { Ok(t) => t, @@ -787,7 +788,7 @@ async fn join_room_by_id_helper( Ok(join_room_by_id::Response::new(room_id.clone()).into()) } -async fn validate_and_add_event_id( +fn validate_and_add_event_id( pdu: &Raw, room_version: &RoomVersionId, pub_key_map: &RwLock>>, @@ -830,7 +831,6 @@ async fn validate_and_add_event_id( } } - server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?; if let Err(e) = ruma::signatures::verify_event( &*pub_key_map .read() diff --git a/src/server_server.rs b/src/server_server.rs index b81610e3..b83eaa4a 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -6,7 +6,7 @@ use crate::{ use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION}; use regex::Regex; -use rocket::response::content::Json; +use rocket::{futures, response::content::Json}; use ruma::{ api::{ client::error::{Error as RumaError, ErrorKind}, @@ -15,8 +15,9 @@ use ruma::{ device::get_devices::{self, v1::UserDevice}, directory::{get_public_rooms, get_public_rooms_filtered}, discovery::{ - get_remote_server_keys, get_server_keys, get_server_version, ServerSigningKeys, - VerifyKey, + get_remote_server_keys, get_remote_server_keys_batch, + get_remote_server_keys_batch::v2::QueryCriteria, get_server_keys, + get_server_version, ServerSigningKeys, VerifyKey, }, event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, keys::{claim_keys, get_keys}, @@ -35,6 +36,7 @@ use ruma::{ }, directory::{IncomingFilter, IncomingRoomNetwork}, events::{ + pdu::Pdu, receipt::{ReceiptEvent, ReceiptEventContent}, room::{ create::CreateEventContent, @@ -3277,6 +3279,204 @@ pub(crate) async fn fetch_required_signing_keys( Ok(()) } +pub fn get_missing_signing_keys_for_pdus( + pdus: &Vec>, + servers: &mut BTreeMap, BTreeMap>, + room_version: &RoomVersionId, + pub_key_map: &RwLock>>, + db: &Database, +) -> Result<()> { + for pdu in pdus { + let value = serde_json::from_str::(pdu.json().get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + let event_id = EventId::try_from(&*format!( + "${}", + ruma::signatures::reference_hash(&value, &room_version) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + if let Some((time, tries)) = db + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&event_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + let signatures = value + .get("signatures") + .ok_or(Error::BadServerResponse( + "No signatures in server response pdu.", + ))? + .as_object() + .ok_or(Error::BadServerResponse( + "Invalid signatures object in server response pdu.", + ))?; + + for (signature_server, signature) in signatures { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + let signature_ids = signature_object.keys().cloned().collect::>(); + + let contains_all_ids = |keys: &BTreeMap| { + signature_ids.iter().all(|id| keys.contains_key(id)) + }; + + let origin = &Box::::try_from(&**signature_server).map_err(|_| { + Error::BadServerResponse("Invalid servername in signatures of server response pdu.") + })?; + + trace!("Loading signing keys for {}", origin); + + let result = db + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect::>(); + + if !contains_all_ids(&result) { + trace!("Signing key not loaded for {}", origin); + servers.insert( + origin.clone(), + BTreeMap::::new(), + ); + } + + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(origin.to_string(), result); + } + } + + Ok(()) +} + +pub async fn fetch_join_signing_keys( + event: &create_join_event::v2::Response, + room_version: &RoomVersionId, + pub_key_map: &RwLock>>, + db: &Database, +) -> Result<()> { + let mut servers = + BTreeMap::, BTreeMap>::new(); + + get_missing_signing_keys_for_pdus( + &event.room_state.state, + &mut servers, + &room_version, + &pub_key_map, + &db, + )?; + get_missing_signing_keys_for_pdus( + &event.room_state.auth_chain, + &mut servers, + &room_version, + &pub_key_map, + &db, + )?; + + if servers.is_empty() { + return Ok(()); + } + + for server in db.globals.trusted_servers() { + if db.globals.signing_keys_for(server)?.is_empty() { + servers.insert( + server.clone(), + BTreeMap::::new(), + ); + } + } + + for server in db.globals.trusted_servers() { + trace!("Asking batch signing keys from trusted server {}", server); + if let Ok(keys) = db + .sending + .send_federation_request( + &db.globals, + server, + get_remote_server_keys_batch::v2::Request { + server_keys: servers.clone(), + minimum_valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + Duration::from_secs(60), + ) + .expect("time is valid"), + }, + ) + .await + { + trace!("Got signing keys: {:?}", keys); + for k in keys.server_keys { + // TODO: Check signature + servers.remove(&k.server_name); + + db.globals.add_signing_key(&k.server_name, k.clone())?; + + let result = db + .globals + .signing_keys_for(&k.server_name)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect::>(); + + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(k.server_name.to_string(), result); + } + } + if servers.is_empty() { + return Ok(()); + } + } + + for result in futures::future::join_all(servers.iter().map(|(server, _)| { + db.sending + .send_federation_request(&db.globals, server, get_server_keys::v2::Request::new()) + })) + .await + { + if let Ok(get_keys_response) = result { + // TODO: We should probably not trust the server_name in the response. + let server = &get_keys_response.server_key.server_name; + db.globals + .add_signing_key(server, get_keys_response.server_key.clone())?; + + let result = db + .globals + .signing_keys_for(server)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect::>(); + + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(server.to_string(), result); + } + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::{add_port_to_hostname, get_ip_with_port, FedDest};