Skip to main content

tuwunel_service/storage/
provider.rs

1pub mod local;
2pub mod s3;
3
4#[cfg(test)]
5mod tests;
6
7use std::{
8	iter::{from_fn, once},
9	ops::Range,
10	sync::Arc,
11};
12
13use bytes::Bytes;
14use derive_more::Debug;
15use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
16use object_store::{
17	Attributes, CopyMode, DynObjectStore, GetResult, MultipartUpload, ObjectMeta, ObjectStore,
18	ObjectStoreExt, PutPayload, PutResult, path::Path,
19};
20use tuwunel_core::{
21	Error, Result,
22	config::StorageProvider,
23	debug, err, error,
24	error::error_chain,
25	extract_variant, implement, info, trace,
26	utils::{
27		BoolExt,
28		result::FlatOk,
29		stream::{IterStream, TryReadyExt},
30	},
31};
32
33#[derive(Debug)]
34pub struct Provider {
35	pub name: String,
36
37	pub config: StorageProvider,
38
39	pub(crate) provider: Box<DynObjectStore>,
40
41	pub(crate) base_path: Option<Path>,
42
43	startup_check: bool,
44
45	#[expect(unused)]
46	#[debug(skip)]
47	services: Arc<crate::services::OnceServices>,
48}
49
50pub type FetchItem = (Bytes, (Range<u64>, u64));
51pub type FetchMetaItem = (Bytes, Arc<(Range<u64>, ObjectMeta, Attributes)>);
52
53#[implement(Provider)]
54#[tracing::instrument(skip_all, err)]
55pub(super) async fn start(self: &Arc<Self>) -> Result {
56	if self.startup_check {
57		self.startup_check().await?;
58	}
59
60	Ok(())
61}
62
63#[implement(Provider)]
64#[tracing::instrument(name = "check", skip_all, err)]
65async fn startup_check(self: &Arc<Self>) -> Result {
66	debug!(
67		name = ?self.name,
68		"Checking storage provider client connection...",
69	);
70	self.ping()
71		.inspect_ok(|()| {
72			info!(
73				name = %self.name,
74				"Connected to storage provider"
75			);
76		})
77		.await
78}
79
80/// Put object into store from streaming input.
81///
82/// Recommended to know the total size of the object. If size is `None`,
83/// multi-part upload may be selected even for small uploads below the
84/// configured threshold.
85#[implement(Provider)]
86#[tracing::instrument(
87	level = "debug",
88	err(level = "debug"),
89	skip_all,
90	fields(
91		provider = %self.name,
92		?path,
93		?size,
94	)
95)]
96pub async fn put<S, T>(&self, path: &str, size: Option<usize>, input: S) -> Result<PutResult>
97where
98	S: Stream<Item = Result<T>> + Send,
99	PutPayload: From<T> + From<PutPayload>,
100{
101	if size.is_none_or(|size| size >= self.multipart_threshold()) {
102		return self.put_multi(path, input).await;
103	}
104
105	debug!(
106		?size,
107		threshold = ?self.multipart_threshold(),
108		"Selecting single-part upload..."
109	);
110
111	let payload: PutPayload = input
112		.map_ok(PutPayload::from)
113		.try_collect::<Vec<_>>()
114		.await?
115		.into_iter()
116		.map(Bytes::from)
117		.collect();
118
119	self.put_single(path, payload).await
120}
121
122/// Put object into the store from contiguous input.
123///
124/// The size of input will be determined and multipart upload will be chosen as
125/// necessary internally.
126#[implement(Provider)]
127#[tracing::instrument(
128	level = "debug",
129	err(level = "debug"),
130	skip_all,
131	fields(
132		provider = %self.name,
133		?path,
134	)
135)]
136pub async fn put_one<T>(&self, path: &str, input: T) -> Result<PutResult>
137where
138	PutPayload: From<T> + From<PutPayload>,
139{
140	let payload: PutPayload = input.into();
141
142	if payload.content_length() < self.multipart_threshold() {
143		return self.put_single(path, payload).await;
144	}
145
146	let part_size = self.multipart_part_size();
147
148	debug!(
149		len = ?payload.content_length(),
150		threshold = ?self.multipart_threshold(),
151		?part_size,
152		"Selecting multi-part upload..."
153	);
154
155	self.put_multi(path, chunked(payload, part_size).try_stream())
156		.await
157}
158
159/// Put object into the store from streaming input using multipart upload.
160#[implement(Provider)]
161#[tracing::instrument(
162	level = "debug",
163	err(level = "debug"),
164	skip_all,
165	fields(
166		provider = %self.name,
167		?path,
168	)
169)]
170async fn put_multi<S, T>(&self, path: &str, input: S) -> Result<PutResult>
171where
172	S: Stream<Item = Result<T>> + Send,
173	PutPayload: From<T> + From<PutPayload>,
174{
175	let path = self.to_abs_path(path)?;
176	let mut handle = self
177		.provider
178		.put_multipart(&path)
179		.map_err(Error::from)
180		.await?;
181
182	match input
183		.try_for_each(|t| handle.put_part(t.into()).map_err(Error::from))
184		.inspect_err(|e| error!(?path, chain = %error_chain(e), "Failed to store object"))
185		.await
186	{
187		| Ok(()) =>
188			handle
189				.complete()
190				.map_err(Error::from)
191				.inspect_err(|e| {
192					error!(
193						?path,
194						chain = %error_chain(e),
195						"Failed to store object during completion",
196					);
197				})
198				.await,
199
200		| Err(e) =>
201			handle
202				.abort()
203				.map_ok(|()| Err(e))
204				.map_err(Error::from)
205				.inspect_err(|e| {
206					error!(
207						?path,
208						chain = %error_chain(e),
209						"Additional errors during error handling",
210					);
211				})
212				.await?,
213	}
214}
215
216/// Put object into the store from contiguous input non-multipart upload.
217#[implement(Provider)]
218#[tracing::instrument(
219	level = "debug",
220	err(level = "debug"),
221	skip_all,
222	fields(
223		provider = %self.name,
224		?path,
225	)
226)]
227async fn put_single(&self, path: &str, input: PutPayload) -> Result<PutResult> {
228	let path = self.to_abs_path(path)?;
229
230	self.provider
231		.put(&path, input)
232		.map_err(Error::from)
233		.await
234}
235
236#[implement(Provider)]
237#[tracing::instrument(
238	level = "debug",
239	skip_all,
240	fields(
241		provider = %self.name,
242		?path,
243	)
244)]
245pub fn fetch_with_metadata(
246	&self,
247	path: &str,
248) -> impl Stream<Item = Result<FetchMetaItem>> + Send {
249	self.load(path)
250		.map_ok(|result| {
251			let meta = (result.range.clone(), result.meta.clone(), result.attributes.clone());
252			let data = Arc::new(meta);
253
254			result
255				.into_stream()
256				.map_err(Error::from)
257				.map_ok(move |bytes| (bytes, data.clone()))
258		})
259		.map_err(Error::from)
260		.try_flatten_stream()
261}
262
263#[implement(Provider)]
264#[tracing::instrument(
265	level = "debug",
266	skip_all,
267	fields(
268		provider = %self.name,
269		?path,
270	)
271)]
272pub fn fetch(&self, path: &str) -> impl Stream<Item = Result<FetchItem>> + Send {
273	self.load(path)
274		.map_ok(|result| {
275			let size = result.meta.size;
276			let range = result.range.clone();
277
278			result
279				.into_stream()
280				.map_err(Error::from)
281				.map_ok(move |bytes| (bytes, (range.clone(), size)))
282		})
283		.map_err(Error::from)
284		.try_flatten_stream()
285}
286
287#[implement(Provider)]
288#[tracing::instrument(
289	level = "debug",
290	err(level = "debug"),
291	skip_all,
292	fields(
293		provider = %self.name,
294		?path,
295	)
296)]
297pub async fn get(&self, path: &str) -> Result<Bytes> {
298	self.load(path)
299		.map_ok(GetResult::bytes)
300		.await?
301		.map_err(Error::from)
302		.await
303}
304
305#[implement(Provider)]
306#[tracing::instrument(
307	level = "debug",
308	err(level = "debug"),
309	skip_all,
310	fields(
311		provider = %self.name,
312		?path,
313	)
314)]
315pub async fn load(&self, path: &str) -> Result<GetResult> {
316	let path = self.to_abs_path(path)?;
317
318	self.provider
319		.get(&path)
320		.map_err(Error::from)
321		.await
322}
323
324#[implement(Provider)]
325#[tracing::instrument(
326	level = "debug",
327	err(level = "debug"),
328	skip_all,
329	fields(
330		provider = %self.name,
331		?path,
332	)
333)]
334pub async fn delete_one(self: &Arc<Self>, path: &str) -> Result {
335	self.delete(once(path.to_owned()).stream())
336		.map_ok(|_| ())
337		.try_collect()
338		.await
339}
340
341#[implement(Provider)]
342#[tracing::instrument(
343	level = "debug",
344	skip_all,
345	fields(
346		provider = %self.name,
347	)
348)]
349pub fn delete<S>(self: &Arc<Self>, paths: S) -> impl Stream<Item = Result<Path>> + Send
350where
351	S: Stream<Item = String> + Send + 'static,
352{
353	let this = self.clone();
354	let paths = paths
355		.map(Ok)
356		.ready_and_then(move |path| {
357			use object_store::{Error, path};
358
359			this.to_abs_path(&path)
360				.map_err(|_| Error::InvalidPath {
361					source: path::Error::InvalidPath { path: path.into() },
362				})
363		})
364		.boxed();
365
366	self.provider
367		.delete_stream(paths)
368		.map_err(Error::from)
369}
370
371#[implement(Provider)]
372#[tracing::instrument(
373	level = "debug",
374	err(level = "debug"),
375	skip_all,
376	fields(
377		provider = %self.name,
378		?src,
379		?dst,
380		?overwrite,
381	)
382)]
383pub async fn rename(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
384	let src = self.to_abs_path(src)?;
385	let dst = self.to_abs_path(dst)?;
386
387	match overwrite {
388		| CopyMode::Overwrite => self.provider.rename(&src, &dst).left_future(),
389		| CopyMode::Create => self
390			.provider
391			.rename_if_not_exists(&src, &dst)
392			.right_future(),
393	}
394	.map_err(Error::from)
395	.await
396}
397
398#[implement(Provider)]
399#[tracing::instrument(
400	level = "debug",
401	err(level = "debug"),
402	skip_all,
403	fields(
404		provider = %self.name,
405		?src,
406		?dst,
407		?overwrite,
408	)
409)]
410pub async fn copy(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
411	let src = self.to_abs_path(src)?;
412	let dst = self.to_abs_path(dst)?;
413
414	match overwrite {
415		| CopyMode::Overwrite => self.provider.copy(&src, &dst).left_future(),
416		| CopyMode::Create => self
417			.provider
418			.copy_if_not_exists(&src, &dst)
419			.right_future(),
420	}
421	.map_err(Error::from)
422	.await
423}
424
425#[implement(Provider)]
426#[tracing::instrument(
427	level = "debug",
428	skip_all,
429	fields(
430		provider = %self.name,
431		?prefix,
432	)
433)]
434pub fn list(&self, prefix: Option<&str>) -> impl Stream<Item = Result<ObjectMeta>> + Send {
435	let abs_prefix = prefix
436		.map(Path::from)
437		.map(|p| self.prepend_base_path(p))
438		.or_else(|| self.base_path.clone());
439
440	self.provider
441		.list(abs_prefix.as_ref())
442		.map_err(Error::from)
443		.map_ok(|meta| ObjectMeta {
444			location: self.strip_base_path(meta.location),
445			..meta
446		})
447}
448
449#[implement(Provider)]
450#[tracing::instrument(
451	level = "debug",
452	err(level = "debug"),
453	skip_all,
454	fields(
455		provider = %self.name,
456		?path,
457	)
458)]
459pub async fn head(&self, path: &str) -> Result<ObjectMeta> {
460	self.provider
461		.head(&self.to_abs_path(path)?)
462		.map_err(Error::from)
463		.await
464}
465
466#[implement(Provider)]
467#[tracing::instrument(
468	level = "debug",
469	err(level = "error"),
470	skip_all,
471	fields(
472		provider = %self.name,
473	)
474)]
475pub async fn ping(&self) -> Result {
476	self.list(None)
477		.try_next()
478		.inspect_err(|e| {
479			error!(chain = %error_chain(e), "Failed to connect to storage provider");
480		})
481		.boxed()
482		.await
483		.map(|_| ())
484}
485
486#[implement(Provider)]
487fn to_abs_path(&self, location: &str) -> Result<Path> {
488	let location = Path::parse(location)
489		.map_err(|e| err!("Failed to parse location into canonical PathPart: {e}"))?;
490
491	let path = self.prepend_base_path(location);
492
493	trace!(
494		provider = ?self.name,
495		base_path = ?self.base_path,
496		?path,
497		"Computed absolute path for object on provider.",
498	);
499
500	Ok(path)
501}
502
503#[implement(Provider)]
504fn prepend_base_path(&self, location: Path) -> Path {
505	match self.base_path.as_ref() {
506		| Some(base_path) if !location.prefix_matches(base_path) => base_path
507			.parts()
508			.chain(location.parts())
509			.collect(),
510
511		| _ => location,
512	}
513}
514
515#[implement(Provider)]
516fn strip_base_path(&self, location: Path) -> Path {
517	self.base_path
518		.as_ref()
519		.and_then(|base_path| location.prefix_match(base_path))
520		.map(Iterator::collect)
521		.unwrap_or(location)
522}
523
524#[implement(Provider)]
525fn multipart_threshold(&self) -> usize {
526	extract_variant!(&self.config, StorageProvider::s3)
527		.map(|config| config.multipart_threshold.as_u64())
528		.map(TryInto::try_into)
529		.flat_ok()
530		.unwrap_or(usize::MAX)
531}
532
533#[implement(Provider)]
534fn multipart_part_size(&self) -> usize {
535	extract_variant!(&self.config, StorageProvider::s3)
536		.map(|config| config.multipart_part_size.as_u64())
537		.map(TryInto::try_into)
538		.flat_ok()
539		.unwrap_or(usize::MAX)
540}
541
542fn chunked(payload: PutPayload, part_size: usize) -> impl Iterator<Item = PutPayload> {
543	let mut buf: Bytes = payload.into();
544	from_fn(move || {
545		buf.is_empty()
546			.is_false()
547			.then(|| buf.split_to(part_size.min(buf.len())).into())
548	})
549}