Skip to main content

tuwunel_service/storage/
provider.rs

// Backend-specific submodules (judging by the names, a local backend and an
// S3 backend — confirm against their contents).
pub mod local;
pub mod s3;

#[cfg(test)]
mod tests;
6
7use std::{
8	iter::{from_fn, once},
9	ops::Range,
10	sync::Arc,
11	time::Duration,
12};
13
14use bytes::Bytes;
15use derive_more::Debug;
16use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
17use http::Method;
18use object_store::{
19	Attributes, CopyMode, DynObjectStore, GetResult, MultipartUpload, ObjectMeta, ObjectStore,
20	ObjectStoreExt, PutPayload, PutResult, path::Path, signer::Signer,
21};
22use tuwunel_core::{
23	Error, Result,
24	config::StorageProvider,
25	debug, err, error,
26	error::error_chain,
27	extract_variant, implement, info, trace,
28	utils::{
29		BoolExt,
30		result::FlatOk,
31		stream::{IterStream, TryReadyExt},
32	},
33};
34use url::Url;
35
/// A configured object-storage backend.
///
/// Methods on this type accept object paths relative to `base_path` (when one
/// is set). They resolve those paths to absolute locations before delegating to
/// `provider`, and `list` strips `base_path` back off the returned locations.
#[derive(Debug)]
pub struct Provider {
	/// Name of this provider; used to label log events and tracing spans.
	pub name: String,

	/// Configuration this provider was built from. In this file only the S3
	/// variant's multipart settings are consulted.
	pub config: StorageProvider,

	/// Underlying object store that every operation delegates to.
	pub(crate) provider: Box<DynObjectStore>,

	/// URL signer, if the backend supports presigning. When `None`,
	/// `signed_get_url` returns `Ok(None)`.
	#[debug(skip)]
	pub(crate) signer: Option<Arc<dyn Signer>>,

	/// Optional path prefix prepended to every object path on this provider.
	pub(crate) base_path: Option<Path>,

	/// Whether `start` should verify connectivity by pinging the provider.
	startup_check: bool,

	// Currently not read anywhere (hence the `expect(unused)`).
	#[expect(unused)]
	#[debug(skip)]
	services: Arc<crate::services::OnceServices>,
}
55
/// Item yielded by `Provider::fetch`: a chunk of object data, paired with the
/// byte range returned by the provider and the object's total size.
pub type FetchItem = (Bytes, (Range<u64>, u64));

/// Item yielded by `Provider::fetch_with_metadata`: a chunk of object data,
/// paired with a shared handle to the returned byte range, the object's
/// metadata, and its attributes (the same handle is attached to every chunk).
pub type FetchMetaItem = (Bytes, Arc<(Range<u64>, ObjectMeta, Attributes)>);
58
59#[implement(Provider)]
60#[tracing::instrument(skip_all, err)]
61pub(super) async fn start(self: &Arc<Self>) -> Result {
62	if self.startup_check {
63		self.startup_check().await?;
64	}
65
66	Ok(())
67}
68
69#[implement(Provider)]
70#[tracing::instrument(name = "check", skip_all, err)]
71async fn startup_check(self: &Arc<Self>) -> Result {
72	debug!(
73		name = ?self.name,
74		"Checking storage provider client connection...",
75	);
76	self.ping()
77		.inspect_ok(|()| {
78			info!(
79				name = %self.name,
80				"Connected to storage provider"
81			);
82		})
83		.await
84}
85
86/// Put object into store from streaming input.
87///
88/// Recommended to know the total size of the object. If size is `None`,
89/// multi-part upload may be selected even for small uploads below the
90/// configured threshold.
91#[implement(Provider)]
92#[tracing::instrument(
93	level = "debug",
94	err(level = "debug"),
95	skip_all,
96	fields(
97		provider = %self.name,
98		?path,
99		?size,
100	)
101)]
102pub async fn put<S, T>(&self, path: &str, size: Option<usize>, input: S) -> Result<PutResult>
103where
104	S: Stream<Item = Result<T>> + Send,
105	PutPayload: From<T> + From<PutPayload>,
106{
107	if size.is_none_or(|size| size >= self.multipart_threshold()) {
108		return self.put_multi(path, input).await;
109	}
110
111	debug!(
112		?size,
113		threshold = ?self.multipart_threshold(),
114		"Selecting single-part upload..."
115	);
116
117	let payload: PutPayload = input
118		.map_ok(PutPayload::from)
119		.try_collect::<Vec<_>>()
120		.await?
121		.into_iter()
122		.map(Bytes::from)
123		.collect();
124
125	self.put_single(path, payload).await
126}
127
128/// Put object into the store from contiguous input.
129///
130/// The size of input will be determined and multipart upload will be chosen as
131/// necessary internally.
132#[implement(Provider)]
133#[tracing::instrument(
134	level = "debug",
135	err(level = "debug"),
136	skip_all,
137	fields(
138		provider = %self.name,
139		?path,
140	)
141)]
142pub async fn put_one<T>(&self, path: &str, input: T) -> Result<PutResult>
143where
144	PutPayload: From<T> + From<PutPayload>,
145{
146	let payload: PutPayload = input.into();
147
148	if payload.content_length() < self.multipart_threshold() {
149		return self.put_single(path, payload).await;
150	}
151
152	let part_size = self.multipart_part_size();
153
154	debug!(
155		len = ?payload.content_length(),
156		threshold = ?self.multipart_threshold(),
157		?part_size,
158		"Selecting multi-part upload..."
159	);
160
161	self.put_multi(path, chunked(payload, part_size).try_stream())
162		.await
163}
164
165/// Put object into the store from streaming input using multipart upload.
166#[implement(Provider)]
167#[tracing::instrument(
168	level = "debug",
169	err(level = "debug"),
170	skip_all,
171	fields(
172		provider = %self.name,
173		?path,
174	)
175)]
176async fn put_multi<S, T>(&self, path: &str, input: S) -> Result<PutResult>
177where
178	S: Stream<Item = Result<T>> + Send,
179	PutPayload: From<T> + From<PutPayload>,
180{
181	let path = self.to_abs_path(path)?;
182	let mut handle = self
183		.provider
184		.put_multipart(&path)
185		.map_err(Error::from)
186		.await?;
187
188	match input
189		.try_for_each(|t| handle.put_part(t.into()).map_err(Error::from))
190		.inspect_err(|e| error!(?path, chain = %error_chain(e), "Failed to store object"))
191		.await
192	{
193		| Ok(()) =>
194			handle
195				.complete()
196				.map_err(Error::from)
197				.inspect_err(|e| {
198					error!(
199						?path,
200						chain = %error_chain(e),
201						"Failed to store object during completion",
202					);
203				})
204				.await,
205
206		| Err(e) =>
207			handle
208				.abort()
209				.map_ok(|()| Err(e))
210				.map_err(Error::from)
211				.inspect_err(|e| {
212					error!(
213						?path,
214						chain = %error_chain(e),
215						"Additional errors during error handling",
216					);
217				})
218				.await?,
219	}
220}
221
222/// Put object into the store from contiguous input non-multipart upload.
223#[implement(Provider)]
224#[tracing::instrument(
225	level = "debug",
226	err(level = "debug"),
227	skip_all,
228	fields(
229		provider = %self.name,
230		?path,
231	)
232)]
233async fn put_single(&self, path: &str, input: PutPayload) -> Result<PutResult> {
234	let path = self.to_abs_path(path)?;
235
236	self.provider
237		.put(&path, input)
238		.map_err(Error::from)
239		.await
240}
241
/// Stream the contents of the object at `path` together with its metadata.
///
/// Each item pairs a chunk of data with a shared (`Arc`) tuple containing:
///
/// - the byte range returned by the provider,
/// - the object's metadata,
/// - its attributes.
///
/// The same tuple is attached to every chunk. If the object cannot be loaded,
/// the stream yields that error as its only item.
///
/// NOTE(review): this is not an `async fn`, so the tracing span below is only
/// entered while the stream is constructed, not while it is polled.
#[implement(Provider)]
#[tracing::instrument(
	level = "debug",
	skip_all,
	fields(
		provider = %self.name,
		?path,
	)
)]
pub fn fetch_with_metadata(
	&self,
	path: &str,
) -> impl Stream<Item = Result<FetchMetaItem>> + Send {
	self.load(path)
		.map_ok(|result| {
			// Capture the metadata once, before `result` is consumed, and share it
			// across every chunk via the `Arc`.
			let meta = (result.range.clone(), result.meta.clone(), result.attributes.clone());
			let data = Arc::new(meta);

			result
				.into_stream()
				.map_err(Error::from)
				.map_ok(move |bytes| (bytes, data.clone()))
		})
		.map_err(Error::from)
		.try_flatten_stream()
}
268
/// Stream the contents of the object at `path`.
///
/// Each item pairs a chunk of data with the byte range returned by the
/// provider and the object's total size (taken from its metadata). If the
/// object cannot be loaded, the stream yields that error as its only item.
///
/// NOTE(review): this is not an `async fn`, so the tracing span below is only
/// entered while the stream is constructed, not while it is polled.
#[implement(Provider)]
#[tracing::instrument(
	level = "debug",
	skip_all,
	fields(
		provider = %self.name,
		?path,
	)
)]
pub fn fetch(&self, path: &str) -> impl Stream<Item = Result<FetchItem>> + Send {
	self.load(path)
		.map_ok(|result| {
			// Copy the per-object values out before `result` is consumed by
			// `into_stream`.
			let size = result.meta.size;
			let range = result.range.clone();

			result
				.into_stream()
				.map_err(Error::from)
				.map_ok(move |bytes| (bytes, (range.clone(), size)))
		})
		.map_err(Error::from)
		.try_flatten_stream()
}
292
293#[implement(Provider)]
294#[tracing::instrument(
295	level = "debug",
296	err(level = "debug"),
297	skip_all,
298	fields(
299		provider = %self.name,
300		?path,
301	)
302)]
303pub async fn get(&self, path: &str) -> Result<Bytes> {
304	self.load(path)
305		.map_ok(GetResult::bytes)
306		.await?
307		.map_err(Error::from)
308		.await
309}
310
311#[implement(Provider)]
312#[tracing::instrument(
313	level = "debug",
314	err(level = "debug"),
315	skip_all,
316	fields(
317		provider = %self.name,
318		?path,
319	)
320)]
321pub async fn load(&self, path: &str) -> Result<GetResult> {
322	let path = self.to_abs_path(path)?;
323
324	self.provider
325		.get(&path)
326		.map_err(Error::from)
327		.await
328}
329
330/// Presign a time-limited GET URL for an object, when this provider supports
331/// signing (S3).
332#[implement(Provider)]
333#[tracing::instrument(
334	level = "debug",
335	err(level = "debug"),
336	skip_all,
337	fields(
338		provider = %self.name,
339		?path,
340		?ttl,
341	)
342)]
343pub async fn signed_get_url(&self, path: &str, ttl: Duration) -> Result<Option<Url>> {
344	let Some(signer) = self.signer.as_ref() else {
345		return Ok(None);
346	};
347
348	let path = self.to_abs_path(path)?;
349
350	signer
351		.signed_url(Method::GET, &path, ttl)
352		.map_err(Error::from)
353		.map_ok(Some)
354		.await
355}
356
357#[implement(Provider)]
358#[tracing::instrument(
359	level = "debug",
360	err(level = "debug"),
361	skip_all,
362	fields(
363		provider = %self.name,
364		?path,
365	)
366)]
367pub async fn delete_one(self: &Arc<Self>, path: &str) -> Result {
368	self.delete(once(path.to_owned()).stream())
369		.map_ok(|_| ())
370		.try_collect()
371		.await
372}
373
374#[implement(Provider)]
375#[tracing::instrument(
376	level = "debug",
377	skip_all,
378	fields(
379		provider = %self.name,
380	)
381)]
382pub fn delete<S>(self: &Arc<Self>, paths: S) -> impl Stream<Item = Result<Path>> + Send
383where
384	S: Stream<Item = String> + Send + 'static,
385{
386	let this = self.clone();
387	let paths = paths
388		.map(Ok)
389		.ready_and_then(move |path| {
390			use object_store::{Error, path};
391
392			this.to_abs_path(&path)
393				.map_err(|_| Error::InvalidPath {
394					source: path::Error::InvalidPath { path: path.into() },
395				})
396		})
397		.boxed();
398
399	self.provider
400		.delete_stream(paths)
401		.map_err(Error::from)
402}
403
404#[implement(Provider)]
405#[tracing::instrument(
406	level = "debug",
407	err(level = "debug"),
408	skip_all,
409	fields(
410		provider = %self.name,
411		?src,
412		?dst,
413		?overwrite,
414	)
415)]
416pub async fn rename(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
417	let src = self.to_abs_path(src)?;
418	let dst = self.to_abs_path(dst)?;
419
420	match overwrite {
421		| CopyMode::Overwrite => self.provider.rename(&src, &dst).left_future(),
422		| CopyMode::Create => self
423			.provider
424			.rename_if_not_exists(&src, &dst)
425			.right_future(),
426	}
427	.map_err(Error::from)
428	.await
429}
430
431#[implement(Provider)]
432#[tracing::instrument(
433	level = "debug",
434	err(level = "debug"),
435	skip_all,
436	fields(
437		provider = %self.name,
438		?src,
439		?dst,
440		?overwrite,
441	)
442)]
443pub async fn copy(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
444	let src = self.to_abs_path(src)?;
445	let dst = self.to_abs_path(dst)?;
446
447	match overwrite {
448		| CopyMode::Overwrite => self.provider.copy(&src, &dst).left_future(),
449		| CopyMode::Create => self
450			.provider
451			.copy_if_not_exists(&src, &dst)
452			.right_future(),
453	}
454	.map_err(Error::from)
455	.await
456}
457
458#[implement(Provider)]
459#[tracing::instrument(
460	level = "debug",
461	skip_all,
462	fields(
463		provider = %self.name,
464		?prefix,
465	)
466)]
467pub fn list(&self, prefix: Option<&str>) -> impl Stream<Item = Result<ObjectMeta>> + Send {
468	let abs_prefix = prefix
469		.map(Path::from)
470		.map(|p| self.prepend_base_path(p))
471		.or_else(|| self.base_path.clone());
472
473	self.provider
474		.list(abs_prefix.as_ref())
475		.map_err(Error::from)
476		.map_ok(|meta| ObjectMeta {
477			location: self.strip_base_path(meta.location),
478			..meta
479		})
480}
481
482#[implement(Provider)]
483#[tracing::instrument(
484	level = "debug",
485	err(level = "debug"),
486	skip_all,
487	fields(
488		provider = %self.name,
489		?path,
490	)
491)]
492pub async fn head(&self, path: &str) -> Result<ObjectMeta> {
493	self.provider
494		.head(&self.to_abs_path(path)?)
495		.map_err(Error::from)
496		.await
497}
498
499#[implement(Provider)]
500#[tracing::instrument(
501	level = "debug",
502	err(level = "error"),
503	skip_all,
504	fields(
505		provider = %self.name,
506	)
507)]
508pub async fn ping(&self) -> Result {
509	self.list(None)
510		.try_next()
511		.inspect_err(|e| {
512			error!(chain = %error_chain(e), "Failed to connect to storage provider");
513		})
514		.boxed()
515		.await
516		.map(|_| ())
517}
518
519#[implement(Provider)]
520fn to_abs_path(&self, location: &str) -> Result<Path> {
521	let location = Path::parse(location)
522		.map_err(|e| err!("Failed to parse location into canonical PathPart: {e}"))?;
523
524	let path = self.prepend_base_path(location);
525
526	trace!(
527		provider = ?self.name,
528		base_path = ?self.base_path,
529		?path,
530		"Computed absolute path for object on provider.",
531	);
532
533	Ok(path)
534}
535
536#[implement(Provider)]
537fn prepend_base_path(&self, location: Path) -> Path {
538	match self.base_path.as_ref() {
539		| Some(base_path) if !location.prefix_matches(base_path) => base_path
540			.parts()
541			.chain(location.parts())
542			.collect(),
543
544		| _ => location,
545	}
546}
547
548#[implement(Provider)]
549fn strip_base_path(&self, location: Path) -> Path {
550	self.base_path
551		.as_ref()
552		.and_then(|base_path| location.prefix_match(base_path))
553		.map(Iterator::collect)
554		.unwrap_or(location)
555}
556
557#[implement(Provider)]
558fn multipart_threshold(&self) -> usize {
559	extract_variant!(&self.config, StorageProvider::s3)
560		.map(|config| config.multipart_threshold.as_u64())
561		.map(TryInto::try_into)
562		.flat_ok()
563		.unwrap_or(usize::MAX)
564}
565
566#[implement(Provider)]
567fn multipart_part_size(&self) -> usize {
568	extract_variant!(&self.config, StorageProvider::s3)
569		.map(|config| config.multipart_part_size.as_u64())
570		.map(TryInto::try_into)
571		.flat_ok()
572		.unwrap_or(usize::MAX)
573}
574
575fn chunked(payload: PutPayload, part_size: usize) -> impl Iterator<Item = PutPayload> {
576	let mut buf: Bytes = payload.into();
577	from_fn(move || {
578		buf.is_empty()
579			.is_false()
580			.then(|| buf.split_to(part_size.min(buf.len())).into())
581	})
582}