Skip to main content

tuwunel_api/client/
search.rs

1use std::collections::BTreeMap;
2
3use axum::extract::State;
4use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join};
5use ruma::{
6	OwnedRoomId, RoomId, UInt, UserId,
7	api::client::search::search_events::{
8		self,
9		v3::{
10			Criteria, EventContext, EventContextResult, ResultCategories, ResultRoomEvents,
11			SearchResult,
12		},
13	},
14	events::AnyStateEvent,
15	serde::Raw,
16};
17use search_events::v3::{Request, Response};
18use tuwunel_core::{
19	Err, Result, at, is_true,
20	matrix::Event,
21	result::FlatOk,
22	utils::{
23		IterStream,
24		option::OptionExt,
25		stream::{ReadyExt, TryIgnore, WidebandExt},
26	},
27};
28use tuwunel_service::{
29	Services,
30	rooms::{search::RoomQuery, timeline::PdusIterItem},
31};
32
33use crate::{Ruma, client::message::visibility_filter};
34
35type RoomStates = BTreeMap<OwnedRoomId, RoomState>;
36type RoomState = Vec<Raw<AnyStateEvent>>;
37
/// Per-room result count used when the filter carries no `limit`.
const LIMIT_DEFAULT: usize = 10;
/// Ceiling applied to any client-requested per-room `limit`.
const LIMIT_MAX: usize = 100;
/// Pagination depth, in pages of `limit`, that bounds the accepted
/// `next_batch` offset.
const BATCH_MAX: usize = 20;
41
42/// # `POST /_matrix/client/r0/search`
43///
44/// Searches rooms for messages.
45///
46/// - Only works if the user is currently joined to the room (TODO: Respect
47///   history visibility)
48pub(crate) async fn search_events_route(
49	State(services): State<crate::State>,
50	body: Ruma<Request>,
51) -> Result<Response> {
52	let sender_user = body.sender_user();
53	let next_batch = body.next_batch.as_deref();
54	let room_events = body
55		.search_categories
56		.room_events
57		.as_ref()
58		.map_async(|criteria| category_room_events(&services, sender_user, next_batch, criteria))
59		.await
60		.transpose()?;
61
62	Ok(Response {
63		search_categories: ResultCategories {
64			room_events: room_events.unwrap_or_default(),
65		},
66	})
67}
68
69#[expect(clippy::map_unwrap_or)]
70async fn category_room_events(
71	services: &Services,
72	sender_user: &UserId,
73	next_batch: Option<&str>,
74	criteria: &Criteria,
75) -> Result<ResultRoomEvents> {
76	let filter = &criteria.filter;
77
78	let limit: usize = filter
79		.limit
80		.map(TryInto::try_into)
81		.flat_ok()
82		.unwrap_or(LIMIT_DEFAULT)
83		.min(LIMIT_MAX);
84
85	let next_batch: usize = next_batch
86		.map(str::parse)
87		.transpose()?
88		.unwrap_or(0)
89		.min(limit.saturating_mul(BATCH_MAX));
90
91	let rooms = filter
92		.rooms
93		.clone()
94		.map(IntoIterator::into_iter)
95		.map(IterStream::stream)
96		.map(StreamExt::boxed)
97		.unwrap_or_else(|| {
98			services
99				.state_cache
100				.rooms_joined(sender_user)
101				.map(ToOwned::to_owned)
102				.boxed()
103		});
104
105	let results: Vec<_> = rooms
106		.filter_map(async |room_id| {
107			check_room_visible(services, sender_user, &room_id, criteria)
108				.await
109				.is_ok()
110				.then_some(room_id)
111		})
112		.filter_map(async |room_id| {
113			let query = RoomQuery {
114				room_id: &room_id,
115				user_id: Some(sender_user),
116				criteria,
117				skip: next_batch,
118				limit,
119			};
120
121			let (count, results) = services.search.search_pdus(&query).await.ok()?;
122
123			results
124				.collect::<Vec<_>>()
125				.map(|results| (room_id.clone(), count, results))
126				.map(Some)
127				.await
128		})
129		.collect()
130		.await;
131
132	let total: UInt = results
133		.iter()
134		.fold(0, |a: usize, (_, count, _)| a.saturating_add(*count))
135		.try_into()?;
136
137	let state: RoomStates = results
138		.iter()
139		.stream()
140		.ready_filter(|_| criteria.include_state.is_some_and(is_true!()))
141		.filter_map(async |(room_id, ..)| {
142			procure_room_state(services, room_id)
143				.map_ok(|state| (room_id.clone(), state))
144				.await
145				.ok()
146		})
147		.collect()
148		.await;
149
150	let results: Vec<SearchResult> = results
151		.into_iter()
152		.map(at!(2))
153		.flatten()
154		.stream()
155		.map(Event::into_pdu)
156		.wide_then(async |pdu| {
157			let context =
158				event_context(services, sender_user, &pdu, &criteria.event_context).await;
159
160			let pdu = services
161				.pdu_metadata
162				.bundle_aggregations(sender_user, pdu)
163				.await;
164
165			SearchResult {
166				rank: None,
167				result: Some(pdu.into_format()),
168				context,
169			}
170		})
171		.collect()
172		.await;
173
174	let highlights = criteria
175		.search_term
176		.split_terminator(|c: char| !c.is_alphanumeric())
177		.map(str::to_lowercase)
178		.collect();
179
180	let next_batch = (results.len() >= limit)
181		.then_some(next_batch.saturating_add(results.len()))
182		.as_ref()
183		.map(ToString::to_string);
184
185	Ok(ResultRoomEvents {
186		count: Some(total),
187		next_batch,
188		results,
189		state,
190		highlights,
191		groups: Default::default(), // TODO
192	})
193}
194
195async fn event_context<E>(
196	services: &Services,
197	sender_user: &UserId,
198	pdu: &E,
199	event_context: &EventContext,
200) -> EventContextResult
201where
202	E: Event,
203{
204	// An absent event_context deserializes to the default 5/5; treat that as no
205	// request.
206	if event_context.is_default() {
207		return EventContextResult::default();
208	}
209
210	let Ok(base_count) = services
211		.timeline
212		.get_pdu_count(pdu.event_id())
213		.await
214	else {
215		return EventContextResult::default();
216	};
217
218	let room_id = pdu.room_id();
219	let before_limit: usize = event_context.before_limit.try_into().unwrap_or(0);
220	let after_limit: usize = event_context.after_limit.try_into().unwrap_or(0);
221
222	let events_before = collect_context_half(
223		services,
224		services
225			.timeline
226			.pdus_rev(Some(sender_user), room_id, Some(base_count)),
227		sender_user,
228		before_limit,
229	);
230
231	let events_after = collect_context_half(
232		services,
233		services
234			.timeline
235			.pdus(Some(sender_user), room_id, Some(base_count)),
236		sender_user,
237		after_limit,
238	);
239
240	let (events_before, events_after) = join(events_before, events_after).await;
241
242	let start = events_before
243		.last()
244		.map(at!(0))
245		.or(Some(base_count))
246		.as_ref()
247		.map(ToString::to_string);
248
249	let end = events_after
250		.last()
251		.map(at!(0))
252		.or_else(|| Some(base_count.saturating_add(1)))
253		.as_ref()
254		.map(ToString::to_string);
255
256	let events_before = events_before
257		.into_iter()
258		.map(at!(1))
259		.map(Event::into_format)
260		.collect();
261
262	let events_after = events_after
263		.into_iter()
264		.map(at!(1))
265		.map(Event::into_format)
266		.collect();
267
268	EventContextResult {
269		start,
270		end,
271		events_before,
272		events_after,
273		profile_info: BTreeMap::new(),
274	}
275}
276
277async fn collect_context_half<'a, S>(
278	services: &'a Services,
279	pdus: S,
280	sender_user: &'a UserId,
281	take: usize,
282) -> Vec<PdusIterItem>
283where
284	S: Stream<Item = Result<PdusIterItem>> + Send + 'a,
285{
286	pdus.ignore_err()
287		.wide_filter_map(|item| visibility_filter(services, item, sender_user))
288		.take(take)
289		.wide_then(async |(count, pdu)| {
290			let pdu = services
291				.pdu_metadata
292				.bundle_aggregations(sender_user, pdu)
293				.await;
294
295			(count, pdu)
296		})
297		.collect()
298		.await
299}
300
301async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result<RoomState> {
302	let state = services
303		.state_accessor
304		.room_state_full_pdus(room_id)
305		.map_ok(Event::into_format)
306		.try_collect()
307		.await?;
308
309	Ok(state)
310}
311
312async fn check_room_visible(
313	services: &Services,
314	user_id: &UserId,
315	room_id: &RoomId,
316	search: &Criteria,
317) -> Result {
318	let check_visible = search.filter.rooms.is_some();
319	let check_state = check_visible && search.include_state.is_some_and(is_true!());
320
321	let is_joined = !check_visible
322		|| services
323			.state_cache
324			.is_joined(user_id, room_id)
325			.await;
326
327	let state_visible = !check_state
328		|| services
329			.state_accessor
330			.user_can_see_state_events(user_id, room_id)
331			.await;
332
333	if !is_joined || !state_visible {
334		return Err!(Request(Forbidden("You don't have permission to view {room_id:?}")));
335	}
336
337	Ok(())
338}