tuwunel_service/resolver/
dns.rs1use std::{net::SocketAddr, sync::Arc, time::Duration};
2
3use futures::FutureExt;
4use hickory_resolver::{
5 TokioResolver,
6 config::{ConnectionConfig, LookupIpStrategy, ResolverConfig, ResolverOpts},
7 lookup_ip::LookupIp,
8 net::runtime::TokioRuntimeProvider,
9};
10use reqwest::dns::{Addrs, Name, Resolve, Resolving};
11use tuwunel_core::{Result, Server, err, trace};
12
13use super::cache::{Cache, CachedOverride};
14
15pub struct Resolver {
16 pub(crate) resolver: Arc<TokioResolver>,
17 pub(crate) passthru: Arc<Passthru>,
18 pub(crate) hooked: Arc<Hooked>,
19 server: Arc<Server>,
20}
21
22pub(crate) struct Hooked {
23 resolver: Arc<TokioResolver>,
24 passthru: Arc<Passthru>,
25 cache: Arc<Cache>,
26 server: Arc<Server>,
27}
28
29pub(crate) struct Passthru {
30 resolver: Arc<TokioResolver>,
31 server: Arc<Server>,
32}
33
34type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>;
35
36impl Resolver {
37 pub(super) fn build(server: &Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self>> {
38 let config = &server.config;
39
40 let (conf, mut opts) = Self::configure(server)?;
42 opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
43 opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
44 opts.cache_size = config.dns_cache_entries.into();
45 let resolver = Self::create(server, conf.clone(), opts.clone())?;
46
47 let (conf, mut opts) = (conf, opts);
49 opts.negative_min_ttl = Some(Duration::ZERO);
50 opts.positive_min_ttl = Some(Duration::ZERO);
51 opts.cache_size = ResolverOpts::default().cache_size;
52 let passthru = Arc::new(Passthru {
53 resolver: Self::create(server, conf, opts)?,
54 server: server.clone(),
55 });
56
57 Ok(Arc::new(Self {
58 hooked: Arc::new(Hooked {
59 resolver: resolver.clone(),
60 passthru: passthru.clone(),
61 server: server.clone(),
62 cache,
63 }),
64 server: server.clone(),
65 passthru,
66 resolver,
67 }))
68 }
69
70 fn create(
71 server: &Arc<Server>,
72 conf: ResolverConfig,
73 opts: ResolverOpts,
74 ) -> Result<Arc<TokioResolver>> {
75 let mut builder =
76 TokioResolver::builder_with_config(conf, TokioRuntimeProvider::default());
77 *builder.options_mut() = Self::configure_opts(server, opts);
78
79 builder
80 .build()
81 .map(Arc::new)
82 .map_err(|e| err!(error!("Failed to build DNS resolver: {e}")))
83 }
84
85 fn configure(server: &Arc<Server>) -> Result<(ResolverConfig, ResolverOpts)> {
86 let config = &server.config;
87 let (sys_conf, opts) =
88 hickory_resolver::system_conf::read_system_conf().map_err(|e| {
89 err!(error!("Failed to configure DNS resolver from `/etc/resolv.conf': {e}"))
90 })?;
91
92 let name_servers = sys_conf
93 .name_servers()
94 .iter()
95 .cloned()
96 .map(|mut ns| {
97 ns.trust_negative_responses = !config.query_all_nameservers;
98 if config.query_over_tcp_only {
99 ns.connections = vec![ConnectionConfig::tcp()];
100 }
101 ns
102 })
103 .collect();
104
105 let conf = ResolverConfig::from_parts(
106 sys_conf.domain().cloned(),
107 sys_conf.search().to_vec(),
108 name_servers,
109 );
110
111 Ok((conf, opts))
112 }
113
114 #[expect(clippy::as_conversions)]
115 fn configure_opts(server: &Arc<Server>, mut opts: ResolverOpts) -> ResolverOpts {
116 let config = &server.config;
117
118 opts.negative_max_ttl = Some(Duration::from_hours(720));
119 opts.positive_max_ttl = Some(Duration::from_hours(168));
120 opts.timeout = Duration::from_secs(config.dns_timeout);
121 opts.attempts = config.dns_attempts as usize;
122 opts.try_tcp_on_error = config.dns_tcp_fallback;
123 opts.num_concurrent_reqs = 1;
124 opts.edns0 = true;
125 opts.case_randomization = config.dns_case_randomization;
126 opts.preserve_intermediates = true;
127 opts.ip_strategy = match config.ip_lookup_strategy {
128 | 1 => LookupIpStrategy::Ipv4Only,
129 | 2 => LookupIpStrategy::Ipv6Only,
130 | 3 => LookupIpStrategy::Ipv4AndIpv6,
131 | 4 => LookupIpStrategy::Ipv6thenIpv4,
132 | _ => LookupIpStrategy::Ipv4thenIpv6,
133 };
134
135 opts
136 }
137
138 #[inline]
140 pub fn clear_cache(&self) { self.resolver.clear_cache(); }
141}
142
143impl Resolve for Resolver {
144 fn resolve(&self, name: Name) -> Resolving {
145 let resolver = if self
146 .server
147 .config
148 .dns_passthru_domains
149 .is_match(name.as_str())
150 {
151 trace!(?name, "matched to passthru resolver");
152 &self.passthru.resolver
153 } else {
154 trace!(?name, "using primary resolver");
155 &self.resolver
156 };
157
158 resolve_to_reqwest(self.server.clone(), resolver.clone(), name).boxed()
159 }
160}
161
162impl Resolve for Hooked {
163 fn resolve(&self, name: Name) -> Resolving {
164 let resolver = if self
165 .server
166 .config
167 .dns_passthru_domains
168 .is_match(name.as_str())
169 {
170 trace!(?name, "matched to passthru resolver");
171 &self.passthru.resolver
172 } else {
173 trace!(?name, "using hooked resolver");
174 &self.resolver
175 };
176
177 hooked_resolve(self.cache.clone(), self.server.clone(), resolver.clone(), name).boxed()
178 }
179}
180
181impl Resolve for Passthru {
182 fn resolve(&self, name: Name) -> Resolving {
183 trace!(?name, "using passthru resolver");
184 resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed()
185 }
186}
187
188#[tracing::instrument(
189 level = "debug",
190 skip_all,
191 fields(name = ?name.as_str())
192)]
193async fn hooked_resolve(
194 cache: Arc<Cache>,
195 server: Arc<Server>,
196 resolver: Arc<TokioResolver>,
197 name: Name,
198) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
199 match cache.get_override(name.as_str()).await {
200 | Ok(cached) if cached.valid() => cached_to_reqwest(cached),
201 | Ok(CachedOverride { overriding, .. }) if overriding.is_some() =>
202 resolve_to_reqwest(
203 server,
204 resolver,
205 overriding
206 .as_deref()
207 .map(str::parse)
208 .expect("overriding is set for this record")
209 .expect("overriding is a valid internet name"),
210 )
211 .boxed()
212 .await,
213
214 | _ =>
215 resolve_to_reqwest(server, resolver, name)
216 .boxed()
217 .await,
218 }
219}
220
221async fn resolve_to_reqwest(
222 server: Arc<Server>,
223 resolver: Arc<TokioResolver>,
224 name: Name,
225) -> ResolvingResult {
226 use std::{io, io::ErrorKind::Interrupted};
227
228 let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down"));
229
230 let handle_results = |results: LookupIp| -> Addrs {
231 let addrs = results
232 .iter()
233 .map(|ip| SocketAddr::new(ip, 0))
234 .collect::<Vec<_>>()
235 .into_iter();
236
237 Box::new(addrs)
238 };
239
240 tokio::select! {
241 results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)),
242 () = server.until_shutdown() => Err(handle_shutdown()),
243 }
244}
245
246fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
247 let addrs = cached
248 .ips
249 .into_iter()
250 .map(move |ip| SocketAddr::new(ip, cached.port));
251
252 Ok(Box::new(addrs))
253}