// tuwunel_service/resolver/dns.rs

1use std::{net::SocketAddr, sync::Arc, time::Duration};
2
3use futures::FutureExt;
4use hickory_resolver::{
5	TokioResolver,
6	config::{ConnectionConfig, LookupIpStrategy, ResolverConfig, ResolverOpts},
7	lookup_ip::LookupIp,
8	net::runtime::TokioRuntimeProvider,
9};
10use reqwest::dns::{Addrs, Name, Resolve, Resolving};
11use tuwunel_core::{Result, Server, err, trace};
12
13use super::cache::{Cache, CachedOverride};
14
15pub struct Resolver {
16	pub(crate) resolver: Arc<TokioResolver>,
17	pub(crate) passthru: Arc<Passthru>,
18	pub(crate) hooked: Arc<Hooked>,
19	server: Arc<Server>,
20}
21
22pub(crate) struct Hooked {
23	resolver: Arc<TokioResolver>,
24	passthru: Arc<Passthru>,
25	cache: Arc<Cache>,
26	server: Arc<Server>,
27}
28
29pub(crate) struct Passthru {
30	resolver: Arc<TokioResolver>,
31	server: Arc<Server>,
32}
33
34type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>;
35
36impl Resolver {
37	pub(super) fn build(server: &Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self>> {
38		let config = &server.config;
39
40		// Create the primary resolver.
41		let (conf, mut opts) = Self::configure(server)?;
42		opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
43		opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
44		opts.cache_size = config.dns_cache_entries.into();
45		let resolver = Self::create(server, conf.clone(), opts.clone())?;
46
47		// Create the passthru resolver with modified options.
48		let (conf, mut opts) = (conf, opts);
49		opts.negative_min_ttl = Some(Duration::ZERO);
50		opts.positive_min_ttl = Some(Duration::ZERO);
51		opts.cache_size = ResolverOpts::default().cache_size;
52		let passthru = Arc::new(Passthru {
53			resolver: Self::create(server, conf, opts)?,
54			server: server.clone(),
55		});
56
57		Ok(Arc::new(Self {
58			hooked: Arc::new(Hooked {
59				resolver: resolver.clone(),
60				passthru: passthru.clone(),
61				server: server.clone(),
62				cache,
63			}),
64			server: server.clone(),
65			passthru,
66			resolver,
67		}))
68	}
69
70	fn create(
71		server: &Arc<Server>,
72		conf: ResolverConfig,
73		opts: ResolverOpts,
74	) -> Result<Arc<TokioResolver>> {
75		let mut builder =
76			TokioResolver::builder_with_config(conf, TokioRuntimeProvider::default());
77		*builder.options_mut() = Self::configure_opts(server, opts);
78
79		builder
80			.build()
81			.map(Arc::new)
82			.map_err(|e| err!(error!("Failed to build DNS resolver: {e}")))
83	}
84
85	fn configure(server: &Arc<Server>) -> Result<(ResolverConfig, ResolverOpts)> {
86		let config = &server.config;
87		let (sys_conf, opts) =
88			hickory_resolver::system_conf::read_system_conf().map_err(|e| {
89				err!(error!("Failed to configure DNS resolver from `/etc/resolv.conf': {e}"))
90			})?;
91
92		let name_servers = sys_conf
93			.name_servers()
94			.iter()
95			.cloned()
96			.map(|mut ns| {
97				ns.trust_negative_responses = !config.query_all_nameservers;
98				if config.query_over_tcp_only {
99					ns.connections = vec![ConnectionConfig::tcp()];
100				}
101				ns
102			})
103			.collect();
104
105		let conf = ResolverConfig::from_parts(
106			sys_conf.domain().cloned(),
107			sys_conf.search().to_vec(),
108			name_servers,
109		);
110
111		Ok((conf, opts))
112	}
113
114	#[expect(clippy::as_conversions)]
115	fn configure_opts(server: &Arc<Server>, mut opts: ResolverOpts) -> ResolverOpts {
116		let config = &server.config;
117
118		opts.negative_max_ttl = Some(Duration::from_hours(720));
119		opts.positive_max_ttl = Some(Duration::from_hours(168));
120		opts.timeout = Duration::from_secs(config.dns_timeout);
121		opts.attempts = config.dns_attempts as usize;
122		opts.try_tcp_on_error = config.dns_tcp_fallback;
123		opts.num_concurrent_reqs = 1;
124		opts.edns0 = true;
125		opts.case_randomization = config.dns_case_randomization;
126		opts.preserve_intermediates = true;
127		opts.ip_strategy = match config.ip_lookup_strategy {
128			| 1 => LookupIpStrategy::Ipv4Only,
129			| 2 => LookupIpStrategy::Ipv6Only,
130			| 3 => LookupIpStrategy::Ipv4AndIpv6,
131			| 4 => LookupIpStrategy::Ipv6thenIpv4,
132			| _ => LookupIpStrategy::Ipv4thenIpv6,
133		};
134
135		opts
136	}
137
138	/// Clear the in-memory hickory-dns caches
139	#[inline]
140	pub fn clear_cache(&self) { self.resolver.clear_cache(); }
141}
142
143impl Resolve for Resolver {
144	fn resolve(&self, name: Name) -> Resolving {
145		let resolver = if self
146			.server
147			.config
148			.dns_passthru_domains
149			.is_match(name.as_str())
150		{
151			trace!(?name, "matched to passthru resolver");
152			&self.passthru.resolver
153		} else {
154			trace!(?name, "using primary resolver");
155			&self.resolver
156		};
157
158		resolve_to_reqwest(self.server.clone(), resolver.clone(), name).boxed()
159	}
160}
161
162impl Resolve for Hooked {
163	fn resolve(&self, name: Name) -> Resolving {
164		let resolver = if self
165			.server
166			.config
167			.dns_passthru_domains
168			.is_match(name.as_str())
169		{
170			trace!(?name, "matched to passthru resolver");
171			&self.passthru.resolver
172		} else {
173			trace!(?name, "using hooked resolver");
174			&self.resolver
175		};
176
177		hooked_resolve(self.cache.clone(), self.server.clone(), resolver.clone(), name).boxed()
178	}
179}
180
181impl Resolve for Passthru {
182	fn resolve(&self, name: Name) -> Resolving {
183		trace!(?name, "using passthru resolver");
184		resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed()
185	}
186}
187
188#[tracing::instrument(
189	level = "debug",
190	skip_all,
191	fields(name = ?name.as_str())
192)]
193async fn hooked_resolve(
194	cache: Arc<Cache>,
195	server: Arc<Server>,
196	resolver: Arc<TokioResolver>,
197	name: Name,
198) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
199	match cache.get_override(name.as_str()).await {
200		| Ok(cached) if cached.valid() => cached_to_reqwest(cached),
201		| Ok(CachedOverride { overriding, .. }) if overriding.is_some() =>
202			resolve_to_reqwest(
203				server,
204				resolver,
205				overriding
206					.as_deref()
207					.map(str::parse)
208					.expect("overriding is set for this record")
209					.expect("overriding is a valid internet name"),
210			)
211			.boxed()
212			.await,
213
214		| _ =>
215			resolve_to_reqwest(server, resolver, name)
216				.boxed()
217				.await,
218	}
219}
220
221async fn resolve_to_reqwest(
222	server: Arc<Server>,
223	resolver: Arc<TokioResolver>,
224	name: Name,
225) -> ResolvingResult {
226	use std::{io, io::ErrorKind::Interrupted};
227
228	let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down"));
229
230	let handle_results = |results: LookupIp| -> Addrs {
231		let addrs = results
232			.iter()
233			.map(|ip| SocketAddr::new(ip, 0))
234			.collect::<Vec<_>>()
235			.into_iter();
236
237		Box::new(addrs)
238	};
239
240	tokio::select! {
241		results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)),
242		() = server.until_shutdown() => Err(handle_shutdown()),
243	}
244}
245
246fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
247	let addrs = cached
248		.ips
249		.into_iter()
250		.map(move |ip| SocketAddr::new(ip, cached.port));
251
252	Ok(Box::new(addrs))
253}