feat: partially support sync filters

merge-requests/217/head
Timo Kösters 3 years ago
parent 68e910bb77
commit 1bd9fd74b3
No known key found for this signature in database
GPG Key ID: 356E705610F626D5

@ -1,32 +1,47 @@
use crate::{utils, ConduitResult}; use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma};
use ruma::api::client::r0::filter::{self, create_filter, get_filter}; use ruma::api::client::{
error::ErrorKind,
r0::filter::{create_filter, get_filter},
};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, post}; use rocket::{get, post};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
/// ///
/// TODO: Loads a filter that was previously created. /// Loads a filter that was previously created.
#[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/user/<_>/filter/<_>"))] ///
#[tracing::instrument] /// - A user can only access their own filters
pub async fn get_filter_route() -> ConduitResult<get_filter::Response> { #[cfg_attr(
// TODO feature = "conduit_bin",
Ok(get_filter::Response::new(filter::IncomingFilterDefinition { get("/_matrix/client/r0/user/<_>/filter/<_>", data = "<body>")
event_fields: None, )]
event_format: filter::EventFormat::default(), #[tracing::instrument(skip(db, body))]
account_data: filter::IncomingFilter::default(), pub async fn get_filter_route(
room: filter::IncomingRoomFilter::default(), db: DatabaseGuard,
presence: filter::IncomingFilter::default(), body: Ruma<get_filter::Request<'_>>,
}) ) -> ConduitResult<get_filter::Response> {
.into()) let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let filter = match db.users.get_filter(sender_user, &body.filter_id)? {
Some(filter) => filter,
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
};
Ok(get_filter::Response::new(filter).into())
} }
/// # `PUT /_matrix/client/r0/user/{userId}/filter` /// # `PUT /_matrix/client/r0/user/{userId}/filter`
/// ///
/// TODO: Creates a new filter to be used by other endpoints. /// Creates a new filter to be used by other endpoints.
#[cfg_attr(feature = "conduit_bin", post("/_matrix/client/r0/user/<_>/filter"))] #[cfg_attr(
#[tracing::instrument] feature = "conduit_bin",
pub async fn create_filter_route() -> ConduitResult<create_filter::Response> { post("/_matrix/client/r0/user/<_>/filter", data = "<body>")
// TODO )]
Ok(create_filter::Response::new(utils::random_string(10)).into()) #[tracing::instrument(skip(db, body))]
pub async fn create_filter_route(
db: DatabaseGuard,
body: Ruma<create_filter::Request<'_>>,
) -> ConduitResult<create_filter::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(create_filter::Response::new(db.users.create_filter(sender_user, &body.filter)?).into())
} }

@ -138,6 +138,9 @@ pub async fn get_message_events_route(
let to = body.to.as_ref().map(|t| t.parse()); let to = body.to.as_ref().map(|t| t.parse());
db.rooms
.lazy_load_confirm_delivery(&sender_user, &sender_device, &body.room_id, from)?;
// Use limit or else 10 // Use limit or else 10
let limit = body.limit.try_into().map_or(10_usize, |l: u32| l as usize); let limit = body.limit.try_into().map_or(10_usize, |l: u32| l as usize);
@ -224,8 +227,6 @@ pub async fn get_message_events_route(
} }
} }
db.rooms
.lazy_load_confirm_delivery(&sender_user, &sender_device, &body.room_id, from)?;
resp.state = Vec::new(); resp.state = Vec::new();
for ll_id in &lazy_loaded { for ll_id in &lazy_loaded {
if let Some(member_event) = if let Some(member_event) =

@ -1,6 +1,10 @@
use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse};
use ruma::{ use ruma::{
api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, api::client::r0::{
filter::{IncomingFilterDefinition, LazyLoadOptions},
sync::sync_events,
uiaa::UiaaResponse,
},
events::{ events::{
room::member::{MembershipState, RoomMemberEventContent}, room::member::{MembershipState, RoomMemberEventContent},
AnySyncEphemeralRoomEvent, EventType, AnySyncEphemeralRoomEvent, EventType,
@ -77,34 +81,32 @@ pub async fn sync_events_route(
Entry::Vacant(v) => { Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::watch::channel(None); let (tx, rx) = tokio::sync::watch::channel(None);
v.insert((body.since.clone(), rx.clone()));
tokio::spawn(sync_helper_wrapper( tokio::spawn(sync_helper_wrapper(
Arc::clone(&arc_db), Arc::clone(&arc_db),
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
body.since.clone(), body,
body.full_state,
body.timeout,
tx, tx,
)); ));
v.insert((body.since.clone(), rx)).1.clone() rx
} }
Entry::Occupied(mut o) => { Entry::Occupied(mut o) => {
if o.get().0 != body.since { if o.get().0 != body.since {
let (tx, rx) = tokio::sync::watch::channel(None); let (tx, rx) = tokio::sync::watch::channel(None);
o.insert((body.since.clone(), rx.clone()));
tokio::spawn(sync_helper_wrapper( tokio::spawn(sync_helper_wrapper(
Arc::clone(&arc_db), Arc::clone(&arc_db),
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
body.since.clone(), body,
body.full_state,
body.timeout,
tx, tx,
)); ));
o.insert((body.since.clone(), rx.clone()));
rx rx
} else { } else {
o.get().1.clone() o.get().1.clone()
@ -135,18 +137,16 @@ async fn sync_helper_wrapper(
db: Arc<DatabaseGuard>, db: Arc<DatabaseGuard>,
sender_user: Box<UserId>, sender_user: Box<UserId>,
sender_device: Box<DeviceId>, sender_device: Box<DeviceId>,
since: Option<String>, body: sync_events::IncomingRequest,
full_state: bool,
timeout: Option<Duration>,
tx: Sender<Option<ConduitResult<sync_events::Response>>>, tx: Sender<Option<ConduitResult<sync_events::Response>>>,
) { ) {
let since = body.since.clone();
let r = sync_helper( let r = sync_helper(
Arc::clone(&db), Arc::clone(&db),
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
since.clone(), body,
full_state,
timeout,
) )
.await; .await;
@ -179,9 +179,7 @@ async fn sync_helper(
db: Arc<DatabaseGuard>, db: Arc<DatabaseGuard>,
sender_user: Box<UserId>, sender_user: Box<UserId>,
sender_device: Box<DeviceId>, sender_device: Box<DeviceId>,
since: Option<String>, body: sync_events::IncomingRequest,
full_state: bool,
timeout: Option<Duration>,
// bool = caching allowed // bool = caching allowed
) -> Result<(sync_events::Response, bool), Error> { ) -> Result<(sync_events::Response, bool), Error> {
// TODO: match body.set_presence { // TODO: match body.set_presence {
@ -193,8 +191,26 @@ async fn sync_helper(
let next_batch = db.globals.current_count()?; let next_batch = db.globals.current_count()?;
let next_batch_string = next_batch.to_string(); let next_batch_string = next_batch.to_string();
// Load filter
let filter = match body.filter {
None => IncomingFilterDefinition::default(),
Some(sync_events::IncomingFilter::FilterDefinition(filter)) => filter,
Some(sync_events::IncomingFilter::FilterId(filter_id)) => db
.users
.get_filter(&sender_user, &filter_id)?
.unwrap_or_default(),
};
let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options {
LazyLoadOptions::Enabled {
include_redundant_members: redundant,
} => (true, redundant),
_ => (false, false),
};
let mut joined_rooms = BTreeMap::new(); let mut joined_rooms = BTreeMap::new();
let since = since let since = body
.since
.clone() .clone()
.and_then(|string| string.parse().ok()) .and_then(|string| string.parse().ok())
.unwrap_or(0); .unwrap_or(0);
@ -374,8 +390,10 @@ async fn sync_helper(
.expect("state events have state keys"); .expect("state events have state keys");
if pdu.kind != EventType::RoomMember { if pdu.kind != EventType::RoomMember {
state_events.push(pdu); state_events.push(pdu);
} else if full_state || timeline_users.contains(state_key) { } else if !lazy_load_enabled
// TODO: check filter: is ll enabled? || body.full_state
|| timeline_users.contains(state_key)
{
lazy_loaded.push( lazy_loaded.push(
UserId::parse(state_key.as_ref()) UserId::parse(state_key.as_ref())
.expect("they are in timeline_users, so they should be correct"), .expect("they are in timeline_users, so they should be correct"),
@ -432,15 +450,6 @@ async fn sync_helper(
let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?;
/*
let state_events = if joined_since_last_sync || full_state {
current_state_ids
.iter()
.map(|(_, id)| db.rooms.get_pdu(id))
.filter_map(|r| r.ok().flatten())
.collect::<Vec<_>>()
} else {
*/
let mut state_events = Vec::new(); let mut state_events = Vec::new();
let mut lazy_loaded = Vec::new(); let mut lazy_loaded = Vec::new();
@ -459,7 +468,7 @@ async fn sync_helper(
.expect("state events have state keys"); .expect("state events have state keys");
if pdu.kind != EventType::RoomMember { if pdu.kind != EventType::RoomMember {
if full_state || since_state_ids.get(&key) != Some(&id) { if body.full_state || since_state_ids.get(&key) != Some(&id) {
state_events.push(pdu); state_events.push(pdu);
} }
continue; continue;
@ -469,16 +478,16 @@ async fn sync_helper(
let state_key_userid = UserId::parse(state_key.as_ref()) let state_key_userid = UserId::parse(state_key.as_ref())
.expect("they are in timeline_users, so they should be correct"); .expect("they are in timeline_users, so they should be correct");
if full_state || since_state_ids.get(&key) != Some(&id) { if body.full_state || since_state_ids.get(&key) != Some(&id) {
lazy_loaded.push(state_key_userid); lazy_loaded.push(state_key_userid);
state_events.push(pdu); state_events.push(pdu);
} else if timeline_users.contains(state_key) } else if timeline_users.contains(state_key)
&& !db.rooms.lazy_load_was_sent_before( && (!db.rooms.lazy_load_was_sent_before(
&sender_user, &sender_user,
&sender_device, &sender_device,
&room_id, &room_id,
&state_key_userid, &state_key_userid,
)? )? || lazy_load_send_redundant)
{ {
lazy_loaded.push(state_key_userid); lazy_loaded.push(state_key_userid);
state_events.push(pdu); state_events.push(pdu);
@ -858,7 +867,7 @@ async fn sync_helper(
}; };
// TODO: Retry the endpoint instead of returning (waiting for #118) // TODO: Retry the endpoint instead of returning (waiting for #118)
if !full_state if !body.full_state
&& response.rooms.is_empty() && response.rooms.is_empty()
&& response.presence.is_empty() && response.presence.is_empty()
&& response.account_data.is_empty() && response.account_data.is_empty()
@ -867,7 +876,7 @@ async fn sync_helper(
{ {
// Hang a few seconds so requests are not spammed // Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives // Stop hanging if new info arrives
let mut duration = timeout.unwrap_or_default(); let mut duration = body.timeout.unwrap_or_default();
if duration.as_secs() > 30 { if duration.as_secs() > 30 {
duration = Duration::from_secs(30); duration = Duration::from_secs(30);
} }

@ -249,6 +249,7 @@ impl Database {
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?, todeviceid_events: builder.open_tree("todeviceid_events")?,
}, },
uiaa: uiaa::Uiaa { uiaa: uiaa::Uiaa {

@ -1,6 +1,9 @@
use crate::{utils, Error, Result}; use crate::{utils, Error, Result};
use ruma::{ use ruma::{
api::client::{error::ErrorKind, r0::device::Device}, api::client::{
error::ErrorKind,
r0::{device::Device, filter::IncomingFilterDefinition},
},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, EventType}, events::{AnyToDeviceEvent, EventType},
identifiers::MxcUri, identifiers::MxcUri,
@ -36,6 +39,8 @@ pub struct Users {
pub(super) userid_selfsigningkeyid: Arc<dyn Tree>, pub(super) userid_selfsigningkeyid: Arc<dyn Tree>,
pub(super) userid_usersigningkeyid: Arc<dyn Tree>, pub(super) userid_usersigningkeyid: Arc<dyn Tree>,
pub(super) userfilterid_filter: Arc<dyn Tree>, // UserFilterId = UserId + FilterId
pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count
} }
@ -996,6 +1001,47 @@ impl Users {
// TODO: Unhook 3PID // TODO: Unhook 3PID
Ok(()) Ok(())
} }
/// Creates a new sync filter. Returns the filter id.
#[tracing::instrument(skip(self))]
pub fn create_filter(
&self,
user_id: &UserId,
filter: &IncomingFilterDefinition,
) -> Result<String> {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter.insert(
&key,
&serde_json::to_vec(&filter).expect("filter is valid json"),
)?;
Ok(filter_id)
}
#[tracing::instrument(skip(self))]
pub fn get_filter(
&self,
user_id: &UserId,
filter_id: &str,
) -> Result<Option<IncomingFilterDefinition>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?;
if let Some(raw) = raw {
serde_json::from_slice(&raw)
.map_err(|_| Error::bad_database("Invalid filter event in db."))
} else {
Ok(None)
}
}
} }
/// Ensure that a user only sees signatures from themselves and the target user /// Ensure that a user only sees signatures from themselves and the target user

Loading…
Cancel
Save