1pub mod local;
2pub mod s3;
3
4#[cfg(test)]
5mod tests;
6
7use std::{
8 iter::{from_fn, once},
9 ops::Range,
10 sync::Arc,
11 time::Duration,
12};
13
14use bytes::Bytes;
15use derive_more::Debug;
16use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
17use http::Method;
18use object_store::{
19 Attributes, CopyMode, DynObjectStore, GetResult, MultipartUpload, ObjectMeta, ObjectStore,
20 ObjectStoreExt, PutPayload, PutResult, path::Path, signer::Signer,
21};
22use tuwunel_core::{
23 Error, Result,
24 config::StorageProvider,
25 debug, err, error,
26 error::error_chain,
27 extract_variant, implement, info, trace,
28 utils::{
29 BoolExt,
30 result::FlatOk,
31 stream::{IterStream, TryReadyExt},
32 },
33};
34use url::Url;
35
36#[derive(Debug)]
37pub struct Provider {
38 pub name: String,
39
40 pub config: StorageProvider,
41
42 pub(crate) provider: Box<DynObjectStore>,
43
44 #[debug(skip)]
45 pub(crate) signer: Option<Arc<dyn Signer>>,
46
47 pub(crate) base_path: Option<Path>,
48
49 startup_check: bool,
50
51 #[expect(unused)]
52 #[debug(skip)]
53 services: Arc<crate::services::OnceServices>,
54}
55
56pub type FetchItem = (Bytes, (Range<u64>, u64));
57pub type FetchMetaItem = (Bytes, Arc<(Range<u64>, ObjectMeta, Attributes)>);
58
59#[implement(Provider)]
60#[tracing::instrument(skip_all, err)]
61pub(super) async fn start(self: &Arc<Self>) -> Result {
62 if self.startup_check {
63 self.startup_check().await?;
64 }
65
66 Ok(())
67}
68
69#[implement(Provider)]
70#[tracing::instrument(name = "check", skip_all, err)]
71async fn startup_check(self: &Arc<Self>) -> Result {
72 debug!(
73 name = ?self.name,
74 "Checking storage provider client connection...",
75 );
76 self.ping()
77 .inspect_ok(|()| {
78 info!(
79 name = %self.name,
80 "Connected to storage provider"
81 );
82 })
83 .await
84}
85
86#[implement(Provider)]
92#[tracing::instrument(
93 level = "debug",
94 err(level = "debug"),
95 skip_all,
96 fields(
97 provider = %self.name,
98 ?path,
99 ?size,
100 )
101)]
102pub async fn put<S, T>(&self, path: &str, size: Option<usize>, input: S) -> Result<PutResult>
103where
104 S: Stream<Item = Result<T>> + Send,
105 PutPayload: From<T> + From<PutPayload>,
106{
107 if size.is_none_or(|size| size >= self.multipart_threshold()) {
108 return self.put_multi(path, input).await;
109 }
110
111 debug!(
112 ?size,
113 threshold = ?self.multipart_threshold(),
114 "Selecting single-part upload..."
115 );
116
117 let payload: PutPayload = input
118 .map_ok(PutPayload::from)
119 .try_collect::<Vec<_>>()
120 .await?
121 .into_iter()
122 .map(Bytes::from)
123 .collect();
124
125 self.put_single(path, payload).await
126}
127
128#[implement(Provider)]
133#[tracing::instrument(
134 level = "debug",
135 err(level = "debug"),
136 skip_all,
137 fields(
138 provider = %self.name,
139 ?path,
140 )
141)]
142pub async fn put_one<T>(&self, path: &str, input: T) -> Result<PutResult>
143where
144 PutPayload: From<T> + From<PutPayload>,
145{
146 let payload: PutPayload = input.into();
147
148 if payload.content_length() < self.multipart_threshold() {
149 return self.put_single(path, payload).await;
150 }
151
152 let part_size = self.multipart_part_size();
153
154 debug!(
155 len = ?payload.content_length(),
156 threshold = ?self.multipart_threshold(),
157 ?part_size,
158 "Selecting multi-part upload..."
159 );
160
161 self.put_multi(path, chunked(payload, part_size).try_stream())
162 .await
163}
164
165#[implement(Provider)]
167#[tracing::instrument(
168 level = "debug",
169 err(level = "debug"),
170 skip_all,
171 fields(
172 provider = %self.name,
173 ?path,
174 )
175)]
176async fn put_multi<S, T>(&self, path: &str, input: S) -> Result<PutResult>
177where
178 S: Stream<Item = Result<T>> + Send,
179 PutPayload: From<T> + From<PutPayload>,
180{
181 let path = self.to_abs_path(path)?;
182 let mut handle = self
183 .provider
184 .put_multipart(&path)
185 .map_err(Error::from)
186 .await?;
187
188 match input
189 .try_for_each(|t| handle.put_part(t.into()).map_err(Error::from))
190 .inspect_err(|e| error!(?path, chain = %error_chain(e), "Failed to store object"))
191 .await
192 {
193 | Ok(()) =>
194 handle
195 .complete()
196 .map_err(Error::from)
197 .inspect_err(|e| {
198 error!(
199 ?path,
200 chain = %error_chain(e),
201 "Failed to store object during completion",
202 );
203 })
204 .await,
205
206 | Err(e) =>
207 handle
208 .abort()
209 .map_ok(|()| Err(e))
210 .map_err(Error::from)
211 .inspect_err(|e| {
212 error!(
213 ?path,
214 chain = %error_chain(e),
215 "Additional errors during error handling",
216 );
217 })
218 .await?,
219 }
220}
221
222#[implement(Provider)]
224#[tracing::instrument(
225 level = "debug",
226 err(level = "debug"),
227 skip_all,
228 fields(
229 provider = %self.name,
230 ?path,
231 )
232)]
233async fn put_single(&self, path: &str, input: PutPayload) -> Result<PutResult> {
234 let path = self.to_abs_path(path)?;
235
236 self.provider
237 .put(&path, input)
238 .map_err(Error::from)
239 .await
240}
241
242#[implement(Provider)]
243#[tracing::instrument(
244 level = "debug",
245 skip_all,
246 fields(
247 provider = %self.name,
248 ?path,
249 )
250)]
251pub fn fetch_with_metadata(
252 &self,
253 path: &str,
254) -> impl Stream<Item = Result<FetchMetaItem>> + Send {
255 self.load(path)
256 .map_ok(|result| {
257 let meta = (result.range.clone(), result.meta.clone(), result.attributes.clone());
258 let data = Arc::new(meta);
259
260 result
261 .into_stream()
262 .map_err(Error::from)
263 .map_ok(move |bytes| (bytes, data.clone()))
264 })
265 .map_err(Error::from)
266 .try_flatten_stream()
267}
268
269#[implement(Provider)]
270#[tracing::instrument(
271 level = "debug",
272 skip_all,
273 fields(
274 provider = %self.name,
275 ?path,
276 )
277)]
278pub fn fetch(&self, path: &str) -> impl Stream<Item = Result<FetchItem>> + Send {
279 self.load(path)
280 .map_ok(|result| {
281 let size = result.meta.size;
282 let range = result.range.clone();
283
284 result
285 .into_stream()
286 .map_err(Error::from)
287 .map_ok(move |bytes| (bytes, (range.clone(), size)))
288 })
289 .map_err(Error::from)
290 .try_flatten_stream()
291}
292
293#[implement(Provider)]
294#[tracing::instrument(
295 level = "debug",
296 err(level = "debug"),
297 skip_all,
298 fields(
299 provider = %self.name,
300 ?path,
301 )
302)]
303pub async fn get(&self, path: &str) -> Result<Bytes> {
304 self.load(path)
305 .map_ok(GetResult::bytes)
306 .await?
307 .map_err(Error::from)
308 .await
309}
310
311#[implement(Provider)]
312#[tracing::instrument(
313 level = "debug",
314 err(level = "debug"),
315 skip_all,
316 fields(
317 provider = %self.name,
318 ?path,
319 )
320)]
321pub async fn load(&self, path: &str) -> Result<GetResult> {
322 let path = self.to_abs_path(path)?;
323
324 self.provider
325 .get(&path)
326 .map_err(Error::from)
327 .await
328}
329
330#[implement(Provider)]
333#[tracing::instrument(
334 level = "debug",
335 err(level = "debug"),
336 skip_all,
337 fields(
338 provider = %self.name,
339 ?path,
340 ?ttl,
341 )
342)]
343pub async fn signed_get_url(&self, path: &str, ttl: Duration) -> Result<Option<Url>> {
344 let Some(signer) = self.signer.as_ref() else {
345 return Ok(None);
346 };
347
348 let path = self.to_abs_path(path)?;
349
350 signer
351 .signed_url(Method::GET, &path, ttl)
352 .map_err(Error::from)
353 .map_ok(Some)
354 .await
355}
356
357#[implement(Provider)]
358#[tracing::instrument(
359 level = "debug",
360 err(level = "debug"),
361 skip_all,
362 fields(
363 provider = %self.name,
364 ?path,
365 )
366)]
367pub async fn delete_one(self: &Arc<Self>, path: &str) -> Result {
368 self.delete(once(path.to_owned()).stream())
369 .map_ok(|_| ())
370 .try_collect()
371 .await
372}
373
374#[implement(Provider)]
375#[tracing::instrument(
376 level = "debug",
377 skip_all,
378 fields(
379 provider = %self.name,
380 )
381)]
382pub fn delete<S>(self: &Arc<Self>, paths: S) -> impl Stream<Item = Result<Path>> + Send
383where
384 S: Stream<Item = String> + Send + 'static,
385{
386 let this = self.clone();
387 let paths = paths
388 .map(Ok)
389 .ready_and_then(move |path| {
390 use object_store::{Error, path};
391
392 this.to_abs_path(&path)
393 .map_err(|_| Error::InvalidPath {
394 source: path::Error::InvalidPath { path: path.into() },
395 })
396 })
397 .boxed();
398
399 self.provider
400 .delete_stream(paths)
401 .map_err(Error::from)
402}
403
404#[implement(Provider)]
405#[tracing::instrument(
406 level = "debug",
407 err(level = "debug"),
408 skip_all,
409 fields(
410 provider = %self.name,
411 ?src,
412 ?dst,
413 ?overwrite,
414 )
415)]
416pub async fn rename(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
417 let src = self.to_abs_path(src)?;
418 let dst = self.to_abs_path(dst)?;
419
420 match overwrite {
421 | CopyMode::Overwrite => self.provider.rename(&src, &dst).left_future(),
422 | CopyMode::Create => self
423 .provider
424 .rename_if_not_exists(&src, &dst)
425 .right_future(),
426 }
427 .map_err(Error::from)
428 .await
429}
430
431#[implement(Provider)]
432#[tracing::instrument(
433 level = "debug",
434 err(level = "debug"),
435 skip_all,
436 fields(
437 provider = %self.name,
438 ?src,
439 ?dst,
440 ?overwrite,
441 )
442)]
443pub async fn copy(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
444 let src = self.to_abs_path(src)?;
445 let dst = self.to_abs_path(dst)?;
446
447 match overwrite {
448 | CopyMode::Overwrite => self.provider.copy(&src, &dst).left_future(),
449 | CopyMode::Create => self
450 .provider
451 .copy_if_not_exists(&src, &dst)
452 .right_future(),
453 }
454 .map_err(Error::from)
455 .await
456}
457
458#[implement(Provider)]
459#[tracing::instrument(
460 level = "debug",
461 skip_all,
462 fields(
463 provider = %self.name,
464 ?prefix,
465 )
466)]
467pub fn list(&self, prefix: Option<&str>) -> impl Stream<Item = Result<ObjectMeta>> + Send {
468 let abs_prefix = prefix
469 .map(Path::from)
470 .map(|p| self.prepend_base_path(p))
471 .or_else(|| self.base_path.clone());
472
473 self.provider
474 .list(abs_prefix.as_ref())
475 .map_err(Error::from)
476 .map_ok(|meta| ObjectMeta {
477 location: self.strip_base_path(meta.location),
478 ..meta
479 })
480}
481
482#[implement(Provider)]
483#[tracing::instrument(
484 level = "debug",
485 err(level = "debug"),
486 skip_all,
487 fields(
488 provider = %self.name,
489 ?path,
490 )
491)]
492pub async fn head(&self, path: &str) -> Result<ObjectMeta> {
493 self.provider
494 .head(&self.to_abs_path(path)?)
495 .map_err(Error::from)
496 .await
497}
498
499#[implement(Provider)]
500#[tracing::instrument(
501 level = "debug",
502 err(level = "error"),
503 skip_all,
504 fields(
505 provider = %self.name,
506 )
507)]
508pub async fn ping(&self) -> Result {
509 self.list(None)
510 .try_next()
511 .inspect_err(|e| {
512 error!(chain = %error_chain(e), "Failed to connect to storage provider");
513 })
514 .boxed()
515 .await
516 .map(|_| ())
517}
518
519#[implement(Provider)]
520fn to_abs_path(&self, location: &str) -> Result<Path> {
521 let location = Path::parse(location)
522 .map_err(|e| err!("Failed to parse location into canonical PathPart: {e}"))?;
523
524 let path = self.prepend_base_path(location);
525
526 trace!(
527 provider = ?self.name,
528 base_path = ?self.base_path,
529 ?path,
530 "Computed absolute path for object on provider.",
531 );
532
533 Ok(path)
534}
535
536#[implement(Provider)]
537fn prepend_base_path(&self, location: Path) -> Path {
538 match self.base_path.as_ref() {
539 | Some(base_path) if !location.prefix_matches(base_path) => base_path
540 .parts()
541 .chain(location.parts())
542 .collect(),
543
544 | _ => location,
545 }
546}
547
548#[implement(Provider)]
549fn strip_base_path(&self, location: Path) -> Path {
550 self.base_path
551 .as_ref()
552 .and_then(|base_path| location.prefix_match(base_path))
553 .map(Iterator::collect)
554 .unwrap_or(location)
555}
556
557#[implement(Provider)]
558fn multipart_threshold(&self) -> usize {
559 extract_variant!(&self.config, StorageProvider::s3)
560 .map(|config| config.multipart_threshold.as_u64())
561 .map(TryInto::try_into)
562 .flat_ok()
563 .unwrap_or(usize::MAX)
564}
565
566#[implement(Provider)]
567fn multipart_part_size(&self) -> usize {
568 extract_variant!(&self.config, StorageProvider::s3)
569 .map(|config| config.multipart_part_size.as_u64())
570 .map(TryInto::try_into)
571 .flat_ok()
572 .unwrap_or(usize::MAX)
573}
574
575fn chunked(payload: PutPayload, part_size: usize) -> impl Iterator<Item = PutPayload> {
576 let mut buf: Bytes = payload.into();
577 from_fn(move || {
578 buf.is_empty()
579 .is_false()
580 .then(|| buf.split_to(part_size.min(buf.len())).into())
581 })
582}