Skip to main content

tuwunel_api/client/session/
jwt.rs

1use std::str::FromStr;
2
3use jwt::{Algorithm, DecodingKey, Validation, decode};
4use ruma::{
5	OwnedUserId, UserId,
6	api::client::session::login::v3::{Request, Token},
7};
8use serde::Deserialize;
9use tuwunel_core::{Err, Result, at, config::JwtConfig, debug, err, jwt, warn};
10use tuwunel_service::Services;
11
12use crate::Ruma;
13
14#[derive(Debug, Deserialize)]
15struct Claim {
16	/// Subject is the localpart of the User MXID
17	sub: String,
18}
19
20pub(super) async fn handle_login(
21	services: &Services,
22	_body: &Ruma<Request>,
23	info: &Token,
24) -> Result<OwnedUserId> {
25	let user_id = validate_user(services, &info.token)?;
26	if !services.users.exists(&user_id).await {
27		let config = &services.config.jwt;
28		if !config.register_user {
29			return Err!(Request(NotFound("User {user_id} is not registered on this server.")));
30		}
31
32		services
33			.users
34			.create(&user_id, Some("*"), Some("jwt"))
35			.await?;
36	}
37
38	Ok(user_id)
39}
40
41pub(crate) fn validate_user(services: &Services, token: &str) -> Result<OwnedUserId> {
42	let config = &services.config.jwt;
43	if !config.enable {
44		return Err!(Request(Unauthorized("JWT login is not enabled.")));
45	}
46
47	let claim = validate(config, token)?;
48	let local = claim.sub.to_lowercase();
49	let server = &services.server.name;
50	let user_id = UserId::parse_with_server_name(local, server).map_err(|e| {
51		err!(Request(InvalidUsername("JWT subject is not a valid user MXID: {e}")))
52	})?;
53
54	Ok(user_id)
55}
56
57fn validate(config: &JwtConfig, token: &str) -> Result<Claim> {
58	let verifier = init_verifier(config)?;
59	let validator = init_validator(config)?;
60	decode::<Claim>(token, &verifier, &validator)
61		.map(|decoded| (decoded.header, decoded.claims))
62		.inspect(|(head, claim)| debug!(?head, ?claim, "JWT token decoded"))
63		.map_err(|e| err!(Request(Forbidden("Invalid JWT token: {e}"))))
64		.map(at!(1))
65}
66
67fn init_verifier(config: &JwtConfig) -> Result<DecodingKey> {
68	let key = &config.key;
69	let format = config.format.to_uppercase();
70
71	Ok(match format.as_str() {
72		| "HMAC" => DecodingKey::from_secret(key.as_bytes()),
73
74		| "HMACB64" => DecodingKey::from_base64_secret(key.as_str())
75			.map_err(|e| err!(Config("jwt.key", "JWT key is not valid base64: {e}")))?,
76
77		| "ECDSA" => DecodingKey::from_ec_pem(key.as_bytes())
78			.map_err(|e| err!(Config("jwt.key", "JWT key is not valid ECDSA PEM: {e}")))?,
79
80		| "EDDSA" => DecodingKey::from_ed_pem(key.as_bytes())
81			.map_err(|e| err!(Config("jwt.key", "JWT key is not valid EDDSA PEM: {e}")))?,
82
83		| _ => return Err!(Config("jwt.format", "Key format {format:?} is not supported.")),
84	})
85}
86
87fn init_validator(config: &JwtConfig) -> Result<Validation> {
88	let alg = config.algorithm.as_str();
89	let alg = Algorithm::from_str(alg).map_err(|e| {
90		err!(Config("jwt.algorithm", "JWT algorithm is not recognized or configured {e}"))
91	})?;
92
93	let mut validator = Validation::new(alg);
94	let mut required_spec_claims: Vec<_> = ["sub"].into();
95
96	validator.validate_exp = config.validate_exp;
97	if config.require_exp {
98		required_spec_claims.push("exp");
99	}
100
101	validator.validate_nbf = config.validate_nbf;
102	if config.require_nbf {
103		required_spec_claims.push("nbf");
104	}
105
106	if !config.audience.is_empty() {
107		required_spec_claims.push("aud");
108		validator.set_audience(&config.audience);
109	}
110
111	if !config.issuer.is_empty() {
112		required_spec_claims.push("iss");
113		validator.set_issuer(&config.issuer);
114	}
115
116	#[expect(deprecated)]
117	if cfg!(debug_assertions) && !config.validate_signature {
118		warn!("JWT signature validation is disabled!");
119		validator.insecure_disable_signature_validation();
120	}
121
122	validator.set_required_spec_claims(&required_spec_claims);
123	debug!(?validator, "JWT configured");
124
125	Ok(validator)
126}