deltachat/log/
stream.rs

1//! Stream that logs errors as events.
2//!
3//! This stream can be used to wrap IMAP,
4//! SMTP and HTTP streams so errors
5//! that occur are logged before
6//! they are processed.
7
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13use anyhow::{Context as _, Result};
14use pin_project::pin_project;
15
16use crate::events::{Event, EventType, Events};
17use crate::net::session::SessionStream;
18use crate::tools::usize_to_u64;
19
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21
/// Byte counters for a single stream.
///
/// Both counters start at zero (see [`Metrics::new`]) and are only ever
/// increased by the stream that owns them.
#[derive(Debug, Default)]
struct Metrics {
    /// Total number of bytes read.
    pub total_read: u64,

    /// Total number of bytes written.
    pub total_written: u64,
}

impl Metrics {
    /// Creates a fresh set of counters, all set to zero.
    ///
    /// Delegates to the derived [`Default`] implementation so that a counter
    /// added later cannot be forgotten in the constructor.
    fn new() -> Self {
        Self::default()
    }
}
39
/// Stream that logs errors to the event channel.
///
/// Wraps an inner [`SessionStream`] and forwards all I/O to it.
/// Successful reads and writes are counted in [`Metrics`]; when a read
/// fails, a warning containing the peer address and both byte counters
/// is emitted via `tracing` and as an [`EventType::Warning`] event
/// before the error is returned to the caller.
#[derive(Debug)]
#[pin_project]
pub(crate) struct LoggingStream<S: SessionStream> {
    /// Wrapped stream that all reads and writes are forwarded to.
    ///
    /// Marked `#[pin]` so the projection hands out `Pin<&mut S>`,
    /// which is what the `poll_*` methods of the inner stream take.
    #[pin]
    inner: S,

    /// Account ID for logging.
    account_id: u32,

    /// Event channel.
    events: Events,

    /// Metrics for this stream.
    metrics: Metrics,

    /// Peer address at the time of creation.
    ///
    /// Socket may become disconnected later,
    /// so we save it when `LoggingStream` is created.
    peer_addr: SocketAddr,
}
62
63impl<S: SessionStream> LoggingStream<S> {
64    pub fn new(inner: S, account_id: u32, events: Events) -> Result<Self> {
65        let peer_addr: SocketAddr = inner
66            .peer_addr()
67            .context("Attempt to create LoggingStream over an unconnected stream")?;
68        Ok(Self {
69            inner,
70            account_id,
71            events,
72            metrics: Metrics::new(),
73            peer_addr,
74        })
75    }
76}
77
78impl<S: SessionStream> AsyncRead for LoggingStream<S> {
79    #[expect(clippy::arithmetic_side_effects)]
80    fn poll_read(
81        self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83        buf: &mut ReadBuf<'_>,
84    ) -> Poll<std::io::Result<()>> {
85        let this = self.project();
86        let old_remaining = buf.remaining();
87
88        let res = this.inner.poll_read(cx, buf);
89
90        if let Poll::Ready(Err(ref err)) = res {
91            let peer_addr = this.peer_addr;
92            let log_message = format!(
93                "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
94                this.metrics.total_read, this.metrics.total_written
95            );
96            tracing::event!(
97                ::tracing::Level::WARN,
98                account_id = *this.account_id,
99                log_message
100            );
101            this.events.emit(Event {
102                id: *this.account_id,
103                typ: EventType::Warning(log_message),
104            });
105        }
106
107        let n = old_remaining - buf.remaining();
108        this.metrics.total_read = this.metrics.total_read.saturating_add(usize_to_u64(n));
109
110        res
111    }
112}
113
114impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
115    fn poll_write(
116        self: Pin<&mut Self>,
117        cx: &mut std::task::Context<'_>,
118        buf: &[u8],
119    ) -> Poll<std::io::Result<usize>> {
120        let this = self.project();
121        let res = this.inner.poll_write(cx, buf);
122        if let Poll::Ready(Ok(n)) = res {
123            this.metrics.total_written = this.metrics.total_written.saturating_add(usize_to_u64(n));
124        }
125        res
126    }
127
128    fn poll_flush(
129        self: Pin<&mut Self>,
130        cx: &mut std::task::Context<'_>,
131    ) -> Poll<std::io::Result<()>> {
132        self.project().inner.poll_flush(cx)
133    }
134
135    fn poll_shutdown(
136        self: Pin<&mut Self>,
137        cx: &mut std::task::Context<'_>,
138    ) -> Poll<std::io::Result<()>> {
139        self.project().inner.poll_shutdown(cx)
140    }
141
142    fn poll_write_vectored(
143        self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        bufs: &[std::io::IoSlice<'_>],
146    ) -> Poll<std::io::Result<usize>> {
147        let this = self.project();
148        let res = this.inner.poll_write_vectored(cx, bufs);
149        if let Poll::Ready(Ok(n)) = res {
150            this.metrics.total_written = this.metrics.total_written.saturating_add(usize_to_u64(n));
151        }
152        res
153    }
154
155    fn is_write_vectored(&self) -> bool {
156        self.inner.is_write_vectored()
157    }
158}
159
160impl<S: SessionStream> SessionStream for LoggingStream<S> {
161    fn set_read_timeout(&mut self, timeout: Option<Duration>) {
162        self.inner.set_read_timeout(timeout)
163    }
164
165    fn peer_addr(&self) -> Result<SocketAddr> {
166        self.inner.peer_addr()
167    }
168}