1use std::collections::HashMap;
4
5use anyhow::{Context as _, Result};
6use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::Deserialize;
8
9use crate::context::Context;
10use crate::log::{info, warn};
11use crate::net::http::post_form;
12use crate::net::read_url_blob;
13use crate::provider;
14use crate::provider::Oauth2Authorizer;
15use crate::tools::time;
16
/// OAuth2 configuration used for Yandex mail accounts.
///
/// The `$CLIENT_ID`, `$CODE` and `$REFRESH_TOKEN` placeholders in the URL templates are
/// substituted at runtime. No userinfo endpoint is configured, so `get_oauth2_addr`
/// returns `None` for Yandex addresses.
// NOTE(review): the client secret is embedded in the token URLs and is therefore
// public in the source — presumably intended for an installed-app client; confirm.
const OAUTH2_YANDEX: Oauth2 = Oauth2 {
    client_id: "c4d0b6735fc8420a816d7e1303469341",
    get_code: "https://oauth.yandex.com/authorize?client_id=$CLIENT_ID&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true",
    init_token: "https://oauth.yandex.com/token?grant_type=authorization_code&code=$CODE&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    refresh_token: "https://oauth.yandex.com/token?grant_type=refresh_token&refresh_token=$REFRESH_TOKEN&client_id=$CLIENT_ID&client_secret=58b8c6e94cf44fbe952da8511955dacf",
    get_userinfo: None,
};
25
/// Client id and endpoint URL templates of an OAuth2 provider.
///
/// The templates contain `$NAME` placeholders (`$CLIENT_ID`, `$CODE`, `$REFRESH_TOKEN`,
/// `$REDIRECT_URI`, `$ACCESS_TOKEN`) that are substituted before a request is made.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Oauth2 {
    /// Client id substituted for `$CLIENT_ID`.
    client_id: &'static str,
    /// URL the user opens to obtain an authorization code.
    get_code: &'static str,
    /// Token endpoint template for exchanging an authorization code (`$CODE`).
    init_token: &'static str,
    /// Token endpoint template for obtaining a new access token via `$REFRESH_TOKEN`.
    refresh_token: &'static str,
    /// Optional userinfo endpoint template (`$ACCESS_TOKEN`); `None` means the
    /// e-mail address cannot be looked up through this provider.
    get_userinfo: Option<&'static str>,
}
34
/// JSON body returned by a provider's token endpoint.
///
/// Only `access_token`, `expires_in` and `refresh_token` are read. `token_type` is not
/// optional, so a response lacking it fails to deserialize and is treated as a failed
/// token request.
#[derive(Debug, Deserialize)]
// `token_type` and `scope` are deserialized but never read.
#[allow(dead_code)]
struct Response {
    access_token: Option<String>,
    token_type: String,
    /// Lifetime of the access token; added to `time()` to compute the stored expiry.
    expires_in: Option<u64>,
    refresh_token: Option<String>,
    scope: Option<String>,
}
48
49pub async fn get_oauth2_url(
52 context: &Context,
53 addr: &str,
54 redirect_uri: &str,
55) -> Result<Option<String>> {
56 if let Some(oauth2) = Oauth2::from_address(addr) {
57 context
58 .sql
59 .set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri))
60 .await?;
61 let oauth2_url = replace_in_uri(oauth2.get_code, "$CLIENT_ID", oauth2.client_id);
62 let oauth2_url = replace_in_uri(&oauth2_url, "$REDIRECT_URI", redirect_uri);
63
64 Ok(Some(oauth2_url))
65 } else {
66 Ok(None)
67 }
68}
69
70pub(crate) async fn get_oauth2_access_token(
71 context: &Context,
72 addr: &str,
73 code: &str,
74 regenerate: bool,
75) -> Result<Option<String>> {
76 if let Some(oauth2) = Oauth2::from_address(addr) {
77 let lock = context.oauth2_mutex.lock().await;
78
79 if !regenerate && !is_expired(context).await? {
81 let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
82 if access_token.is_some() {
83 return Ok(access_token);
85 }
86 }
87
88 let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
90 let refresh_token_for = context
91 .sql
92 .get_raw_config("oauth2_refresh_token_for")
93 .await?
94 .unwrap_or_else(|| "unset".into());
95
96 let (redirect_uri, token_url, update_redirect_uri_on_success) =
97 if refresh_token.is_none() || refresh_token_for != code {
98 info!(context, "Generate OAuth2 refresh_token and access_token...",);
99 (
100 context
101 .sql
102 .get_raw_config("oauth2_pending_redirect_uri")
103 .await?
104 .unwrap_or_else(|| "unset".into()),
105 oauth2.init_token,
106 true,
107 )
108 } else {
109 info!(
110 context,
111 "Regenerate OAuth2 access_token by refresh_token...",
112 );
113 (
114 context
115 .sql
116 .get_raw_config("oauth2_redirect_uri")
117 .await?
118 .unwrap_or_else(|| "unset".into()),
119 oauth2.refresh_token,
120 false,
121 )
122 };
123
124 let mut parts = token_url.splitn(2, '?');
128 let post_url = parts.next().unwrap_or_default();
129 let post_args = parts.next().unwrap_or_default();
130 let mut post_param = HashMap::new();
131 for key_value_pair in post_args.split('&') {
132 let mut parts = key_value_pair.splitn(2, '=');
133 let key = parts.next().unwrap_or_default();
134 let mut value = parts.next().unwrap_or_default();
135
136 if value == "$CLIENT_ID" {
137 value = oauth2.client_id;
138 } else if value == "$REDIRECT_URI" {
139 value = &redirect_uri;
140 } else if value == "$CODE" {
141 value = code;
142 } else if value == "$REFRESH_TOKEN" {
143 if let Some(refresh_token) = refresh_token.as_ref() {
144 value = refresh_token;
145 }
146 }
147
148 post_param.insert(key, value);
149 }
150
151 let response: Response = match post_form(context, post_url, &post_param).await {
154 Ok(resp) => match serde_json::from_slice(&resp) {
155 Ok(response) => response,
156 Err(err) => {
157 warn!(
158 context,
159 "Failed to parse OAuth2 JSON response from {token_url}: {err:#}."
160 );
161 return Ok(None);
162 }
163 },
164 Err(err) => {
165 warn!(context, "Error calling OAuth2 at {token_url}: {err:#}.");
166 return Ok(None);
167 }
168 };
169
170 if let Some(ref token) = response.refresh_token {
172 context
173 .sql
174 .set_raw_config("oauth2_refresh_token", Some(token))
175 .await?;
176 context
177 .sql
178 .set_raw_config("oauth2_refresh_token_for", Some(code))
179 .await?;
180 }
181
182 if let Some(ref token) = response.access_token {
185 context
186 .sql
187 .set_raw_config("oauth2_access_token", Some(token))
188 .await?;
189 let expires_in = response
190 .expires_in
191 .map(|t| time() + t as i64 - 5)
193 .unwrap_or_else(|| 0);
194 context
195 .sql
196 .set_raw_config_int64("oauth2_timestamp_expires", expires_in)
197 .await?;
198
199 if update_redirect_uri_on_success {
200 context
201 .sql
202 .set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
203 .await?;
204 }
205 } else {
206 warn!(context, "Failed to find OAuth2 access token");
207 }
208
209 drop(lock);
210
211 Ok(response.access_token)
212 } else {
213 warn!(context, "Internal OAuth2 error: 2");
214
215 Ok(None)
216 }
217}
218
219pub(crate) async fn get_oauth2_addr(
220 context: &Context,
221 addr: &str,
222 code: &str,
223) -> Result<Option<String>> {
224 let oauth2 = match Oauth2::from_address(addr) {
225 Some(o) => o,
226 None => return Ok(None),
227 };
228 if oauth2.get_userinfo.is_none() {
229 return Ok(None);
230 }
231
232 if let Some(access_token) = get_oauth2_access_token(context, addr, code, false).await? {
233 let addr_out = match oauth2.get_addr(context, &access_token).await {
234 Ok(addr) => addr,
235 Err(err) => {
236 warn!(context, "Error getting addr: {err:#}.");
237 None
238 }
239 };
240 if addr_out.is_none() {
241 if let Some(access_token) = get_oauth2_access_token(context, addr, code, true).await? {
243 Ok(oauth2
244 .get_addr(context, &access_token)
245 .await
246 .unwrap_or_default())
247 } else {
248 Ok(None)
249 }
250 } else {
251 Ok(addr_out)
252 }
253 } else {
254 Ok(None)
255 }
256}
257
258impl Oauth2 {
259 fn from_address(addr: &str) -> Option<Self> {
260 let addr_normalized = normalize_addr(addr);
261 if let Some(domain) = addr_normalized
262 .find('@')
263 .map(|index| addr_normalized.split_at(index + 1).1)
264 {
265 if let Some(oauth2_authorizer) = provider::get_provider_info(domain)
266 .and_then(|provider| provider.oauth2_authorizer.as_ref())
267 {
268 return Some(match oauth2_authorizer {
269 Oauth2Authorizer::Yandex => OAUTH2_YANDEX,
270 });
271 }
272 }
273 None
274 }
275
276 async fn get_addr(&self, context: &Context, access_token: &str) -> Result<Option<String>> {
277 let userinfo_url = self.get_userinfo.unwrap_or("");
278 let userinfo_url = replace_in_uri(userinfo_url, "$ACCESS_TOKEN", access_token);
279
280 let response = read_url_blob(context, &userinfo_url).await?;
289 let parsed: HashMap<String, serde_json::Value> =
290 serde_json::from_slice(&response.blob).context("Error getting userinfo")?;
291 if let Some(addr) = parsed.get("email") {
294 if let Some(s) = addr.as_str() {
295 Ok(Some(s.to_string()))
296 } else {
297 warn!(context, "E-mail in userinfo is not a string: {}", addr);
298 Ok(None)
299 }
300 } else {
301 warn!(context, "E-mail missing in userinfo.");
302 Ok(None)
303 }
304 }
305}
306
307async fn is_expired(context: &Context) -> Result<bool> {
308 let expire_timestamp = context
309 .sql
310 .get_raw_config_int64("oauth2_timestamp_expires")
311 .await?
312 .unwrap_or_default();
313
314 if expire_timestamp <= 0 {
315 return Ok(false);
316 }
317 if expire_timestamp > time() {
318 return Ok(false);
319 }
320
321 Ok(true)
322}
323
324fn replace_in_uri(uri: &str, key: &str, value: &str) -> String {
325 let value_urlencoded = utf8_percent_encode(value, NON_ALPHANUMERIC).to_string();
326 uri.replace(key, &value_urlencoded)
327}
328
/// Strips surrounding whitespace from `addr`, then any leading `mailto:` prefixes.
fn normalize_addr(addr: &str) -> &str {
    addr.trim().trim_start_matches("mailto:")
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestContext;

    /// Surrounding whitespace and a `mailto:` prefix are stripped.
    #[test]
    fn test_normalize_addr() {
        assert_eq!(normalize_addr(" hello@mail.de "), "hello@mail.de");
        assert_eq!(normalize_addr("mailto:hello@mail.de "), "hello@mail.de");
    }

    /// Non-alphanumeric characters in the value (here `-` and space) are percent-encoded.
    #[test]
    fn test_replace_in_uri() {
        assert_eq!(
            replace_in_uri("helloworld", "world", "a-b c"),
            "helloa%2Db%20c"
        );
    }

    /// Yandex domains resolve to `OAUTH2_YANDEX`; Gmail and web.de resolve to no config.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_oauth_from_address() {
        assert_eq!(Oauth2::from_address("hello@gmail.com"), None);
        assert_eq!(Oauth2::from_address("hello@googlemail.com"), None);

        assert_eq!(
            Oauth2::from_address("hello@yandex.com"),
            Some(OAUTH2_YANDEX)
        );
        assert_eq!(Oauth2::from_address("hello@yandex.ru"), Some(OAUTH2_YANDEX));
        assert_eq!(Oauth2::from_address("hello@web.de"), None);
    }

    /// gmail.com has no OAuth2 config, so `get_oauth2_addr` returns `None` before
    /// requesting any token.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_addr() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_addr(&ctx.ctx, addr, code).await.unwrap();
        assert_eq!(res, None);
    }

    /// The Yandex `get_code` template has no `$REDIRECT_URI`, so only the client id is
    /// substituted and `redirect_uri` does not appear in the result.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_url() {
        let ctx = TestContext::new().await;
        let addr = "example@yandex.com";
        let redirect_uri = "chat.delta:/com.b44t.messenger";
        let res = get_oauth2_url(&ctx.ctx, addr, redirect_uri).await.unwrap();

        assert_eq!(res, Some("https://oauth.yandex.com/authorize?client_id=c4d0b6735fc8420a816d7e1303469341&response_type=code&scope=mail%3Aimap_full%20mail%3Asmtp&force_confirm=true".into()));
    }

    /// gmail.com has no OAuth2 config, so no token request is made and `None` is returned.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_get_oauth2_token() {
        let ctx = TestContext::new().await;
        let addr = "dignifiedquire@gmail.com";
        let code = "fail";
        let res = get_oauth2_access_token(&ctx.ctx, addr, code, false)
            .await
            .unwrap();
        assert_eq!(res, None);
    }
}