diff --git a/Cargo.lock b/Cargo.lock index 949f90d6..267f4099 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1560,7 +1560,6 @@ dependencies = [ [[package]] name = "ruma" version = "0.0.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "ruma-api", "ruma-client-api", @@ -1574,7 +1573,6 @@ dependencies = [ [[package]] name = "ruma-api" version = "0.17.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "http", "percent-encoding", @@ -1589,7 +1587,6 @@ dependencies = [ [[package]] name = "ruma-api-macros" version = "0.17.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1600,7 +1597,6 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.10.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "assign", "http", @@ -1618,7 +1614,6 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "js_int", "ruma-identifiers", @@ -1631,7 +1626,6 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.22.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "js_int", "ruma-common", @@ -1646,7 +1640,6 @@ dependencies = [ [[package]] name = "ruma-events-macros" version = "0.22.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1657,7 +1650,6 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.0.3" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "js_int", "ruma-api", @@ -1672,7 +1664,6 @@ dependencies = [ [[package]] name = "ruma-identifiers" version = "0.17.4" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "rand", "ruma-identifiers-macros", @@ -1684,7 +1675,6 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" version = "0.17.4" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "proc-macro2", "quote", @@ -1695,7 +1685,6 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.1.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "ruma-serde", "serde", @@ -1706,7 +1695,6 @@ dependencies = [ [[package]] name = "ruma-serde" version = "0.2.3" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "form_urlencoded", "itoa", @@ -1718,7 +1706,6 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.6.0-dev.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#7f8c78e8ba4be7fda450285e62493f6b33cb085a" dependencies = [ "base64 0.12.3", "ring", diff --git a/Cargo.toml b/Cargo.toml index 4945e3c8..ceb78839 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,8 @@ edition = "2018" rocket = { git = "https://github.com/timokoesters/Rocket.git", branch = "empty_parameters", features = ["tls"] } #ruma = { git = "https://github.com/ruma/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"], rev = "987d48666cf166cf12100b5dbc61b5e3385c4014" } # Used for matrix spec type definitions and helpers -ruma = { git = "https://github.com/timokoesters/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"], branch = "timo-fixes" } # Used for matrix spec type definitions and helpers -#ruma = { path = "../ruma/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"] } +#ruma = { git = "https://github.com/timokoesters/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"], branch = "timo-fixes" } # Used for matrix spec type definitions and helpers +ruma = { path = "../ruma/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"] } tokio = "0.2.22" # Used for long polling sled = "0.32.0" # Used for storing data permanently log = "0.4.8" # Used for emitting log entries diff --git a/src/client_server/mod.rs b/src/client_server/mod.rs index 7703198b..e5a36f3a 100644 --- a/src/client_server/mod.rs +++ b/src/client_server/mod.rs @@ -17,6 +17,7 @@ mod push; mod read_marker; mod redact; mod room; +mod search; mod session; mod state; mod sync; @@ -47,6 +48,7 @@ pub use push::*; pub use read_marker::*; pub use redact::*; pub use room::*; +pub use search::*; pub use session::*; pub use state::*; pub use sync::*; diff --git a/src/client_server/search.rs b/src/client_server/search.rs new file mode 100644 index 00000000..9e465dd9 --- /dev/null +++ b/src/client_server/search.rs @@ -0,0 +1,93 @@ +use super::State; +use crate::{ConduitResult, Database, Error, Ruma}; +use js_int::uint; +use ruma::api::client::{error::ErrorKind, r0::search::search_events}; + +#[cfg(feature = "conduit_bin")] +use rocket::post; +use search_events::{ResultCategories, ResultRoomEvents, SearchResult}; +use std::collections::BTreeMap; + +#[cfg_attr( + feature = "conduit_bin", + post("/_matrix/client/r0/search", data = "") +)] +pub fn search_events_route( + db: State<'_, Database>, + body: Ruma, +) -> ConduitResult { + let sender_id = body.sender_id.as_ref().expect("user is authenticated"); + + let search_criteria = body.search_categories.room_events.as_ref().unwrap(); + let filter = search_criteria + .filter + .as_ref() + .unwrap(); + + let room_id = filter.rooms + .as_ref() + .unwrap() + .first() + .unwrap(); + + let limit = filter.limit.map_or(10, |l| u64::from(l) as usize); + + if !db.rooms.is_joined(sender_id, &room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } + + let skip = match body.next_batch.as_ref().map(|s| s.parse()) { + Some(Ok(s)) => s, + Some(Err(_)) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid next_batch token.", + )) + } + None => 0, // Default to the start + }; + + let search = db + .rooms + .search_pdus(&room_id, &search_criteria.search_term)?; + + let results = search + .0 + .map(|result| { + Ok::<_, Error>(SearchResult { + context: None, + rank: None, + result: dbg!(db + .rooms + .get_pdu_from_id(dbg!(&result))? + .map(|pdu| pdu.to_room_event())), + }) + }) + .filter_map(|r| r.ok()) + .skip(skip) + .take(limit) + .collect::>(); + + let next_batch = if results.len() < limit as usize { + None + } else { + Some((skip + limit).to_string()) + }; + + Ok(search_events::Response { + search_categories: ResultCategories { + room_events: Some(ResultRoomEvents { + count: uint!(0), // TODO + groups: BTreeMap::new(), // TODO + next_batch, + results, + state: BTreeMap::new(), // TODO + highlights: search.1, + }), + }, + } + .into()) +} diff --git a/src/database.rs b/src/database.rs index 844a1f47..eb27325f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -104,6 +104,8 @@ impl Database { aliasid_alias: db.open_tree("alias_roomid")?, publicroomids: db.open_tree("publicroomids")?, + tokenids: db.open_tree("tokenids")?, + userroomid_joined: db.open_tree("userroomid_joined")?, roomuserid_joined: db.open_tree("roomuserid_joined")?, userroomid_invited: db.open_tree("userroomid_invited")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index fe633180..3b3c2c6a 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -35,6 +35,8 @@ pub struct Rooms { pub(super) aliasid_alias: sled::Tree, // AliasId = RoomId + Count pub(super) publicroomids: sled::Tree, + pub(super) tokenids: sled::Tree, // TokenId = RoomId + Token + PduId + pub(super) userroomid_joined: sled::Tree, pub(super) roomuserid_joined: sled::Tree, pub(super) userroomid_invited: sled::Tree, @@ -562,7 +564,7 @@ impl Rooms { self.pduid_pdu.insert(&pdu_id, &*pdu_json.to_string())?; self.eventid_pduid - .insert(pdu.event_id.to_string(), pdu_id)?; + .insert(pdu.event_id.to_string(), pdu_id.clone())?; if let Some(state_key) = pdu.state_key { let mut key = room_id.to_string().as_bytes().to_vec(); @@ -573,7 +575,7 @@ impl Rooms { self.roomstateid_pdu.insert(key, &*pdu_json.to_string())?; } - match event_type { + match dbg!(event_type) { EventType::RoomRedaction => { if let Some(redact_id) = &redacts { // TODO: Reason @@ -616,6 +618,21 @@ impl Rooms { )?; } } + EventType::RoomMessage => { + if let Some(body) = dbg!(content).get("body").and_then(|b| b.as_str()) { + for word in body + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) + { + let mut key = room_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(word.as_bytes()); + key.push(0xff); + key.extend_from_slice(&pdu_id); + self.tokenids.insert(key, &[])?; + } + } + } _ => {} } self.edus.room_read_set(&room_id, &sender, index)?; @@ -928,6 +945,80 @@ impl Rooms { }) } + pub fn search_pdus( + &self, + room_id: &RoomId, + search_string: &str, + ) -> Result<(impl Iterator, Vec)> { + let mut prefix = room_id.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + let words = search_string + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) + .collect::>(); + + let mut iterators = words.iter().map(|word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xff); + self.tokenids + .scan_prefix(&prefix2) + .keys() + .filter_map(|r| r.ok()) + .map(|key| { + let pduid_index = key + .iter() + .enumerate() + .filter(|(_, &b)| b == 0xff) + .nth(1) + .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? + .0 + 1; // +1 because the pdu id starts AFTER the separator + + let pdu_id = + key.subslice(pduid_index, key.len() - pduid_index); + + Ok::<_, Error>(pdu_id) + }) + .filter_map(|r| r.ok()) + .peekable() + }); + + let first_iterator = match iterators.next() { + Some(i) => i, + None => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "search_term needs to contain at least one word.", + )) + } + }; + + let mut other_iterators = iterators.collect::>(); + + Ok(( + first_iterator.filter(move |target| { + other_iterators + .iter_mut() + .map(|it| { + while let Some(element) = it.peek() { + if dbg!(element) > dbg!(target) { + return false; + } else if element == target { + return true; + } else { + it.next(); + } + } + + false + }) + .all(|b| b) + }), + words, + )) + } + /// Returns an iterator over all joined members of a room. pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> { self.roomuserid_joined diff --git a/src/main.rs b/src/main.rs index d3a673f6..bbe7c962 100644 --- a/src/main.rs +++ b/src/main.rs @@ -90,6 +90,7 @@ fn setup_rocket() -> rocket::Rocket { client_server::sync_events_route, client_server::get_context_route, client_server::get_message_events_route, + client_server::search_events_route, client_server::turn_server_route, client_server::send_event_to_device_route, client_server::get_media_config_route,