1use std::{
2 collections::{BTreeSet, VecDeque},
3 convert::identity,
4 str::FromStr,
5};
6
7use axum::extract::State;
8use futures::{
9 StreamExt,
10 future::ready,
11 stream::{once, unfold},
12};
13use ruma::{
14 OwnedRoomId, OwnedServerName, RoomId, UInt, UserId, api::client::space::get_hierarchy,
15};
16use tuwunel_core::{
17 Err, Result, debug_error, error,
18 smallvec::SmallVec,
19 trace,
20 utils::{
21 BoolExt,
22 stream::{IterStream, ReadyExt, WidebandExt},
23 },
24};
25use tuwunel_service::{
26 Services,
27 rooms::{
28 short::ShortRoomId,
29 spaces::{
30 Accessibility, Identifier, PaginationToken, get_parent_children_via,
31 is_summary_serializable, summary_to_chunk,
32 },
33 },
34};
35
36use crate::Ruma;
37
38pub(crate) async fn get_hierarchy_route(
43 State(services): State<crate::State>,
44 body: Ruma<get_hierarchy::v1::Request>,
45) -> Result<get_hierarchy::v1::Response> {
46 let limit = body
47 .limit
48 .unwrap_or_else(|| UInt::from(10_u32))
49 .min(UInt::from(100_u32));
50
51 let max_depth = body
52 .max_depth
53 .unwrap_or_else(|| UInt::from(3_u32))
54 .min(UInt::from(10_u32));
55
56 let key = body
57 .from
58 .as_ref()
59 .and_then(|s| PaginationToken::from_str(s).ok());
60
61 if let Some(ref token) = key
63 && (token.suggested_only != body.suggested_only || token.max_depth != max_depth)
64 {
65 return Err!(Request(InvalidParam(
66 "suggested_only and max_depth cannot change on paginated requests"
67 )));
68 }
69
70 get_client_hierarchy(
71 &services,
72 body.sender_user(),
73 &body.room_id,
74 limit.try_into().unwrap_or(10),
75 max_depth.try_into().unwrap_or(usize::MAX),
76 body.suggested_only,
77 key.as_ref()
78 .map(|t| t.short_room_ids.as_slice())
79 .unwrap_or_default(),
80 )
81 .await
82}
83
84async fn get_client_hierarchy(
85 services: &Services,
86 sender_user: &UserId,
87 room_id: &RoomId,
88 limit: usize,
89 max_depth: usize,
90 suggested_only: bool,
91 skip_room_ids: &[ShortRoomId],
92) -> Result<get_hierarchy::v1::Response> {
93 type Via = SmallVec<[OwnedServerName; 1]>;
94 type QueueItem = (OwnedRoomId, Via, usize);
95
96 let root_via: Via = room_id
99 .server_name()
100 .map(ToOwned::to_owned)
101 .into_iter()
102 .collect();
103
104 let root_summary = match services
105 .spaces
106 .get_summary_and_children(room_id, &Identifier::UserId(sender_user), &root_via)
107 .await
108 {
109 | Err(e) => {
110 debug_error!(?room_id, "space hierarchy root: {e}");
111 return Err(e);
112 },
113 | Ok(Accessibility::Inaccessible) => {
114 return Err!(Request(Forbidden(debug_error!("The requested room is inaccessible."))));
115 },
116 | Ok(Accessibility::Accessible(s)) => s,
117 };
118
119 let initial_queue: VecDeque<QueueItem> = max_depth
122 .gt(&0)
123 .then(|| {
124 get_parent_children_via(&root_summary, suggested_only)
125 .filter(|(room_id_, _)| room_id.ne(room_id_))
126 .map(|(room_id, via)| (room_id, via.collect(), 1_usize))
127 })
128 .into_iter()
129 .flatten()
130 .collect();
131
132 let skip_ids: BTreeSet<ShortRoomId> = skip_room_ids.iter().copied().collect();
135
136 let initial_state = (initial_queue, BTreeSet::from([room_id.to_owned()]));
137
138 let rooms = once(ready(Some(root_summary)))
141 .chain(unfold(initial_state, async |(mut queue, mut visited)| {
142 let (current_room, via, depth) = queue.pop_front()?;
143
144 if visited.contains(¤t_room) {
147 return Some((None, (queue, visited)));
148 }
149
150 match services
151 .spaces
152 .get_summary_and_children(¤t_room, &Identifier::UserId(sender_user), &via)
153 .await
154 {
155 | Err(e) if !e.is_not_found() => {
156 error!(?current_room, ?depth, "space child error: {e}");
157
158 Some((None, (queue, visited)))
159 },
160 | Err(_) | Ok(Accessibility::Inaccessible) => {
161 trace!(?current_room, ?depth, "child inaccessible or not found");
162
163 Some((None, (queue, visited)))
164 },
165 | Ok(Accessibility::Accessible(s)) => {
166 visited.insert(current_room);
167
168 if depth < max_depth {
170 get_parent_children_via(&s, suggested_only)
171 .filter(|(child, _)| !visited.contains(child))
172 .for_each(|(child, via)| {
173 queue.push_back((child, via.collect(), depth.saturating_add(1)));
174 });
175 }
176
177 Some((Some(s), (queue, visited)))
178 },
179 }
180 }))
181 .ready_filter_map(identity)
182 .wide_filter_map(async |summary| {
183 skip_ids
184 .is_empty()
185 .is_false()
186 .then_async(async || {
187 services
188 .short
189 .get_shortroomid(&summary.summary.room_id)
190 .await
191 .ok()
192 .filter(|shortid| skip_ids.contains(shortid))
193 })
194 .await
195 .flatten()
196 .is_none()
197 .then_some(summary)
198 .filter(is_summary_serializable)
199 .map(summary_to_chunk)
200 })
201 .take(limit)
202 .collect::<Vec<_>>()
203 .await;
204
205 let next_batch = (limit > 0 && rooms.len() >= limit)
209 .then_async(async || {
210 let next_skip = skip_room_ids
211 .iter()
212 .copied()
213 .stream()
214 .chain(rooms.iter().stream().then(async |chunk| {
215 services
221 .short
222 .get_or_create_shortroomid(&chunk.summary.room_id)
223 .await
224 }))
225 .collect::<Vec<_>>()
226 .await;
227
228 (next_skip.len() > skip_room_ids.len()).then_some(PaginationToken {
232 suggested_only,
233 short_room_ids: next_skip,
234 limit: limit.try_into().unwrap_or_default(),
235 max_depth: max_depth.try_into().unwrap_or_default(),
236 })
237 })
238 .await
239 .flatten()
240 .as_ref()
241 .map(ToString::to_string);
242
243 Ok(get_hierarchy::v1::Response { rooms, next_batch })
244}