1use std::collections::HashMap;
4
5use anyhow::{Context as _, Result};
6use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::Deserialize;
8
9use crate::context::Context;
10use crate::log::warn;
11use crate::net::http::post_form;
12use crate::net::read_url_blob;
13use crate::provider;
14use crate::provider::Oauth2Authorizer;
15use crate::tools::time;
16
/// OAuth2 configuration for Yandex mail.
///
/// The URL templates contain placeholders (`$CLIENT_ID`, `$CODE`,
/// `$REFRESH_TOKEN`) that are substituted at runtime by `get_oauth2_url` and
/// `get_oauth2_access_token`. Yandex provides no userinfo endpoint here, so
/// `get_userinfo` is `None`.
///
/// NOTE(review): the client secret is embedded in the token URLs and therefore
/// ships inside the binary.
const OAUTH2_YANDEX: Oauth2 = Oauth2 {
    client_id: "c4d0b6735fc8420a816d7e1303469341",
    get_code: "https://oauth.yandex.com/authorize?client_id=$CLIENT_ID&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true",
    init_token: "https://oauth.yandex.com/token?grant_type=authorization_code&code=$CODE&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    refresh_token: "https://oauth.yandex.com/token?grant_type=refresh_token&refresh_token=$REFRESH_TOKEN&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    get_userinfo: None,
};
25
/// Static description of one OAuth2 provider: its client id and the URL
/// templates used during the authorization flow.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Oauth2 {
    /// Client id substituted for the `$CLIENT_ID` placeholder.
    client_id: &'static str,
    /// URL the user opens to grant access; may contain `$CLIENT_ID` and
    /// `$REDIRECT_URI` placeholders.
    get_code: &'static str,
    /// Token endpoint (with query template) used to exchange an authorization
    /// code for the first refresh/access token pair.
    init_token: &'static str,
    /// Token endpoint (with query template) used to obtain a new access token
    /// from a stored refresh token.
    refresh_token: &'static str,
    /// Optional userinfo endpoint; `$ACCESS_TOKEN` is substituted before the
    /// request. `None` if the provider cannot report the user's address.
    get_userinfo: Option<&'static str>,
}
34
/// JSON body returned by an OAuth2 token endpoint.
#[derive(Debug, Deserialize)]
// `token_type` and `scope` are deserialized but never read in this module,
// which would otherwise trigger dead-code warnings.
#[allow(dead_code)]
struct Response {
    /// Newly issued access token, if the server returned one.
    access_token: Option<String>,
    /// Token type reported by the server. Not optional, so a response without
    /// it fails to deserialize.
    token_type: String,
    /// Access-token lifetime; added to `time()` to compute the stored expiry.
    expires_in: Option<u64>,
    /// New refresh token, if the server issued one.
    refresh_token: Option<String>,
    /// Granted scope, if reported.
    scope: Option<String>,
}
48
49pub async fn get_oauth2_url(
52 context: &Context,
53 addr: &str,
54 redirect_uri: &str,
55) -> Result<Option<String>> {
56 if let Some(oauth2) = Oauth2::from_address(addr) {
57 context
58 .sql
59 .set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
60 .await?;
61 let oauth2_url = replace_in_uri(oauth2.get_code, "$CLIENT_ID", oauth2.client_id);
62 let oauth2_url = replace_in_uri(&oauth2_url, "$REDIRECT_URI", redirect_uri);
63
64 Ok(Some(oauth2_url))
65 } else {
66 Ok(None)
67 }
68}
69
70pub(crate) async fn get_oauth2_access_token(
71 context: &Context,
72 addr: &str,
73 code: &str,
74 regenerate: bool,
75) -> Result<Option<String>> {
76 if let Some(oauth2) = Oauth2::from_address(addr) {
77 let lock = context.oauth2_mutex.lock().await;
78
79 if !regenerate && !is_expired(context).await? {
81 let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
82 if access_token.is_some() {
83 return Ok(access_token);
85 }
86 }
87
88 let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
90 let refresh_token_for = context
91 .sql
92 .get_raw_config("oauth2_refresh_token_for")
93 .await?
94 .unwrap_or_else(|| "unset".into());
95
96 let (redirect_uri, token_url, update_redirect_uri_on_success) =
97 if refresh_token.is_none() || refresh_token_for != code {
98 info!(context, "Generate OAuth2 refresh_token and access_token...",);
99 (
100 context
101 .sql
102 .get_raw_config("oauth2_pending_redirect_uri")
103 .await?
104 .unwrap_or_else(|| "unset".into()),
105 oauth2.init_token,
106 true,
107 )
108 } else {
109 info!(
110 context,
111 "Regenerate OAuth2 access_token by refresh_token...",
112 );
113 (
114 context
115 .sql
116 .get_raw_config("oauth2_redirect_uri")
117 .await?
118 .unwrap_or_else(|| "unset".into()),
119 oauth2.refresh_token,
120 false,
121 )
122 };
123
124 let mut parts = token_url.splitn(2, '?');
128 let post_url = parts.next().unwrap_or_default();
129 let post_args = parts.next().unwrap_or_default();
130 let mut post_param = HashMap::new();
131 for key_value_pair in post_args.split('&') {
132 let mut parts = key_value_pair.splitn(2, '=');
133 let key = parts.next().unwrap_or_default();
134 let mut value = parts.next().unwrap_or_default();
135
136 if value == "$CLIENT_ID" {
137 value = oauth2.client_id;
138 } else if value == "$REDIRECT_URI" {
139 value = &redirect_uri;
140 } else if value == "$CODE" {
141 value = code;
142 } else if value == "$REFRESH_TOKEN"
143 && let Some(refresh_token) = refresh_token.as_ref()
144 {
145 value = refresh_token;
146 }
147
148 post_param.insert(key, value);
149 }
150
151 let response: Response = match post_form(context, post_url, &post_param).await {
154 Ok(resp) => match serde_json::from_slice(&resp) {
155 Ok(response) => response,
156 Err(err) => {
157 warn!(
158 context,
159 "Failed to parse OAuth2 JSON response from {token_url}: {err:#}."
160 );
161 return Ok(None);
162 }
163 },
164 Err(err) => {
165 warn!(context, "Error calling OAuth2 at {token_url}: {err:#}.");
166 return Ok(None);
167 }
168 };
169
170 if let Some(ref token) = response.refresh_token {
172 context
173 .sql
174 .set_raw_config("oauth2_refresh_token", Some(token))
175 .await?;
176 context
177 .sql
178 .set_raw_config("oauth2_refresh_token_for", Some(code))
179 .await?;
180 }
181
182 if let Some(ref token) = response.access_token {
185 context
186 .sql
187 .set_raw_config("oauth2_access_token", Some(token))
188 .await?;
189 let expires_in = response
190 .expires_in
191 .map(|t| time() + t as i64 - 5)
193 .unwrap_or_else(|| 0);
194 context
195 .sql
196 .set_raw_config_int64("oauth2_timestamp_expires", expires_in)
197 .await?;
198
199 if update_redirect_uri_on_success {
200 context
201 .sql
202 .set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
203 .await?;
204 }
205 } else {
206 warn!(context, "Failed to find OAuth2 access token");
207 }
208
209 drop(lock);
210
211 Ok(response.access_token)
212 } else {
213 warn!(context, "Internal OAuth2 error: 2");
214
215 Ok(None)
216 }
217}
218
219pub(crate) async fn get_oauth2_addr(
220 context: &Context,
221 addr: &str,
222 code: &str,
223) -> Result<Option<String>> {
224 let oauth2 = match Oauth2::from_address(addr) {
225 Some(o) => o,
226 None => return Ok(None),
227 };
228 if oauth2.get_userinfo.is_none() {
229 return Ok(None);
230 }
231
232 if let Some(access_token) = get_oauth2_access_token(context, addr, code, false).await? {
233 let addr_out = match oauth2.get_addr(context, &access_token).await {
234 Ok(addr) => addr,
235 Err(err) => {
236 warn!(context, "Error getting addr: {err:#}.");
237 None
238 }
239 };
240 if addr_out.is_none() {
241 if let Some(access_token) = get_oauth2_access_token(context, addr, code, true).await? {
243 Ok(oauth2
244 .get_addr(context, &access_token)
245 .await
246 .unwrap_or_default())
247 } else {
248 Ok(None)
249 }
250 } else {
251 Ok(addr_out)
252 }
253 } else {
254 Ok(None)
255 }
256}
257
258impl Oauth2 {
259 fn from_address(addr: &str) -> Option<Self> {
260 let addr_normalized = normalize_addr(addr);
261 if let Some(domain) = addr_normalized
262 .find('@')
263 .map(|index| addr_normalized.split_at(index + 1).1)
264 && let Some(oauth2_authorizer) = provider::get_provider_info(domain)
265 .and_then(|provider| provider.oauth2_authorizer.as_ref())
266 {
267 return Some(match oauth2_authorizer {
268 Oauth2Authorizer::Yandex => OAUTH2_YANDEX,
269 });
270 }
271 None
272 }
273
274 async fn get_addr(&self, context: &Context, access_token: &str) -> Result<Option<String>> {
275 let userinfo_url = self.get_userinfo.unwrap_or("");
276 let userinfo_url = replace_in_uri(userinfo_url, "$ACCESS_TOKEN", access_token);
277
278 let response = read_url_blob(context, &userinfo_url).await?;
287 let parsed: HashMap<String, serde_json::Value> =
288 serde_json::from_slice(&response.blob).context("Error getting userinfo")?;
289 if let Some(addr) = parsed.get("email") {
292 if let Some(s) = addr.as_str() {
293 Ok(Some(s.to_string()))
294 } else {
295 warn!(context, "E-mail in userinfo is not a string: {}", addr);
296 Ok(None)
297 }
298 } else {
299 warn!(context, "E-mail missing in userinfo.");
300 Ok(None)
301 }
302 }
303}
304
305async fn is_expired(context: &Context) -> Result<bool> {
306 let expire_timestamp = context
307 .sql
308 .get_raw_config_int64("oauth2_timestamp_expires")
309 .await?
310 .unwrap_or_default();
311
312 if expire_timestamp <= 0 {
313 return Ok(false);
314 }
315 if expire_timestamp > time() {
316 return Ok(false);
317 }
318
319 Ok(true)
320}
321
322fn replace_in_uri(uri: &str, key: &str, value: &str) -> String {
323 let value_urlencoded = utf8_percent_encode(value, NON_ALPHANUMERIC).to_string();
324 uri.replace(key, &value_urlencoded)
325}
326
/// Normalizes an e-mail address for provider lookup.
///
/// Surrounding whitespace is trimmed, then any leading `mailto:` prefixes are
/// removed. Repeated prefixes are all stripped. The scheme is matched
/// ASCII-case-insensitively, because URI schemes are case-insensitive
/// (RFC 3986 §3.1), so `MAILTO:` is stripped too.
fn normalize_addr(addr: &str) -> &str {
    const MAILTO: &str = "mailto:";
    let mut normalized = addr.trim();
    // `get` returns `None` for too-short input or a non-char boundary, so
    // slicing below can never panic.
    while normalized
        .get(..MAILTO.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(MAILTO))
    {
        normalized = &normalized[MAILTO.len()..];
    }
    normalized
}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestContext;

    /// Surrounding whitespace and a leading `mailto:` are stripped.
    #[test]
    fn test_normalize_addr() {
        assert_eq!(normalize_addr(" hello@mail.de "), "hello@mail.de");
        assert_eq!(normalize_addr("mailto:hello@mail.de "), "hello@mail.de");
    }

    /// The substituted value is percent-encoded; every non-alphanumeric
    /// character, including `-` and space, is escaped.
    #[test]
    fn test_replace_in_uri() {
        assert_eq!(
            replace_in_uri("helloworld", "world", "a-b c"),
            "helloa%2Db%20c"
        );
    }

    /// Only domains whose provider has an OAuth2 authorizer map to a config:
    /// Yandex domains yield `OAUTH2_YANDEX`; Gmail and web.de yield `None`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_oauth_from_address() {
        assert_eq!(Oauth2::from_address("hello@gmail.com"), None);
        assert_eq!(Oauth2::from_address("hello@googlemail.com"), None);

        assert_eq!(
            Oauth2::from_address("hello@yandex.com"),
            Some(OAUTH2_YANDEX)
        );
        assert_eq!(Oauth2::from_address("hello@yandex.ru"), Some(OAUTH2_YANDEX));
        assert_eq!(Oauth2::from_address("hello@web.de"), None);
    }

    /// A Gmail address has no OAuth2 config (see `test_oauth_from_address`),
    /// so no address can be fetched.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_addr() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_addr(&ctx.ctx, addr, code).await.unwrap();
        assert_eq!(res, None);
    }

    /// The Yandex `get_code` template has no `$REDIRECT_URI` placeholder, so
    /// only the client id is substituted into the expected URL.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_url() {
        let ctx = TestContext::new().await;
        let addr = "example@yandex.com";
        let redirect_uri = "chat.delta:/com.b44t.messenger";
        let res = get_oauth2_url(&ctx.ctx, addr, redirect_uri).await.unwrap();

        assert_eq!(res, Some("https://oauth.yandex.com/authorize?client_id=c4d0b6735fc8420a816d7e1303469341&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true".into()));
    }

    /// A Gmail address has no OAuth2 config, so no access token is returned.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_token() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_access_token(&ctx.ctx, addr, code, false)
            .await
            .unwrap();
        assert_eq!(res, None);
    }
}