Merge branch 'statediffs' into 'master'

Statediffs

See merge request famedly/conduit!145
merge-requests/160/merge
Timo Kösters 3 years ago
commit 33481ec062

112
Cargo.lock generated

@ -248,7 +248,7 @@ dependencies = [
"jsonwebtoken", "jsonwebtoken",
"lru-cache", "lru-cache",
"num_cpus", "num_cpus",
"opentelemetry", "opentelemetry 0.16.0",
"opentelemetry-jaeger", "opentelemetry-jaeger",
"parking_lot", "parking_lot",
"pretty_env_logger", "pretty_env_logger",
@ -1465,17 +1465,47 @@ dependencies = [
"thiserror", "thiserror",
] ]
[[package]]
name = "opentelemetry"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf9b1c4e9a6c4de793c632496fa490bdc0e1eea73f0c91394f7b6990935d22"
dependencies = [
"async-trait",
"crossbeam-channel",
"futures",
"js-sys",
"lazy_static",
"percent-encoding",
"pin-project",
"rand 0.8.4",
"thiserror",
"tokio",
"tokio-stream",
]
[[package]] [[package]]
name = "opentelemetry-jaeger" name = "opentelemetry-jaeger"
version = "0.14.0" version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09a9fc8192722e7daa0c56e59e2336b797122fb8598383dcb11c8852733b435c" checksum = "db22f492873ea037bc267b35a0e8e4fb846340058cb7c864efe3d0bf23684593"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"lazy_static", "lazy_static",
"opentelemetry", "opentelemetry 0.16.0",
"opentelemetry-semantic-conventions",
"thiserror", "thiserror",
"thrift", "thrift",
"tokio",
]
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffeac823339e8b0f27b961f4385057bf9f97f2863bc745bd015fd6091f2270e9"
dependencies = [
"opentelemetry 0.16.0",
] ]
[[package]] [[package]]
@ -2014,8 +2044,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.2.0" version = "0.3.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"assign", "assign",
"js_int", "js_int",
@ -2035,8 +2065,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-api" name = "ruma-api"
version = "0.17.1" version = "0.18.3"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"bytes", "bytes",
"http", "http",
@ -2051,8 +2081,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-api-macros" name = "ruma-api-macros"
version = "0.17.1" version = "0.18.3"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2", "proc-macro2",
@ -2062,8 +2092,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-appservice-api" name = "ruma-appservice-api"
version = "0.3.0" version = "0.4.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"ruma-api", "ruma-api",
"ruma-common", "ruma-common",
@ -2076,8 +2106,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-client-api" name = "ruma-client-api"
version = "0.11.0" version = "0.12.2"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"assign", "assign",
"bytes", "bytes",
@ -2096,8 +2126,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-common" name = "ruma-common"
version = "0.5.4" version = "0.6.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"indexmap", "indexmap",
"js_int", "js_int",
@ -2111,8 +2141,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events" name = "ruma-events"
version = "0.23.2" version = "0.24.4"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"indoc", "indoc",
"js_int", "js_int",
@ -2127,8 +2157,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events-macros" name = "ruma-events-macros"
version = "0.23.2" version = "0.24.4"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2", "proc-macro2",
@ -2138,8 +2168,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-federation-api" name = "ruma-federation-api"
version = "0.2.0" version = "0.3.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-api", "ruma-api",
@ -2153,8 +2183,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers" name = "ruma-identifiers"
version = "0.19.4" version = "0.20.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"paste", "paste",
"rand 0.8.4", "rand 0.8.4",
@ -2167,8 +2197,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-macros" name = "ruma-identifiers-macros"
version = "0.19.4" version = "0.20.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"quote", "quote",
"ruma-identifiers-validation", "ruma-identifiers-validation",
@ -2177,13 +2207,13 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-validation" name = "ruma-identifiers-validation"
version = "0.4.0" version = "0.5.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
[[package]] [[package]]
name = "ruma-identity-service-api" name = "ruma-identity-service-api"
version = "0.2.0" version = "0.3.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-api", "ruma-api",
@ -2195,8 +2225,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-push-gateway-api" name = "ruma-push-gateway-api"
version = "0.2.0" version = "0.3.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-api", "ruma-api",
@ -2210,8 +2240,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-serde" name = "ruma-serde"
version = "0.4.1" version = "0.5.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
@ -2224,8 +2254,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-serde-macros" name = "ruma-serde-macros"
version = "0.4.1" version = "0.5.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2", "proc-macro2",
@ -2235,8 +2265,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.8.0" version = "0.9.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"base64 0.13.0", "base64 0.13.0",
"ed25519-dalek", "ed25519-dalek",
@ -2252,8 +2282,8 @@ dependencies = [
[[package]] [[package]]
name = "ruma-state-res" name = "ruma-state-res"
version = "0.2.0" version = "0.3.0"
source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced"
dependencies = [ dependencies = [
"itertools 0.10.1", "itertools 0.10.1",
"js_int", "js_int",
@ -3022,7 +3052,7 @@ version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c47440f2979c4cd3138922840eec122e3c0ba2148bc290f756bd7fd60fc97fff" checksum = "c47440f2979c4cd3138922840eec122e3c0ba2148bc290f756bd7fd60fc97fff"
dependencies = [ dependencies = [
"opentelemetry", "opentelemetry 0.15.0",
"tracing", "tracing",
"tracing-core", "tracing-core",
"tracing-log", "tracing-log",

@ -18,8 +18,8 @@ edition = "2018"
rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
#ruma = { git = "https://github.com/ruma/ruma", rev = "eb19b0e08a901b87d11b3be0890ec788cc760492", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
ruma = { git = "https://github.com/timokoesters/ruma", rev = "a2d93500e1dbc87e7032a3c74f3b2479a7f84e93", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { git = "https://github.com/timokoesters/ruma", rev = "995ccea20f5f6d4a8fb22041749ed4de22fa1b6a", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
#ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
# Used for long polling and federation sender, should be the same as rocket::tokio # Used for long polling and federation sender, should be the same as rocket::tokio
@ -66,11 +66,11 @@ regex = "1.5.4"
jsonwebtoken = "7.2.0" jsonwebtoken = "7.2.0"
# Performance measurements # Performance measurements
tracing = { version = "0.1.26", features = ["release_max_level_warn"] } tracing = { version = "0.1.26", features = ["release_max_level_warn"] }
opentelemetry = "0.15.0"
tracing-subscriber = "0.2.19" tracing-subscriber = "0.2.19"
tracing-opentelemetry = "0.14.0" tracing-opentelemetry = "0.14.0"
tracing-flame = "0.1.0" tracing-flame = "0.1.0"
opentelemetry-jaeger = "0.14.0" opentelemetry = { version = "0.16.0", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.15.0", features = ["rt-tokio"] }
pretty_env_logger = "0.4.0" pretty_env_logger = "0.4.0"
lru-cache = "0.1.2" lru-cache = "0.1.2"
rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] } rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] }

@ -249,6 +249,8 @@ pub async fn register_route(
let room_id = RoomId::new(db.globals.server_name()); let room_id = RoomId::new(db.globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals db.globals
.roomid_mutex_state .roomid_mutex_state
@ -290,6 +292,7 @@ pub async fn register_route(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: None, blurhash: None,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -455,6 +458,7 @@ pub async fn register_route(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: None, blurhash: None,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -476,6 +480,7 @@ pub async fn register_route(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: None, blurhash: None,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -681,6 +686,7 @@ pub async fn deactivate_route(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: None, blurhash: None,
reason: None,
}; };
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
@ -731,7 +737,7 @@ pub async fn deactivate_route(
pub async fn third_party_route( pub async fn third_party_route(
body: Ruma<get_contacts::Request>, body: Ruma<get_contacts::Request>,
) -> ConduitResult<get_contacts::Response> { ) -> ConduitResult<get_contacts::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_contacts::Response::new(Vec::new()).into()) Ok(get_contacts::Response::new(Vec::new()).into())
} }

@ -44,7 +44,7 @@ pub async fn get_context_route(
let events_before = db let events_before = db
.rooms .rooms
.pdus_until(&sender_user, &body.room_id, base_token) .pdus_until(&sender_user, &body.room_id, base_token)?
.take( .take(
u32::try_from(body.limit).map_err(|_| { u32::try_from(body.limit).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.")
@ -66,7 +66,7 @@ pub async fn get_context_route(
let events_after = db let events_after = db
.rooms .rooms
.pdus_after(&sender_user, &body.room_id, base_token) .pdus_after(&sender_user, &body.room_id, base_token)?
.take( .take(
u32::try_from(body.limit).map_err(|_| { u32::try_from(body.limit).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.")

@ -262,6 +262,7 @@ pub async fn ban_user_route(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(&body.user_id)?, blurhash: db.users.blurhash(&body.user_id)?,
reason: None,
}), }),
|event| { |event| {
let mut event = serde_json::from_value::<Raw<member::MemberEventContent>>( let mut event = serde_json::from_value::<Raw<member::MemberEventContent>>(
@ -563,6 +564,7 @@ async fn join_room_by_id_helper(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(&sender_user)?, blurhash: db.users.blurhash(&sender_user)?,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
); );
@ -609,6 +611,8 @@ async fn join_room_by_id_helper(
) )
.await?; .await?;
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) let pdu = PduEvent::from_id_val(&event_id, join_event.clone())
.map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?;
@ -693,6 +697,7 @@ async fn join_room_by_id_helper(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(&sender_user)?, blurhash: db.users.blurhash(&sender_user)?,
reason: None,
}; };
db.rooms.build_and_append_pdu( db.rooms.build_and_append_pdu(
@ -844,6 +849,7 @@ pub async fn invite_helper<'a>(
membership: MembershipState::Invite, membership: MembershipState::Invite,
third_party_invite: None, third_party_invite: None,
blurhash: None, blurhash: None,
reason: None,
}) })
.expect("member event is valid value"); .expect("member event is valid value");
@ -1038,6 +1044,7 @@ pub async fn invite_helper<'a>(
is_direct: Some(is_direct), is_direct: Some(is_direct),
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(&user_id)?, blurhash: db.users.blurhash(&user_id)?,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,

@ -128,7 +128,7 @@ pub async fn get_message_events_route(
get_message_events::Direction::Forward => { get_message_events::Direction::Forward => {
let events_after = db let events_after = db
.rooms .rooms
.pdus_after(&sender_user, &body.room_id, from) .pdus_after(&sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| { .filter_map(|(pdu_id, pdu)| {
@ -158,7 +158,7 @@ pub async fn get_message_events_route(
get_message_events::Direction::Backward => { get_message_events::Direction::Backward => {
let events_before = db let events_before = db
.rooms .rooms
.pdus_until(&sender_user, &body.room_id, from) .pdus_until(&sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| { .filter_map(|(pdu_id, pdu)| {

@ -33,6 +33,8 @@ pub async fn create_room_route(
let room_id = RoomId::new(db.globals.server_name()); let room_id = RoomId::new(db.globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals db.globals
.roomid_mutex_state .roomid_mutex_state
@ -105,6 +107,7 @@ pub async fn create_room_route(
is_direct: Some(body.is_direct), is_direct: Some(body.is_direct),
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(&sender_user)?, blurhash: db.users.blurhash(&sender_user)?,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -173,7 +176,6 @@ pub async fn create_room_route(
)?; )?;
// 4. Canonical room alias // 4. Canonical room alias
if let Some(room_alias_id) = &alias { if let Some(room_alias_id) = &alias {
db.rooms.build_and_append_pdu( db.rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
@ -193,7 +195,7 @@ pub async fn create_room_route(
&room_id, &room_id,
&db, &db,
&state_lock, &state_lock,
); )?;
} }
// 5. Events set by preset // 5. Events set by preset
@ -516,6 +518,7 @@ pub async fn upgrade_room_route(
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(&sender_user)?, blurhash: db.users.blurhash(&sender_user)?,
reason: None,
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,

@ -3,7 +3,10 @@ use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
r0::session::{get_login_types, login, logout, logout_all}, r0::{
session::{get_login_types, login, logout, logout_all},
uiaa::IncomingUserIdentifier,
},
}, },
UserId, UserId,
}; };
@ -60,7 +63,7 @@ pub async fn login_route(
identifier, identifier,
password, password,
} => { } => {
let username = if let login::IncomingUserIdentifier::MatrixId(matrix_id) = identifier { let username = if let IncomingUserIdentifier::MatrixId(matrix_id) = identifier {
matrix_id matrix_id
} else { } else {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));

@ -205,7 +205,7 @@ async fn sync_helper(
let mut non_timeline_pdus = db let mut non_timeline_pdus = db
.rooms .rooms
.pdus_until(&sender_user, &room_id, u64::MAX) .pdus_until(&sender_user, &room_id, u64::MAX)?
.filter_map(|r| { .filter_map(|r| {
// Filter out buggy events // Filter out buggy events
if r.is_err() { if r.is_err() {
@ -248,13 +248,13 @@ async fn sync_helper(
let first_pdu_before_since = db let first_pdu_before_since = db
.rooms .rooms
.pdus_until(&sender_user, &room_id, since) .pdus_until(&sender_user, &room_id, since)?
.next() .next()
.transpose()?; .transpose()?;
let pdus_after_since = db let pdus_after_since = db
.rooms .rooms
.pdus_after(&sender_user, &room_id, since) .pdus_after(&sender_user, &room_id, since)?
.next() .next()
.is_some(); .is_some();
@ -286,7 +286,7 @@ async fn sync_helper(
for hero in db for hero in db
.rooms .rooms
.all_pdus(&sender_user, &room_id) .all_pdus(&sender_user, &room_id)?
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
.filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .filter(|(_, pdu)| pdu.kind == EventType::RoomMember)
.map(|(_, pdu)| { .map(|(_, pdu)| {
@ -328,11 +328,11 @@ async fn sync_helper(
} }
} }
( Ok::<_, Error>((
Some(joined_member_count), Some(joined_member_count),
Some(invited_member_count), Some(invited_member_count),
heroes, heroes,
) ))
}; };
let ( let (
@ -343,7 +343,7 @@ async fn sync_helper(
state_events, state_events,
) = if since_shortstatehash.is_none() { ) = if since_shortstatehash.is_none() {
// Probably since = 0, we will do an initial sync // Probably since = 0, we will do an initial sync
let (joined_member_count, invited_member_count, heroes) = calculate_counts(); let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?;
let state_events = current_state_ids let state_events = current_state_ids
@ -510,7 +510,7 @@ async fn sync_helper(
} }
let (joined_member_count, invited_member_count, heroes) = if send_member_count { let (joined_member_count, invited_member_count, heroes) = if send_member_count {
calculate_counts() calculate_counts()?
} else { } else {
(None, None, Vec::new()) (None, None, Vec::new())
}; };

@ -24,13 +24,14 @@ use rocket::{
request::{FromRequest, Request}, request::{FromRequest, Request},
Shutdown, State, Shutdown, State,
}; };
use ruma::{DeviceId, RoomId, ServerName, UserId}; use ruma::{DeviceId, EventId, RoomId, ServerName, UserId};
use serde::{de::IgnoredAny, Deserialize}; use serde::{de::IgnoredAny, Deserialize};
use std::{ use std::{
collections::{BTreeMap, HashMap}, collections::{BTreeMap, HashMap, HashSet},
convert::TryFrom, convert::{TryFrom, TryInto},
fs::{self, remove_dir_all}, fs::{self, remove_dir_all},
io::Write, io::Write,
mem::size_of,
ops::Deref, ops::Deref,
path::Path, path::Path,
sync::{Arc, Mutex, RwLock}, sync::{Arc, Mutex, RwLock},
@ -107,7 +108,7 @@ fn default_db_cache_capacity_mb() -> f64 {
} }
fn default_sqlite_wal_clean_second_interval() -> u32 { fn default_sqlite_wal_clean_second_interval() -> u32 {
15 * 60 // every 15 minutes 1 * 60 // every minute
} }
fn default_max_request_size() -> u32 { fn default_max_request_size() -> u32 {
@ -261,7 +262,11 @@ impl Database {
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
stateid_shorteventid: builder.open_tree("stateid_shorteventid")?,
shortroomid_roomid: builder.open_tree("shortroomid_roomid")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
@ -270,8 +275,12 @@ impl Database {
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
referencedevents: builder.open_tree("referencedevents")?, referencedevents: builder.open_tree("referencedevents")?,
pdu_cache: Mutex::new(LruCache::new(1_000_000)), pdu_cache: Mutex::new(LruCache::new(100_000)),
auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), auth_chain_cache: Mutex::new(LruCache::new(100_000)),
shorteventid_cache: Mutex::new(LruCache::new(1_000_000)),
eventidshort_cache: Mutex::new(LruCache::new(1_000_000)),
statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)),
stateinfo_cache: Mutex::new(LruCache::new(50)),
}, },
account_data: account_data::AccountData { account_data: account_data::AccountData {
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
@ -424,7 +433,6 @@ impl Database {
} }
if db.globals.database_version()? < 6 { if db.globals.database_version()? < 6 {
// TODO update to 6
// Set room member count // Set room member count
for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { for (roomid, _) in db.rooms.roomid_shortstatehash.iter() {
let room_id = let room_id =
@ -437,6 +445,261 @@ impl Database {
println!("Migration: 5 -> 6 finished"); println!("Migration: 5 -> 6 finished");
} }
if db.globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<RoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let mut counter = 0;
let mut handle_state =
|current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
counter += 1;
println!("counter: {}", counter);
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash))
},
)?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.cloned()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
db.rooms.save_state_from_diff(
dbg!(current_sstatehash),
statediffnew,
statediffremoved,
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = db.rooms.load_shortstatehash_info(&current_sstatehash, &db)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_ref().unwrap(),
current_state,
&mut last_roomstates,
)?;
last_roomstates
.insert(current_room.clone().unwrap(), current_sstatehash);
}
current_state = HashSet::new();
current_sstatehash = Some(sstatehash);
let event_id = db
.rooms
.shorteventid_eventid
.get(&seventid)
.unwrap()
.unwrap();
let event_id =
EventId::try_from(utils::string_from_bytes(&event_id).unwrap())
.unwrap();
let pdu = db.rooms.get_pdu(&event_id).unwrap().unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
}
}
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_ref().unwrap(),
current_state,
&mut last_roomstates,
)?;
}
db.globals.bump_database_version(7)?;
println!("Migration: 6 -> 7 finished");
}
if db.globals.database_version()? < 8 {
// Generate short room ids for all rooms
for (room_id, _) in db.rooms.roomid_shortstatehash.iter() {
let shortroomid = db.globals.next_count()?.to_be_bytes();
db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?;
db.rooms.shortroomid_roomid.insert(&shortroomid, &room_id)?;
println!("Migration: 8");
}
// Update pduids db layout
let mut batch = db.rooms.pduid_pdu.iter().filter_map(|(key, v)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(2, |&b| b == 0xff);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(&room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(count);
Some((new_key, v))
});
db.rooms.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xff);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(&room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
db.rooms.eventid_pduid.insert_batch(&mut batch2)?;
db.globals.bump_database_version(8)?;
println!("Migration: 7 -> 8 finished");
}
if db.globals.database_version()? < 9 {
// Update tokenids db layout
let batch = db
.rooms
.tokenids
.iter()
.filter_map(|(key, _)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(4, |&b| b == 0xff);
let room_id = parts.next().unwrap();
let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(&room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(word);
new_key.push(0xff);
new_key.extend_from_slice(pdu_id_count);
println!("old {:?}", key);
println!("new {:?}", new_key);
Some((new_key, Vec::new()))
})
.collect::<Vec<_>>();
let mut iter = batch.into_iter().peekable();
while iter.peek().is_some() {
db.rooms
.tokenids
.insert_batch(&mut iter.by_ref().take(1000))?;
println!("smaller batch done");
}
println!("Deleting starts");
let batch2 = db
.rooms
.tokenids
.iter()
.filter_map(|(key, _)| {
if key.starts_with(b"!") {
println!("del {:?}", key);
Some(key)
} else {
None
}
})
.collect::<Vec<_>>();
for key in batch2 {
println!("del");
db.rooms.tokenids.remove(&key)?;
}
db.globals.bump_database_version(9)?;
println!("Migration: 8 -> 9 finished");
}
} }
let guard = db.read().await; let guard = db.read().await;

@ -35,6 +35,7 @@ pub trait Tree: Send + Sync {
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
fn scan_prefix<'a>( fn scan_prefix<'a>(
&'a self, &'a self,

@ -9,15 +9,13 @@ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
time::{Duration, Instant},
}; };
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tracing::{debug, warn}; use tracing::debug;
pub const MILLI: Duration = Duration::from_millis(1);
thread_local! { thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None); static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
} }
struct PreparedStatementIterator<'a> { struct PreparedStatementIterator<'a> {
@ -51,11 +49,11 @@ impl Engine {
fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> {
let conn = Connection::open(&path)?; let conn = Connection::open(&path)?;
conn.pragma_update(Some(Main), "page_size", &32768)?; conn.pragma_update(Some(Main), "page_size", &2048)?;
conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "journal_mode", &"WAL")?;
conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?;
conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?;
conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; conn.pragma_update(Some(Main), "wal_autocheckpoint", &2000)?;
Ok(conn) Ok(conn)
} }
@ -79,9 +77,25 @@ impl Engine {
}) })
} }
fn read_lock_iterator(&self) -> &'static Connection {
READ_CONNECTION_ITERATOR.with(|cell| {
let connection = &mut cell.borrow_mut();
if (*connection).is_none() {
let c = Box::leak(Box::new(
Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(),
));
**connection = Some(c);
}
connection.unwrap()
})
}
pub fn flush_wal(self: &Arc<Self>) -> Result<()> { pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock() // We use autocheckpoints
.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; //self.write_lock()
//.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?;
Ok(()) Ok(())
} }
} }
@ -153,6 +167,34 @@ impl SqliteTable {
)?; )?;
Ok(()) Ok(())
} }
pub fn iter_with_guard<'a>(
&'a self,
guard: &'a Connection,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement);
let iterator = Box::new(
statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(|r| r.unwrap()),
);
Box::new(PreparedStatementIterator {
iterator,
statement_ref,
})
}
} }
impl Tree for SqliteTable { impl Tree for SqliteTable {
@ -164,16 +206,7 @@ impl Tree for SqliteTable {
#[tracing::instrument(skip(self, key, value))] #[tracing::instrument(skip(self, key, value))]
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
let start = Instant::now();
self.insert_with_guard(&guard, key, value)?; self.insert_with_guard(&guard, key, value)?;
let elapsed = start.elapsed();
if elapsed > MILLI {
warn!("insert took {:?} : {}", elapsed, &self.name);
}
drop(guard); drop(guard);
let watchers = self.watchers.read(); let watchers = self.watchers.read();
@ -216,53 +249,41 @@ impl Tree for SqliteTable {
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, iter))]
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
let guard = self.engine.write_lock();
guard.execute("BEGIN", [])?;
for key in iter {
let old = self.get_with_guard(&guard, &key)?;
let new = crate::utils::increment(old.as_deref())
.expect("utils::increment always returns Some");
self.insert_with_guard(&guard, &key, &new)?;
}
guard.execute("COMMIT", [])?;
drop(guard);
Ok(())
}
#[tracing::instrument(skip(self, key))] #[tracing::instrument(skip(self, key))]
fn remove(&self, key: &[u8]) -> Result<()> { fn remove(&self, key: &[u8]) -> Result<()> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
let start = Instant::now();
guard.execute( guard.execute(
format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), format!("DELETE FROM {} WHERE key = ?", self.name).as_str(),
[key], [key],
)?; )?;
let elapsed = start.elapsed();
if elapsed > MILLI {
debug!("remove: took {:012?} : {}", elapsed, &self.name);
}
// debug!("remove key: {:?}", &key);
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock(); let guard = self.engine.read_lock_iterator();
let statement = Box::leak(Box::new( self.iter_with_guard(&guard)
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement);
let iterator = Box::new(
statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(|r| r.unwrap()),
);
Box::new(PreparedStatementIterator {
iterator,
statement_ref,
})
} }
#[tracing::instrument(skip(self, from, backwards))] #[tracing::instrument(skip(self, from, backwards))]
@ -271,7 +292,7 @@ impl Tree for SqliteTable {
from: &[u8], from: &[u8],
backwards: bool, backwards: bool,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { ) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock(); let guard = self.engine.read_lock_iterator();
let from = from.to_vec(); // TODO change interface? let from = from.to_vec(); // TODO change interface?
if backwards { if backwards {
@ -326,8 +347,6 @@ impl Tree for SqliteTable {
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let guard = self.engine.write_lock(); let guard = self.engine.write_lock();
let start = Instant::now();
let old = self.get_with_guard(&guard, key)?; let old = self.get_with_guard(&guard, key)?;
let new = let new =
@ -335,26 +354,11 @@ impl Tree for SqliteTable {
self.insert_with_guard(&guard, key, &new)?; self.insert_with_guard(&guard, key, &new)?;
let elapsed = start.elapsed();
if elapsed > MILLI {
debug!("increment: took {:012?} : {}", elapsed, &self.name);
}
// debug!("increment key: {:?}", &key);
Ok(new) Ok(new)
} }
#[tracing::instrument(skip(self, prefix))] #[tracing::instrument(skip(self, prefix))]
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
// let name = self.name.clone();
// self.iter_from_thread(
// format!(
// "SELECT key, value FROM {} WHERE key BETWEEN ?1 AND ?1 || X'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF' ORDER BY key ASC",
// name
// )
// [prefix]
// )
Box::new( Box::new(
self.iter_from(&prefix, false) self.iter_from(&prefix, false)
.take_while(move |(key, _)| key.starts_with(&prefix)), .take_while(move |(key, _)| key.starts_with(&prefix)),

File diff suppressed because it is too large Load Diff

@ -4,11 +4,14 @@ use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
r0::uiaa::{IncomingAuthData, UiaaInfo}, r0::uiaa::{
IncomingAuthData, IncomingPassword, IncomingUserIdentifier::MatrixId, UiaaInfo,
},
}, },
signatures::CanonicalJsonValue, signatures::CanonicalJsonValue,
DeviceId, UserId, DeviceId, UserId,
}; };
use tracing::error;
use super::abstraction::Tree; use super::abstraction::Tree;
@ -49,126 +52,91 @@ impl Uiaa {
users: &super::users::Users, users: &super::users::Users,
globals: &super::globals::Globals, globals: &super::globals::Globals,
) -> Result<(bool, UiaaInfo)> { ) -> Result<(bool, UiaaInfo)> {
if let IncomingAuthData::DirectRequest { let mut uiaainfo = auth
kind, .session()
session, .map(|session| self.get_uiaa_session(&user_id, &device_id, session))
auth_parameters, .unwrap_or_else(|| Ok(uiaainfo.clone()))?;
} = &auth
{
let mut uiaainfo = session
.as_ref()
.map(|session| self.get_uiaa_session(&user_id, &device_id, session))
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
if uiaainfo.session.is_none() { if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
} }
match auth {
// Find out what the user completed // Find out what the user completed
match &**kind { IncomingAuthData::Password(IncomingPassword {
"m.login.password" => { identifier,
let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest( password,
ErrorKind::MissingParam, ..
"m.login.password needs identifier.", }) => {
))?; let username = match identifier {
MatrixId(username) => username,
let identifier_type = identifier.get("type").ok_or(Error::BadRequest( _ => {
ErrorKind::MissingParam,
"Identifier needs a type.",
))?;
if identifier_type != "m.id.user" {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Unrecognized, ErrorKind::Unrecognized,
"Identifier type not recognized.", "Identifier type not recognized.",
)); ))
} }
};
let username = identifier let user_id =
.get("user") UserId::parse_with_server_name(username.clone(), globals.server_name())
.ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Identifier needs user field.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"User is not a string.",
))?;
let user_id = UserId::parse_with_server_name(username, globals.server_name())
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?; })?;
let password = auth_parameters // Check if password is correct
.get("password") if let Some(hash) = users.password_hash(&user_id)? {
.ok_or(Error::BadRequest( let hash_matches =
ErrorKind::MissingParam, argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
"Password is missing.",
))? if !hash_matches {
.as_str() uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
.ok_or(Error::BadRequest( kind: ErrorKind::Forbidden,
ErrorKind::BadJson, message: "Invalid username or password.".to_owned(),
"Password is not a string.", });
))?; return Ok((false, uiaainfo));
// Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? {
let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
} }
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push("m.login.password".to_owned());
} }
"m.login.dummy" => {
uiaainfo.completed.push("m.login.dummy".to_owned());
}
k => panic!("type not supported: {}", k),
}
// Check if a flow now succeeds // Password was correct! Let's add it to `completed`
let mut completed = false; uiaainfo.completed.push("m.login.password".to_owned());
'flows: for flow in &mut uiaainfo.flows { }
for stage in &flow.stages { IncomingAuthData::Dummy(_) => {
if !uiaainfo.completed.contains(stage) { uiaainfo.completed.push("m.login.dummy".to_owned());
continue 'flows;
}
}
// We didn't break, so this flow succeeded!
completed = true;
} }
k => error!("type not supported: {:?}", k),
}
if !completed { // Check if a flow now succeeds
self.update_uiaa_session( let mut completed = false;
user_id, 'flows: for flow in &mut uiaainfo.flows {
device_id, for stage in &flow.stages {
uiaainfo.session.as_ref().expect("session is always set"), if !uiaainfo.completed.contains(stage) {
Some(&uiaainfo), continue 'flows;
)?; }
return Ok((false, uiaainfo));
} }
// We didn't break, so this flow succeeded!
completed = true;
}
// UIAA was successful! Remove this session and return true if !completed {
self.update_uiaa_session( self.update_uiaa_session(
user_id, user_id,
device_id, device_id,
uiaainfo.session.as_ref().expect("session is always set"), uiaainfo.session.as_ref().expect("session is always set"),
None, Some(&uiaainfo),
)?; )?;
Ok((true, uiaainfo)) return Ok((false, uiaainfo));
} else {
panic!("FallbackAcknowledgement is not supported yet");
} }
// UIAA was successful! Remove this session and return true
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
)?;
Ok((true, uiaainfo))
} }
fn set_uiaa_request( fn set_uiaa_request(

@ -17,7 +17,7 @@ use std::sync::Arc;
use database::Config; use database::Config;
pub use database::Database; pub use database::Database;
pub use error::{Error, Result}; pub use error::{Error, Result};
use opentelemetry::trace::Tracer; use opentelemetry::trace::{FutureExt, Tracer};
pub use pdu::PduEvent; pub use pdu::PduEvent;
pub use rocket::State; pub use rocket::State;
use ruma::api::client::error::ErrorKind; use ruma::api::client::error::ErrorKind;
@ -220,14 +220,17 @@ async fn main() {
}; };
if config.allow_jaeger { if config.allow_jaeger {
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
let tracer = opentelemetry_jaeger::new_pipeline() let tracer = opentelemetry_jaeger::new_pipeline()
.with_service_name("conduit") .install_batch(opentelemetry::runtime::Tokio)
.install_simple()
.unwrap(); .unwrap();
let span = tracer.start("conduit"); let span = tracer.start("conduit");
start.await; start.with_current_context().await;
drop(span); drop(span);
println!("exporting");
opentelemetry::global::shutdown_tracer_provider();
} else { } else {
std::env::set_var("RUST_LOG", &config.log); std::env::set_var("RUST_LOG", &config.log);

@ -12,7 +12,7 @@ use ruma::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use std::{cmp::Ordering, collections::BTreeMap, convert::TryFrom}; use std::{cmp::Ordering, collections::BTreeMap, convert::TryFrom};
use tracing::error; use tracing::warn;
#[derive(Clone, Deserialize, Serialize, Debug)] #[derive(Clone, Deserialize, Serialize, Debug)]
pub struct PduEvent { pub struct PduEvent {
@ -322,7 +322,7 @@ pub(crate) fn gen_event_id_canonical_json(
pdu: &Raw<ruma::events::pdu::Pdu>, pdu: &Raw<ruma::events::pdu::Pdu>,
) -> crate::Result<(EventId, CanonicalJsonObject)> { ) -> crate::Result<(EventId, CanonicalJsonObject)> {
let value = serde_json::from_str(pdu.json().get()).map_err(|e| { let value = serde_json::from_str(pdu.json().get()).map_err(|e| {
error!("{:?}: {:?}", pdu, e); warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response") Error::BadServerResponse("Invalid PDU in server response")
})?; })?;

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save