1use std::collections::HashMap;
4
5use anyhow::{Context as _, Result};
6use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::Deserialize;
8
9use crate::context::Context;
10use crate::log::{info, warn};
11use crate::net::http::post_form;
12use crate::net::read_url_blob;
13use crate::provider;
14use crate::provider::Oauth2Authorizer;
15use crate::tools::time;
16
/// OAuth2 configuration used for Gmail accounts.
///
/// The URLs are templates: `$CLIENT_ID`, `$REDIRECT_URI`, `$CODE`, `$REFRESH_TOKEN` and
/// `$ACCESS_TOKEN` are substituted at runtime before the URL is used.
const OAUTH2_GMAIL: Oauth2 = Oauth2 {
    client_id: "959970109878-4mvtgf6feshskf7695nfln6002mom908.apps.googleusercontent.com",
    get_code: "https://accounts.google.com/o/oauth2/auth?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&response_type=code&scope=https%3A%2F%2Fmail.google.com%2F%20email&access_type=offline",
    init_token: "https://accounts.google.com/o/oauth2/token?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&code=$CODE&grant_type=authorization_code",
    refresh_token: "https://accounts.google.com/o/oauth2/token?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&refresh_token=$REFRESH_TOKEN&grant_type=refresh_token",
    // Used to look up the account's e-mail address; the access token is passed in the query.
    get_userinfo: Some(
        "https://www.googleapis.com/oauth2/v1/userinfo?alt=json&access_token=$ACCESS_TOKEN",
    ),
};
27
/// OAuth2 configuration used for Yandex accounts.
///
/// The URLs are templates whose `$…` placeholders are substituted at runtime. Yandex's
/// authorization URL takes no `$REDIRECT_URI`.
// NOTE(review): the client secret is embedded in the token URLs; presumably this is a
// non-confidential installed-app secret — confirm before rotating or moving it.
const OAUTH2_YANDEX: Oauth2 = Oauth2 {
    client_id: "c4d0b6735fc8420a816d7e1303469341",
    get_code: "https://oauth.yandex.com/authorize?client_id=$CLIENT_ID&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true",
    init_token: "https://oauth.yandex.com/token?grant_type=authorization_code&code=$CODE&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    refresh_token: "https://oauth.yandex.com/token?grant_type=refresh_token&refresh_token=$REFRESH_TOKEN&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    // No userinfo endpoint is configured, so the address cannot be looked up via OAuth2.
    get_userinfo: None,
};
36
/// OAuth2 endpoints and client ID of one provider.
///
/// All URL fields are templates; `$…` placeholders are replaced before use.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Oauth2 {
    /// Client ID substituted for `$CLIENT_ID`.
    client_id: &'static str,
    /// URL the user opens to authorize the app (returned by `get_oauth2_url`).
    get_code: &'static str,
    /// Token URL used to exchange an authorization code (`$CODE`) for tokens.
    init_token: &'static str,
    /// Token URL used to obtain a new access token from a stored refresh token.
    refresh_token: &'static str,
    /// Userinfo URL used to look up the account's e-mail address, if the provider has one.
    get_userinfo: Option<&'static str>,
}
45
/// JSON body returned by an OAuth2 token endpoint.
///
/// `token_type` is not optional, so a response that lacks it fails to deserialize; the
/// caller logs that case and reports no access token.
#[derive(Debug, Deserialize)]
// `token_type` and `scope` are never read; they are kept to mirror the response shape.
#[allow(dead_code)]
struct Response {
    access_token: Option<String>,
    token_type: String,
    /// Token lifetime; added to `time()` by the caller, presumably in seconds.
    expires_in: Option<u64>,
    refresh_token: Option<String>,
    scope: Option<String>,
}
59
60pub async fn get_oauth2_url(
63 context: &Context,
64 addr: &str,
65 redirect_uri: &str,
66) -> Result<Option<String>> {
67 if let Some(oauth2) = Oauth2::from_address(context, addr).await {
68 context
69 .sql
70 .set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
71 .await?;
72 let oauth2_url = replace_in_uri(oauth2.get_code, "$CLIENT_ID", oauth2.client_id);
73 let oauth2_url = replace_in_uri(&oauth2_url, "$REDIRECT_URI", redirect_uri);
74
75 Ok(Some(oauth2_url))
76 } else {
77 Ok(None)
78 }
79}
80
81pub(crate) async fn get_oauth2_access_token(
82 context: &Context,
83 addr: &str,
84 code: &str,
85 regenerate: bool,
86) -> Result<Option<String>> {
87 if let Some(oauth2) = Oauth2::from_address(context, addr).await {
88 let lock = context.oauth2_mutex.lock().await;
89
90 if !regenerate && !is_expired(context).await? {
92 let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
93 if access_token.is_some() {
94 return Ok(access_token);
96 }
97 }
98
99 let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
101 let refresh_token_for = context
102 .sql
103 .get_raw_config("oauth2_refresh_token_for")
104 .await?
105 .unwrap_or_else(|| "unset".into());
106
107 let (redirect_uri, token_url, update_redirect_uri_on_success) =
108 if refresh_token.is_none() || refresh_token_for != code {
109 info!(context, "Generate OAuth2 refresh_token and access_token...",);
110 (
111 context
112 .sql
113 .get_raw_config("oauth2_pending_redirect_uri")
114 .await?
115 .unwrap_or_else(|| "unset".into()),
116 oauth2.init_token,
117 true,
118 )
119 } else {
120 info!(
121 context,
122 "Regenerate OAuth2 access_token by refresh_token...",
123 );
124 (
125 context
126 .sql
127 .get_raw_config("oauth2_redirect_uri")
128 .await?
129 .unwrap_or_else(|| "unset".into()),
130 oauth2.refresh_token,
131 false,
132 )
133 };
134
135 let mut parts = token_url.splitn(2, '?');
139 let post_url = parts.next().unwrap_or_default();
140 let post_args = parts.next().unwrap_or_default();
141 let mut post_param = HashMap::new();
142 for key_value_pair in post_args.split('&') {
143 let mut parts = key_value_pair.splitn(2, '=');
144 let key = parts.next().unwrap_or_default();
145 let mut value = parts.next().unwrap_or_default();
146
147 if value == "$CLIENT_ID" {
148 value = oauth2.client_id;
149 } else if value == "$REDIRECT_URI" {
150 value = &redirect_uri;
151 } else if value == "$CODE" {
152 value = code;
153 } else if value == "$REFRESH_TOKEN" {
154 if let Some(refresh_token) = refresh_token.as_ref() {
155 value = refresh_token;
156 }
157 }
158
159 post_param.insert(key, value);
160 }
161
162 let response: Response = match post_form(context, post_url, &post_param).await {
165 Ok(resp) => match serde_json::from_slice(&resp) {
166 Ok(response) => response,
167 Err(err) => {
168 warn!(
169 context,
170 "Failed to parse OAuth2 JSON response from {token_url}: {err:#}."
171 );
172 return Ok(None);
173 }
174 },
175 Err(err) => {
176 warn!(context, "Error calling OAuth2 at {token_url}: {err:#}.");
177 return Ok(None);
178 }
179 };
180
181 if let Some(ref token) = response.refresh_token {
183 context
184 .sql
185 .set_raw_config("oauth2_refresh_token", Some(token))
186 .await?;
187 context
188 .sql
189 .set_raw_config("oauth2_refresh_token_for", Some(code))
190 .await?;
191 }
192
193 if let Some(ref token) = response.access_token {
196 context
197 .sql
198 .set_raw_config("oauth2_access_token", Some(token))
199 .await?;
200 let expires_in = response
201 .expires_in
202 .map(|t| time() + t as i64 - 5)
204 .unwrap_or_else(|| 0);
205 context
206 .sql
207 .set_raw_config_int64("oauth2_timestamp_expires", expires_in)
208 .await?;
209
210 if update_redirect_uri_on_success {
211 context
212 .sql
213 .set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
214 .await?;
215 }
216 } else {
217 warn!(context, "Failed to find OAuth2 access token");
218 }
219
220 drop(lock);
221
222 Ok(response.access_token)
223 } else {
224 warn!(context, "Internal OAuth2 error: 2");
225
226 Ok(None)
227 }
228}
229
230pub(crate) async fn get_oauth2_addr(
231 context: &Context,
232 addr: &str,
233 code: &str,
234) -> Result<Option<String>> {
235 let oauth2 = match Oauth2::from_address(context, addr).await {
236 Some(o) => o,
237 None => return Ok(None),
238 };
239 if oauth2.get_userinfo.is_none() {
240 return Ok(None);
241 }
242
243 if let Some(access_token) = get_oauth2_access_token(context, addr, code, false).await? {
244 let addr_out = match oauth2.get_addr(context, &access_token).await {
245 Ok(addr) => addr,
246 Err(err) => {
247 warn!(context, "Error getting addr: {err:#}.");
248 None
249 }
250 };
251 if addr_out.is_none() {
252 if let Some(access_token) = get_oauth2_access_token(context, addr, code, true).await? {
254 Ok(oauth2
255 .get_addr(context, &access_token)
256 .await
257 .unwrap_or_default())
258 } else {
259 Ok(None)
260 }
261 } else {
262 Ok(addr_out)
263 }
264 } else {
265 Ok(None)
266 }
267}
268
269impl Oauth2 {
270 async fn from_address(context: &Context, addr: &str) -> Option<Self> {
271 let addr_normalized = normalize_addr(addr);
272 let skip_mx = true;
273 if let Some(domain) = addr_normalized
274 .find('@')
275 .map(|index| addr_normalized.split_at(index + 1).1)
276 {
277 if let Some(oauth2_authorizer) = provider::get_provider_info(context, domain, skip_mx)
278 .await
279 .and_then(|provider| provider.oauth2_authorizer.as_ref())
280 {
281 return Some(match oauth2_authorizer {
282 Oauth2Authorizer::Gmail => OAUTH2_GMAIL,
283 Oauth2Authorizer::Yandex => OAUTH2_YANDEX,
284 });
285 }
286 }
287 None
288 }
289
290 async fn get_addr(&self, context: &Context, access_token: &str) -> Result<Option<String>> {
291 let userinfo_url = self.get_userinfo.unwrap_or("");
292 let userinfo_url = replace_in_uri(userinfo_url, "$ACCESS_TOKEN", access_token);
293
294 let response = read_url_blob(context, &userinfo_url).await?;
303 let parsed: HashMap<String, serde_json::Value> =
304 serde_json::from_slice(&response.blob).context("Error getting userinfo")?;
305 if let Some(addr) = parsed.get("email") {
308 if let Some(s) = addr.as_str() {
309 Ok(Some(s.to_string()))
310 } else {
311 warn!(context, "E-mail in userinfo is not a string: {}", addr);
312 Ok(None)
313 }
314 } else {
315 warn!(context, "E-mail missing in userinfo.");
316 Ok(None)
317 }
318 }
319}
320
321async fn is_expired(context: &Context) -> Result<bool> {
322 let expire_timestamp = context
323 .sql
324 .get_raw_config_int64("oauth2_timestamp_expires")
325 .await?
326 .unwrap_or_default();
327
328 if expire_timestamp <= 0 {
329 return Ok(false);
330 }
331 if expire_timestamp > time() {
332 return Ok(false);
333 }
334
335 Ok(true)
336}
337
338fn replace_in_uri(uri: &str, key: &str, value: &str) -> String {
339 let value_urlencoded = utf8_percent_encode(value, NON_ALPHANUMERIC).to_string();
340 uri.replace(key, &value_urlencoded)
341}
342
/// Normalizes an address for lookup.
///
/// Surrounding whitespace is trimmed first. Then every leading `mailto:` prefix is removed,
/// repeated ones included.
fn normalize_addr(addr: &str) -> &str {
    let mut rest = addr.trim();
    while let Some(stripped) = rest.strip_prefix("mailto:") {
        rest = stripped;
    }
    rest
}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestContext;

    #[test]
    fn test_normalize_addr() {
        let cases = [
            (" hello@mail.de ", "hello@mail.de"),
            ("mailto:hello@mail.de ", "hello@mail.de"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_addr(input), expected);
        }
    }

    #[test]
    fn test_replace_in_uri() {
        let replaced = replace_in_uri("helloworld", "world", "a-b c");
        assert_eq!(replaced, "helloa%2Db%20c");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_oauth_from_address() {
        let t = TestContext::new().await;

        // These addresses are expected to have no OAuth2 configuration.
        for addr in ["hello@gmail.com", "hello@googlemail.com", "hello@web.de"] {
            assert_eq!(Oauth2::from_address(&t, addr).await, None, "{addr}");
        }
        for addr in ["hello@yandex.com", "hello@yandex.ru"] {
            assert_eq!(
                Oauth2::from_address(&t, addr).await,
                Some(OAUTH2_YANDEX),
                "{addr}"
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_addr() {
        let t = TestContext::new().await;
        let found = get_oauth2_addr(&t.ctx, "dignifiedquire@gmail.com", "fail")
            .await
            .unwrap();
        assert_eq!(found, None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_url() {
        let t = TestContext::new().await;
        let url = get_oauth2_url(&t.ctx, "example@yandex.com", "chat.delta:/com.b44t.messenger")
            .await
            .unwrap();
        let expected = "https://oauth.yandex.com/authorize?client_id=c4d0b6735fc8420a816d7e1303469341&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true";
        assert_eq!(url.as_deref(), Some(expected));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_token() {
        let t = TestContext::new().await;
        let token = get_oauth2_access_token(&t.ctx, "dignifiedquire@gmail.com", "fail", false)
            .await
            .unwrap();
        assert_eq!(token, None);
    }
}