diff --git a/src/main.rs b/src/main.rs index fe6cfc01..59e82a79 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ #![allow(clippy::suspicious_else_formatting)] #![deny(clippy::dbg_macro)] -use std::{future::Future, io, net::SocketAddr, time::Duration}; +use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration}; use axum::{ extract::{DefaultBodyLimit, FromRequest, MatchedPath}, @@ -147,6 +147,7 @@ async fn run_server() -> io::Result<()> { let middlewares = ServiceBuilder::new() .sensitive_headers([header::AUTHORIZATION]) + .layer(axum::middleware::from_fn(spawn_task)) .layer( TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| { let path = if let Some(path) = request.extensions().get::() { @@ -211,16 +212,21 @@ async fn run_server() -> io::Result<()> { } } - // On shutdown - info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); - services().globals.rotate.fire(); - - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); - Ok(()) } +async fn spawn_task( + req: axum::http::Request, + next: axum::middleware::Next, +) -> std::result::Result { + if services().globals.shutdown.load(atomic::Ordering::Relaxed) { + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + tokio::spawn(next.run(req)) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) +} + async fn unrecognized_method( req: axum::http::Request, next: axum::middleware::Next, @@ -442,6 +448,11 @@ async fn shutdown_signal(handle: ServerHandle) { warn!("Received {}, shutting down...", sig); handle.graceful_shutdown(Some(Duration::from_secs(30))); + + services().globals.shutdown(); + + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); } async fn not_found(uri: Uri) -> impl IntoResponse { diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index cd3be081..9206d43f 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -6,7 +6,7 @@ use ruma::{ use crate::api::server_server::FedDest; -use crate::{Config, Error, Result}; +use crate::{services, Config, Error, Result}; use ruma::{ api::{ client::sync::sync_events, @@ -14,6 +14,7 @@ use ruma::{ }, DeviceId, RoomVersionId, ServerName, UserId, }; +use std::sync::atomic::{self, AtomicBool}; use std::{ collections::{BTreeMap, HashMap}, fs, @@ -24,7 +25,7 @@ use std::{ time::{Duration, Instant}, }; use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; -use tracing::error; +use tracing::{error, info}; use trust_dns_resolver::TokioAsyncResolver; type WellKnownMap = HashMap; @@ -58,6 +59,8 @@ pub struct Service { pub roomid_federationhandletime: RwLock>, pub stateres_mutex: Arc>, pub rotate: RotationHandler, + + pub shutdown: AtomicBool, } /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like. @@ -160,6 +163,7 @@ impl Service { stateres_mutex: Arc::new(Mutex::new(())), sync_receivers: RwLock::new(HashMap::new()), rotate: RotationHandler::new(), + shutdown: AtomicBool::new(false), }; fs::create_dir_all(s.get_media_folder())?; @@ -341,6 +345,13 @@ impl Service { r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD)); r } + + pub fn shutdown(&self) { + self.shutdown.store(true, atomic::Ordering::Relaxed); + // On shutdown + info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); + services().globals.rotate.fire(); + } } fn reqwest_client_builder(config: &Config) -> Result {