1use std::collections::HashMap;
4
5use anyhow::{Context as _, Result};
6use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
7use serde::Deserialize;
8
9use crate::context::Context;
10use crate::net::http::post_form;
11use crate::net::read_url_blob;
12use crate::provider;
13use crate::provider::Oauth2Authorizer;
14use crate::tools::time;
15
/// OAuth2 endpoints for Gmail.
///
/// The `$CLIENT_ID`, `$REDIRECT_URI`, `$CODE`, `$REFRESH_TOKEN` and `$ACCESS_TOKEN`
/// placeholders are filled in at runtime: `get_code` and `get_userinfo` through
/// `replace_in_uri` (percent-encoded), the token URLs by splitting their query string into
/// POST form fields.
const OAUTH2_GMAIL: Oauth2 = Oauth2 {
    client_id: "959970109878-4mvtgf6feshskf7695nfln6002mom908.apps.googleusercontent.com",
    get_code: "https://accounts.google.com/o/oauth2/auth?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&response_type=code&scope=https%3A%2F%2Fmail.google.com%2F%20email&access_type=offline",
    init_token: "https://accounts.google.com/o/oauth2/token?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&code=$CODE&grant_type=authorization_code",
    refresh_token: "https://accounts.google.com/o/oauth2/token?client_id=$CLIENT_ID&redirect_uri=$REDIRECT_URI&refresh_token=$REFRESH_TOKEN&grant_type=refresh_token",
    get_userinfo: Some("https://www.googleapis.com/oauth2/v1/userinfo?alt=json&access_token=$ACCESS_TOKEN"),
};
24
/// OAuth2 endpoints for Yandex.
///
/// Placeholders (`$CLIENT_ID`, `$CODE`, `$REFRESH_TOKEN`) are substituted at runtime like for
/// Gmail. The client secret is embedded literally in both token URLs and is therefore sent
/// as a POST form field. There is no userinfo endpoint, so no e-mail address can be looked
/// up via OAuth2 for this provider.
const OAUTH2_YANDEX: Oauth2 = Oauth2 {
    client_id: "c4d0b6735fc8420a816d7e1303469341",
    get_code: "https://oauth.yandex.com/authorize?client_id=$CLIENT_ID&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true",
    init_token: "https://oauth.yandex.com/token?grant_type=authorization_code&code=$CODE&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    refresh_token: "https://oauth.yandex.com/token?grant_type=refresh_token&refresh_token=$REFRESH_TOKEN&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    get_userinfo: None,
};
33
/// Static description of one OAuth2 provider: its client ID and URL templates.
///
/// The URL fields contain `$PLACEHOLDER` tokens that are replaced at runtime.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Oauth2 {
    /// Client ID substituted for `$CLIENT_ID`.
    client_id: &'static str,
    /// Authorization URL template returned by `get_oauth2_url` for the user to open.
    get_code: &'static str,
    /// Token endpoint template used to exchange an authorization code (`$CODE`) for tokens.
    init_token: &'static str,
    /// Token endpoint template used to get a new access token from `$REFRESH_TOKEN`.
    refresh_token: &'static str,
    /// Optional userinfo URL template (with `$ACCESS_TOKEN`) used to read the e-mail address.
    get_userinfo: Option<&'static str>,
}
42
/// JSON body returned by an OAuth2 token endpoint.
#[derive(Debug, Deserialize)]
// `token_type` and `scope` are deserialized but never read in this module.
#[allow(dead_code)]
struct Response {
    /// The access token; `None` is logged and treated as a failed token request.
    access_token: Option<String>,
    /// Required: a body without `token_type` fails to deserialize.
    token_type: String,
    /// Token lifetime; added to `time()` when computing `oauth2_timestamp_expires`.
    expires_in: Option<u64>,
    /// Stored as `oauth2_refresh_token` when present.
    refresh_token: Option<String>,
    scope: Option<String>,
}
56
57pub async fn get_oauth2_url(
60 context: &Context,
61 addr: &str,
62 redirect_uri: &str,
63) -> Result<Option<String>> {
64 if let Some(oauth2) = Oauth2::from_address(context, addr).await {
65 context
66 .sql
67 .set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
68 .await?;
69 let oauth2_url = replace_in_uri(oauth2.get_code, "$CLIENT_ID", oauth2.client_id);
70 let oauth2_url = replace_in_uri(&oauth2_url, "$REDIRECT_URI", redirect_uri);
71
72 Ok(Some(oauth2_url))
73 } else {
74 Ok(None)
75 }
76}
77
78pub(crate) async fn get_oauth2_access_token(
79 context: &Context,
80 addr: &str,
81 code: &str,
82 regenerate: bool,
83) -> Result<Option<String>> {
84 if let Some(oauth2) = Oauth2::from_address(context, addr).await {
85 let lock = context.oauth2_mutex.lock().await;
86
87 if !regenerate && !is_expired(context).await? {
89 let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
90 if access_token.is_some() {
91 return Ok(access_token);
93 }
94 }
95
96 let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
98 let refresh_token_for = context
99 .sql
100 .get_raw_config("oauth2_refresh_token_for")
101 .await?
102 .unwrap_or_else(|| "unset".into());
103
104 let (redirect_uri, token_url, update_redirect_uri_on_success) =
105 if refresh_token.is_none() || refresh_token_for != code {
106 info!(context, "Generate OAuth2 refresh_token and access_token...",);
107 (
108 context
109 .sql
110 .get_raw_config("oauth2_pending_redirect_uri")
111 .await?
112 .unwrap_or_else(|| "unset".into()),
113 oauth2.init_token,
114 true,
115 )
116 } else {
117 info!(
118 context,
119 "Regenerate OAuth2 access_token by refresh_token...",
120 );
121 (
122 context
123 .sql
124 .get_raw_config("oauth2_redirect_uri")
125 .await?
126 .unwrap_or_else(|| "unset".into()),
127 oauth2.refresh_token,
128 false,
129 )
130 };
131
132 let mut parts = token_url.splitn(2, '?');
136 let post_url = parts.next().unwrap_or_default();
137 let post_args = parts.next().unwrap_or_default();
138 let mut post_param = HashMap::new();
139 for key_value_pair in post_args.split('&') {
140 let mut parts = key_value_pair.splitn(2, '=');
141 let key = parts.next().unwrap_or_default();
142 let mut value = parts.next().unwrap_or_default();
143
144 if value == "$CLIENT_ID" {
145 value = oauth2.client_id;
146 } else if value == "$REDIRECT_URI" {
147 value = &redirect_uri;
148 } else if value == "$CODE" {
149 value = code;
150 } else if value == "$REFRESH_TOKEN" {
151 if let Some(refresh_token) = refresh_token.as_ref() {
152 value = refresh_token;
153 }
154 }
155
156 post_param.insert(key, value);
157 }
158
159 let response: Response = match post_form(context, post_url, &post_param).await {
162 Ok(resp) => match serde_json::from_slice(&resp) {
163 Ok(response) => response,
164 Err(err) => {
165 warn!(
166 context,
167 "Failed to parse OAuth2 JSON response from {token_url}: {err:#}."
168 );
169 return Ok(None);
170 }
171 },
172 Err(err) => {
173 warn!(context, "Error calling OAuth2 at {token_url}: {err:#}.");
174 return Ok(None);
175 }
176 };
177
178 if let Some(ref token) = response.refresh_token {
180 context
181 .sql
182 .set_raw_config("oauth2_refresh_token", Some(token))
183 .await?;
184 context
185 .sql
186 .set_raw_config("oauth2_refresh_token_for", Some(code))
187 .await?;
188 }
189
190 if let Some(ref token) = response.access_token {
193 context
194 .sql
195 .set_raw_config("oauth2_access_token", Some(token))
196 .await?;
197 let expires_in = response
198 .expires_in
199 .map(|t| time() + t as i64 - 5)
201 .unwrap_or_else(|| 0);
202 context
203 .sql
204 .set_raw_config_int64("oauth2_timestamp_expires", expires_in)
205 .await?;
206
207 if update_redirect_uri_on_success {
208 context
209 .sql
210 .set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
211 .await?;
212 }
213 } else {
214 warn!(context, "Failed to find OAuth2 access token");
215 }
216
217 drop(lock);
218
219 Ok(response.access_token)
220 } else {
221 warn!(context, "Internal OAuth2 error: 2");
222
223 Ok(None)
224 }
225}
226
227pub(crate) async fn get_oauth2_addr(
228 context: &Context,
229 addr: &str,
230 code: &str,
231) -> Result<Option<String>> {
232 let oauth2 = match Oauth2::from_address(context, addr).await {
233 Some(o) => o,
234 None => return Ok(None),
235 };
236 if oauth2.get_userinfo.is_none() {
237 return Ok(None);
238 }
239
240 if let Some(access_token) = get_oauth2_access_token(context, addr, code, false).await? {
241 let addr_out = match oauth2.get_addr(context, &access_token).await {
242 Ok(addr) => addr,
243 Err(err) => {
244 warn!(context, "Error getting addr: {err:#}.");
245 None
246 }
247 };
248 if addr_out.is_none() {
249 if let Some(access_token) = get_oauth2_access_token(context, addr, code, true).await? {
251 Ok(oauth2
252 .get_addr(context, &access_token)
253 .await
254 .unwrap_or_default())
255 } else {
256 Ok(None)
257 }
258 } else {
259 Ok(addr_out)
260 }
261 } else {
262 Ok(None)
263 }
264}
265
266impl Oauth2 {
267 async fn from_address(context: &Context, addr: &str) -> Option<Self> {
268 let addr_normalized = normalize_addr(addr);
269 let skip_mx = true;
270 if let Some(domain) = addr_normalized
271 .find('@')
272 .map(|index| addr_normalized.split_at(index + 1).1)
273 {
274 if let Some(oauth2_authorizer) = provider::get_provider_info(context, domain, skip_mx)
275 .await
276 .and_then(|provider| provider.oauth2_authorizer.as_ref())
277 {
278 return Some(match oauth2_authorizer {
279 Oauth2Authorizer::Gmail => OAUTH2_GMAIL,
280 Oauth2Authorizer::Yandex => OAUTH2_YANDEX,
281 });
282 }
283 }
284 None
285 }
286
287 async fn get_addr(&self, context: &Context, access_token: &str) -> Result<Option<String>> {
288 let userinfo_url = self.get_userinfo.unwrap_or("");
289 let userinfo_url = replace_in_uri(userinfo_url, "$ACCESS_TOKEN", access_token);
290
291 let response = read_url_blob(context, &userinfo_url).await?;
300 let parsed: HashMap<String, serde_json::Value> =
301 serde_json::from_slice(&response.blob).context("Error getting userinfo")?;
302 if let Some(addr) = parsed.get("email") {
305 if let Some(s) = addr.as_str() {
306 Ok(Some(s.to_string()))
307 } else {
308 warn!(context, "E-mail in userinfo is not a string: {}", addr);
309 Ok(None)
310 }
311 } else {
312 warn!(context, "E-mail missing in userinfo.");
313 Ok(None)
314 }
315 }
316}
317
318async fn is_expired(context: &Context) -> Result<bool> {
319 let expire_timestamp = context
320 .sql
321 .get_raw_config_int64("oauth2_timestamp_expires")
322 .await?
323 .unwrap_or_default();
324
325 if expire_timestamp <= 0 {
326 return Ok(false);
327 }
328 if expire_timestamp > time() {
329 return Ok(false);
330 }
331
332 Ok(true)
333}
334
335fn replace_in_uri(uri: &str, key: &str, value: &str) -> String {
336 let value_urlencoded = utf8_percent_encode(value, NON_ALPHANUMERIC).to_string();
337 uri.replace(key, &value_urlencoded)
338}
339
/// Strips surrounding whitespace, then any number of leading `mailto:` prefixes, from `addr`.
///
/// Whitespace that follows a `mailto:` prefix is not removed.
fn normalize_addr(addr: &str) -> &str {
    addr.trim().trim_start_matches("mailto:")
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestContext;

    #[test]
    fn test_normalize_addr() {
        for (input, expected) in [
            (" hello@mail.de ", "hello@mail.de"),
            ("mailto:hello@mail.de ", "hello@mail.de"),
        ] {
            assert_eq!(normalize_addr(input), expected);
        }
    }

    #[test]
    fn test_replace_in_uri() {
        let replaced = replace_in_uri("helloworld", "world", "a-b c");
        assert_eq!(replaced, "helloa%2Db%20c");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_oauth_from_address() {
        let ctx = TestContext::new().await;

        let cases = [
            ("hello@gmail.com", None),
            ("hello@googlemail.com", None),
            ("hello@yandex.com", Some(OAUTH2_YANDEX)),
            ("hello@yandex.ru", Some(OAUTH2_YANDEX)),
            ("hello@web.de", None),
        ];
        for (addr, expected) in cases {
            assert_eq!(Oauth2::from_address(&ctx, addr).await, expected, "{addr}");
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_addr() {
        let t = TestContext::new().await;
        let found = get_oauth2_addr(&t.ctx, "dignifiedquire@gmail.com", "fail")
            .await
            .unwrap();
        assert_eq!(found, None);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_url() {
        let t = TestContext::new().await;
        let url = get_oauth2_url(&t.ctx, "example@yandex.com", "chat.delta:/com.b44t.messenger")
            .await
            .unwrap();

        assert_eq!(url, Some("https://oauth.yandex.com/authorize?client_id=c4d0b6735fc8420a816d7e1303469341&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true".into()));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_token() {
        let t = TestContext::new().await;
        let token = get_oauth2_access_token(&t.ctx, "dignifiedquire@gmail.com", "fail", false)
            .await
            .unwrap();
        assert_eq!(token, None);
    }
}