1use parking_lot::Mutex;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::Result;
7
8use crate::net::session::SessionStream;
9
10use tokio_rustls::rustls;
11use tokio_rustls::rustls::client::ClientSessionStore;
12
13mod danger;
14use danger::NoCertificateVerification;
15
16pub async fn wrap_tls<'a>(
17 strict_tls: bool,
18 hostname: &str,
19 port: u16,
20 use_sni: bool,
21 alpn: &str,
22 stream: impl SessionStream + 'static,
23 tls_session_store: &TlsSessionStore,
24) -> Result<impl SessionStream + 'a> {
25 if strict_tls {
26 let tls_stream =
27 wrap_rustls(hostname, port, use_sni, alpn, stream, tls_session_store).await?;
28 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
29 Ok(boxed_stream)
30 } else {
31 let alpns = if alpn.is_empty() {
35 Box::from([])
36 } else {
37 Box::from([alpn])
38 };
39 let tls = async_native_tls::TlsConnector::new()
40 .min_protocol_version(Some(async_native_tls::Protocol::Tlsv12))
41 .use_sni(use_sni)
42 .request_alpns(&alpns)
43 .danger_accept_invalid_hostnames(true)
44 .danger_accept_invalid_certs(true);
45 let tls_stream = tls.connect(hostname, stream).await?;
46 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
47 Ok(boxed_stream)
48 }
49}
50
51#[derive(Debug)]
59pub(crate) struct TlsSessionStore {
60 sessions: Mutex<HashMap<(u16, String), Arc<dyn ClientSessionStore>>>,
61}
62
/// Size passed to every `ClientSessionMemoryCache` that `TlsSessionStore::get` creates.
const TLS_CACHE_SIZE: usize = 256;
68
69impl TlsSessionStore {
70 pub fn new() -> Self {
75 Self {
76 sessions: Default::default(),
77 }
78 }
79
80 pub fn get(&self, port: u16, alpn: &str) -> Arc<dyn ClientSessionStore> {
84 Arc::clone(
85 self.sessions
86 .lock()
87 .entry((port, alpn.to_string()))
88 .or_insert_with(|| {
89 Arc::new(rustls::client::ClientSessionMemoryCache::new(
90 TLS_CACHE_SIZE,
91 ))
92 }),
93 )
94 }
95}
96
97pub async fn wrap_rustls<'a>(
98 hostname: &str,
99 port: u16,
100 use_sni: bool,
101 alpn: &str,
102 stream: impl SessionStream + 'a,
103 tls_session_store: &TlsSessionStore,
104) -> Result<impl SessionStream + 'a> {
105 let mut root_cert_store = rustls::RootCertStore::empty();
106 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
107
108 let mut config = rustls::ClientConfig::builder()
109 .with_root_certificates(root_cert_store)
110 .with_no_client_auth();
111 config.alpn_protocols = if alpn.is_empty() {
112 vec![]
113 } else {
114 vec![alpn.as_bytes().to_vec()]
115 };
116
117 let resumption_store = tls_session_store.get(port, alpn);
125 let resumption = rustls::client::Resumption::store(resumption_store)
126 .tls12_resumption(rustls::client::Tls12Resumption::Disabled);
127 config.resumption = resumption;
128 config.enable_sni = use_sni;
129
130 if hostname.starts_with("_") {
137 config
138 .dangerous()
139 .set_certificate_verifier(Arc::new(NoCertificateVerification::new()));
140 }
141
142 let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
143 let name = tokio_rustls::rustls::pki_types::ServerName::try_from(hostname)?.to_owned();
144 let tls_stream = tls.connect(name, stream).await?;
145 Ok(tls_stream)
146}