1use std::collections::HashMap;
4
5use anyhow::{Context as _, Result};
6use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
7use serde::Deserialize;
8
9use crate::context::Context;
10use crate::log::{info, warn};
11use crate::net::http::post_form;
12use crate::net::read_url_blob;
13use crate::provider;
14use crate::provider::Oauth2Authorizer;
15use crate::tools::time;
16
/// OAuth2 configuration used when the provider lookup yields `Oauth2Authorizer::Gmail`.
///
/// The URL fields are templates: `$CLIENT_ID`, `$REDIRECT_URI`, `$CODE`,
/// `$REFRESH_TOKEN` and `$ACCESS_TOKEN` are substituted at runtime by the
/// functions in this module before the URL is used.
const OAUTH2_GMAIL: Oauth2 = Oauth2 {
    client_id: "959970109878-4mvtgf6feshskf7695nfln6002mom908.apps.googleusercontent.com",
    // Authorization URL; `get_oauth2_url` fills in `$CLIENT_ID` and `$REDIRECT_URI`.
    get_code: "https://accounts.google.com/o/oauth2/auth?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&response_type=code&scope=https%3A%2F%2Fmail.google.com%2F%20email&access_type=offline",
    // Token endpoint in GET notation; `get_oauth2_access_token` splits it at `?`
    // and sends the query parameters as a POST form.
    init_token: "https://accounts.google.com/o/oauth2/token?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&code=$CODE&grant_type=authorization_code",
    refresh_token: "https://accounts.google.com/o/oauth2/token?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&refresh_token=$REFRESH_TOKEN&grant_type=refresh_token",
    // Userinfo endpoint; `Oauth2::get_addr` reads the `email` key of its JSON response.
    get_userinfo: Some("https://www.googleapis.com/oauth2/v1/userinfo?alt=json&access_token=$ACCESS_TOKEN"),
};
25
/// OAuth2 configuration used when the provider lookup yields `Oauth2Authorizer::Yandex`.
///
/// The URL fields are templates: `$CLIENT_ID`, `$CODE` and `$REFRESH_TOKEN`
/// are substituted at runtime before the URL is used.
const OAUTH2_YANDEX: Oauth2 = Oauth2 {
    client_id: "c4d0b6735fc8420a816d7e1303469341",
    // Authorization URL; it has no `$REDIRECT_URI` placeholder, so the redirect
    // URI passed to `get_oauth2_url` does not appear in the resulting URL.
    get_code: "https://oauth.yandex.com/authorize?client_id=$CLIENT_ID&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true",
    // Token endpoint in GET notation; `get_oauth2_access_token` splits it at `?`
    // and sends the query parameters as a POST form.
    init_token: "https://oauth.yandex.com/token?grant_type=authorization_code&code=$CODE&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    refresh_token: "https://oauth.yandex.com/token?grant_type=refresh_token&refresh_token=$REFRESH_TOKEN&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    // No userinfo endpoint: `get_oauth2_addr` returns `None` for this configuration.
    get_userinfo: None,
};
34
/// Static OAuth2 configuration of a single provider.
///
/// All URL fields are templates containing `$NAME` placeholders that are
/// substituted before the URL is used.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Oauth2 {
    /// Value substituted for the `$CLIENT_ID` placeholder.
    client_id: &'static str,
    /// Template of the authorization URL returned by `get_oauth2_url`.
    get_code: &'static str,
    /// Template of the token request made with an authorization code (`$CODE`).
    init_token: &'static str,
    /// Template of the token request made with a stored refresh token (`$REFRESH_TOKEN`).
    refresh_token: &'static str,
    /// Template of the userinfo request (`$ACCESS_TOKEN`) used to look up the
    /// e-mail address; `None` if the address cannot be looked up this way.
    get_userinfo: Option<&'static str>,
}
43
/// JSON body of a token endpoint response, as parsed by `get_oauth2_access_token`.
#[derive(Debug, Deserialize)]
// `token_type` and `scope` are deserialized but not read by the code in this module.
#[allow(dead_code)]
struct Response {
    /// The access token; its absence is tolerated and only logged as a warning.
    access_token: Option<String>,
    /// Not optional: a response without this field fails to deserialize.
    token_type: String,
    /// Validity of the access token, added to the current `time()` to compute
    /// the stored expiry timestamp (presumably seconds, as in RFC 6749 — confirm
    /// against the unit of `time()`).
    expires_in: Option<u64>,
    /// Refresh token; stored whenever the response contains one.
    refresh_token: Option<String>,
    scope: Option<String>,
}
57
58pub async fn get_oauth2_url(
61 context: &Context,
62 addr: &str,
63 redirect_uri: &str,
64) -> Result<Option<String>> {
65 if let Some(oauth2) = Oauth2::from_address(context, addr).await {
66 context
67 .sql
68 .set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
69 .await?;
70 let oauth2_url = replace_in_uri(oauth2.get_code, "$CLIENT_ID", oauth2.client_id);
71 let oauth2_url = replace_in_uri(&oauth2_url, "$REDIRECT_URI", redirect_uri);
72
73 Ok(Some(oauth2_url))
74 } else {
75 Ok(None)
76 }
77}
78
79pub(crate) async fn get_oauth2_access_token(
80 context: &Context,
81 addr: &str,
82 code: &str,
83 regenerate: bool,
84) -> Result<Option<String>> {
85 if let Some(oauth2) = Oauth2::from_address(context, addr).await {
86 let lock = context.oauth2_mutex.lock().await;
87
88 if !regenerate && !is_expired(context).await? {
90 let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
91 if access_token.is_some() {
92 return Ok(access_token);
94 }
95 }
96
97 let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
99 let refresh_token_for = context
100 .sql
101 .get_raw_config("oauth2_refresh_token_for")
102 .await?
103 .unwrap_or_else(|| "unset".into());
104
105 let (redirect_uri, token_url, update_redirect_uri_on_success) =
106 if refresh_token.is_none() || refresh_token_for != code {
107 info!(context, "Generate OAuth2 refresh_token and access_token...",);
108 (
109 context
110 .sql
111 .get_raw_config("oauth2_pending_redirect_uri")
112 .await?
113 .unwrap_or_else(|| "unset".into()),
114 oauth2.init_token,
115 true,
116 )
117 } else {
118 info!(
119 context,
120 "Regenerate OAuth2 access_token by refresh_token...",
121 );
122 (
123 context
124 .sql
125 .get_raw_config("oauth2_redirect_uri")
126 .await?
127 .unwrap_or_else(|| "unset".into()),
128 oauth2.refresh_token,
129 false,
130 )
131 };
132
133 let mut parts = token_url.splitn(2, '?');
137 let post_url = parts.next().unwrap_or_default();
138 let post_args = parts.next().unwrap_or_default();
139 let mut post_param = HashMap::new();
140 for key_value_pair in post_args.split('&') {
141 let mut parts = key_value_pair.splitn(2, '=');
142 let key = parts.next().unwrap_or_default();
143 let mut value = parts.next().unwrap_or_default();
144
145 if value == "$CLIENT_ID" {
146 value = oauth2.client_id;
147 } else if value == "$REDIRECT_URI" {
148 value = &redirect_uri;
149 } else if value == "$CODE" {
150 value = code;
151 } else if value == "$REFRESH_TOKEN" {
152 if let Some(refresh_token) = refresh_token.as_ref() {
153 value = refresh_token;
154 }
155 }
156
157 post_param.insert(key, value);
158 }
159
160 let response: Response = match post_form(context, post_url, &post_param).await {
163 Ok(resp) => match serde_json::from_slice(&resp) {
164 Ok(response) => response,
165 Err(err) => {
166 warn!(
167 context,
168 "Failed to parse OAuth2 JSON response from {token_url}: {err:#}."
169 );
170 return Ok(None);
171 }
172 },
173 Err(err) => {
174 warn!(context, "Error calling OAuth2 at {token_url}: {err:#}.");
175 return Ok(None);
176 }
177 };
178
179 if let Some(ref token) = response.refresh_token {
181 context
182 .sql
183 .set_raw_config("oauth2_refresh_token", Some(token))
184 .await?;
185 context
186 .sql
187 .set_raw_config("oauth2_refresh_token_for", Some(code))
188 .await?;
189 }
190
191 if let Some(ref token) = response.access_token {
194 context
195 .sql
196 .set_raw_config("oauth2_access_token", Some(token))
197 .await?;
198 let expires_in = response
199 .expires_in
200 .map(|t| time() + t as i64 - 5)
202 .unwrap_or_else(|| 0);
203 context
204 .sql
205 .set_raw_config_int64("oauth2_timestamp_expires", expires_in)
206 .await?;
207
208 if update_redirect_uri_on_success {
209 context
210 .sql
211 .set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
212 .await?;
213 }
214 } else {
215 warn!(context, "Failed to find OAuth2 access token");
216 }
217
218 drop(lock);
219
220 Ok(response.access_token)
221 } else {
222 warn!(context, "Internal OAuth2 error: 2");
223
224 Ok(None)
225 }
226}
227
228pub(crate) async fn get_oauth2_addr(
229 context: &Context,
230 addr: &str,
231 code: &str,
232) -> Result<Option<String>> {
233 let oauth2 = match Oauth2::from_address(context, addr).await {
234 Some(o) => o,
235 None => return Ok(None),
236 };
237 if oauth2.get_userinfo.is_none() {
238 return Ok(None);
239 }
240
241 if let Some(access_token) = get_oauth2_access_token(context, addr, code, false).await? {
242 let addr_out = match oauth2.get_addr(context, &access_token).await {
243 Ok(addr) => addr,
244 Err(err) => {
245 warn!(context, "Error getting addr: {err:#}.");
246 None
247 }
248 };
249 if addr_out.is_none() {
250 if let Some(access_token) = get_oauth2_access_token(context, addr, code, true).await? {
252 Ok(oauth2
253 .get_addr(context, &access_token)
254 .await
255 .unwrap_or_default())
256 } else {
257 Ok(None)
258 }
259 } else {
260 Ok(addr_out)
261 }
262 } else {
263 Ok(None)
264 }
265}
266
267impl Oauth2 {
268 async fn from_address(context: &Context, addr: &str) -> Option<Self> {
269 let addr_normalized = normalize_addr(addr);
270 let skip_mx = true;
271 if let Some(domain) = addr_normalized
272 .find('@')
273 .map(|index| addr_normalized.split_at(index + 1).1)
274 {
275 if let Some(oauth2_authorizer) = provider::get_provider_info(context, domain, skip_mx)
276 .await
277 .and_then(|provider| provider.oauth2_authorizer.as_ref())
278 {
279 return Some(match oauth2_authorizer {
280 Oauth2Authorizer::Gmail => OAUTH2_GMAIL,
281 Oauth2Authorizer::Yandex => OAUTH2_YANDEX,
282 });
283 }
284 }
285 None
286 }
287
288 async fn get_addr(&self, context: &Context, access_token: &str) -> Result<Option<String>> {
289 let userinfo_url = self.get_userinfo.unwrap_or("");
290 let userinfo_url = replace_in_uri(userinfo_url, "$ACCESS_TOKEN", access_token);
291
292 let response = read_url_blob(context, &userinfo_url).await?;
301 let parsed: HashMap<String, serde_json::Value> =
302 serde_json::from_slice(&response.blob).context("Error getting userinfo")?;
303 if let Some(addr) = parsed.get("email") {
306 if let Some(s) = addr.as_str() {
307 Ok(Some(s.to_string()))
308 } else {
309 warn!(context, "E-mail in userinfo is not a string: {}", addr);
310 Ok(None)
311 }
312 } else {
313 warn!(context, "E-mail missing in userinfo.");
314 Ok(None)
315 }
316 }
317}
318
319async fn is_expired(context: &Context) -> Result<bool> {
320 let expire_timestamp = context
321 .sql
322 .get_raw_config_int64("oauth2_timestamp_expires")
323 .await?
324 .unwrap_or_default();
325
326 if expire_timestamp <= 0 {
327 return Ok(false);
328 }
329 if expire_timestamp > time() {
330 return Ok(false);
331 }
332
333 Ok(true)
334}
335
336fn replace_in_uri(uri: &str, key: &str, value: &str) -> String {
337 let value_urlencoded = utf8_percent_encode(value, NON_ALPHANUMERIC).to_string();
338 uri.replace(key, &value_urlencoded)
339}
340
/// Trims surrounding whitespace from `addr` and then removes every leading
/// `mailto:` prefix.
fn normalize_addr(addr: &str) -> &str {
    addr.trim().trim_start_matches("mailto:")
}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestContext;

    /// Whitespace is trimmed and a leading `mailto:` is removed.
    #[test]
    fn test_normalize_addr() {
        assert_eq!(normalize_addr(" hello@mail.de "), "hello@mail.de");
        assert_eq!(normalize_addr("mailto:hello@mail.de "), "hello@mail.de");
    }

    /// The inserted value is percent-encoded: `-` becomes `%2D`, space becomes `%20`.
    #[test]
    fn test_replace_in_uri() {
        assert_eq!(
            replace_in_uri("helloworld", "world", "a-b c"),
            "helloa%2Db%20c"
        );
    }

    /// Maps addresses to OAuth2 configurations via the provider lookup.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_oauth_from_address() {
        let t = TestContext::new().await;

        // Gmail addresses are expected to yield no configuration here —
        // presumably the provider data names no OAuth2 authorizer for them
        // (TODO confirm against the provider database).
        assert_eq!(Oauth2::from_address(&t, "hello@gmail.com").await, None);
        assert_eq!(Oauth2::from_address(&t, "hello@googlemail.com").await, None);

        assert_eq!(
            Oauth2::from_address(&t, "hello@yandex.com").await,
            Some(OAUTH2_YANDEX)
        );
        assert_eq!(
            Oauth2::from_address(&t, "hello@yandex.ru").await,
            Some(OAUTH2_YANDEX)
        );
        assert_eq!(Oauth2::from_address(&t, "hello@web.de").await, None);
    }

    /// With no OAuth2 configuration for the address, the lookup returns `None`
    /// without making a request.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_addr() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_addr(&ctx.ctx, addr, code).await.unwrap();
        assert_eq!(res, None);
    }

    /// The Yandex authorization template has no `$REDIRECT_URI` placeholder, so
    /// only the client id is substituted and the redirect URI does not show up
    /// in the result.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_url() {
        let ctx = TestContext::new().await;
        let addr = "example@yandex.com";
        let redirect_uri = "chat.delta:/com.b44t.messenger";
        let res = get_oauth2_url(&ctx.ctx, addr, redirect_uri).await.unwrap();

        assert_eq!(res, Some("https://oauth.yandex.com/authorize?client_id=c4d0b6735fc8420a816d7e1303469341&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true".into()));
    }

    /// With no OAuth2 configuration for the address, no token is returned.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_token() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_access_token(&ctx.ctx, addr, code, false)
            .await
            .unwrap();
        assert_eq!(res, None);
    }
}