// Source: tuwunel_service/rooms/short/mod.rs

1use std::{borrow::Borrow, mem::size_of_val, sync::Arc};
2
3use futures::{FutureExt, Stream, StreamExt, pin_mut};
4use ruma::{EventId, OwnedRoomId, RoomId, events::StateEventType};
5use serde::Deserialize;
6pub use tuwunel_core::matrix::{ShortEventId, ShortId, ShortRoomId, ShortStateKey};
7use tuwunel_core::{
8	Err, Result, err, implement,
9	matrix::StateKey,
10	utils,
11	utils::{IterStream, stream::ReadyExt},
12};
13use tuwunel_database::{Deserialized, Get, Map, Qry};
14
15pub struct Service {
16	db: Data,
17	services: Arc<crate::services::OnceServices>,
18}
19
20struct Data {
21	eventid_shorteventid: Arc<Map>,
22	shorteventid_eventid: Arc<Map>,
23	statekey_shortstatekey: Arc<Map>,
24	shortstatekey_statekey: Arc<Map>,
25	roomid_shortroomid: Arc<Map>,
26	statehash_shortstatehash: Arc<Map>,
27}
28
29pub type ShortStateHash = ShortId;
30
31impl crate::Service for Service {
32	fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
33		Ok(Arc::new(Self {
34			db: Data {
35				eventid_shorteventid: args.db["eventid_shorteventid"].clone(),
36				shorteventid_eventid: args.db["shorteventid_eventid"].clone(),
37				statekey_shortstatekey: args.db["statekey_shortstatekey"].clone(),
38				shortstatekey_statekey: args.db["shortstatekey_statekey"].clone(),
39				roomid_shortroomid: args.db["roomid_shortroomid"].clone(),
40				statehash_shortstatehash: args.db["statehash_shortstatehash"].clone(),
41			},
42			services: args.services.clone(),
43		}))
44	}
45
46	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
47}
48
49#[implement(Service)]
50pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
51	if let Ok(shorteventid) = self.get_shorteventid(event_id).await {
52		return shorteventid;
53	}
54
55	self.create_shorteventid(event_id)
56}
57
58#[implement(Service)]
59pub fn multi_get_or_create_shorteventid<'a, I>(
60	&'a self,
61	event_ids: I,
62) -> impl Stream<Item = ShortEventId> + Send + '_
63where
64	I: Iterator<Item = &'a EventId> + Clone + Send + 'a,
65{
66	event_ids
67		.clone()
68		.stream()
69		.get(&self.db.eventid_shorteventid)
70		.zip(event_ids.into_iter().stream())
71		.map(|(result, event_id)| match result {
72			| Ok(ref short) => utils::u64_from_u8(short),
73			| Err(_) => self.create_shorteventid(event_id),
74		})
75}
76
77#[implement(Service)]
78fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
79	const BUFSIZE: usize = size_of::<ShortEventId>();
80
81	let short = self.services.globals.next_count();
82	debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed");
83
84	self.db
85		.eventid_shorteventid
86		.raw_aput::<BUFSIZE, _, _>(event_id, *short);
87
88	self.db
89		.shorteventid_eventid
90		.aput_raw::<BUFSIZE, _, _>(*short, event_id);
91
92	*short
93}
94
95#[implement(Service)]
96pub async fn get_shorteventid(&self, event_id: &EventId) -> Result<ShortEventId> {
97	self.db
98		.eventid_shorteventid
99		.get(event_id)
100		.await
101		.deserialized()
102}
103
104#[implement(Service)]
105pub async fn get_or_create_shortstatekey(
106	&self,
107	event_type: &StateEventType,
108	state_key: &str,
109) -> ShortStateKey {
110	const BUFSIZE: usize = size_of::<ShortStateKey>();
111
112	if let Ok(shortstatekey) = self
113		.get_shortstatekey(event_type, state_key)
114		.await
115	{
116		return shortstatekey;
117	}
118
119	let key = (event_type, state_key);
120	let shortstatekey = self.services.globals.next_count();
121
122	debug_assert!(size_of_val(&*shortstatekey) == BUFSIZE, "buffer requirement changed");
123
124	self.db
125		.statekey_shortstatekey
126		.put_aput::<BUFSIZE, _, _>(key, *shortstatekey);
127
128	self.db
129		.shortstatekey_statekey
130		.aput_put::<BUFSIZE, _, _>(*shortstatekey, key);
131
132	*shortstatekey
133}
134
135#[implement(Service)]
136pub async fn get_shortstatekey(
137	&self,
138	event_type: &StateEventType,
139	state_key: &str,
140) -> Result<ShortStateKey> {
141	let key = (event_type, state_key);
142	self.db
143		.statekey_shortstatekey
144		.qry(&key)
145		.await
146		.deserialized()
147}
148
149#[implement(Service)]
150pub async fn get_eventid_from_short<Id>(&self, shorteventid: ShortEventId) -> Result<Id>
151where
152	Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
153	<Id as ToOwned>::Owned: Borrow<EventId>,
154{
155	const BUFSIZE: usize = size_of::<ShortEventId>();
156
157	self.db
158		.shorteventid_eventid
159		.aqry::<BUFSIZE, _>(&shorteventid)
160		.await
161		.deserialized()
162		.map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}")))
163}
164
165#[implement(Service)]
166pub fn multi_get_eventid_from_short<'a, Id, S>(
167	&'a self,
168	shorteventid: S,
169) -> impl Stream<Item = Result<Id>> + Send + 'a
170where
171	S: Stream<Item = ShortEventId> + Send + 'a,
172	Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a,
173	<Id as ToOwned>::Owned: Borrow<EventId>,
174{
175	shorteventid
176		.qry(&self.db.shorteventid_eventid)
177		.map(Deserialized::deserialized)
178}
179
180#[implement(Service)]
181pub async fn get_statekey_from_short(
182	&self,
183	shortstatekey: ShortStateKey,
184) -> Result<(StateEventType, StateKey)> {
185	const BUFSIZE: usize = size_of::<ShortStateKey>();
186
187	self.db
188		.shortstatekey_statekey
189		.aqry::<BUFSIZE, _>(&shortstatekey)
190		.await
191		.deserialized()
192		.map_err(|e| {
193			err!(Database(
194				"Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}"
195			))
196		})
197}
198
199#[implement(Service)]
200pub fn multi_get_statekey_from_short<'a, S>(
201	&'a self,
202	shortstatekey: S,
203) -> impl Stream<Item = Result<(StateEventType, StateKey)>> + Send + 'a
204where
205	S: Stream<Item = ShortStateKey> + Send + 'a,
206{
207	shortstatekey
208		.qry(&self.db.shortstatekey_statekey)
209		.map(Deserialized::deserialized)
210}
211
212/// Returns (shortstatehash, already_existed)
213#[implement(Service)]
214pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortStateHash, bool) {
215	const BUFSIZE: usize = size_of::<ShortStateHash>();
216
217	if let Ok(shortstatehash) = self
218		.db
219		.statehash_shortstatehash
220		.get(state_hash)
221		.await
222		.deserialized()
223	{
224		return (shortstatehash, true);
225	}
226
227	let shortstatehash = self.services.globals.next_count();
228	debug_assert!(size_of_val(&*shortstatehash) == BUFSIZE, "buffer requirement changed");
229
230	self.db
231		.statehash_shortstatehash
232		.raw_aput::<BUFSIZE, _, _>(state_hash, *shortstatehash);
233
234	(*shortstatehash, false)
235}
236
237#[implement(Service)]
238pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result<ShortRoomId> {
239	self.db
240		.roomid_shortroomid
241		.get(room_id)
242		.await
243		.deserialized()
244}
245
246#[implement(Service)]
247pub async fn get_roomid_from_short(&self, shortroomid_: ShortRoomId) -> Result<OwnedRoomId> {
248	let stream = self
249		.db
250		.roomid_shortroomid
251		.stream()
252		.ready_filter_map(Result::ok);
253
254	pin_mut!(stream);
255	stream
256		.ready_find(|&(_, shortroomid)| shortroomid == shortroomid_)
257		.map(|found| found.map(|(room_id, _): (&RoomId, ShortRoomId)| room_id.to_owned()))
258		.await
259		.ok_or_else(|| err!(Database("Failed to find RoomId from {shortroomid_:?}")))
260}
261
262#[implement(Service)]
263pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> ShortRoomId {
264	self.db
265		.roomid_shortroomid
266		.get(room_id)
267		.await
268		.deserialized()
269		.unwrap_or_else(|_| {
270			const BUFSIZE: usize = size_of::<ShortRoomId>();
271
272			let short = self.services.globals.next_count();
273			debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed");
274
275			self.db
276				.roomid_shortroomid
277				.raw_aput::<BUFSIZE, _, _>(room_id, *short);
278
279			*short
280		})
281}
282
283#[implement(Service)]
284pub async fn delete_shortroomid(&self, room_id: &RoomId) -> Result {
285	if self
286		.db
287		.roomid_shortroomid
288		.exists(room_id)
289		.await
290		.is_ok()
291	{
292		self.db.roomid_shortroomid.remove(room_id);
293		Ok(())
294	} else {
295		Err!(Database("not found"))
296	}
297}