1use std::collections::HashMap;
4
5use anyhow::{Context as _, Result};
6use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::Deserialize;
8
9use crate::context::Context;
10use crate::log::warn;
11use crate::net::http::post_form;
12use crate::net::read_url_blob;
13use crate::provider;
14use crate::provider::Oauth2Authorizer;
15use crate::tools::time;
16
/// OAuth2 configuration for Yandex.
///
/// The `$CLIENT_ID`, `$CODE` and `$REFRESH_TOKEN` placeholders are substituted
/// at request time. `get_userinfo` is `None`, so `get_oauth2_addr` returns
/// `None` for Yandex addresses.
const OAUTH2_YANDEX: Oauth2 = Oauth2 {
    client_id: "c4d0b6735fc8420a816d7e1303469341",
    // Authorization URL; contains no `$REDIRECT_URI` placeholder.
    get_code: "https://oauth.yandex.com/authorize?client_id=$CLIENT_ID&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true",
    // NOTE(review): the client secret is embedded in the token URLs below and
    // is therefore part of every shipped binary; presumably acceptable for a
    // public/native client registration — confirm.
    init_token: "https://oauth.yandex.com/token?grant_type=authorization_code&code=$CODE&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    refresh_token: "https://oauth.yandex.com/token?grant_type=refresh_token&refresh_token=$REFRESH_TOKEN&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    get_userinfo: None,
};
25
/// OAuth2 configuration of a single provider.
///
/// The URL fields are templates containing `$NAME` placeholders that are
/// replaced before the URL is used.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Oauth2 {
    /// Client ID sent to the provider; substituted for `$CLIENT_ID`.
    client_id: &'static str,
    /// Authorization URL opened by the user. `get_oauth2_url` replaces
    /// `$CLIENT_ID` and `$REDIRECT_URI` with percent-encoded values.
    get_code: &'static str,
    /// Token endpoint used to exchange an authorization code (`$CODE`) for
    /// tokens. The query string is not sent as-is but split into POST form
    /// parameters, with placeholders substituted.
    init_token: &'static str,
    /// Token endpoint used to get a new access token from a stored
    /// `$REFRESH_TOKEN`. The query string is sent as POST form parameters.
    refresh_token: &'static str,
    /// Optional userinfo URL containing `$ACCESS_TOKEN`. The `email` field of
    /// its JSON response is used as the account address.
    get_userinfo: Option<&'static str>,
}
34
35#[derive(Debug, Deserialize)]
37#[allow(dead_code)]
38struct Response {
39 access_token: Option<String>,
42 token_type: String,
43 expires_in: Option<u64>,
45 refresh_token: Option<String>,
46 scope: Option<String>,
47}
48
49pub async fn get_oauth2_url(
52 context: &Context,
53 addr: &str,
54 redirect_uri: &str,
55) -> Result<Option<String>> {
56 if let Some(oauth2) = Oauth2::from_address(addr) {
57 context
58 .sql
59 .set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
60 .await?;
61 let oauth2_url = replace_in_uri(oauth2.get_code, "$CLIENT_ID", oauth2.client_id);
62 let oauth2_url = replace_in_uri(&oauth2_url, "$REDIRECT_URI", redirect_uri);
63
64 Ok(Some(oauth2_url))
65 } else {
66 Ok(None)
67 }
68}
69
70#[expect(clippy::arithmetic_side_effects)]
71pub(crate) async fn get_oauth2_access_token(
72 context: &Context,
73 addr: &str,
74 code: &str,
75 regenerate: bool,
76) -> Result<Option<String>> {
77 if let Some(oauth2) = Oauth2::from_address(addr) {
78 let lock = context.oauth2_mutex.lock().await;
79
80 if !regenerate && !is_expired(context).await? {
82 let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
83 if access_token.is_some() {
84 return Ok(access_token);
86 }
87 }
88
89 let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
91 let refresh_token_for = context
92 .sql
93 .get_raw_config("oauth2_refresh_token_for")
94 .await?
95 .unwrap_or_else(|| "unset".into());
96
97 let (redirect_uri, token_url, update_redirect_uri_on_success) =
98 if refresh_token.is_none() || refresh_token_for != code {
99 info!(context, "Generate OAuth2 refresh_token and access_token...",);
100 (
101 context
102 .sql
103 .get_raw_config("oauth2_pending_redirect_uri")
104 .await?
105 .unwrap_or_else(|| "unset".into()),
106 oauth2.init_token,
107 true,
108 )
109 } else {
110 info!(
111 context,
112 "Regenerate OAuth2 access_token by refresh_token...",
113 );
114 (
115 context
116 .sql
117 .get_raw_config("oauth2_redirect_uri")
118 .await?
119 .unwrap_or_else(|| "unset".into()),
120 oauth2.refresh_token,
121 false,
122 )
123 };
124
125 let mut parts = token_url.splitn(2, '?');
129 let post_url = parts.next().unwrap_or_default();
130 let post_args = parts.next().unwrap_or_default();
131 let mut post_param = HashMap::new();
132 for key_value_pair in post_args.split('&') {
133 let mut parts = key_value_pair.splitn(2, '=');
134 let key = parts.next().unwrap_or_default();
135 let mut value = parts.next().unwrap_or_default();
136
137 if value == "$CLIENT_ID" {
138 value = oauth2.client_id;
139 } else if value == "$REDIRECT_URI" {
140 value = &redirect_uri;
141 } else if value == "$CODE" {
142 value = code;
143 } else if value == "$REFRESH_TOKEN"
144 && let Some(refresh_token) = refresh_token.as_ref()
145 {
146 value = refresh_token;
147 }
148
149 post_param.insert(key, value);
150 }
151
152 let response: Response = match post_form(context, post_url, &post_param).await {
155 Ok(resp) => match serde_json::from_slice(&resp) {
156 Ok(response) => response,
157 Err(err) => {
158 warn!(
159 context,
160 "Failed to parse OAuth2 JSON response from {token_url}: {err:#}."
161 );
162 return Ok(None);
163 }
164 },
165 Err(err) => {
166 warn!(context, "Error calling OAuth2 at {token_url}: {err:#}.");
167 return Ok(None);
168 }
169 };
170
171 if let Some(ref token) = response.refresh_token {
173 context
174 .sql
175 .set_raw_config("oauth2_refresh_token", Some(token))
176 .await?;
177 context
178 .sql
179 .set_raw_config("oauth2_refresh_token_for", Some(code))
180 .await?;
181 }
182
183 if let Some(ref token) = response.access_token {
186 context
187 .sql
188 .set_raw_config("oauth2_access_token", Some(token))
189 .await?;
190 let expires_in = response
191 .expires_in
192 .map(|t| time() + t as i64 - 5)
194 .unwrap_or_else(|| 0);
195 context
196 .sql
197 .set_raw_config_int64("oauth2_timestamp_expires", expires_in)
198 .await?;
199
200 if update_redirect_uri_on_success {
201 context
202 .sql
203 .set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
204 .await?;
205 }
206 } else {
207 warn!(context, "Failed to find OAuth2 access token");
208 }
209
210 drop(lock);
211
212 Ok(response.access_token)
213 } else {
214 warn!(context, "Internal OAuth2 error: 2");
215
216 Ok(None)
217 }
218}
219
220pub(crate) async fn get_oauth2_addr(
221 context: &Context,
222 addr: &str,
223 code: &str,
224) -> Result<Option<String>> {
225 let oauth2 = match Oauth2::from_address(addr) {
226 Some(o) => o,
227 None => return Ok(None),
228 };
229 if oauth2.get_userinfo.is_none() {
230 return Ok(None);
231 }
232
233 if let Some(access_token) = get_oauth2_access_token(context, addr, code, false).await? {
234 let addr_out = match oauth2.get_addr(context, &access_token).await {
235 Ok(addr) => addr,
236 Err(err) => {
237 warn!(context, "Error getting addr: {err:#}.");
238 None
239 }
240 };
241 if addr_out.is_none() {
242 if let Some(access_token) = get_oauth2_access_token(context, addr, code, true).await? {
244 Ok(oauth2
245 .get_addr(context, &access_token)
246 .await
247 .unwrap_or_default())
248 } else {
249 Ok(None)
250 }
251 } else {
252 Ok(addr_out)
253 }
254 } else {
255 Ok(None)
256 }
257}
258
259impl Oauth2 {
260 #[expect(clippy::arithmetic_side_effects)]
261 fn from_address(addr: &str) -> Option<Self> {
262 let addr_normalized = normalize_addr(addr);
263 if let Some(domain) = addr_normalized
264 .find('@')
265 .map(|index| addr_normalized.split_at(index + 1).1)
266 && let Some(oauth2_authorizer) = provider::get_provider_info(domain)
267 .and_then(|provider| provider.oauth2_authorizer.as_ref())
268 {
269 return Some(match oauth2_authorizer {
270 Oauth2Authorizer::Yandex => OAUTH2_YANDEX,
271 });
272 }
273 None
274 }
275
276 async fn get_addr(&self, context: &Context, access_token: &str) -> Result<Option<String>> {
277 let userinfo_url = self.get_userinfo.unwrap_or("");
278 let userinfo_url = replace_in_uri(userinfo_url, "$ACCESS_TOKEN", access_token);
279
280 let response = read_url_blob(context, &userinfo_url).await?;
289 let parsed: HashMap<String, serde_json::Value> =
290 serde_json::from_slice(&response.blob).context("Error getting userinfo")?;
291 if let Some(addr) = parsed.get("email") {
294 if let Some(s) = addr.as_str() {
295 Ok(Some(s.to_string()))
296 } else {
297 warn!(context, "E-mail in userinfo is not a string: {}", addr);
298 Ok(None)
299 }
300 } else {
301 warn!(context, "E-mail missing in userinfo.");
302 Ok(None)
303 }
304 }
305}
306
307async fn is_expired(context: &Context) -> Result<bool> {
308 let expire_timestamp = context
309 .sql
310 .get_raw_config_int64("oauth2_timestamp_expires")
311 .await?
312 .unwrap_or_default();
313
314 if expire_timestamp <= 0 {
315 return Ok(false);
316 }
317 if expire_timestamp > time() {
318 return Ok(false);
319 }
320
321 Ok(true)
322}
323
324fn replace_in_uri(uri: &str, key: &str, value: &str) -> String {
325 let value_urlencoded = utf8_percent_encode(value, NON_ALPHANUMERIC).to_string();
326 uri.replace(key, &value_urlencoded)
327}
328
/// Trims surrounding whitespace from `addr`, then strips any number of
/// leading `mailto:` prefixes.
fn normalize_addr(addr: &str) -> &str {
    let mut rest = addr.trim();
    while let Some(stripped) = rest.strip_prefix("mailto:") {
        rest = stripped;
    }
    rest
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestContext;

    #[test]
    fn test_normalize_addr() {
        assert_eq!(normalize_addr(" hello@mail.de "), "hello@mail.de");
        assert_eq!(normalize_addr("mailto:hello@mail.de "), "hello@mail.de");
    }

    #[test]
    fn test_replace_in_uri() {
        assert_eq!(
            replace_in_uri("helloworld", "world", "a-b c"),
            "helloa%2Db%20c"
        );
        // A key that does not occur leaves the URI unchanged.
        assert_eq!(replace_in_uri("hello", "world", "x"), "hello");
    }

    // `Oauth2::from_address` is synchronous, so no async runtime is needed.
    #[test]
    fn test_oauth_from_address() {
        assert_eq!(Oauth2::from_address("hello@gmail.com"), None);
        assert_eq!(Oauth2::from_address("hello@googlemail.com"), None);

        assert_eq!(
            Oauth2::from_address("hello@yandex.com"),
            Some(OAUTH2_YANDEX)
        );
        assert_eq!(Oauth2::from_address("hello@yandex.ru"), Some(OAUTH2_YANDEX));
        assert_eq!(Oauth2::from_address("hello@web.de"), None);

        // The address is normalized before the domain is looked up.
        assert_eq!(
            Oauth2::from_address(" mailto:hello@yandex.com "),
            Some(OAUTH2_YANDEX)
        );
        // No domain part at all.
        assert_eq!(Oauth2::from_address("hello"), None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_addr() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_addr(&ctx.ctx, addr, code).await.unwrap();
        assert_eq!(res, None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_url() {
        let ctx = TestContext::new().await;
        let addr = "example@yandex.com";
        let redirect_uri = "chat.delta:/com.b44t.messenger";
        let res = get_oauth2_url(&ctx.ctx, addr, redirect_uri).await.unwrap();

        assert_eq!(res, Some("https://oauth.yandex.com/authorize?client_id=c4d0b6735fc8420a816d7e1303469341&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true".into()));

        // Providers without OAuth2 support yield no URL.
        let res = get_oauth2_url(&ctx.ctx, "hello@web.de", redirect_uri)
            .await
            .unwrap();
        assert_eq!(res, None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_token() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_access_token(&ctx.ctx, addr, code, false)
            .await
            .unwrap();
        assert_eq!(res, None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_is_expired() {
        let ctx = TestContext::new().await;

        // Nothing stored yet: an unknown expiry never counts as expired.
        assert!(!is_expired(&ctx.ctx).await.unwrap());

        ctx.ctx
            .sql
            .set_raw_config_int64("oauth2_timestamp_expires", time().saturating_sub(10))
            .await
            .unwrap();
        assert!(is_expired(&ctx.ctx).await.unwrap());

        ctx.ctx
            .sql
            .set_raw_config_int64("oauth2_timestamp_expires", time().saturating_add(3600))
            .await
            .unwrap();
        assert!(!is_expired(&ctx.ctx).await.unwrap());
    }
}