diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 8f1afba9..1ae9f80c 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,19 +1,23 @@ use super::SESSION_ID_LENGTH; use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ - api::client::{ - error::ErrorKind, - r0::{ - keys::{ - claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, - upload_signing_keys, + api::{ + client::{ + error::ErrorKind, + r0::{ + keys::{ + claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, + upload_signing_keys, + }, + uiaa::{AuthFlow, UiaaInfo}, }, - uiaa::{AuthFlow, UiaaInfo}, }, + federation, }, encryption::UnsignedDeviceInfo, DeviceId, DeviceKeyAlgorithm, UserId, }; +use serde_json::json; use std::collections::{BTreeMap, HashSet}; #[cfg(feature = "conduit_bin")] @@ -84,7 +88,8 @@ pub async fn get_keys_route( &body.device_keys, |u| u == sender_user, &db, - )?; + ) + .await?; Ok(response.into()) } @@ -98,7 +103,7 @@ pub async fn claim_keys_route( db: DatabaseGuard, body: Ruma, ) -> ConduitResult { - let response = claim_keys_helper(&body.one_time_keys, &db)?; + let response = claim_keys_helper(&body.one_time_keys, &db).await?; db.flush().await?; @@ -278,7 +283,7 @@ pub async fn get_key_changes_route( .into()) } -pub fn get_keys_helper bool>( +pub async fn get_keys_helper bool>( sender_user: Option<&UserId>, device_keys_input: &BTreeMap>>, allowed_signatures: F, @@ -289,7 +294,16 @@ pub fn get_keys_helper bool>( let mut user_signing_keys = BTreeMap::new(); let mut device_keys = BTreeMap::new(); + let mut get_over_federation = BTreeMap::new(); + for (user_id, device_ids) in device_keys_input { + if user_id.server_name() != db.globals.server_name() { + get_over_federation + .entry(user_id.server_name()) + .or_insert_with(Vec::new) + .push((user_id, device_ids)); + } + if device_ids.is_empty() { let mut container = BTreeMap::new(); for device_id in db.users.all_device_ids(user_id) { @@ -347,21 +361,51 @@ pub fn get_keys_helper bool>( } } + let mut failures = BTreeMap::new(); + + for (server, vec) in get_over_federation { + let mut device_keys = BTreeMap::new(); + for (user_id, keys) in vec { + device_keys.insert(user_id.clone(), keys.clone()); + } + if let Err(_e) = db + .sending + .send_federation_request( + &db.globals, + server, + federation::keys::get_keys::v1::Request { device_keys }, + ) + .await + { + failures.insert(server.to_string(), json!({})); + } + } + Ok(get_keys::Response { master_keys, self_signing_keys, user_signing_keys, device_keys, - failures: BTreeMap::new(), + failures, }) } -pub fn claim_keys_helper( +pub async fn claim_keys_helper( one_time_keys_input: &BTreeMap, DeviceKeyAlgorithm>>, db: &Database, ) -> Result { let mut one_time_keys = BTreeMap::new(); + + let mut get_over_federation = BTreeMap::new(); + for (user_id, map) in one_time_keys_input { + if user_id.server_name() != db.globals.server_name() { + get_over_federation + .entry(user_id.server_name()) + .or_insert_with(Vec::new) + .push((user_id, map)); + } + let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { if let Some(one_time_keys) = @@ -376,6 +420,26 @@ pub fn claim_keys_helper( one_time_keys.insert(user_id.clone(), container); } + for (server, vec) in get_over_federation { + let mut one_time_keys_input_fed = BTreeMap::new(); + for (user_id, keys) in vec { + one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); + } + // Ignore failures + let keys = db + .sending + .send_federation_request( + &db.globals, + server, + federation::keys::claim_keys::v1::Request { + one_time_keys: one_time_keys_input_fed, + }, + ) + .await?; + + one_time_keys.extend(keys.one_time_keys); + } + Ok(claim_keys::Response { failures: BTreeMap::new(), one_time_keys, diff --git a/src/server_server.rs b/src/server_server.rs index d51b672e..e8c19dbc 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -2461,7 +2461,8 @@ pub async fn get_keys_route( &body.device_keys, |u| Some(u.server_name()) == body.sender_servername.as_deref(), &db, - )?; + ) + .await?; db.flush().await?; @@ -2486,7 +2487,7 @@ pub async fn claim_keys_route( return Err(Error::bad_config("Federation is disabled.")); } - let result = claim_keys_helper(&body.one_time_keys, &db)?; + let result = claim_keys_helper(&body.one_time_keys, &db).await?; db.flush().await?;