1use std::sync::Arc;
3
4use anyhow::Result;
5
6use crate::net::session::SessionStream;
7
8pub async fn wrap_tls(
9 strict_tls: bool,
10 hostname: &str,
11 alpn: &[&str],
12 stream: impl SessionStream + 'static,
13) -> Result<impl SessionStream> {
14 if strict_tls {
15 let tls_stream = wrap_rustls(hostname, alpn, stream).await?;
16 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
17 Ok(boxed_stream)
18 } else {
19 let tls = async_native_tls::TlsConnector::new()
23 .min_protocol_version(Some(async_native_tls::Protocol::Tlsv12))
24 .request_alpns(alpn)
25 .danger_accept_invalid_hostnames(true)
26 .danger_accept_invalid_certs(true);
27 let tls_stream = tls.connect(hostname, stream).await?;
28 let boxed_stream: Box<dyn SessionStream> = Box::new(tls_stream);
29 Ok(boxed_stream)
30 }
31}
32
33pub async fn wrap_rustls(
34 hostname: &str,
35 alpn: &[&str],
36 stream: impl SessionStream,
37) -> Result<impl SessionStream> {
38 let mut root_cert_store = rustls::RootCertStore::empty();
39 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
40
41 let mut config = rustls::ClientConfig::builder()
42 .with_root_certificates(root_cert_store)
43 .with_no_client_auth();
44 config.alpn_protocols = alpn.iter().map(|s| s.as_bytes().to_vec()).collect();
45
46 let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
47 let name = rustls_pki_types::ServerName::try_from(hostname)?.to_owned();
48 let tls_stream = tls.connect(name, stream).await?;
49 Ok(tls_stream)
50}