Skip to main content

tuwunel_service/media/
mod.rs

1mod data;
2pub(super) mod migrations;
3mod preview;
4mod remote;
5mod tests;
6mod thumbnail;
7use std::{
8	collections::HashMap,
9	path::PathBuf,
10	sync::{Arc, Mutex},
11	time::{Duration, Instant, SystemTime},
12};
13
14use async_trait::async_trait;
15use base64::{Engine as _, engine::general_purpose};
16use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, pin_mut};
17use http::StatusCode;
18use ruma::{
19	Mxc, OwnedMxcUri, OwnedUserId, UserId,
20	api::error::{ErrorKind, RetryAfter},
21	http_headers::ContentDisposition,
22};
23use tokio::{fs, sync::Notify};
24use tuwunel_core::{
25	Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, trace,
26	utils::{
27		self, BoolExt, MutexMap,
28		result::LogDebugErr,
29		stream::{IterStream, TryReadyExt},
30		time::now_millis,
31	},
32	warn,
33};
34use url::Url;
35
36use self::data::{Data, Metadata};
37pub use self::thumbnail::Dim;
38use crate::storage::Provider;
39
/// A stored media object: its bytes plus the HTTP-relevant metadata that was
/// recorded alongside it in the media metadata table.
#[derive(Debug)]
pub struct Media {
	/// Raw file bytes as returned by a storage provider.
	pub content: Vec<u8>,
	/// Content type from the stored metadata, if one was recorded.
	pub content_type: Option<String>,
	/// `Content-Disposition` from the stored metadata, if one was recorded.
	pub content_disposition: Option<ContentDisposition>,
}
46
/// In-memory state for MSC2246 asynchronous media uploads.
struct MXCState {
	/// One `Notify` per pending MXC that has readers waiting in `get`;
	/// `upload_pending` removes the entry and wakes those readers.
	notifiers: Mutex<HashMap<OwnedMxcUri, Arc<Notify>>>,
	/// Per-user token bucket used by `create_pending`:
	/// `(time of the last accepted request, tokens remaining)`.
	ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>,
}
54
/// Media service: metadata lives in the database, file bytes live on the
/// configured storage providers.
pub struct Service {
	/// Media metadata tables (file metadata, pending uploads, owners).
	pub(super) db: Data,
	/// Handle to sibling services (config, globals, storage, ...).
	services: Arc<crate::services::OnceServices>,
	/// Per-key lock map for URL preview work; not used in this file
	/// (presumably keyed by URL in the preview module — confirm).
	url_preview_mutex: MutexMap<String, ()>,
	/// Per-MXC lock (keyed by the MXC string) that serializes remote fetches
	/// in `get_or_fetch`.
	federation_mutex: MutexMap<String, ()>,
	/// MSC2246 upload notifiers and creation rate limiter.
	mxc_state: MXCState,
}
62
/// Length of a generated MXC media ID (the `media-id` component of
/// `mxc://<server-name>/<media-id>`).
pub const MXC_LENGTH: usize = 32;

/// `Cache-Control` value for immutable objects: private caches only, kept for
/// one year (31536000 seconds), never revalidated.
pub const CACHE_CONTROL_IMMUTABLE: &str = "private,max-age=31536000,immutable";

/// Default cross-origin resource policy (`Cross-Origin-Resource-Policy`
/// value).
pub const CORP_CROSS_ORIGIN: &str = "cross-origin";

/// Validity window for a presigned media download redirect (MSC3860); passed
/// to the storage provider when a signed URL is requested.
const REDIRECT_TTL: Duration = Duration::from_mins(5);
74
75#[async_trait]
76impl crate::Service for Service {
77	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
78		Ok(Arc::new(Self {
79			db: Data::new(args.db),
80			services: args.services.clone(),
81			url_preview_mutex: MutexMap::new(),
82			federation_mutex: MutexMap::new(),
83			mxc_state: MXCState {
84				notifiers: Mutex::new(HashMap::new()),
85				ratelimiter: Mutex::new(HashMap::new()),
86			},
87		}))
88	}
89
90	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
91}
92
93impl Service {
94	/// Create a pending media upload ID.
95	#[tracing::instrument(level = "debug", skip(self))]
96	pub async fn create_pending(
97		&self,
98		mxc: &Mxc<'_>,
99		user: &UserId,
100		unused_expires_at: u64,
101	) -> Result {
102		let config = &self.services.server.config;
103
104		// Rate limiting (rc_media_create)
105		let rate = f64::from(config.media_rc_create_per_second);
106		let burst = f64::from(config.media_rc_create_burst_count);
107
108		// Check rate limiting
109		if rate > 0.0 && burst > 0.0 {
110			let now = Instant::now();
111			let mut ratelimiter = self.mxc_state.ratelimiter.lock()?;
112
113			let (last_time, tokens) = ratelimiter
114				.entry(user.to_owned())
115				.or_insert_with(|| (now, burst));
116
117			let elapsed = now.duration_since(*last_time).as_secs_f64();
118			let new_tokens = elapsed.mul_add(rate, *tokens).min(burst);
119
120			if new_tokens >= 1.0 {
121				*last_time = now;
122				*tokens = new_tokens - 1.0;
123			} else {
124				return Err(Error::Request(
125					ErrorKind::LimitExceeded(ruma::api::error::LimitExceededErrorData {
126						retry_after: None,
127					}),
128					"Too many pending media creation requests.".into(),
129					StatusCode::TOO_MANY_REQUESTS,
130				));
131			}
132		}
133
134		let max_uploads = config.max_pending_media_uploads;
135		let (current_uploads, earliest_expiration) =
136			self.db.count_pending_mxc_for_user(user).await;
137
138		// Check if the user has reached the maximum number of pending media uploads
139		if current_uploads >= max_uploads {
140			let retry_after = earliest_expiration.saturating_sub(now_millis());
141			return Err(Error::Request(
142				ErrorKind::LimitExceeded(ruma::api::error::LimitExceededErrorData {
143					retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))),
144				}),
145				"Maximum number of pending media uploads reached.".into(),
146				StatusCode::TOO_MANY_REQUESTS,
147			));
148		}
149
150		self.db
151			.insert_pending_mxc(mxc, user, unused_expires_at);
152
153		Ok(())
154	}
155
156	/// Uploads content to a pending media ID.
157	#[tracing::instrument(level = "debug", skip(self))]
158	pub async fn upload_pending(
159		&self,
160		mxc: &Mxc<'_>,
161		user: &UserId,
162		content_disposition: Option<&ContentDisposition>,
163		content_type: Option<&str>,
164		file: &[u8],
165	) -> Result {
166		let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else {
167			if self.get_metadata(mxc).await.is_some() {
168				return Err!(Request(CannotOverwriteMedia("Media ID already has content")));
169			}
170
171			return Err!(Request(NotFound("Media not found")));
172		};
173
174		if owner_id != user {
175			return Err!(Request(Forbidden("You did not create this media ID")));
176		}
177
178		let current_time = now_millis();
179		if expires_at < current_time {
180			return Err!(Request(NotFound("Pending media ID expired")));
181		}
182
183		self.create(mxc, Some(user), content_disposition, content_type, file)
184			.await?;
185
186		self.db.remove_pending_mxc(mxc);
187
188		let mxc_uri: OwnedMxcUri = mxc.to_string().into();
189		if let Some(notifier) = self.mxc_state.notifiers.lock()?.remove(&mxc_uri) {
190			notifier.notify_waiters();
191		}
192
193		Ok(())
194	}
195
196	/// Uploads a file.
197	pub async fn create(
198		&self,
199		mxc: &Mxc<'_>,
200		user: Option<&UserId>,
201		content_disposition: Option<&ContentDisposition>,
202		content_type: Option<&str>,
203		file: &[u8],
204	) -> Result {
205		// Width, Height = 0 if it's not a thumbnail
206		let key = self.db.create_file_metadata(
207			mxc,
208			user,
209			&Dim::default(),
210			content_disposition,
211			content_type,
212		)?;
213
214		//TODO: Dangling metadata in database if creation fails
215		self.create_media_file(&key, file).await
216	}
217
218	/// Deletes a file in the database and from the media directory via an MXC
219	#[tracing::instrument(level = "trace", skip(self))]
220	pub async fn delete(&self, mxc: &Mxc<'_>) -> Result {
221		match self.db.search_mxc_metadata_prefix(mxc).await {
222			| Ok(keys) => {
223				for key in keys {
224					trace!(?mxc, "MXC Key: {key:?}");
225					debug_info!(?mxc, "Deleting from storage provider");
226
227					if let Err(e) = self.remove_media_file(&key).await {
228						debug_error!(?mxc, "Failed to remove media file: {e}");
229					}
230
231					debug_info!(?mxc, "Deleting from database");
232					self.db.delete_file_mxc(mxc).await;
233				}
234
235				Ok(())
236			},
237			| _ => {
238				Err!(Database(error!(
239					"Failed to find any media keys for MXC {mxc} in our database."
240				)))
241			},
242		}
243	}
244
245	/// Deletes all media by the specified user
246	///
247	/// currently, this is only practical for local users
248	#[tracing::instrument(level = "trace", skip(self))]
249	pub async fn delete_from_user(&self, user: &UserId) -> Result<usize> {
250		let mxcs = self.db.get_all_user_mxcs(user).await;
251		let mut deletion_count: usize = 0;
252
253		for mxc in mxcs {
254			let Ok(mxc) = mxc.as_str().try_into().inspect_err(|e| {
255				debug_error!(?mxc, "Failed to parse MXC URI from database: {e}");
256			}) else {
257				continue;
258			};
259
260			debug_info!(
261				%deletion_count,
262				"Deleting MXC {mxc} by user {user} from database and filesystem",
263			);
264			match self.delete(&mxc).await {
265				| Ok(()) => {
266					deletion_count = deletion_count.saturating_add(1);
267				},
268				| Err(e) => {
269					debug_error!(
270						%deletion_count,
271						"Failed to delete {mxc} from user {user}, ignoring error: {e}"
272					);
273				},
274			}
275		}
276
277		Ok(deletion_count)
278	}
279
280	/// Get file from local storage or make a federation request if it
281	/// originates remotely.
282	#[tracing::instrument(
283		level = "debug",
284		err(level = "debug")
285		skip(self),
286	)]
287	pub async fn get_or_fetch(&self, mxc: &Mxc<'_>, timeout_ms: Duration) -> Result<Media> {
288		if let Ok(media) = self.get(mxc, Some(timeout_ms)).await {
289			return Ok(media);
290		}
291
292		if self
293			.services
294			.globals
295			.server_is_ours(mxc.server_name)
296		{
297			return Err!(Request(NotFound("Local media not found.")));
298		}
299
300		let lock = self.federation_mutex.lock(&mxc.to_string()).await;
301
302		if self
303			.db
304			.file_metadata_exists(mxc, &Dim::default())
305			.await
306		{
307			drop(lock);
308			return self.get(mxc, None).await;
309		}
310
311		self.fetch_remote_content(mxc, None, timeout_ms)
312			.await
313	}
314
315	/// Get file from local storage while waiting up to a timeout_ms if it is
316	/// pending.
317	#[tracing::instrument(
318		level = "debug",
319		err(level = "trace")
320		skip(self),
321	)]
322	pub async fn get(&self, mxc: &Mxc<'_>, timeout: Option<Duration>) -> Result<Media> {
323		if let Ok(meta) = self.get_stored(mxc).await {
324			return Ok(meta);
325		}
326
327		let Some(timeout) = timeout else {
328			return Err!(Request(NotFound("Media not found.")));
329		};
330
331		let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
332			return Err!(Request(NotFound("Media not found.")));
333		};
334
335		let notifier = self
336			.mxc_state
337			.notifiers
338			.lock()?
339			.entry(mxc.to_string().into())
340			.or_insert_with(|| Arc::new(Notify::new()))
341			.clone();
342
343		if tokio::time::timeout(timeout, notifier.notified())
344			.await
345			.is_err()
346		{
347			return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
348		}
349
350		self.get_stored(mxc).await
351	}
352
353	/// Get file from local storage.
354	#[tracing::instrument(level = "debug", skip(self))]
355	pub async fn get_stored(&self, mxc: &Mxc<'_>) -> Result<Media> {
356		let meta = self
357			.db
358			.search_file_metadata(mxc, &Dim::default())
359			.await;
360
361		let Ok(Metadata { content_type, content_disposition, key }) = meta else {
362			return Err!(Request(NotFound("Media not found.")));
363		};
364
365		let path = self.get_media_name_sha256(&key);
366		let fetch = self
367			.storage_providers()
368			.stream()
369			.filter_map(async |provider| {
370				provider
371					.get(path.as_str())
372					.await
373					.log_debug_err()
374					.ok()
375			});
376
377		pin_mut!(fetch);
378		let Some(bytes) = fetch.next().await else {
379			return Err!(Request(NotFound("Media not found.")));
380		};
381
382		Ok(Media {
383			content: bytes.to_vec(),
384			content_type,
385			content_disposition,
386		})
387	}
388
389	/// Presigned redirect URL for locally-stored media (MSC3860).
390	///
391	/// Returns the first configured provider's signed URL for the object, or
392	/// `None` when redirects are disabled, the media is unknown, or no provider
393	/// can presign (filesystem-only media).
394	#[tracing::instrument(level = "debug", skip(self))]
395	pub async fn redirect_url(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Option<Url>> {
396		if !self.services.config.media_allow_redirect {
397			return Ok(None);
398		}
399
400		let Ok(Metadata { key, .. }) = self.db.search_file_metadata(mxc, dim).await else {
401			return Ok(None);
402		};
403
404		let path = self.get_media_name_sha256(&key);
405		let urls = self
406			.storage_providers()
407			.stream()
408			.filter_map(async |provider| {
409				provider
410					.signed_get_url(path.as_str(), REDIRECT_TTL)
411					.await
412					.log_debug_err()
413					.ok()
414					.flatten()
415			});
416
417		pin_mut!(urls);
418
419		Ok(urls.next().await)
420	}
421
422	/// Gets all the MXC URIs in our media database
423	pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
424		let all_keys = self.db.get_all_media_keys().await;
425
426		let mut mxcs = Vec::with_capacity(all_keys.len());
427
428		for key in all_keys {
429			trace!("Full MXC key from database: {key:?}");
430
431			let mut parts = key.split(|&b| b == 0xFF);
432			let mxc = parts
433				.next()
434				.map(|bytes| {
435					utils::string_from_bytes(bytes).map_err(|e| {
436						err!(Database(error!(
437							"Failed to parse MXC unicode bytes from our database: {e}"
438						)))
439					})
440				})
441				.transpose()?;
442
443			let Some(mxc_s) = mxc else {
444				debug_warn!(
445					?mxc,
446					"Parsed MXC URL unicode bytes from database but is still invalid"
447				);
448				continue;
449			};
450
451			trace!("Parsed MXC key to URL: {mxc_s}");
452			let mxc = OwnedMxcUri::from(mxc_s);
453
454			if mxc.is_valid() {
455				mxcs.push(mxc);
456			} else {
457				debug_warn!("{mxc:?} from database was found to not be valid");
458			}
459		}
460
461		Ok(mxcs)
462	}
463
464	/// Deletes all media files before or after the given time. Returns a usize
465	/// with the number of media files deleted.
466	pub async fn delete_range(
467		&self,
468		time: SystemTime,
469		older_than: bool,
470		newer_than: bool,
471		yes_i_want_to_delete_local_media: bool,
472	) -> Result<usize> {
473		let all_keys = self.db.get_all_media_keys().await;
474		let mut remote_mxcs = Vec::with_capacity(all_keys.len());
475
476		for key in all_keys {
477			trace!("Full MXC key from database: {key:?}");
478			let mut parts = key.split(|&b| b == 0xFF);
479			let mxc = parts
480				.next()
481				.map(|bytes| {
482					utils::string_from_bytes(bytes).map_err(|e| {
483						err!(Database(error!(
484							"Failed to parse MXC unicode bytes from our database: {e}"
485						)))
486					})
487				})
488				.transpose()?;
489
490			let Some(mxc_s) = mxc else {
491				debug_warn!(
492					?mxc,
493					"Parsed MXC URL unicode bytes from database but is still invalid"
494				);
495				continue;
496			};
497
498			trace!("Parsed MXC key to URL: {mxc_s}");
499			let mxc = OwnedMxcUri::from(mxc_s);
500			if (mxc.server_name() == Ok(self.services.globals.server_name())
501				&& !yes_i_want_to_delete_local_media)
502				|| !mxc.is_valid()
503			{
504				debug!("Ignoring local or broken media MXC: {mxc}");
505				continue;
506			}
507
508			let file_created_at = if let Some(file_metadata) = self
509				.storage_providers()
510				.stream()
511				.filter_map(async |provider| {
512					let path = self.get_media_name_sha256(&key);
513					match provider.head(&path).await {
514						| Ok(file_metadata) => {
515							trace!(%mxc, ?path, "Provider file metadata: {file_metadata:?}");
516							Some(file_metadata)
517						},
518						| Err(e) => {
519							debug_warn!(
520								"Failed to obtain {:?} file metadata for MXC {mxc} at file path \
521								 {path:?}\", skipping: {e}",
522								provider.name,
523							);
524							None
525						},
526					}
527				})
528				.boxed()
529				.next()
530				.await
531			{
532				SystemTime::from(file_metadata.last_modified)
533			} else {
534				continue;
535			};
536
537			debug!("File created at: {file_created_at:?}");
538
539			if file_created_at <= time && older_than {
540				debug!(
541					"File is older than user duration, pushing to list of file paths and keys \
542					 to delete."
543				);
544				remote_mxcs.push(mxc.to_string());
545			} else if file_created_at >= time && newer_than {
546				debug!(
547					"File is newer than user duration, pushing to list of file paths and keys \
548					 to delete."
549				);
550				remote_mxcs.push(mxc.to_string());
551			}
552		}
553
554		if remote_mxcs.is_empty() {
555			return Err!(Database("Did not found any eligible MXCs to delete."));
556		}
557
558		debug_info!("Deleting media now in the past {time:?}");
559
560		let mut deletion_count: usize = 0;
561
562		for mxc in remote_mxcs {
563			let Ok(mxc) = mxc.as_str().try_into() else {
564				debug_warn!("Invalid MXC in database, skipping");
565				continue;
566			};
567
568			debug_info!("Deleting MXC {mxc} from database and filesystem");
569
570			match self.delete(&mxc).await {
571				| Ok(()) => {
572					deletion_count = deletion_count.saturating_add(1);
573				},
574				| Err(e) => {
575					warn!("Failed to delete {mxc}, ignoring error and skipping: {e}");
576					continue;
577				},
578			}
579		}
580
581		Ok(deletion_count)
582	}
583
584	pub async fn create_media_dir(&self) -> Result {
585		let dir = self.get_media_dir();
586		Ok(fs::create_dir_all(dir).await?)
587	}
588
589	async fn remove_media_file(&self, key: &[u8]) -> Result {
590		let path = self.get_media_name_sha256(key);
591		self.storage_providers()
592			.stream()
593			.filter_map(async |provider| {
594				debug!(
595					?key, ?path, provider = ?provider.name,
596					"Deleting media file from provider",
597				);
598
599				provider
600					.delete_one(&path)
601					.await
602					.log_debug_err()
603					.ok()
604			})
605			.count()
606			.map(|count| {
607				count
608					.ge(&0)
609					.into_option()
610					.ok_or_else(|| err!(Request(NotFound("Failed to remove on any provider."))))
611			})
612			.await
613	}
614
615	async fn create_media_file(&self, key: &[u8], file: &[u8]) -> Result {
616		self.storage_providers()
617			.try_stream()
618			.ready_try_filter(|provider| {
619				let store_media_on_providers = &self.services.config.store_media_on_providers;
620
621				store_media_on_providers.is_empty()
622					|| store_media_on_providers.contains(&provider.name)
623			})
624			.and_then(async |provider| {
625				let path = self.get_media_name_sha256(key);
626				debug!(
627					?key, ?path,
628					len = ?file.len(),
629					provider = ?provider.name,
630					"Creating media file on storage provider."
631				);
632
633				if let Err(e) = provider
634					.put_one(path.as_str(), file.to_vec())
635					.await
636				{
637					return Err!(Database(error!(
638						?path,
639						?provider,
640						"Failed to store media on provider: {e:?}"
641					)));
642				}
643
644				Ok(1)
645			})
646			.ready_try_fold(0_usize, |a, c| Ok(a.saturating_add(c)))
647			.inspect_ok(|&uploads| assert!(uploads > 0, "Successfully saved to nowhere."))
648			.map_ok(|_| ())
649			.await
650	}
651
652	fn storage_providers(&self) -> impl Iterator<Item = &Arc<Provider>> + Send + '_ {
653		let explicit_providers = &self.services.config.media_storage_providers;
654
655		let or_all_providers = explicit_providers
656			.is_empty()
657			.then(|| self.services.storage.providers())
658			.into_iter()
659			.flatten();
660
661		explicit_providers
662			.iter()
663			.filter_map(|id| self.services.storage.provider(id).ok())
664			.chain(or_all_providers)
665	}
666
667	#[inline]
668	pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option<Metadata> {
669		self.db
670			.search_file_metadata(mxc, &Dim::default())
671			.await
672			.ok()
673	}
674
675	#[inline]
676	#[must_use]
677	pub fn get_media_path_sha256(&self, key: &[u8]) -> PathBuf {
678		let mut r = self.get_media_dir();
679		r.push(self.get_media_name_sha256(key));
680		r
681	}
682
683	/// new SHA256 file name media function. requires database migrated. uses
684	/// SHA256 hash of the base64 key as the file name
685	#[inline]
686	#[must_use]
687	pub fn get_media_name_sha256(&self, key: &[u8]) -> String {
688		// Using the hash of the base64 key as the filename prevents the total
689		// length of the path from exceeding the maximum length in most
690		// filesystems
691		let digest = <sha2::Sha256 as sha2::Digest>::digest(key);
692		encode_key(&digest)
693	}
694
695	/// old base64 file name media function
696	/// This is the old version of `get_media_path_sha256` that uses the full
697	/// base64 key as the filename.
698	#[must_use]
699	pub fn get_media_path_b64(&self, key: &[u8]) -> PathBuf {
700		let mut r = self.get_media_dir();
701		let encoded = encode_key(key);
702		r.push(encoded);
703		r
704	}
705
706	#[must_use]
707	pub fn get_media_dir(&self) -> PathBuf {
708		let mut r = PathBuf::new();
709		r.push(self.services.server.config.database_path.clone());
710		r.push("media");
711		r
712	}
713}
714
715#[inline]
716#[must_use]
717pub fn encode_key(key: &[u8]) -> String { general_purpose::URL_SAFE_NO_PAD.encode(key) }