1use parking_lot::Mutex;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::Result;
7
8use crate::net::session::SessionStream;
9
10use tokio_rustls::rustls::client::ClientSessionStore;
11
12pub async fn wrap_tls<'a>(
13 strict_tls: bool,
14 hostname: &str,
15 port: u16,
16 use_sni: bool,
17 alpn: &str,
18 stream: impl SessionStream + 'static,
19 tls_session_store: &TlsSessionStore,
20) -> Result<impl SessionStream + 'a> {
21 if strict_tls {
22 let tls_stream =
23 wrap_rustls(hostname, port, use_sni, alpn, stream, tls_session_store).await?;
24 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
25 Ok(boxed_stream)
26 } else {
27 let alpns = if alpn.is_empty() {
31 Box::from([])
32 } else {
33 Box::from([alpn])
34 };
35 let tls = async_native_tls::TlsConnector::new()
36 .min_protocol_version(Some(async_native_tls::Protocol::Tlsv12))
37 .use_sni(use_sni)
38 .request_alpns(&alpns)
39 .danger_accept_invalid_hostnames(true)
40 .danger_accept_invalid_certs(true);
41 let tls_stream = tls.connect(hostname, stream).await?;
42 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
43 Ok(boxed_stream)
44 }
45}
46
47#[derive(Debug)]
55pub(crate) struct TlsSessionStore {
56 sessions: Mutex<HashMap<(u16, String), Arc<dyn ClientSessionStore>>>,
57}
58
59const TLS_CACHE_SIZE: usize = 256;
64
65impl TlsSessionStore {
66 pub fn new() -> Self {
71 Self {
72 sessions: Default::default(),
73 }
74 }
75
76 pub fn get(&self, port: u16, alpn: &str) -> Arc<dyn ClientSessionStore> {
80 Arc::clone(
81 self.sessions
82 .lock()
83 .entry((port, alpn.to_string()))
84 .or_insert_with(|| {
85 Arc::new(tokio_rustls::rustls::client::ClientSessionMemoryCache::new(
86 TLS_CACHE_SIZE,
87 ))
88 }),
89 )
90 }
91}
92
93pub async fn wrap_rustls<'a>(
94 hostname: &str,
95 port: u16,
96 use_sni: bool,
97 alpn: &str,
98 stream: impl SessionStream + 'a,
99 tls_session_store: &TlsSessionStore,
100) -> Result<impl SessionStream + 'a> {
101 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
102 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
103
104 let mut config = tokio_rustls::rustls::ClientConfig::builder()
105 .with_root_certificates(root_cert_store)
106 .with_no_client_auth();
107 config.alpn_protocols = if alpn.is_empty() {
108 vec![]
109 } else {
110 vec![alpn.as_bytes().to_vec()]
111 };
112
113 let resumption_store = tls_session_store.get(port, alpn);
121 let resumption = tokio_rustls::rustls::client::Resumption::store(resumption_store)
122 .tls12_resumption(tokio_rustls::rustls::client::Tls12Resumption::Disabled);
123 config.resumption = resumption;
124 config.enable_sni = use_sni;
125
126 let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
127 let name = rustls_pki_types::ServerName::try_from(hostname)?.to_owned();
128 let tls_stream = tls.connect(name, stream).await?;
129 Ok(tls_stream)
130}