Skip to main content

tuwunel_database/
de.rs

1use serde::{
2	Deserialize, de,
3	de::{DeserializeSeed, Visitor},
4};
5use tuwunel_core::{
6	Error, Result, arrayvec::ArrayVec, checked, debug::DebugInspect, err, unhandled,
7	utils::string,
8};
9
10/// Deserialize into T from buffer.
11#[cfg_attr(
12	unabridged,
13	tracing::instrument(
14		name = "deserialize",
15		level = "trace",
16		skip_all,
17		fields(len = %buf.len()),
18	)
19)]
20pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result<T>
21where
22	T: Deserialize<'a>,
23{
24	let mut deserializer = Deserializer { buf, pos: 0, rec: 0, seq: 0 };
25
26	T::deserialize(&mut deserializer).debug_inspect(|_| {
27		deserializer
28			.finished()
29			.expect("deserialization failed to consume trailing bytes");
30	})
31}
32
/// Deserialization state.
pub(crate) struct Deserializer<'de> {
	/// Complete input buffer; deserialized strings and byte slices may borrow
	/// from it.
	buf: &'de [u8],
	/// Byte offset into `buf` of the current parse position.
	pos: usize,
	/// Number of records started so far; incremented by `record_start`.
	rec: usize,
	/// Expected element count of the sequence being deserialized; starts at
	/// zero and is set by `sequence_start`.
	seq: usize,
}
40
/// Directive to ignore a record. This type can be used to skip deserialization
/// until the next separator is found.
///
/// The deserializer recognizes this directive by its unit-struct name
/// (`"Ignore"`). Outside of a sequence it consumes all remaining input rather
/// than a single record.
#[derive(Debug, Deserialize)]
pub struct Ignore;
45
/// Directive to ignore all remaining records. This can be used in a sequence to
/// ignore the rest of the sequence.
///
/// The deserializer recognizes this directive by its unit-struct name
/// (`"IgnoreAll"`) and consumes every remaining byte of the input, separators
/// included.
#[derive(Debug, Deserialize)]
pub struct IgnoreAll;
50
51impl<'de> Deserializer<'de> {
52	const SEP: u8 = crate::ser::SEP;
53
54	/// Determine if the input was fully consumed and error if bytes remaining.
55	/// This is intended for debug assertions; not optimized for parsing logic.
56	fn finished(&self) -> Result {
57		let pos = self.pos;
58		let len = self.buf.len();
59		let parsed = &self.buf[0..pos];
60		let unparsed = &self.buf[pos..];
61		let remain = self.remaining()?;
62		let trailing_sep = remain == 1 && unparsed[0] == Self::SEP;
63		(remain == 0 || trailing_sep)
64			.then_some(())
65			.ok_or(err!(SerdeDe(
66				"{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}",
67			)))
68	}
69
70	/// Called at the start of arrays and tuples
71	#[inline]
72	fn sequence_start(&mut self, len: usize) {
73		debug_assert!(self.seq == 0, "Nested sequences are not handled at this time");
74		self.seq = len;
75	}
76
77	/// Consume the current record to ignore it. Inside a sequence the next
78	/// record is skipped but at the top-level all records are skipped such that
79	/// deserialization completes with self.finished() == Ok.
80	#[inline]
81	fn record_ignore(&mut self) {
82		if self.seq > 0 {
83			self.record_next();
84		} else {
85			self.record_ignore_all();
86		}
87	}
88
89	/// Consume the current and all remaining records to ignore them. Similar to
90	/// Ignore at the top-level, but it can be provided in a sequence to Ignore
91	/// all remaining elements.
92	#[inline]
93	fn record_ignore_all(&mut self) { self.record_trail(); }
94
95	/// Consume the current record. The position pointer is moved to the start
96	/// of the next record. Slice of the current record is returned.
97	#[inline]
98	fn record_next(&mut self) -> &'de [u8] {
99		self.buf[self.pos..]
100			.split(|b| *b == Deserializer::SEP)
101			.inspect(|record| self.inc_pos(record.len()))
102			.next()
103			.expect("remainder of buf even if SEP was not found")
104	}
105
106	/// Peek at the first byte of the current record. If all records were
107	/// consumed None is returned instead.
108	#[inline]
109	fn record_peek_byte(&self) -> Option<u8> {
110		let started = self.pos != 0 || self.rec > 0;
111		let buf = &self.buf[self.pos..];
112		debug_assert!(
113			!started || buf[0] == Self::SEP,
114			"Missing expected record separator at current position"
115		);
116
117		buf.get::<usize>(started.into()).copied()
118	}
119
120	/// Consume the record separator such that the position cleanly points to
121	/// the start of the next record. When input is exhausted but the
122	/// sequence is not, advance no bytes; the caller deserializes from an
123	/// empty slice. See `next_element_seed` for the additive-tail mechanic
124	/// this enables.
125	#[inline]
126	fn record_start(&mut self) {
127		let started = self.pos != 0 || self.rec > 0;
128		let input_done = self.pos >= self.buf.len();
129		let output_done = self.rec >= self.seq;
130		let incomplete = input_done && !output_done;
131		debug_assert!(
132			!started || incomplete || self.buf.get(self.pos) == Some(&Self::SEP),
133			"Missing expected record separator at current position"
134		);
135
136		let inc = started && !incomplete;
137		self.inc_pos(inc.into());
138		self.inc_rec(1);
139	}
140
141	/// Consume all remaining bytes, which may include record separators,
142	/// returning a raw slice.
143	#[inline]
144	fn record_trail(&mut self) -> &'de [u8] {
145		let record = &self.buf[self.pos..];
146		self.inc_pos(record.len());
147		record
148	}
149
150	/// Increment the position pointer.
151	#[inline]
152	#[cfg_attr(
153		unabridged,
154		tracing::instrument(
155			level = "trace",
156			skip(self),
157			fields(
158				len = self.buf.len(),
159				rem = self.remaining().unwrap_or_default().saturating_sub(n),
160			),
161		)
162	)]
163	fn inc_pos(&mut self, n: usize) {
164		self.pos = self.pos.saturating_add(n);
165		debug_assert!(self.pos <= self.buf.len(), "pos out of range");
166	}
167
168	#[inline]
169	fn inc_rec(&mut self, n: usize) { self.rec = self.rec.saturating_add(n); }
170
171	/// Unconsumed input bytes.
172	#[inline]
173	fn remaining(&self) -> Result<usize> {
174		let pos = self.pos;
175		let len = self.buf.len();
176		checked!(len - pos)
177	}
178}
179
180impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
181	type Error = Error;
182
183	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
184	fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
185	where
186		V: Visitor<'de>,
187	{
188		self.sequence_start(1);
189		visitor.visit_seq(self)
190	}
191
192	#[cfg_attr(
193		unabridged,
194		tracing::instrument(level = "trace", skip(self, visitor))
195	)]
196	fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
197	where
198		V: Visitor<'de>,
199	{
200		self.sequence_start(len);
201		visitor.visit_seq(self)
202	}
203
204	#[cfg_attr(
205		unabridged,
206		tracing::instrument(level = "trace", skip(self, visitor))
207	)]
208	fn deserialize_tuple_struct<V>(
209		self,
210		_name: &'static str,
211		len: usize,
212		visitor: V,
213	) -> Result<V::Value>
214	where
215		V: Visitor<'de>,
216	{
217		self.sequence_start(len);
218		visitor.visit_seq(self)
219	}
220
221	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
222	fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
223	where
224		V: Visitor<'de>,
225	{
226		let input = self.record_next();
227		let mut d = serde_json::Deserializer::from_slice(input);
228		d.deserialize_map(visitor).map_err(Into::into)
229	}
230
231	#[cfg_attr(
232		unabridged,
233		tracing::instrument(level = "trace", skip(self, visitor))
234	)]
235	fn deserialize_struct<V>(
236		self,
237		name: &'static str,
238		fields: &'static [&'static str],
239		visitor: V,
240	) -> Result<V::Value>
241	where
242		V: Visitor<'de>,
243	{
244		let input = self.record_next();
245		let mut d = serde_json::Deserializer::from_slice(input);
246		d.deserialize_struct(name, fields, visitor)
247			.map_err(Into::into)
248	}
249
250	#[cfg_attr(
251		unabridged,
252		tracing::instrument(level = "trace", skip(self, visitor))
253	)]
254	fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
255	where
256		V: Visitor<'de>,
257	{
258		match name {
259			| "Ignore" => self.record_ignore(),
260			| "IgnoreAll" => self.record_ignore_all(),
261			| _ => unhandled!("Unrecognized deserialization Directive {name:?}"),
262		}
263
264		visitor.visit_unit()
265	}
266
267	#[cfg_attr(
268		unabridged,
269		tracing::instrument(level = "trace", skip(self, visitor))
270	)]
271	fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
272	where
273		V: Visitor<'de>,
274	{
275		match name {
276			| "$serde_json::private::RawValue" => visitor.visit_map(self),
277			| "Json" => visitor
278				.visit_newtype_struct(&mut serde_json::Deserializer::from_slice(
279					self.record_trail(),
280				))
281				.map_err(|e| Self::Error::SerdeDe(format!("{name}: {e}").into())),
282
283			| "Cbor" => visitor
284				.visit_newtype_struct(&mut minicbor_serde::Deserializer::new(self.record_trail()))
285				.map_err(|e| Self::Error::SerdeDe(format!("{name}: {e}").into())),
286
287			| _ => visitor.visit_newtype_struct(self),
288		}
289	}
290
291	#[cfg_attr(
292		unabridged,
293		tracing::instrument(level = "trace", skip(self, _visitor))
294	)]
295	fn deserialize_enum<V>(
296		self,
297		_name: &'static str,
298		_variants: &'static [&'static str],
299		_visitor: V,
300	) -> Result<V::Value>
301	where
302		V: Visitor<'de>,
303	{
304		unhandled!("deserialize Enum not implemented")
305	}
306
307	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
308	fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
309		if self
310			.buf
311			.get(self.pos)
312			.is_none_or(|b| *b == Deserializer::SEP)
313		{
314			visitor.visit_none()
315		} else {
316			visitor.visit_some(self)
317		}
318	}
319
320	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
321	fn deserialize_bool<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
322		unhandled!("deserialize bool not implemented")
323	}
324
325	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
326	fn deserialize_i8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
327		unhandled!("deserialize i8 not implemented")
328	}
329
330	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
331	fn deserialize_i16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
332		unhandled!("deserialize i16 not implemented")
333	}
334
335	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
336	fn deserialize_i32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
337		unhandled!("deserialize i32 not implemented")
338	}
339
340	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
341	fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
342		const BYTES: usize = size_of::<i64>();
343
344		let end = self.pos.saturating_add(BYTES).min(self.buf.len());
345		let bytes: ArrayVec<u8, BYTES> = self.buf[self.pos..end].try_into()?;
346		let bytes = bytes
347			.into_inner()
348			.map_err(|_| Self::Error::SerdeDe("i64 buffer underflow".into()))?;
349
350		self.inc_pos(BYTES);
351		visitor.visit_i64(i64::from_be_bytes(bytes))
352	}
353
354	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
355	fn deserialize_u8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
356		unhandled!(
357			"deserialize u8 not implemented; try dereferencing the Handle for [u8] access \
358			 instead"
359		)
360	}
361
362	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
363	fn deserialize_u16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
364		unhandled!("deserialize u16 not implemented")
365	}
366
367	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
368	fn deserialize_u32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
369		unhandled!("deserialize u32 not implemented")
370	}
371
372	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
373	fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
374		const BYTES: usize = size_of::<u64>();
375
376		let end = self.pos.saturating_add(BYTES).min(self.buf.len());
377		let bytes: ArrayVec<u8, BYTES> = self.buf[self.pos..end].try_into()?;
378		let bytes = bytes
379			.into_inner()
380			.map_err(|_| Self::Error::SerdeDe("u64 buffer underflow".into()))?;
381
382		self.inc_pos(BYTES);
383		visitor.visit_u64(u64::from_be_bytes(bytes))
384	}
385
386	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
387	fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
388		unhandled!("deserialize f32 not implemented")
389	}
390
391	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
392	fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
393		unhandled!("deserialize f64 not implemented")
394	}
395
396	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
397	fn deserialize_char<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
398		unhandled!("deserialize char not implemented")
399	}
400
401	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
402	fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
403		let input = self.record_next();
404		let out = deserialize_str(input)?;
405		visitor.visit_borrowed_str(out)
406	}
407
408	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
409	fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
410		let input = self.record_next();
411		let out = string::string_from_bytes(input)?;
412		visitor.visit_string(out)
413	}
414
415	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
416	fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
417		let input = self.record_trail();
418		visitor.visit_borrowed_bytes(input)
419	}
420
421	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
422	fn deserialize_byte_buf<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
423		unhandled!("deserialize Byte Buf not implemented")
424	}
425
426	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
427	fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
428		unhandled!("deserialize Unit not implemented")
429	}
430
431	// this only used for $serde_json::private::RawValue at this time; see MapAccess
432	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
433	fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
434		let input = "$serde_json::private::RawValue";
435		visitor.visit_borrowed_str(input)
436	}
437
438	#[cfg_attr(unabridged, tracing::instrument(level = "trace", skip_all))]
439	fn deserialize_ignored_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
440		unhandled!("deserialize Ignored Any not implemented")
441	}
442
443	#[cfg_attr(
444		unabridged,
445		tracing::instrument(level = "trace", skip_all, fields(?self.buf))
446	)]
447	fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
448		const TYPE_PRE_1_91: &str = "serde_json::value::de::<impl serde_core::de::Deserialize \
449		                             for serde_json::value::Value>::deserialize::ValueVisitor";
450		const TYPE: &str = "serde_json::value::de::<impl serde_core::de::Deserialize<'_> for \
451		                    serde_json::value::Value>::deserialize::ValueVisitor";
452		debug_assert!(
453			matches!(tuwunel_core::debug::type_name::<V>(), TYPE | TYPE_PRE_1_91),
454			"deserialize_any: type not expected {0}",
455			tuwunel_core::debug::type_name::<V>()
456		);
457
458		match self.record_peek_byte() {
459			| Some(b'{') => self.deserialize_map(visitor),
460			| Some(b'[') => serde_json::Deserializer::from_slice(self.record_next())
461				.deserialize_seq(visitor)
462				.map_err(Into::into),
463
464			| _ => self.deserialize_str(visitor),
465		}
466	}
467}
468
469impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> {
470	type Error = Error;
471
472	#[cfg_attr(
473		unabridged,
474		tracing::instrument(level = "trace", skip(self, seed))
475	)]
476	fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
477	where
478		T: DeserializeSeed<'de>,
479	{
480		// Finished parsing the input.
481		let finished = self.pos >= self.buf.len();
482
483		// Completely satisfied the output.
484		let complete = self.rec >= self.seq;
485
486		// Early-return only when both input and output are exhausted. If
487		// input is exhausted but the tuple is not, fall through:
488		// `record_start` does not advance pos and the inner deserializer
489		// runs against an empty slice. Tail types that visit an empty slice
490		// successfully (`&str` -> "", `&[u8]` -> &[], `Option<_>` -> None)
491		// round-trip; non-tolerant tails (numerics, typed Matrix IDs) error.
492		// This enables additive evolution of record-key tuples without a
493		// migration.
494		//
495		// Returning early before the input is exhausted trips the
496		// `finished()` check; before the tuple is exhausted, serde's length
497		// check.
498		if finished && complete {
499			return Ok(None);
500		}
501
502		self.record_start();
503		seed.deserialize(&mut **self).map(Some)
504	}
505}
506
// This is only used for $serde_json::private::RawValue at this time. Our db
// schema doesn't have its own map format; we use json for that anyway.
509impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> {
510	type Error = Error;
511
512	#[cfg_attr(
513		unabridged,
514		tracing::instrument(level = "trace", skip(self, seed))
515	)]
516	fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
517	where
518		K: DeserializeSeed<'de>,
519	{
520		seed.deserialize(&mut **self).map(Some)
521	}
522
523	#[cfg_attr(
524		unabridged,
525		tracing::instrument(level = "trace", skip(self, seed))
526	)]
527	fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
528	where
529		V: DeserializeSeed<'de>,
530	{
531		seed.deserialize(&mut **self)
532	}
533}
534
535// activate when stable; too soon now
536//#[cfg(debug_assertions)]
537#[inline]
538fn deserialize_str(input: &[u8]) -> Result<&str> { string::str_from_bytes(input) }
539
//#[cfg(not(debug_assertions))]
/// Unchecked variant of `deserialize_str`, currently compiled out by
/// `#[cfg(disable)]`. Converts a record's bytes into a string slice without
/// validating UTF-8.
#[cfg(disable)]
#[inline]
fn deserialize_str(input: &[u8]) -> Result<&str> {
	// SAFETY: Strings were written by the serializer to the database. Assuming no
	// database corruption, the string will be valid. Database corruption is
	// detected via rocksdb checksums.
	//
	// The value is wrapped in `Ok` to match the declared `Result<&str>` return
	// type, so this variant compiles when it is eventually enabled.
	Ok(unsafe { std::str::from_utf8_unchecked(input) })
}