1pub mod local;
2pub mod s3;
3
4#[cfg(test)]
5mod tests;
6
7use std::{
8 iter::{from_fn, once},
9 ops::Range,
10 sync::Arc,
11};
12
13use bytes::Bytes;
14use derive_more::Debug;
15use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
16use object_store::{
17 Attributes, CopyMode, DynObjectStore, GetResult, MultipartUpload, ObjectMeta, ObjectStore,
18 ObjectStoreExt, PutPayload, PutResult, path::Path,
19};
20use tuwunel_core::{
21 Error, Result,
22 config::StorageProvider,
23 debug, err, error,
24 error::error_chain,
25 extract_variant, implement, info, trace,
26 utils::{
27 BoolExt,
28 result::FlatOk,
29 stream::{IterStream, TryReadyExt},
30 },
31};
32
/// A named storage backend: an `object_store` client plus the configuration and
/// optional base path used to map caller-supplied locations onto that client.
#[derive(Debug)]
pub struct Provider {
	/// Provider name; included in log and tracing output.
	pub name: String,

	/// Configuration for this provider. Multipart upload tuning
	/// (`multipart_threshold`, `multipart_part_size`) is read from it when it is
	/// the `s3` variant; other variants disable size-based multipart selection.
	pub config: StorageProvider,

	/// The object store client every operation is forwarded to.
	pub(crate) provider: Box<DynObjectStore>,

	/// Optional prefix prepended to caller-supplied locations (unless they already
	/// start with it) and stripped from locations yielded by `list`.
	pub(crate) base_path: Option<Path>,

	/// When true, `start` verifies connectivity via `ping` before returning.
	startup_check: bool,

	// Not read anywhere in this file (hence `expect(unused)`); excluded from the
	// derived `Debug` output.
	#[expect(unused)]
	#[debug(skip)]
	services: Arc<crate::services::OnceServices>,
}
49
/// One chunk yielded by `Provider::fetch`: the chunk's bytes, paired with the byte
/// range of the underlying `GetResult` and the object's total size
/// (`ObjectMeta::size`).
pub type FetchItem = (Bytes, (Range<u64>, u64));

/// One chunk yielded by `Provider::fetch_with_metadata`: the chunk's bytes, paired
/// with the `GetResult`'s byte range, object metadata and attributes. The tuple is
/// shared (via `Arc`) by every chunk of the same fetch.
pub type FetchMetaItem = (Bytes, Arc<(Range<u64>, ObjectMeta, Attributes)>);
52
53#[implement(Provider)]
54#[tracing::instrument(skip_all, err)]
55pub(super) async fn start(self: &Arc<Self>) -> Result {
56 if self.startup_check {
57 self.startup_check().await?;
58 }
59
60 Ok(())
61}
62
63#[implement(Provider)]
64#[tracing::instrument(name = "check", skip_all, err)]
65async fn startup_check(self: &Arc<Self>) -> Result {
66 debug!(
67 name = ?self.name,
68 "Checking storage provider client connection...",
69 );
70 self.ping()
71 .inspect_ok(|()| {
72 info!(
73 name = %self.name,
74 "Connected to storage provider"
75 );
76 })
77 .await
78}
79
80#[implement(Provider)]
86#[tracing::instrument(
87 level = "debug",
88 err(level = "debug"),
89 skip_all,
90 fields(
91 provider = %self.name,
92 ?path,
93 ?size,
94 )
95)]
96pub async fn put<S, T>(&self, path: &str, size: Option<usize>, input: S) -> Result<PutResult>
97where
98 S: Stream<Item = Result<T>> + Send,
99 PutPayload: From<T> + From<PutPayload>,
100{
101 if size.is_none_or(|size| size >= self.multipart_threshold()) {
102 return self.put_multi(path, input).await;
103 }
104
105 debug!(
106 ?size,
107 threshold = ?self.multipart_threshold(),
108 "Selecting single-part upload..."
109 );
110
111 let payload: PutPayload = input
112 .map_ok(PutPayload::from)
113 .try_collect::<Vec<_>>()
114 .await?
115 .into_iter()
116 .map(Bytes::from)
117 .collect();
118
119 self.put_single(path, payload).await
120}
121
122#[implement(Provider)]
127#[tracing::instrument(
128 level = "debug",
129 err(level = "debug"),
130 skip_all,
131 fields(
132 provider = %self.name,
133 ?path,
134 )
135)]
136pub async fn put_one<T>(&self, path: &str, input: T) -> Result<PutResult>
137where
138 PutPayload: From<T> + From<PutPayload>,
139{
140 let payload: PutPayload = input.into();
141
142 if payload.content_length() < self.multipart_threshold() {
143 return self.put_single(path, payload).await;
144 }
145
146 let part_size = self.multipart_part_size();
147
148 debug!(
149 len = ?payload.content_length(),
150 threshold = ?self.multipart_threshold(),
151 ?part_size,
152 "Selecting multi-part upload..."
153 );
154
155 self.put_multi(path, chunked(payload, part_size).try_stream())
156 .await
157}
158
159#[implement(Provider)]
161#[tracing::instrument(
162 level = "debug",
163 err(level = "debug"),
164 skip_all,
165 fields(
166 provider = %self.name,
167 ?path,
168 )
169)]
170async fn put_multi<S, T>(&self, path: &str, input: S) -> Result<PutResult>
171where
172 S: Stream<Item = Result<T>> + Send,
173 PutPayload: From<T> + From<PutPayload>,
174{
175 let path = self.to_abs_path(path)?;
176 let mut handle = self
177 .provider
178 .put_multipart(&path)
179 .map_err(Error::from)
180 .await?;
181
182 match input
183 .try_for_each(|t| handle.put_part(t.into()).map_err(Error::from))
184 .inspect_err(|e| error!(?path, chain = %error_chain(e), "Failed to store object"))
185 .await
186 {
187 | Ok(()) =>
188 handle
189 .complete()
190 .map_err(Error::from)
191 .inspect_err(|e| {
192 error!(
193 ?path,
194 chain = %error_chain(e),
195 "Failed to store object during completion",
196 );
197 })
198 .await,
199
200 | Err(e) =>
201 handle
202 .abort()
203 .map_ok(|()| Err(e))
204 .map_err(Error::from)
205 .inspect_err(|e| {
206 error!(
207 ?path,
208 chain = %error_chain(e),
209 "Additional errors during error handling",
210 );
211 })
212 .await?,
213 }
214}
215
216#[implement(Provider)]
218#[tracing::instrument(
219 level = "debug",
220 err(level = "debug"),
221 skip_all,
222 fields(
223 provider = %self.name,
224 ?path,
225 )
226)]
227async fn put_single(&self, path: &str, input: PutPayload) -> Result<PutResult> {
228 let path = self.to_abs_path(path)?;
229
230 self.provider
231 .put(&path, input)
232 .map_err(Error::from)
233 .await
234}
235
236#[implement(Provider)]
237#[tracing::instrument(
238 level = "debug",
239 skip_all,
240 fields(
241 provider = %self.name,
242 ?path,
243 )
244)]
245pub fn fetch_with_metadata(
246 &self,
247 path: &str,
248) -> impl Stream<Item = Result<FetchMetaItem>> + Send {
249 self.load(path)
250 .map_ok(|result| {
251 let meta = (result.range.clone(), result.meta.clone(), result.attributes.clone());
252 let data = Arc::new(meta);
253
254 result
255 .into_stream()
256 .map_err(Error::from)
257 .map_ok(move |bytes| (bytes, data.clone()))
258 })
259 .map_err(Error::from)
260 .try_flatten_stream()
261}
262
263#[implement(Provider)]
264#[tracing::instrument(
265 level = "debug",
266 skip_all,
267 fields(
268 provider = %self.name,
269 ?path,
270 )
271)]
272pub fn fetch(&self, path: &str) -> impl Stream<Item = Result<FetchItem>> + Send {
273 self.load(path)
274 .map_ok(|result| {
275 let size = result.meta.size;
276 let range = result.range.clone();
277
278 result
279 .into_stream()
280 .map_err(Error::from)
281 .map_ok(move |bytes| (bytes, (range.clone(), size)))
282 })
283 .map_err(Error::from)
284 .try_flatten_stream()
285}
286
287#[implement(Provider)]
288#[tracing::instrument(
289 level = "debug",
290 err(level = "debug"),
291 skip_all,
292 fields(
293 provider = %self.name,
294 ?path,
295 )
296)]
297pub async fn get(&self, path: &str) -> Result<Bytes> {
298 self.load(path)
299 .map_ok(GetResult::bytes)
300 .await?
301 .map_err(Error::from)
302 .await
303}
304
305#[implement(Provider)]
306#[tracing::instrument(
307 level = "debug",
308 err(level = "debug"),
309 skip_all,
310 fields(
311 provider = %self.name,
312 ?path,
313 )
314)]
315pub async fn load(&self, path: &str) -> Result<GetResult> {
316 let path = self.to_abs_path(path)?;
317
318 self.provider
319 .get(&path)
320 .map_err(Error::from)
321 .await
322}
323
324#[implement(Provider)]
325#[tracing::instrument(
326 level = "debug",
327 err(level = "debug"),
328 skip_all,
329 fields(
330 provider = %self.name,
331 ?path,
332 )
333)]
334pub async fn delete_one(self: &Arc<Self>, path: &str) -> Result {
335 self.delete(once(path.to_owned()).stream())
336 .map_ok(|_| ())
337 .try_collect()
338 .await
339}
340
341#[implement(Provider)]
342#[tracing::instrument(
343 level = "debug",
344 skip_all,
345 fields(
346 provider = %self.name,
347 )
348)]
349pub fn delete<S>(self: &Arc<Self>, paths: S) -> impl Stream<Item = Result<Path>> + Send
350where
351 S: Stream<Item = String> + Send + 'static,
352{
353 let this = self.clone();
354 let paths = paths
355 .map(Ok)
356 .ready_and_then(move |path| {
357 use object_store::{Error, path};
358
359 this.to_abs_path(&path)
360 .map_err(|_| Error::InvalidPath {
361 source: path::Error::InvalidPath { path: path.into() },
362 })
363 })
364 .boxed();
365
366 self.provider
367 .delete_stream(paths)
368 .map_err(Error::from)
369}
370
371#[implement(Provider)]
372#[tracing::instrument(
373 level = "debug",
374 err(level = "debug"),
375 skip_all,
376 fields(
377 provider = %self.name,
378 ?src,
379 ?dst,
380 ?overwrite,
381 )
382)]
383pub async fn rename(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
384 let src = self.to_abs_path(src)?;
385 let dst = self.to_abs_path(dst)?;
386
387 match overwrite {
388 | CopyMode::Overwrite => self.provider.rename(&src, &dst).left_future(),
389 | CopyMode::Create => self
390 .provider
391 .rename_if_not_exists(&src, &dst)
392 .right_future(),
393 }
394 .map_err(Error::from)
395 .await
396}
397
398#[implement(Provider)]
399#[tracing::instrument(
400 level = "debug",
401 err(level = "debug"),
402 skip_all,
403 fields(
404 provider = %self.name,
405 ?src,
406 ?dst,
407 ?overwrite,
408 )
409)]
410pub async fn copy(&self, src: &str, dst: &str, overwrite: CopyMode) -> Result {
411 let src = self.to_abs_path(src)?;
412 let dst = self.to_abs_path(dst)?;
413
414 match overwrite {
415 | CopyMode::Overwrite => self.provider.copy(&src, &dst).left_future(),
416 | CopyMode::Create => self
417 .provider
418 .copy_if_not_exists(&src, &dst)
419 .right_future(),
420 }
421 .map_err(Error::from)
422 .await
423}
424
425#[implement(Provider)]
426#[tracing::instrument(
427 level = "debug",
428 skip_all,
429 fields(
430 provider = %self.name,
431 ?prefix,
432 )
433)]
434pub fn list(&self, prefix: Option<&str>) -> impl Stream<Item = Result<ObjectMeta>> + Send {
435 let abs_prefix = prefix
436 .map(Path::from)
437 .map(|p| self.prepend_base_path(p))
438 .or_else(|| self.base_path.clone());
439
440 self.provider
441 .list(abs_prefix.as_ref())
442 .map_err(Error::from)
443 .map_ok(|meta| ObjectMeta {
444 location: self.strip_base_path(meta.location),
445 ..meta
446 })
447}
448
449#[implement(Provider)]
450#[tracing::instrument(
451 level = "debug",
452 err(level = "debug"),
453 skip_all,
454 fields(
455 provider = %self.name,
456 ?path,
457 )
458)]
459pub async fn head(&self, path: &str) -> Result<ObjectMeta> {
460 self.provider
461 .head(&self.to_abs_path(path)?)
462 .map_err(Error::from)
463 .await
464}
465
466#[implement(Provider)]
467#[tracing::instrument(
468 level = "debug",
469 err(level = "error"),
470 skip_all,
471 fields(
472 provider = %self.name,
473 )
474)]
475pub async fn ping(&self) -> Result {
476 self.list(None)
477 .try_next()
478 .inspect_err(|e| {
479 error!(chain = %error_chain(e), "Failed to connect to storage provider");
480 })
481 .boxed()
482 .await
483 .map(|_| ())
484}
485
486#[implement(Provider)]
487fn to_abs_path(&self, location: &str) -> Result<Path> {
488 let location = Path::parse(location)
489 .map_err(|e| err!("Failed to parse location into canonical PathPart: {e}"))?;
490
491 let path = self.prepend_base_path(location);
492
493 trace!(
494 provider = ?self.name,
495 base_path = ?self.base_path,
496 ?path,
497 "Computed absolute path for object on provider.",
498 );
499
500 Ok(path)
501}
502
503#[implement(Provider)]
504fn prepend_base_path(&self, location: Path) -> Path {
505 match self.base_path.as_ref() {
506 | Some(base_path) if !location.prefix_matches(base_path) => base_path
507 .parts()
508 .chain(location.parts())
509 .collect(),
510
511 | _ => location,
512 }
513}
514
515#[implement(Provider)]
516fn strip_base_path(&self, location: Path) -> Path {
517 self.base_path
518 .as_ref()
519 .and_then(|base_path| location.prefix_match(base_path))
520 .map(Iterator::collect)
521 .unwrap_or(location)
522}
523
524#[implement(Provider)]
525fn multipart_threshold(&self) -> usize {
526 extract_variant!(&self.config, StorageProvider::s3)
527 .map(|config| config.multipart_threshold.as_u64())
528 .map(TryInto::try_into)
529 .flat_ok()
530 .unwrap_or(usize::MAX)
531}
532
533#[implement(Provider)]
534fn multipart_part_size(&self) -> usize {
535 extract_variant!(&self.config, StorageProvider::s3)
536 .map(|config| config.multipart_part_size.as_u64())
537 .map(TryInto::try_into)
538 .flat_ok()
539 .unwrap_or(usize::MAX)
540}
541
542fn chunked(payload: PutPayload, part_size: usize) -> impl Iterator<Item = PutPayload> {
543 let mut buf: Bytes = payload.into();
544 from_fn(move || {
545 buf.is_empty()
546 .is_false()
547 .then(|| buf.split_to(part_size.min(buf.len())).into())
548 })
549}