1use parking_lot::Mutex;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::Result;
7
8use crate::net::session::SessionStream;
9
10use tokio_rustls::rustls::client::ClientSessionStore;
11
12pub async fn wrap_tls<'a>(
13 strict_tls: bool,
14 hostname: &str,
15 port: u16,
16 alpn: &str,
17 stream: impl SessionStream + 'static,
18 tls_session_store: &TlsSessionStore,
19) -> Result<impl SessionStream + 'a> {
20 if strict_tls {
21 let tls_stream = wrap_rustls(hostname, port, alpn, stream, tls_session_store).await?;
22 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
23 Ok(boxed_stream)
24 } else {
25 let alpns = if alpn.is_empty() {
29 Box::from([])
30 } else {
31 Box::from([alpn])
32 };
33 let tls = async_native_tls::TlsConnector::new()
34 .min_protocol_version(Some(async_native_tls::Protocol::Tlsv12))
35 .request_alpns(&alpns)
36 .danger_accept_invalid_hostnames(true)
37 .danger_accept_invalid_certs(true);
38 let tls_stream = tls.connect(hostname, stream).await?;
39 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
40 Ok(boxed_stream)
41 }
42}
43
44#[derive(Debug)]
52pub(crate) struct TlsSessionStore {
53 sessions: Mutex<HashMap<(u16, String), Arc<dyn ClientSessionStore>>>,
54}
55
56const TLS_CACHE_SIZE: usize = 256;
61
62impl TlsSessionStore {
63 pub fn new() -> Self {
68 Self {
69 sessions: Default::default(),
70 }
71 }
72
73 pub fn get(&self, port: u16, alpn: &str) -> Arc<dyn ClientSessionStore> {
77 Arc::clone(
78 self.sessions
79 .lock()
80 .entry((port, alpn.to_string()))
81 .or_insert_with(|| {
82 Arc::new(tokio_rustls::rustls::client::ClientSessionMemoryCache::new(
83 TLS_CACHE_SIZE,
84 ))
85 }),
86 )
87 }
88}
89
90pub async fn wrap_rustls<'a>(
91 hostname: &str,
92 port: u16,
93 alpn: &str,
94 stream: impl SessionStream + 'a,
95 tls_session_store: &TlsSessionStore,
96) -> Result<impl SessionStream + 'a> {
97 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
98 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
99
100 let mut config = tokio_rustls::rustls::ClientConfig::builder()
101 .with_root_certificates(root_cert_store)
102 .with_no_client_auth();
103 config.alpn_protocols = if alpn.is_empty() {
104 vec![]
105 } else {
106 vec![alpn.as_bytes().to_vec()]
107 };
108
109 let resumption_store = tls_session_store.get(port, alpn);
117 let resumption = tokio_rustls::rustls::client::Resumption::store(resumption_store)
118 .tls12_resumption(tokio_rustls::rustls::client::Tls12Resumption::Disabled);
119 config.resumption = resumption;
120
121 let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
122 let name = rustls_pki_types::ServerName::try_from(hostname)?.to_owned();
123 let tls_stream = tls.connect(name, stream).await?;
124 Ok(tls_stream)
125}