From 475a68cbb93da208fed175a2868c97d4c44a8839 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Fri, 12 Apr 2024 20:52:14 +0100 Subject: [PATCH] refactor: disable federation at the router level --- src/api/ruma_wrapper/axum.rs | 4 -- src/api/server_server.rs | 8 ---- src/main.rs | 76 +++++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 895b601d..906904bf 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -149,10 +149,6 @@ where Token::User((user_id, device_id)), ) => (Some(user_id), Some(device_id), None, false), (AuthScheme::ServerSignatures, Token::None) => { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let TypedHeader(Authorization(x_matrix)) = parts .extract::>>() .await diff --git a/src/api/server_server.rs b/src/api/server_server.rs index fa7f1315..b25b1313 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -527,10 +527,6 @@ async fn request_well_known(destination: &str) -> Option { pub async fn get_server_version_route( _body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - Ok(get_server_version::v1::Response { server: Some(get_server_version::v1::Server { name: Some("Conduit".to_owned()), @@ -547,10 +543,6 @@ pub async fn get_server_version_route( /// forever. // Response type for this endpoint is Json because we need to calculate a signature for the response pub async fn get_server_keys_route() -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let mut verify_keys: BTreeMap = BTreeMap::new(); verify_keys.insert( format!("ed25519:{}", services().globals.keypair().version()) diff --git a/src/main.rs b/src/main.rs index b5bf742d..7beeb8ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration}; use axum::{ extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, response::IntoResponse, - routing::{get, on, MethodFilter}, + routing::{any, get, on, MethodFilter}, Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; @@ -188,7 +188,7 @@ async fn run_server() -> io::Result<()> { .expect("failed to convert max request size"), )); - let app = routes().layer(middlewares).into_make_service(); + let app = routes(config).layer(middlewares).into_make_service(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); @@ -249,8 +249,8 @@ async fn unrecognized_method( Ok(inner) } -fn routes() -> Router { - Router::new() +fn routes(config: &Config) -> Router { + let router = Router::new() .ruma_route(client_server::get_supported_versions_route) .ruma_route(client_server::get_register_available_route) .ruma_route(client_server::register_route) @@ -390,33 +390,6 @@ fn routes() -> Router { .ruma_route(client_server::get_relating_events_with_rel_type_route) .ruma_route(client_server::get_relating_events_route) .ruma_route(client_server::get_hierarchy_route) - .ruma_route(server_server::get_server_version_route) - .route( - "/_matrix/key/v2/server", - get(server_server::get_server_keys_route), - ) - .route( - "/_matrix/key/v2/server/:key_id", - get(server_server::get_server_keys_deprecated_route), - ) - .ruma_route(server_server::get_public_rooms_route) - .ruma_route(server_server::get_public_rooms_filtered_route) - .ruma_route(server_server::send_transaction_message_route) - .ruma_route(server_server::get_event_route) - .ruma_route(server_server::get_backfill_route) - .ruma_route(server_server::get_missing_events_route) - .ruma_route(server_server::get_event_authorization_route) - .ruma_route(server_server::get_room_state_route) - .ruma_route(server_server::get_room_state_ids_route) - .ruma_route(server_server::create_join_event_template_route) - .ruma_route(server_server::create_join_event_v1_route) - .ruma_route(server_server::create_join_event_v2_route) - .ruma_route(server_server::create_invite_route) - .ruma_route(server_server::get_devices_route) - .ruma_route(server_server::get_room_information_route) - .ruma_route(server_server::get_profile_information_route) - .ruma_route(server_server::get_keys_route) - .ruma_route(server_server::claim_keys_route) .route( "/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync), @@ -426,7 +399,42 @@ fn routes() -> Router { get(initial_sync), ) .route("/", get(it_works)) - .fallback(not_found) + .fallback(not_found); + + if config.allow_federation { + router + .ruma_route(server_server::get_server_version_route) + .route( + "/_matrix/key/v2/server", + get(server_server::get_server_keys_route), + ) + .route( + "/_matrix/key/v2/server/:key_id", + get(server_server::get_server_keys_deprecated_route), + ) + .ruma_route(server_server::get_public_rooms_route) + .ruma_route(server_server::get_public_rooms_filtered_route) + .ruma_route(server_server::send_transaction_message_route) + .ruma_route(server_server::get_event_route) + .ruma_route(server_server::get_backfill_route) + .ruma_route(server_server::get_missing_events_route) + .ruma_route(server_server::get_event_authorization_route) + .ruma_route(server_server::get_room_state_route) + .ruma_route(server_server::get_room_state_ids_route) + .ruma_route(server_server::create_join_event_template_route) + .ruma_route(server_server::create_join_event_v1_route) + .ruma_route(server_server::create_join_event_v2_route) + .ruma_route(server_server::create_invite_route) + .ruma_route(server_server::get_devices_route) + .ruma_route(server_server::get_room_information_route) + .ruma_route(server_server::get_profile_information_route) + .ruma_route(server_server::get_keys_route) + .ruma_route(server_server::claim_keys_route) + } else { + router + .route("/_matrix/federation/*path", any(federation_disabled)) + .route("/_matrix/key/*path", any(federation_disabled)) + } } async fn shutdown_signal(handle: ServerHandle) { @@ -463,6 +471,10 @@ async fn shutdown_signal(handle: ServerHandle) { let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); } +async fn federation_disabled(_: Uri) -> impl IntoResponse { + Error::bad_config("Federation is disabled.") +} + async fn not_found(uri: Uri) -> impl IntoResponse { warn!("Not found: {uri}"); Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")