1use std::future::Future;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::time::Duration;
6
7use anyhow::{Context as _, Result, format_err};
8use tokio::net::TcpStream;
9use tokio::task::JoinSet;
10use tokio::time::timeout;
11use tokio_io_timeout::TimeoutStream;
12
13use crate::context::Context;
14use crate::net::session::SessionStream;
15use crate::net::tls::{SpkiHashStore, TlsSessionStore};
16use crate::sql::Sql;
17use crate::tools::time;
18
19pub(crate) mod dns;
20pub(crate) mod http;
21pub(crate) mod proxy;
22pub(crate) mod session;
23pub(crate) mod tls;
24
25use dns::lookup_host_with_cache;
26pub(crate) use http::read_url_with_tls;
27pub use http::{Response as HttpResponse, read_url, read_url_blob};
28use tls::wrap_tls;
29
/// Network timeout: used as the deadline for establishing a TCP connection
/// and as the per-operation read and write timeout on connected streams.
pub(crate) const TIMEOUT: Duration = Duration::from_millis(60 * 1000);

/// Maximum age of a `connection_history` entry before `prune_connection_history`
/// deletes it: 30 days, expressed in seconds.
pub(crate) const CACHE_TTL: u64 = 60 * 60 * 24 * 30;
37
38pub(crate) async fn prune_connection_history(context: &Context) -> Result<()> {
40 let now = time();
41 context
42 .sql
43 .execute(
44 "DELETE FROM connection_history
45 WHERE ? > timestamp + ?",
46 (now, CACHE_TTL),
47 )
48 .await?;
49 Ok(())
50}
51
52pub(crate) async fn update_connection_history(
61 context: &Context,
62 alpn: &str,
63 host: &str,
64 port: u16,
65 addr: &str,
66 now: i64,
67) -> Result<()> {
68 context
69 .sql
70 .execute(
71 "INSERT INTO connection_history (host, port, alpn, addr, timestamp)
72 VALUES (?, ?, ?, ?, ?)
73 ON CONFLICT (host, port, alpn, addr)
74 DO UPDATE SET timestamp=excluded.timestamp",
75 (host, port, alpn, addr, now),
76 )
77 .await?;
78 Ok(())
79}
80
81pub(crate) async fn load_connection_timestamp(
84 sql: &Sql,
85 alpn: &str,
86 host: &str,
87 port: u16,
88 addr: Option<&str>,
89) -> Result<Option<i64>> {
90 let timestamp = sql
91 .query_get_value(
92 "SELECT timestamp FROM connection_history
93 WHERE host = ?
94 AND port = ?
95 AND alpn = ?
96 AND addr = IFNULL(?, addr)",
97 (host, port, alpn, addr),
98 )
99 .await?;
100 Ok(timestamp)
101}
102
103pub(crate) async fn connect_tcp_inner(
109 addr: SocketAddr,
110) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
111 let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr))
112 .await
113 .with_context(|| format!("Connection to {addr} timed out"))?
114 .with_context(|| format!("Connection to {addr} failed"))?;
115
116 tcp_stream.set_nodelay(true)?;
118
119 let mut timeout_stream = TimeoutStream::new(tcp_stream);
120 timeout_stream.set_write_timeout(Some(TIMEOUT));
121 timeout_stream.set_read_timeout(Some(TIMEOUT));
122
123 Ok(Box::pin(timeout_stream))
124}
125
126pub(crate) async fn connect_tls_inner(
129 addr: SocketAddr,
130 host: &str,
131 strict_tls: bool,
132 alpn: &str,
133 tls_session_store: &TlsSessionStore,
134 spki_hash_store: &SpkiHashStore,
135 sql: &Sql,
136) -> Result<impl SessionStream + 'static> {
137 let use_sni = true;
138 let tcp_stream = connect_tcp_inner(addr).await?;
139 let tls_stream = wrap_tls(
140 strict_tls,
141 host,
142 addr.port(),
143 use_sni,
144 alpn,
145 tcp_stream,
146 tls_session_store,
147 spki_hash_store,
148 sql,
149 )
150 .await?;
151 Ok(tls_stream)
152}
153
154pub(crate) async fn run_connection_attempts<O, I, F>(mut futures: I) -> Result<O>
165where
166 I: Iterator<Item = F>,
167 F: Future<Output = Result<O>> + Send + 'static,
168 O: Send + 'static,
169{
170 let mut connection_attempt_set = JoinSet::new();
171
172 let mut delay_set = JoinSet::new();
175 for delay in [
176 Duration::from_millis(300),
177 Duration::from_secs(1),
178 Duration::from_secs(5),
179 Duration::from_secs(10),
180 ] {
181 delay_set.spawn(tokio::time::sleep(delay));
182 }
183
184 let mut all_errors = Vec::new();
185
186 let res = loop {
187 if let Some(fut) = futures.next() {
188 connection_attempt_set.spawn(fut);
189 }
190
191 tokio::select! {
192 biased;
193
194 res = connection_attempt_set.join_next() => {
195 match res {
196 Some(res) => {
197 match res.context("Failed to join task") {
198 Ok(Ok(conn)) => {
199 break Ok(conn);
201 }
202 Ok(Err(err)) => {
203 all_errors.push(err);
205 }
206 Err(err) => {
207 break Err(err);
208 }
209 }
210 }
211 None => {
212 break if all_errors.is_empty() {
216 Err(format_err!("No connection attempts were made"))
217 } else {
218 Err(format_err!("All connection attempts failed: {}", all_errors.into_iter().map(|err| format!("{err:#}")).collect::<Vec<String>>().join("; ")))
219 };
220 }
221 }
222 },
223
224 _ = delay_set.join_next(), if !delay_set.is_empty() => {
225 }
230 }
231 };
232
233 connection_attempt_set.shutdown().await;
242
243 res
244}
245
246pub(crate) async fn connect_tcp(
253 context: &Context,
254 host: &str,
255 port: u16,
256 load_cache: bool,
257) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
258 let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache)
259 .await?
260 .into_iter()
261 .map(connect_tcp_inner);
262 run_connection_attempts(connection_futures).await
263}