1use parking_lot::Mutex;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::Result;
7
8use crate::net::session::SessionStream;
9use crate::sql::Sql;
10use crate::tools::time;
11
12use tokio_rustls::rustls;
13use tokio_rustls::rustls::client::ClientSessionStore;
14use tokio_rustls::rustls::server::ParsedCertificate;
15
16mod danger;
17use danger::CustomCertificateVerifier;
18
19mod spki;
20pub use spki::SpkiHashStore;
21
22#[expect(clippy::too_many_arguments)]
23pub async fn wrap_tls<'a>(
24 strict_tls: bool,
25 hostname: &str,
26 port: u16,
27 use_sni: bool,
28 alpn: &str,
29 stream: impl SessionStream + 'static,
30 tls_session_store: &TlsSessionStore,
31 spki_hash_store: &SpkiHashStore,
32 sql: &Sql,
33) -> Result<impl SessionStream + 'a> {
34 if strict_tls {
35 let tls_stream = wrap_rustls(
36 hostname,
37 port,
38 use_sni,
39 alpn,
40 stream,
41 tls_session_store,
42 spki_hash_store,
43 sql,
44 )
45 .await?;
46 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
47 Ok(boxed_stream)
48 } else {
49 let alpns = if alpn.is_empty() {
53 Box::from([])
54 } else {
55 Box::from([alpn])
56 };
57 let tls = async_native_tls::TlsConnector::new()
58 .min_protocol_version(Some(async_native_tls::Protocol::Tlsv12))
59 .use_sni(use_sni)
60 .request_alpns(&alpns)
61 .danger_accept_invalid_hostnames(true)
62 .danger_accept_invalid_certs(true);
63 let tls_stream = tls.connect(hostname, stream).await?;
64 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
65 Ok(boxed_stream)
66 }
67}
68
69#[derive(Debug)]
77pub(crate) struct TlsSessionStore {
78 sessions: Mutex<HashMap<(u16, String), Arc<dyn ClientSessionStore>>>,
79}
80
81const TLS_CACHE_SIZE: usize = 256;
86
87impl TlsSessionStore {
88 pub fn new() -> Self {
93 Self {
94 sessions: Default::default(),
95 }
96 }
97
98 pub fn get(&self, port: u16, alpn: &str) -> Arc<dyn ClientSessionStore> {
102 Arc::clone(
103 self.sessions
104 .lock()
105 .entry((port, alpn.to_string()))
106 .or_insert_with(|| {
107 Arc::new(rustls::client::ClientSessionMemoryCache::new(
108 TLS_CACHE_SIZE,
109 ))
110 }),
111 )
112 }
113}
114
115#[expect(clippy::too_many_arguments)]
116pub async fn wrap_rustls<'a>(
117 hostname: &str,
118 port: u16,
119 use_sni: bool,
120 alpn: &str,
121 stream: impl SessionStream + 'a,
122 tls_session_store: &TlsSessionStore,
123 spki_hash_store: &SpkiHashStore,
124 sql: &Sql,
125) -> Result<impl SessionStream + 'a> {
126 let root_cert_store =
127 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
128
129 let mut config = rustls::ClientConfig::builder()
130 .with_root_certificates(root_cert_store)
131 .with_no_client_auth();
132 config.alpn_protocols = if alpn.is_empty() {
133 vec![]
134 } else {
135 vec![alpn.as_bytes().to_vec()]
136 };
137
138 let resumption_store = tls_session_store.get(port, alpn);
146 let resumption = rustls::client::Resumption::store(resumption_store)
147 .tls12_resumption(rustls::client::Tls12Resumption::Disabled);
148 config.resumption = resumption;
149 config.enable_sni = use_sni;
150
151 config
152 .dangerous()
153 .set_certificate_verifier(Arc::new(CustomCertificateVerifier::new(
154 spki_hash_store.get_spki_hash(hostname, sql).await?,
155 )));
156
157 let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
158 let name = tokio_rustls::rustls::pki_types::ServerName::try_from(hostname)?.to_owned();
159 let tls_stream = tls.connect(name, stream).await?;
160
161 let (_io, client_connection) = tls_stream.get_ref();
164 if let Some(end_entity) = client_connection
165 .peer_certificates()
166 .and_then(|certs| certs.first())
167 {
168 let now = time();
169 let parsed_certificate = ParsedCertificate::try_from(end_entity)?;
170 let spki = parsed_certificate.subject_public_key_info();
171 spki_hash_store.save_spki(hostname, &spki, sql, now).await?;
172 }
173
174 Ok(tls_stream)
175}