1use std::{collections::BTreeMap, sync::Arc};
2
3use futures::{Stream, StreamExt, future::join};
4use ruma::{
5 MxcUri, OwnedMxcUri, OwnedRoomId, RoomId, UserId,
6 api::federation::query::get_profile_information,
7 events::room::member::{MembershipState, RoomMemberEventContent},
8 profile::{ProfileFieldName, ProfileFieldValue},
9};
10use serde::Deserialize;
11use serde_json::Value;
12use tuwunel_core::{
13 Err, Result, err, extract_variant, implement,
14 matrix::PduBuilder,
15 utils::{
16 TryReadyExt,
17 future::TryExtExt,
18 stream::{IterStream, TryIgnore, automatic_width},
19 },
20 warn,
21};
22use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map};
23
24pub struct Service {
25 services: Arc<crate::services::OnceServices>,
26 useridprofilekey_value: Arc<Map>,
27}
28
29impl crate::Service for Service {
30 fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
31 Ok(Arc::new(Self {
32 services: args.services.clone(),
33 useridprofilekey_value: args.db["useridprofilekey_value"].clone(),
34 }))
35 }
36
37 fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
38}
39
/// Controls how a profile change is mirrored into the member events of the
/// rooms the user has joined.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Propagation {
	/// Write the new value into the member event of every joined room.
	All,

	/// Only update rooms whose member event still carries the user's current
	/// global value; rooms where the value differs are left untouched.
	Unchanged,

	/// Do not touch any room; only the stored profile is changed.
	None,
}
56
57#[implement(Service)]
58pub async fn update_all_rooms(
59 &self,
60 user_id: &UserId,
61 profile_values: &[(ProfileFieldName, Option<Value>)],
62 propagation: Propagation,
63) {
64 if matches!(propagation, Propagation::None) {
65 return;
66 }
67
68 if !profile_values.iter().any(|(name, _)| {
69 matches!(name, ProfileFieldName::DisplayName | ProfileFieldName::AvatarUrl)
70 }) {
71 return;
72 }
73
74 let (current_displayname, current_avatar_url) =
75 if matches!(propagation, Propagation::Unchanged) {
76 join(self.displayname(user_id).ok(), self.avatar_url(user_id).ok()).await
77 } else {
78 (None, None)
79 };
80
81 let rooms: Vec<OwnedRoomId> = self
82 .services
83 .state_cache
84 .rooms_joined(user_id)
85 .map(Into::into)
86 .collect()
87 .await;
88
89 rooms
90 .iter()
91 .stream()
92 .for_each_concurrent(automatic_width(), async |room_id| {
93 if let Err(e) = self
94 .update_room(
95 user_id,
96 room_id,
97 profile_values,
98 propagation,
99 current_displayname.as_deref(),
100 current_avatar_url.as_deref(),
101 )
102 .await
103 {
104 warn!(
105 %user_id,
106 %room_id,
107 %e,
108 "Failed to update room profile",
109 );
110 }
111 })
112 .await;
113}
114
115#[implement(Service)]
116async fn update_room(
117 &self,
118 user_id: &UserId,
119 room_id: &RoomId,
120 profile_values: &[(ProfileFieldName, Option<Value>)],
121 propagation: Propagation,
122 current_displayname: Option<&str>,
123 current_avatar_url: Option<&MxcUri>,
124) -> Result {
125 let unchanged = match propagation {
126 | Propagation::All => false,
127 | Propagation::Unchanged => true,
128 | Propagation::None => return Ok(()),
129 };
130
131 let mut content = self
132 .services
133 .state_accessor
134 .get_member(room_id, user_id)
135 .await?;
136
137 if !matches!(content.membership, MembershipState::Join) {
138 return Ok(());
139 }
140
141 let mut changed = false;
142
143 for (name, value) in profile_values {
144 match name {
145 | ProfileFieldName::DisplayName => {
146 if unchanged && content.displayname.as_deref() != current_displayname {
147 continue;
148 }
149
150 let displayname = value.clone().map(|value| {
151 extract_variant!(value, Value::String).expect("invalid profile value type")
152 });
153
154 content.displayname = displayname;
155
156 changed = true;
157 },
158 | ProfileFieldName::AvatarUrl => {
159 if unchanged && content.avatar_url.as_deref() != current_avatar_url {
160 continue;
161 }
162
163 let avatar_url = value.clone().map(|value| {
164 serde_json::from_value(value).expect("invalid profile value type")
165 });
166
167 content.avatar_url = avatar_url;
168
169 changed = true;
170 },
171 | _ => {},
172 }
173 }
174
175 if !changed {
176 return Ok(());
177 }
178
179 content.reason = None;
180
181 let state_lock = self.services.state.mutex.lock(room_id).await;
182
183 self.services
184 .timeline
185 .build_and_append_pdu(
186 PduBuilder::state(user_id.as_str(), &content),
187 user_id,
188 room_id,
189 &state_lock,
190 )
191 .await?;
192
193 Ok(())
194}
195
196#[implement(Service)]
199pub async fn set_displayname(
200 &self,
201 user_id: &UserId,
202 displayname: Option<&str>,
203 propagation: Option<Propagation>,
204) -> Result {
205 self.set_profile_keys(
206 user_id,
207 &[(
208 ProfileFieldName::DisplayName,
209 displayname.map(|displayname| {
210 serde_json::to_value(displayname).expect("displayname serialization cannot fail")
211 }),
212 )],
213 propagation,
214 )
215 .await
216}
217
218#[implement(Service)]
220pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
221 self.profile_key(user_id, &ProfileFieldName::DisplayName)
222 .await
223}
224
225#[implement(Service)]
227pub async fn set_avatar_url(
228 &self,
229 user_id: &UserId,
230 avatar_url: Option<&MxcUri>,
231 propagation: Option<Propagation>,
232) -> Result {
233 self.set_profile_keys(
234 user_id,
235 &[(
236 ProfileFieldName::AvatarUrl,
237 avatar_url.map(|avatar_url| {
238 serde_json::to_value(avatar_url).expect("avatar url serialization cannot fail")
239 }),
240 )],
241 propagation,
242 )
243 .await
244}
245
246#[implement(Service)]
248pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
249 self.profile_key(user_id, &ProfileFieldName::AvatarUrl)
250 .await
251}
252
253#[implement(Service)]
255pub async fn set_timezone(
256 &self,
257 user_id: &UserId,
258 timezone: Option<&str>,
259 propagation: Option<Propagation>,
260) -> Result {
261 self.set_profile_keys(
262 user_id,
263 &[(
264 ProfileFieldName::TimeZone,
265 timezone.map(|timezone| {
266 serde_json::to_value(timezone).expect("timezone serialization cannot fail")
267 }),
268 )],
269 propagation,
270 )
271 .await
272}
273
274#[implement(Service)]
276pub async fn timezone(&self, user_id: &UserId) -> Result<String> {
277 self.profile_key(user_id, &ProfileFieldName::TimeZone)
278 .await
279}
280
281#[implement(Service)]
283pub fn all_profile_keys(&self, user_id: &UserId) -> impl Stream<Item = ProfileFieldValue> + Send {
284 let prefix = (user_id, Interfix);
285 self.useridprofilekey_value
286 .stream_prefix(&prefix)
287 .ignore_err()
288 .map(move |((_, key), Json(val)): ((Ignore, _), _)| {
289 ProfileFieldValue::new(key, val).map_err(|_| {
290 err!(Database(
291 error!(%user_id, %key, "Invalid json in database profile value while iterating")
292 ))
293 })
294 })
295 .ignore_err()
296}
297
298#[implement(Service)]
299pub async fn clear_profile_keys(&self, user_id: &UserId) {
300 let prefix = (user_id, Interfix);
301
302 self.useridprofilekey_value
303 .keys_prefix_raw(&prefix)
304 .ready_try_for_each(|key| {
305 self.useridprofilekey_value.remove(key);
306 Ok(())
307 })
308 .await
309 .ok();
310}
311
312#[implement(Service)]
314pub async fn set_profile_keys(
315 &self,
316 user_id: &UserId,
317 profile_values: &[(ProfileFieldName, Option<Value>)],
318 propagation: Option<Propagation>,
319) -> Result {
320 if self.services.globals.user_is_local(user_id) {
321 for (name, value) in profile_values {
322 check_profile_key(name.as_str())?;
323
324 if let Some(value) = value {
325 self.enforce_profile_size(user_id, name.as_str(), value)
326 .await?;
327 }
328 }
329 }
330
331 let propagation = propagation.unwrap_or(
332 if self
333 .services
334 .config
335 .preserve_room_profile_overrides
336 {
337 Propagation::Unchanged
338 } else {
339 Propagation::All
340 },
341 );
342
343 if !matches!(propagation, Propagation::None) && self.services.globals.user_is_local(user_id) {
344 self.update_all_rooms(user_id, profile_values, propagation)
345 .await;
346 }
347
348 for (name, value) in profile_values {
349 let key = (user_id, name.as_str());
350
351 if let Some(value) = value {
352 self.useridprofilekey_value.put(key, Json(value));
353 } else {
354 self.useridprofilekey_value.del(key);
355 }
356 }
357
358 Ok(())
359}
360
361#[implement(Service)]
363pub async fn profile_key<T>(&self, user_id: &UserId, profile_key: &ProfileFieldName) -> Result<T>
364where
365 T: for<'de> Deserialize<'de> + Send,
366{
367 let key = (user_id, profile_key);
368 let Json(value) = self
369 .useridprofilekey_value
370 .qry(&key)
371 .await
372 .map_err(|_| err!(Request(NotFound("The requested profile key does not exist."))))?
373 .deserialized()
374 .map_err(|_| err!(Database("Cannot deserialize database profile value")))?;
375
376 Ok(value)
377}
378
379#[implement(Service)]
380pub async fn fill_profile_data(&self, user_id: &UserId, content: &mut RoomMemberEventContent) {
381 let displayname = self.displayname(user_id).ok();
382 let avatar_url = self.avatar_url(user_id).ok();
383
384 let (displayname, avatar_url) = join(displayname, avatar_url).await;
385
386 content.displayname = displayname;
387 content.avatar_url = avatar_url;
388}
389
390#[implement(Service)]
391pub async fn fetch_remote_profile(&self, user_id: &UserId) -> Result {
392 assert!(
393 !self.services.globals.user_is_local(user_id),
394 "fetch remote profile called with a local user"
395 );
396
397 if let Ok(response) = self
398 .services
399 .federation
400 .execute(user_id.server_name(), get_profile_information::v1::Request {
401 user_id: user_id.to_owned(),
402 field: None,
403 })
404 .await
405 {
406 if !self.services.users.exists(user_id).await {
407 self.services
408 .users
409 .create(user_id, None, None)
410 .await?;
411 }
412
413 for (key, value) in response.iter() {
414 self.set_profile_keys(
415 user_id,
416 &[(key.as_str().into(), Some(value.clone()))],
417 Some(Propagation::None),
418 )
419 .await?;
420 }
421 }
422
423 Ok(())
424}
425
426pub(super) const MAX_PROFILE_SIZE: usize = 65_536;
429
430#[implement(Service)]
434async fn enforce_profile_size(&self, user_id: &UserId, key: &str, value: &Value) -> Result {
435 let mut profile: BTreeMap<_, _> = self
436 .all_profile_keys(user_id)
437 .map(|profile_value| {
438 (
439 profile_value.field_name().as_str().to_owned(),
440 profile_value.value().into_owned(),
441 )
442 })
443 .collect()
444 .await;
445 profile.insert(key.to_owned(), value.clone());
446
447 let profile_size = serde_json::to_vec(&profile).map_or(0, |buf| buf.len());
448
449 if profile_size > MAX_PROFILE_SIZE {
450 return Err!(Request(ProfileTooLarge(
451 "Profile would exceed the maximum size of 64 KiB."
452 )));
453 }
454
455 Ok(())
456}
457
458const MAX_KEY_LENGTH: usize = 255;
460
461fn check_profile_key(name: &str) -> Result {
465 if name.len() > MAX_KEY_LENGTH {
466 return Err!(Request(KeyTooLarge("Profile key names cannot be longer than 255 bytes.")));
467 }
468
469 let ok = name
470 .bytes()
471 .next()
472 .is_some_and(|b| b.is_ascii_lowercase())
473 && name.bytes().all(|b| {
474 b.is_ascii_lowercase() || b.is_ascii_digit() || matches!(b, b'_' | b'.' | b'-')
475 });
476
477 if !ok {
478 return Err!(Request(BadJson(
479 "Profile key names must follow the Common Namespaced Identifier Grammar."
480 )));
481 }
482
483 Ok(())
484}