// tuwunel_service/resolver/dns.rs

1use std::{io, io::ErrorKind::PermissionDenied, net::SocketAddr, sync::Arc, time::Duration};
2
3use futures::FutureExt;
4use hickory_resolver::{
5	TokioResolver,
6	config::{ConnectionConfig, LookupIpStrategy, ResolverConfig, ResolverOpts},
7	lookup_ip::LookupIp,
8	net::runtime::TokioRuntimeProvider,
9};
10use ipaddress::IPAddress;
11use reqwest::dns::{Addrs, Name, Resolve, Resolving};
12use tuwunel_core::{Result, Server, err, trace};
13
14use super::cache::{Cache, CachedOverride};
15use crate::client::ipaddress_from_std;
16
/// Set of hickory-backed DNS resolvers exposed to reqwest through [`Resolve`].
pub struct Resolver {
	/// Primary resolver, configured with the `dns_min_ttl`,
	/// `dns_min_ttl_nxdomain` and `dns_cache_entries` settings.
	pub(crate) resolver: Arc<TokioResolver>,
	/// Resolver used for names matching `dns_passthru_domains`.
	pub(crate) passthru: Arc<Passthru>,
	/// Resolver variant that consults the override cache before looking up.
	pub(crate) hooked: Arc<Hooked>,
	/// Server handle providing configuration and the shutdown signal.
	server: Arc<Server>,
}
23
/// Inner resolver wrapper that drops addresses matching the configured CIDR
/// denylist before reqwest opens a connection.
///
/// Lookups are delegated to `inner`. If every address it returns is denied,
/// the lookup fails with a `PermissionDenied` I/O error instead of yielding an
/// empty address list.
pub struct Validating<R> {
	/// Resolver whose results are filtered.
	inner: Arc<R>,
	/// CIDR ranges whose addresses are discarded.
	denylist: Arc<[IPAddress]>,
}
30
/// Resolver that consults the override [`Cache`] before performing a DNS
/// lookup. Names matching `dns_passthru_domains` are looked up with the
/// passthru resolver instead of the primary one.
pub(crate) struct Hooked {
	/// Primary resolver (the same instance held by [`Resolver`]).
	resolver: Arc<TokioResolver>,
	/// Passthru resolver used for `dns_passthru_domains` matches.
	passthru: Arc<Passthru>,
	/// Cache of resolution overrides queried on every lookup.
	cache: Arc<Cache>,
	/// Server handle providing configuration and the shutdown signal.
	server: Arc<Server>,
}
37
/// Resolver built with zero minimum TTLs and hickory's default cache size,
/// so cached answers are not held longer to satisfy a configured minimum.
pub(crate) struct Passthru {
	/// Underlying hickory resolver.
	resolver: Arc<TokioResolver>,
	/// Server handle providing the shutdown signal.
	server: Arc<Server>,
}
42
/// Output of a reqwest [`Resolving`] future: the resolved socket addresses or
/// a boxed, thread-safe error.
type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>;
44
45impl Resolver {
46	pub(super) fn build(server: &Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self>> {
47		let config = &server.config;
48
49		// Create the primary resolver.
50		let (conf, mut opts) = Self::configure(server)?;
51		opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
52		opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
53		opts.cache_size = config.dns_cache_entries.into();
54		let resolver = Self::create(server, conf.clone(), opts.clone())?;
55
56		// Create the passthru resolver with modified options.
57		let (conf, mut opts) = (conf, opts);
58		opts.negative_min_ttl = Some(Duration::ZERO);
59		opts.positive_min_ttl = Some(Duration::ZERO);
60		opts.cache_size = ResolverOpts::default().cache_size;
61		let passthru = Arc::new(Passthru {
62			resolver: Self::create(server, conf, opts)?,
63			server: server.clone(),
64		});
65
66		Ok(Arc::new(Self {
67			hooked: Arc::new(Hooked {
68				resolver: resolver.clone(),
69				passthru: passthru.clone(),
70				server: server.clone(),
71				cache,
72			}),
73			server: server.clone(),
74			passthru,
75			resolver,
76		}))
77	}
78
79	fn create(
80		server: &Arc<Server>,
81		conf: ResolverConfig,
82		opts: ResolverOpts,
83	) -> Result<Arc<TokioResolver>> {
84		let mut builder =
85			TokioResolver::builder_with_config(conf, TokioRuntimeProvider::default());
86		*builder.options_mut() = Self::configure_opts(server, opts);
87
88		builder
89			.build()
90			.map(Arc::new)
91			.map_err(|e| err!(error!("Failed to build DNS resolver: {e}")))
92	}
93
94	fn configure(server: &Arc<Server>) -> Result<(ResolverConfig, ResolverOpts)> {
95		let config = &server.config;
96		let (sys_conf, opts) =
97			hickory_resolver::system_conf::read_system_conf().map_err(|e| {
98				err!(error!("Failed to configure DNS resolver from `/etc/resolv.conf': {e}"))
99			})?;
100
101		let name_servers = sys_conf
102			.name_servers()
103			.iter()
104			.cloned()
105			.map(|mut ns| {
106				ns.trust_negative_responses = !config.query_all_nameservers;
107				if config.query_over_tcp_only {
108					ns.connections = vec![ConnectionConfig::tcp()];
109				}
110				ns
111			})
112			.collect();
113
114		let conf = ResolverConfig::from_parts(
115			sys_conf.domain().cloned(),
116			sys_conf.search().to_vec(),
117			name_servers,
118		);
119
120		Ok((conf, opts))
121	}
122
123	#[expect(clippy::as_conversions)]
124	fn configure_opts(server: &Arc<Server>, mut opts: ResolverOpts) -> ResolverOpts {
125		let config = &server.config;
126
127		opts.negative_max_ttl = Some(Duration::from_hours(720));
128		opts.positive_max_ttl = Some(Duration::from_hours(168));
129		opts.timeout = Duration::from_secs(config.dns_timeout);
130		opts.attempts = config.dns_attempts as usize;
131		opts.try_tcp_on_error = config.dns_tcp_fallback;
132		opts.num_concurrent_reqs = 1;
133		opts.edns0 = true;
134		opts.case_randomization = config.dns_case_randomization;
135		opts.preserve_intermediates = true;
136		opts.ip_strategy = match config.ip_lookup_strategy {
137			| 1 => LookupIpStrategy::Ipv4Only,
138			| 2 => LookupIpStrategy::Ipv6Only,
139			| 3 => LookupIpStrategy::Ipv4AndIpv6,
140			| 4 => LookupIpStrategy::Ipv6thenIpv4,
141			| _ => LookupIpStrategy::Ipv4thenIpv6,
142		};
143
144		opts
145	}
146
147	/// Clear the in-memory hickory-dns caches
148	#[inline]
149	pub fn clear_cache(&self) { self.resolver.clear_cache(); }
150}
151
152impl<R: Resolve + 'static> Validating<R> {
153	pub fn new(inner: Arc<R>, denylist: Arc<[IPAddress]>) -> Arc<Self> {
154		Arc::new(Self { inner, denylist })
155	}
156}
157
158impl<R: Resolve + 'static> Resolve for Validating<R> {
159	fn resolve(&self, name: Name) -> Resolving {
160		validate_addrs(self.inner.clone(), self.denylist.clone(), name).boxed()
161	}
162}
163
164async fn validate_addrs<R: Resolve + 'static>(
165	inner: Arc<R>,
166	denylist: Arc<[IPAddress]>,
167	name: Name,
168) -> ResolvingResult {
169	let mut filtered = inner
170		.resolve(name)
171		.await?
172		.filter(move |sa| {
173			let ip = ipaddress_from_std(sa.ip());
174			!denylist.iter().any(|cidr| cidr.includes(&ip))
175		})
176		.peekable();
177
178	if filtered.peek().is_none() {
179		return Err(Box::new(io::Error::new(
180			PermissionDenied,
181			"All resolved addresses are denied by ip_range_denylist",
182		)));
183	}
184
185	Ok(Box::new(filtered))
186}
187
188impl Resolve for Resolver {
189	fn resolve(&self, name: Name) -> Resolving {
190		let resolver = if self
191			.server
192			.config
193			.dns_passthru_domains
194			.is_match(name.as_str())
195		{
196			trace!(?name, "matched to passthru resolver");
197			&self.passthru.resolver
198		} else {
199			trace!(?name, "using primary resolver");
200			&self.resolver
201		};
202
203		resolve_to_reqwest(self.server.clone(), resolver.clone(), name).boxed()
204	}
205}
206
207impl Resolve for Hooked {
208	fn resolve(&self, name: Name) -> Resolving {
209		let resolver = if self
210			.server
211			.config
212			.dns_passthru_domains
213			.is_match(name.as_str())
214		{
215			trace!(?name, "matched to passthru resolver");
216			&self.passthru.resolver
217		} else {
218			trace!(?name, "using hooked resolver");
219			&self.resolver
220		};
221
222		hooked_resolve(self.cache.clone(), self.server.clone(), resolver.clone(), name).boxed()
223	}
224}
225
226impl Resolve for Passthru {
227	fn resolve(&self, name: Name) -> Resolving {
228		trace!(?name, "using passthru resolver");
229		resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed()
230	}
231}
232
233#[tracing::instrument(
234	level = "debug",
235	skip_all,
236	fields(name = ?name.as_str())
237)]
238async fn hooked_resolve(
239	cache: Arc<Cache>,
240	server: Arc<Server>,
241	resolver: Arc<TokioResolver>,
242	name: Name,
243) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
244	match cache.get_override(name.as_str()).await {
245		| Ok(cached) if cached.valid() => cached_to_reqwest(cached),
246		| Ok(CachedOverride { overriding, .. }) if overriding.is_some() =>
247			resolve_to_reqwest(
248				server,
249				resolver,
250				overriding
251					.as_deref()
252					.map(str::parse)
253					.expect("overriding is set for this record")
254					.expect("overriding is a valid internet name"),
255			)
256			.boxed()
257			.await,
258
259		| _ =>
260			resolve_to_reqwest(server, resolver, name)
261				.boxed()
262				.await,
263	}
264}
265
266async fn resolve_to_reqwest(
267	server: Arc<Server>,
268	resolver: Arc<TokioResolver>,
269	name: Name,
270) -> ResolvingResult {
271	use std::{io, io::ErrorKind::Interrupted};
272
273	let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down"));
274
275	let handle_results = |results: LookupIp| -> Addrs {
276		let addrs = results
277			.iter()
278			.map(|ip| SocketAddr::new(ip, 0))
279			.collect::<Vec<_>>()
280			.into_iter();
281
282		Box::new(addrs)
283	};
284
285	tokio::select! {
286		results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)),
287		() = server.until_shutdown() => Err(handle_shutdown()),
288	}
289}
290
291fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
292	let addrs = cached
293		.ips
294		.into_iter()
295		.map(move |ip| SocketAddr::new(ip, cached.port));
296
297	Ok(Box::new(addrs))
298}