tuwunel_api/client/session/
jwt.rs1use std::str::FromStr;
2
3use jwt::{Algorithm, DecodingKey, Validation, decode};
4use ruma::{
5 OwnedUserId, UserId,
6 api::client::session::login::v3::{Request, Token},
7};
8use serde::Deserialize;
9use tuwunel_core::{Err, Result, at, config::JwtConfig, debug, err, jwt, warn};
10use tuwunel_service::Services;
11
12use crate::Ruma;
13
14#[derive(Debug, Deserialize)]
15struct Claim {
16 sub: String,
18}
19
20pub(super) async fn handle_login(
21 services: &Services,
22 _body: &Ruma<Request>,
23 info: &Token,
24) -> Result<OwnedUserId> {
25 let user_id = validate_user(services, &info.token)?;
26 if !services.users.exists(&user_id).await {
27 let config = &services.config.jwt;
28 if !config.register_user {
29 return Err!(Request(NotFound("User {user_id} is not registered on this server.")));
30 }
31
32 services
33 .users
34 .create(&user_id, Some("*"), Some("jwt"))
35 .await?;
36 }
37
38 Ok(user_id)
39}
40
41pub(crate) fn validate_user(services: &Services, token: &str) -> Result<OwnedUserId> {
42 let config = &services.config.jwt;
43 if !config.enable {
44 return Err!(Request(Unauthorized("JWT login is not enabled.")));
45 }
46
47 let claim = validate(config, token)?;
48 let local = claim.sub.to_lowercase();
49 let server = &services.server.name;
50 let user_id = UserId::parse_with_server_name(local, server).map_err(|e| {
51 err!(Request(InvalidUsername("JWT subject is not a valid user MXID: {e}")))
52 })?;
53
54 Ok(user_id)
55}
56
57fn validate(config: &JwtConfig, token: &str) -> Result<Claim> {
58 let verifier = init_verifier(config)?;
59 let validator = init_validator(config)?;
60 decode::<Claim>(token, &verifier, &validator)
61 .map(|decoded| (decoded.header, decoded.claims))
62 .inspect(|(head, claim)| debug!(?head, ?claim, "JWT token decoded"))
63 .map_err(|e| err!(Request(Forbidden("Invalid JWT token: {e}"))))
64 .map(at!(1))
65}
66
67fn init_verifier(config: &JwtConfig) -> Result<DecodingKey> {
68 let key = &config.key;
69 let format = config.format.to_uppercase();
70
71 Ok(match format.as_str() {
72 | "HMAC" => DecodingKey::from_secret(key.as_bytes()),
73
74 | "HMACB64" => DecodingKey::from_base64_secret(key.as_str())
75 .map_err(|e| err!(Config("jwt.key", "JWT key is not valid base64: {e}")))?,
76
77 | "ECDSA" => DecodingKey::from_ec_pem(key.as_bytes())
78 .map_err(|e| err!(Config("jwt.key", "JWT key is not valid ECDSA PEM: {e}")))?,
79
80 | "EDDSA" => DecodingKey::from_ed_pem(key.as_bytes())
81 .map_err(|e| err!(Config("jwt.key", "JWT key is not valid EDDSA PEM: {e}")))?,
82
83 | _ => return Err!(Config("jwt.format", "Key format {format:?} is not supported.")),
84 })
85}
86
87fn init_validator(config: &JwtConfig) -> Result<Validation> {
88 let alg = config.algorithm.as_str();
89 let alg = Algorithm::from_str(alg).map_err(|e| {
90 err!(Config("jwt.algorithm", "JWT algorithm is not recognized or configured {e}"))
91 })?;
92
93 let mut validator = Validation::new(alg);
94 let mut required_spec_claims: Vec<_> = ["sub"].into();
95
96 validator.validate_exp = config.validate_exp;
97 if config.require_exp {
98 required_spec_claims.push("exp");
99 }
100
101 validator.validate_nbf = config.validate_nbf;
102 if config.require_nbf {
103 required_spec_claims.push("nbf");
104 }
105
106 if !config.audience.is_empty() {
107 required_spec_claims.push("aud");
108 validator.set_audience(&config.audience);
109 }
110
111 if !config.issuer.is_empty() {
112 required_spec_claims.push("iss");
113 validator.set_issuer(&config.issuer);
114 }
115
116 #[expect(deprecated)]
117 if cfg!(debug_assertions) && !config.validate_signature {
118 warn!("JWT signature validation is disabled!");
119 validator.insecure_disable_signature_validation();
120 }
121
122 validator.set_required_spec_claims(&required_spec_claims);
123 debug!(?validator, "JWT configured");
124
125 Ok(validator)
126}