Skip to main content

tuwunel_service/media/
mod.rs

1pub mod blurhash;
2mod data;
3pub(super) mod migrations;
4mod preview;
5mod remote;
6mod tests;
7mod thumbnail;
8use std::{
9	collections::HashMap,
10	path::PathBuf,
11	sync::{Arc, Mutex},
12	time::{Duration, Instant, SystemTime},
13};
14
15use async_trait::async_trait;
16use base64::{Engine as _, engine::general_purpose};
17use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, pin_mut};
18use http::StatusCode;
19use ruma::{
20	Mxc, OwnedMxcUri, OwnedUserId, UserId,
21	api::error::{ErrorKind, RetryAfter},
22	http_headers::ContentDisposition,
23};
24use tokio::{fs, sync::Notify};
25use tuwunel_core::{
26	Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, trace,
27	utils::{
28		self, BoolExt, MutexMap,
29		result::LogDebugErr,
30		stream::{IterStream, TryReadyExt},
31		time::now_millis,
32	},
33	warn,
34};
35
36use self::data::{Data, Metadata};
37pub use self::thumbnail::Dim;
38use crate::storage::Provider;
39
40#[derive(Debug)]
41pub struct Media {
42	pub content: Vec<u8>,
43	pub content_type: Option<String>,
44	pub content_disposition: Option<ContentDisposition>,
45}
46
47/// For MSC2246
48struct MXCState {
49	/// Save the notifier for each pending media upload
50	notifiers: Mutex<HashMap<OwnedMxcUri, Arc<Notify>>>,
51	/// Save the ratelimiter for each user
52	ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>,
53}
54
55pub struct Service {
56	pub(super) db: Data,
57	services: Arc<crate::services::OnceServices>,
58	url_preview_mutex: MutexMap<String, ()>,
59	federation_mutex: MutexMap<String, ()>,
60	mxc_state: MXCState,
61}
62
/// Length of a locally generated MXC media ID (the `media-id` part).
pub const MXC_LENGTH: usize = 32;

/// `Cache-Control` value for immutable objects.
///
/// Lets private caches keep the response for one year (31536000 seconds)
/// without revalidating it.
pub const CACHE_CONTROL_IMMUTABLE: &str = "private,max-age=31536000,immutable";

/// Default `Cross-Origin-Resource-Policy` value.
///
/// Allows the media to be loaded from any origin.
pub const CORP_CROSS_ORIGIN: &str = "cross-origin";
71
72#[async_trait]
73impl crate::Service for Service {
74	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
75		Ok(Arc::new(Self {
76			db: Data::new(args.db),
77			services: args.services.clone(),
78			url_preview_mutex: MutexMap::new(),
79			federation_mutex: MutexMap::new(),
80			mxc_state: MXCState {
81				notifiers: Mutex::new(HashMap::new()),
82				ratelimiter: Mutex::new(HashMap::new()),
83			},
84		}))
85	}
86
87	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
88}
89
90impl Service {
91	/// Create a pending media upload ID.
92	#[tracing::instrument(level = "debug", skip(self))]
93	pub async fn create_pending(
94		&self,
95		mxc: &Mxc<'_>,
96		user: &UserId,
97		unused_expires_at: u64,
98	) -> Result {
99		let config = &self.services.server.config;
100
101		// Rate limiting (rc_media_create)
102		let rate = f64::from(config.media_rc_create_per_second);
103		let burst = f64::from(config.media_rc_create_burst_count);
104
105		// Check rate limiting
106		if rate > 0.0 && burst > 0.0 {
107			let now = Instant::now();
108			let mut ratelimiter = self.mxc_state.ratelimiter.lock()?;
109
110			let (last_time, tokens) = ratelimiter
111				.entry(user.to_owned())
112				.or_insert_with(|| (now, burst));
113
114			let elapsed = now.duration_since(*last_time).as_secs_f64();
115			let new_tokens = elapsed.mul_add(rate, *tokens).min(burst);
116
117			if new_tokens >= 1.0 {
118				*last_time = now;
119				*tokens = new_tokens - 1.0;
120			} else {
121				return Err(Error::Request(
122					ErrorKind::LimitExceeded(ruma::api::error::LimitExceededErrorData {
123						retry_after: None,
124					}),
125					"Too many pending media creation requests.".into(),
126					StatusCode::TOO_MANY_REQUESTS,
127				));
128			}
129		}
130
131		let max_uploads = config.max_pending_media_uploads;
132		let (current_uploads, earliest_expiration) =
133			self.db.count_pending_mxc_for_user(user).await;
134
135		// Check if the user has reached the maximum number of pending media uploads
136		if current_uploads >= max_uploads {
137			let retry_after = earliest_expiration.saturating_sub(now_millis());
138			return Err(Error::Request(
139				ErrorKind::LimitExceeded(ruma::api::error::LimitExceededErrorData {
140					retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))),
141				}),
142				"Maximum number of pending media uploads reached.".into(),
143				StatusCode::TOO_MANY_REQUESTS,
144			));
145		}
146
147		self.db
148			.insert_pending_mxc(mxc, user, unused_expires_at);
149
150		Ok(())
151	}
152
153	/// Uploads content to a pending media ID.
154	#[tracing::instrument(level = "debug", skip(self))]
155	pub async fn upload_pending(
156		&self,
157		mxc: &Mxc<'_>,
158		user: &UserId,
159		content_disposition: Option<&ContentDisposition>,
160		content_type: Option<&str>,
161		file: &[u8],
162	) -> Result {
163		let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else {
164			if self.get_metadata(mxc).await.is_some() {
165				return Err!(Request(CannotOverwriteMedia("Media ID already has content")));
166			}
167
168			return Err!(Request(NotFound("Media not found")));
169		};
170
171		if owner_id != user {
172			return Err!(Request(Forbidden("You did not create this media ID")));
173		}
174
175		let current_time = now_millis();
176		if expires_at < current_time {
177			return Err!(Request(NotFound("Pending media ID expired")));
178		}
179
180		self.create(mxc, Some(user), content_disposition, content_type, file)
181			.await?;
182
183		self.db.remove_pending_mxc(mxc);
184
185		let mxc_uri: OwnedMxcUri = mxc.to_string().into();
186		if let Some(notifier) = self.mxc_state.notifiers.lock()?.remove(&mxc_uri) {
187			notifier.notify_waiters();
188		}
189
190		Ok(())
191	}
192
193	/// Uploads a file.
194	pub async fn create(
195		&self,
196		mxc: &Mxc<'_>,
197		user: Option<&UserId>,
198		content_disposition: Option<&ContentDisposition>,
199		content_type: Option<&str>,
200		file: &[u8],
201	) -> Result {
202		// Width, Height = 0 if it's not a thumbnail
203		let key = self.db.create_file_metadata(
204			mxc,
205			user,
206			&Dim::default(),
207			content_disposition,
208			content_type,
209		)?;
210
211		//TODO: Dangling metadata in database if creation fails
212		self.create_media_file(&key, file).await
213	}
214
215	/// Deletes a file in the database and from the media directory via an MXC
216	#[tracing::instrument(level = "trace", skip(self))]
217	pub async fn delete(&self, mxc: &Mxc<'_>) -> Result {
218		match self.db.search_mxc_metadata_prefix(mxc).await {
219			| Ok(keys) => {
220				for key in keys {
221					trace!(?mxc, "MXC Key: {key:?}");
222					debug_info!(?mxc, "Deleting from storage provider");
223
224					if let Err(e) = self.remove_media_file(&key).await {
225						debug_error!(?mxc, "Failed to remove media file: {e}");
226					}
227
228					debug_info!(?mxc, "Deleting from database");
229					self.db.delete_file_mxc(mxc).await;
230				}
231
232				Ok(())
233			},
234			| _ => {
235				Err!(Database(error!(
236					"Failed to find any media keys for MXC {mxc} in our database."
237				)))
238			},
239		}
240	}
241
242	/// Deletes all media by the specified user
243	///
244	/// currently, this is only practical for local users
245	#[tracing::instrument(level = "trace", skip(self))]
246	pub async fn delete_from_user(&self, user: &UserId) -> Result<usize> {
247		let mxcs = self.db.get_all_user_mxcs(user).await;
248		let mut deletion_count: usize = 0;
249
250		for mxc in mxcs {
251			let Ok(mxc) = mxc.as_str().try_into().inspect_err(|e| {
252				debug_error!(?mxc, "Failed to parse MXC URI from database: {e}");
253			}) else {
254				continue;
255			};
256
257			debug_info!(
258				%deletion_count,
259				"Deleting MXC {mxc} by user {user} from database and filesystem",
260			);
261			match self.delete(&mxc).await {
262				| Ok(()) => {
263					deletion_count = deletion_count.saturating_add(1);
264				},
265				| Err(e) => {
266					debug_error!(
267						%deletion_count,
268						"Failed to delete {mxc} from user {user}, ignoring error: {e}"
269					);
270				},
271			}
272		}
273
274		Ok(deletion_count)
275	}
276
277	/// Get file from local storage or make a federation request if it
278	/// originates remotely.
279	#[tracing::instrument(
280		level = "debug",
281		err(level = "debug")
282		skip(self),
283	)]
284	pub async fn get_or_fetch(&self, mxc: &Mxc<'_>, timeout_ms: Duration) -> Result<Media> {
285		if let Ok(media) = self.get(mxc, Some(timeout_ms)).await {
286			return Ok(media);
287		}
288
289		if self
290			.services
291			.globals
292			.server_is_ours(mxc.server_name)
293		{
294			return Err!(Request(NotFound("Local media not found.")));
295		}
296
297		let lock = self.federation_mutex.lock(&mxc.to_string()).await;
298
299		if self
300			.db
301			.file_metadata_exists(mxc, &Dim::default())
302			.await
303		{
304			drop(lock);
305			return self.get(mxc, None).await;
306		}
307
308		self.fetch_remote_content(mxc, None, timeout_ms)
309			.await
310	}
311
312	/// Get file from local storage while waiting up to a timeout_ms if it is
313	/// pending.
314	#[tracing::instrument(
315		level = "debug",
316		err(level = "trace")
317		skip(self),
318	)]
319	pub async fn get(&self, mxc: &Mxc<'_>, timeout: Option<Duration>) -> Result<Media> {
320		if let Ok(meta) = self.get_stored(mxc).await {
321			return Ok(meta);
322		}
323
324		let Some(timeout) = timeout else {
325			return Err!(Request(NotFound("Media not found.")));
326		};
327
328		let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
329			return Err!(Request(NotFound("Media not found.")));
330		};
331
332		let notifier = self
333			.mxc_state
334			.notifiers
335			.lock()?
336			.entry(mxc.to_string().into())
337			.or_insert_with(|| Arc::new(Notify::new()))
338			.clone();
339
340		if tokio::time::timeout(timeout, notifier.notified())
341			.await
342			.is_err()
343		{
344			return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
345		}
346
347		self.get_stored(mxc).await
348	}
349
350	/// Get file from local storage.
351	#[tracing::instrument(level = "debug", skip(self))]
352	pub async fn get_stored(&self, mxc: &Mxc<'_>) -> Result<Media> {
353		let meta = self
354			.db
355			.search_file_metadata(mxc, &Dim::default())
356			.await;
357
358		let Ok(Metadata { content_type, content_disposition, key }) = meta else {
359			return Err!(Request(NotFound("Media not found.")));
360		};
361
362		let path = self.get_media_name_sha256(&key);
363		let fetch = self
364			.storage_providers()
365			.stream()
366			.filter_map(async |provider| {
367				provider
368					.get(path.as_str())
369					.await
370					.log_debug_err()
371					.ok()
372			});
373
374		pin_mut!(fetch);
375		let Some(bytes) = fetch.next().await else {
376			return Err!(Request(NotFound("Media not found.")));
377		};
378
379		Ok(Media {
380			content: bytes.to_vec(),
381			content_type,
382			content_disposition,
383		})
384	}
385
386	/// Gets all the MXC URIs in our media database
387	pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
388		let all_keys = self.db.get_all_media_keys().await;
389
390		let mut mxcs = Vec::with_capacity(all_keys.len());
391
392		for key in all_keys {
393			trace!("Full MXC key from database: {key:?}");
394
395			let mut parts = key.split(|&b| b == 0xFF);
396			let mxc = parts
397				.next()
398				.map(|bytes| {
399					utils::string_from_bytes(bytes).map_err(|e| {
400						err!(Database(error!(
401							"Failed to parse MXC unicode bytes from our database: {e}"
402						)))
403					})
404				})
405				.transpose()?;
406
407			let Some(mxc_s) = mxc else {
408				debug_warn!(
409					?mxc,
410					"Parsed MXC URL unicode bytes from database but is still invalid"
411				);
412				continue;
413			};
414
415			trace!("Parsed MXC key to URL: {mxc_s}");
416			let mxc = OwnedMxcUri::from(mxc_s);
417
418			if mxc.is_valid() {
419				mxcs.push(mxc);
420			} else {
421				debug_warn!("{mxc:?} from database was found to not be valid");
422			}
423		}
424
425		Ok(mxcs)
426	}
427
428	/// Deletes all media files before or after the given time. Returns a usize
429	/// with the number of media files deleted.
430	pub async fn delete_range(
431		&self,
432		time: SystemTime,
433		older_than: bool,
434		newer_than: bool,
435		yes_i_want_to_delete_local_media: bool,
436	) -> Result<usize> {
437		let all_keys = self.db.get_all_media_keys().await;
438		let mut remote_mxcs = Vec::with_capacity(all_keys.len());
439
440		for key in all_keys {
441			trace!("Full MXC key from database: {key:?}");
442			let mut parts = key.split(|&b| b == 0xFF);
443			let mxc = parts
444				.next()
445				.map(|bytes| {
446					utils::string_from_bytes(bytes).map_err(|e| {
447						err!(Database(error!(
448							"Failed to parse MXC unicode bytes from our database: {e}"
449						)))
450					})
451				})
452				.transpose()?;
453
454			let Some(mxc_s) = mxc else {
455				debug_warn!(
456					?mxc,
457					"Parsed MXC URL unicode bytes from database but is still invalid"
458				);
459				continue;
460			};
461
462			trace!("Parsed MXC key to URL: {mxc_s}");
463			let mxc = OwnedMxcUri::from(mxc_s);
464			if (mxc.server_name() == Ok(self.services.globals.server_name())
465				&& !yes_i_want_to_delete_local_media)
466				|| !mxc.is_valid()
467			{
468				debug!("Ignoring local or broken media MXC: {mxc}");
469				continue;
470			}
471
472			let file_created_at = if let Some(file_metadata) = self
473				.storage_providers()
474				.stream()
475				.filter_map(async |provider| {
476					let path = self.get_media_name_sha256(&key);
477					match provider.head(&path).await {
478						| Ok(file_metadata) => {
479							trace!(%mxc, ?path, "Provider file metadata: {file_metadata:?}");
480							Some(file_metadata)
481						},
482						| Err(e) => {
483							debug_warn!(
484								"Failed to obtain {:?} file metadata for MXC {mxc} at file path \
485								 {path:?}\", skipping: {e}",
486								provider.name,
487							);
488							None
489						},
490					}
491				})
492				.boxed()
493				.next()
494				.await
495			{
496				SystemTime::from(file_metadata.last_modified)
497			} else {
498				continue;
499			};
500
501			debug!("File created at: {file_created_at:?}");
502
503			if file_created_at <= time && older_than {
504				debug!(
505					"File is older than user duration, pushing to list of file paths and keys \
506					 to delete."
507				);
508				remote_mxcs.push(mxc.to_string());
509			} else if file_created_at >= time && newer_than {
510				debug!(
511					"File is newer than user duration, pushing to list of file paths and keys \
512					 to delete."
513				);
514				remote_mxcs.push(mxc.to_string());
515			}
516		}
517
518		if remote_mxcs.is_empty() {
519			return Err!(Database("Did not found any eligible MXCs to delete."));
520		}
521
522		debug_info!("Deleting media now in the past {time:?}");
523
524		let mut deletion_count: usize = 0;
525
526		for mxc in remote_mxcs {
527			let Ok(mxc) = mxc.as_str().try_into() else {
528				debug_warn!("Invalid MXC in database, skipping");
529				continue;
530			};
531
532			debug_info!("Deleting MXC {mxc} from database and filesystem");
533
534			match self.delete(&mxc).await {
535				| Ok(()) => {
536					deletion_count = deletion_count.saturating_add(1);
537				},
538				| Err(e) => {
539					warn!("Failed to delete {mxc}, ignoring error and skipping: {e}");
540					continue;
541				},
542			}
543		}
544
545		Ok(deletion_count)
546	}
547
548	pub async fn create_media_dir(&self) -> Result {
549		let dir = self.get_media_dir();
550		Ok(fs::create_dir_all(dir).await?)
551	}
552
553	async fn remove_media_file(&self, key: &[u8]) -> Result {
554		let path = self.get_media_name_sha256(key);
555		self.storage_providers()
556			.stream()
557			.filter_map(async |provider| {
558				debug!(
559					?key, ?path, provider = ?provider.name,
560					"Deleting media file from provider",
561				);
562
563				provider
564					.delete_one(&path)
565					.await
566					.log_debug_err()
567					.ok()
568			})
569			.count()
570			.map(|count| {
571				count
572					.ge(&0)
573					.ok_or_else(|| err!(Request(NotFound("Failed to remove on any provider."))))
574			})
575			.await
576	}
577
578	async fn create_media_file(&self, key: &[u8], file: &[u8]) -> Result {
579		self.storage_providers()
580			.try_stream()
581			.ready_try_filter(|provider| {
582				let store_media_on_providers = &self.services.config.store_media_on_providers;
583
584				store_media_on_providers.is_empty()
585					|| store_media_on_providers.contains(&provider.name)
586			})
587			.and_then(async |provider| {
588				let path = self.get_media_name_sha256(key);
589				debug!(
590					?key, ?path,
591					len = ?file.len(),
592					provider = ?provider.name,
593					"Creating media file on storage provider."
594				);
595
596				if let Err(e) = provider
597					.put_one(path.as_str(), file.to_vec())
598					.await
599				{
600					return Err!(Database(error!(
601						?path,
602						?provider,
603						"Failed to store media on provider: {e:?}"
604					)));
605				}
606
607				Ok(1)
608			})
609			.ready_try_fold(0_usize, |a, c| Ok(a.saturating_add(c)))
610			.inspect_ok(|&uploads| assert!(uploads > 0, "Successfully saved to nowhere."))
611			.map_ok(|_| ())
612			.await
613	}
614
615	fn storage_providers(&self) -> impl Iterator<Item = &Arc<Provider>> + Send + '_ {
616		let explicit_providers = &self.services.config.media_storage_providers;
617
618		let or_all_providers = explicit_providers
619			.is_empty()
620			.then(|| self.services.storage.providers())
621			.into_iter()
622			.flatten();
623
624		explicit_providers
625			.iter()
626			.filter_map(|id| self.services.storage.provider(id).ok())
627			.chain(or_all_providers)
628	}
629
630	#[inline]
631	pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option<Metadata> {
632		self.db
633			.search_file_metadata(mxc, &Dim::default())
634			.await
635			.ok()
636	}
637
638	#[inline]
639	#[must_use]
640	pub fn get_media_path_sha256(&self, key: &[u8]) -> PathBuf {
641		let mut r = self.get_media_dir();
642		r.push(self.get_media_name_sha256(key));
643		r
644	}
645
646	/// new SHA256 file name media function. requires database migrated. uses
647	/// SHA256 hash of the base64 key as the file name
648	#[inline]
649	#[must_use]
650	pub fn get_media_name_sha256(&self, key: &[u8]) -> String {
651		// Using the hash of the base64 key as the filename prevents the total
652		// length of the path from exceeding the maximum length in most
653		// filesystems
654		let digest = <sha2::Sha256 as sha2::Digest>::digest(key);
655		encode_key(&digest)
656	}
657
658	/// old base64 file name media function
659	/// This is the old version of `get_media_path_sha256` that uses the full
660	/// base64 key as the filename.
661	#[must_use]
662	pub fn get_media_path_b64(&self, key: &[u8]) -> PathBuf {
663		let mut r = self.get_media_dir();
664		let encoded = encode_key(key);
665		r.push(encoded);
666		r
667	}
668
669	#[must_use]
670	pub fn get_media_dir(&self) -> PathBuf {
671		let mut r = PathBuf::new();
672		r.push(self.services.server.config.database_path.clone());
673		r.push("media");
674		r
675	}
676}
677
678#[inline]
679#[must_use]
680pub fn encode_key(key: &[u8]) -> String { general_purpose::URL_SAFE_NO_PAD.encode(key) }