From 821c608c6a9bfafd11a6f4654852f3778f713049 Mon Sep 17 00:00:00 2001 From: timokoesters Date: Mon, 18 May 2020 17:53:34 +0200 Subject: [PATCH] feat: media --- src/client_server.rs | 69 +++++++++++++++++++++++++++++++++++++++---- src/database.rs | 5 ++++ src/database/media.rs | 56 +++++++++++++++++++++++++++++++++++ src/main.rs | 3 ++ src/ruma_wrapper.rs | 12 +++----- 5 files changed, 132 insertions(+), 13 deletions(-) create mode 100644 src/database/media.rs diff --git a/src/client_server.rs b/src/client_server.rs index ab70b315..30b409c0 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -16,7 +16,7 @@ use ruma_client_api::{ directory::{self, get_public_rooms_filtered}, filter::{self, create_filter, get_filter}, keys::{claim_keys, get_keys, upload_keys}, - media::get_media_config, + media::{create_content, get_content_thumbnail, get_content, get_media_config}, membership::{ forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, leave_room, @@ -53,6 +53,7 @@ const GUEST_NAME_LENGTH: usize = 10; const DEVICE_ID_LENGTH: usize = 10; const SESSION_ID_LENGTH: usize = 256; const TOKEN_LENGTH: usize = 256; +const MXC_LENGTH: usize = 256; #[get("/_matrix/client/versions")] pub fn get_supported_versions_route() -> MatrixResult { @@ -1259,7 +1260,7 @@ pub fn create_message_event_route( body.room_id.clone(), user_id.clone(), body.event_type.clone(), - body.json_body.clone(), + body.json_body.clone().unwrap(), Some(unsigned), None, &db.globals, @@ -1291,7 +1292,7 @@ pub fn create_state_event_for_key_route( body.room_id.clone(), user_id.clone(), body.event_type.clone(), - body.json_body.clone(), + body.json_body.clone().unwrap(), None, Some(body.state_key.clone()), &db.globals, @@ -1322,7 +1323,7 @@ pub fn create_state_event_for_empty_key_route( body.room_id.clone(), user_id.clone(), body.event_type.clone(), - body.json_body.clone(), + body.json_body.clone().unwrap(), None, Some("".to_owned()), &db.globals, @@ -1766,10 +1767,68 @@ pub fn send_event_to_device_route( pub fn get_media_config_route() -> MatrixResult { warn!("TODO: get_media_config_route"); MatrixResult(Ok(get_media_config::Response { - upload_size: 0_u32.into(), + upload_size: (20_u32 * 1024 * 1024).into(), // 20 MB })) } +#[post("/_matrix/media/r0/upload", data = "")] +pub fn create_content_route( + db: State<'_, Database>, + body: Ruma, +) -> MatrixResult { + let mxc = format!("mxc://{}/{}", db.globals.server_name(), utils::random_string(MXC_LENGTH)); + db.media + .create(mxc.clone(), body.filename.as_ref(), &body.content_type, &body.file) + .unwrap(); + + MatrixResult(Ok(create_content::Response { + content_uri: mxc, + })) +} + +#[get("/_matrix/media/r0/download/<_server_name>/<_media_id>", data = "")] +pub fn get_content_route( + db: State<'_, Database>, + body: Ruma, + _server_name: String, + _media_id: String, +) -> MatrixResult { + if let Some((filename, content_type, file)) = db.media.get(format!("mxc://{}/{}", body.server_name, body.media_id)).unwrap() { + MatrixResult(Ok(get_content::Response { + file, + content_type, + content_disposition: filename.unwrap_or_default(), // TODO: Spec says this should be optional + })) + } else { + MatrixResult(Err(Error { + kind: ErrorKind::NotFound, + message: "Media not found.".to_owned(), + status_code: http::StatusCode::NOT_FOUND, + })) + } +} + +#[get("/_matrix/media/r0/thumbnail/<_server_name>/<_media_id>", data = "")] +pub fn get_content_thumbnail_route( + db: State<'_, Database>, + body: Ruma, + _server_name: String, + _media_id: String, +) -> MatrixResult { + if let Some((_, content_type, file)) = db.media.get(format!("mxc://{}/{}", body.server_name, body.media_id)).unwrap() { + MatrixResult(Ok(get_content_thumbnail::Response { + file, + content_type, + })) + } else { + MatrixResult(Err(Error { + kind: ErrorKind::NotFound, + message: "Media not found.".to_owned(), + status_code: http::StatusCode::NOT_FOUND, + })) + } +} + #[options("/<_segments..>")] pub fn options_route( _segments: rocket::http::uri::Segments<'_>, diff --git a/src/database.rs b/src/database.rs index 77ea2f9d..9c08a22e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,7 @@ pub(self) mod account_data; pub(self) mod global_edus; pub(self) mod globals; +pub(self) mod media; pub(self) mod rooms; pub(self) mod users; @@ -15,6 +16,7 @@ pub struct Database { pub rooms: rooms::Rooms, pub account_data: account_data::AccountData, pub global_edus: global_edus::GlobalEdus, + pub media: media::Media, pub _db: sled::Db, } @@ -88,6 +90,9 @@ impl Database { //globalallid_globalall: db.open_tree("globalallid_globalall").unwrap(), globallatestid_globallatest: db.open_tree("globallatestid_globallatest").unwrap(), // Presence }, + media: media::Media { + mediaid_file: db.open_tree("mediaid_file").unwrap(), + }, _db: db, } } diff --git a/src/database/media.rs b/src/database/media.rs new file mode 100644 index 00000000..36d94101 --- /dev/null +++ b/src/database/media.rs @@ -0,0 +1,56 @@ +use crate::{utils, Error, Result}; + +pub struct Media { + pub(super) mediaid_file: sled::Tree, // MediaId = MXC + Filename + ContentType +} + +impl Media { + /// Uploads or replaces a file. + pub fn create( + &self, + mxc: String, + filename: Option<&String>, + content_type: &str, + file: &[u8], + ) -> Result<()> { + let mut key = mxc.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(filename.map(|f| f.as_bytes()).unwrap_or_default()); + key.push(0xff); + key.extend_from_slice(content_type.as_bytes()); + + self.mediaid_file.insert(key, file)?; + + Ok(()) + } + + /// Downloads a file. + pub fn get(&self, mxc: String) -> Result, String, Vec)>> { + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xff); + + if let Some(r) = self.mediaid_file.scan_prefix(&prefix).next() { + let (key, file) = r?; + let mut parts = key.split(|&b| b == 0xff).skip(1); + + let filename_bytes = parts + .next() + .ok_or(Error::BadDatabase("mediaid is invalid"))?; + let filename = if filename_bytes.is_empty() { + None + } else { + Some(utils::string_from_bytes(filename_bytes)?) + }; + + let content_type = utils::string_from_bytes( + parts + .next() + .ok_or(Error::BadDatabase("mediaid is invalid"))?, + )?; + + Ok(Some((filename, content_type, file.to_vec()))) + } else { + Ok(None) + } + } +} diff --git a/src/main.rs b/src/main.rs index 7581da16..043f7571 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,6 +71,9 @@ fn setup_rocket() -> rocket::Rocket { client_server::publicised_groups_route, client_server::send_event_to_device_route, client_server::get_media_config_route, + client_server::create_content_route, + client_server::get_content_route, + client_server::get_content_thumbnail_route, client_server::options_route, server_server::well_known_server, server_server::get_server_version, diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 21c59255..d6f6cfe8 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -11,7 +11,7 @@ use ruma_identifiers::UserId; use std::{convert::TryInto, io::Cursor, ops::Deref}; use tokio::io::AsyncReadExt; -const MESSAGE_LIMIT: u64 = 65535; +const MESSAGE_LIMIT: u64 = 20 * 1024 * 1024; // 20 MB /// This struct converts rocket requests into ruma structs by converting them into http requests /// first. @@ -19,7 +19,7 @@ pub struct Ruma { body: T, pub user_id: Option, pub device_id: Option, - pub json_body: serde_json::Value, + pub json_body: Option, // This is None if parsing failed (for raw byte bodies) } impl<'a, T: Endpoint> FromData<'a> for Ruma { @@ -85,12 +85,8 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma { body: t, user_id, device_id, - // TODO: Can we avoid parsing it again? - json_body: if !body.is_empty() { - serde_json::from_slice(&body).expect("Ruma already parsed it successfully") - } else { - serde_json::Value::default() - }, + // TODO: Can we avoid parsing it again? (We only need this for append_pdu) + json_body: serde_json::from_slice(&body).ok() }), Err(e) => { warn!("{:?}", e);