deltachat/log/
stream.rs

//! Stream that logs errors as events.
//!
//! This stream can be used to wrap IMAP, SMTP and HTTP streams
//! so that errors occurring on them are logged before they are
//! processed.
7
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13use anyhow::{Context as _, Result};
14use pin_project::pin_project;
15
16use crate::events::{Event, EventType, Events};
17use crate::net::session::SessionStream;
18
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20
/// Byte counters kept for a [`LoggingStream`] and included in its error log messages.
#[derive(Debug, Default)]
struct Metrics {
    /// Total number of bytes read.
    pub total_read: usize,

    /// Total number of bytes written.
    pub total_written: usize,
}

impl Metrics {
    /// Creates metrics with both counters set to zero.
    fn new() -> Self {
        Self::default()
    }
}
38
39/// Stream that logs errors to the event channel.
40#[derive(Debug)]
41#[pin_project]
42pub(crate) struct LoggingStream<S: SessionStream> {
43    #[pin]
44    inner: S,
45
46    /// Account ID for logging.
47    account_id: u32,
48
49    /// Event channel.
50    events: Events,
51
52    /// Metrics for this stream.
53    metrics: Metrics,
54
55    /// Peer address at the time of creation.
56    ///
57    /// Socket may become disconnected later,
58    /// so we save it when `LoggingStream` is created.
59    peer_addr: SocketAddr,
60}
61
62impl<S: SessionStream> LoggingStream<S> {
63    pub fn new(inner: S, account_id: u32, events: Events) -> Result<Self> {
64        let peer_addr: SocketAddr = inner
65            .peer_addr()
66            .context("Attempt to create LoggingStream over an unconnected stream")?;
67        Ok(Self {
68            inner,
69            account_id,
70            events,
71            metrics: Metrics::new(),
72            peer_addr,
73        })
74    }
75}
76
77impl<S: SessionStream> AsyncRead for LoggingStream<S> {
78    fn poll_read(
79        self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81        buf: &mut ReadBuf<'_>,
82    ) -> Poll<std::io::Result<()>> {
83        let this = self.project();
84        let old_remaining = buf.remaining();
85
86        let res = this.inner.poll_read(cx, buf);
87
88        if let Poll::Ready(Err(ref err)) = res {
89            let peer_addr = this.peer_addr;
90            let log_message = format!(
91                "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
92                this.metrics.total_read, this.metrics.total_written
93            );
94            this.events.emit(Event {
95                id: *this.account_id,
96                typ: EventType::Warning(log_message),
97            });
98        }
99
100        let n = old_remaining - buf.remaining();
101        this.metrics.total_read = this.metrics.total_read.saturating_add(n);
102
103        res
104    }
105}
106
107impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
108    fn poll_write(
109        self: Pin<&mut Self>,
110        cx: &mut std::task::Context<'_>,
111        buf: &[u8],
112    ) -> Poll<std::io::Result<usize>> {
113        let this = self.project();
114        let res = this.inner.poll_write(cx, buf);
115        if let Poll::Ready(Ok(n)) = res {
116            this.metrics.total_written = this.metrics.total_written.saturating_add(n);
117        }
118        res
119    }
120
121    fn poll_flush(
122        self: Pin<&mut Self>,
123        cx: &mut std::task::Context<'_>,
124    ) -> Poll<std::io::Result<()>> {
125        self.project().inner.poll_flush(cx)
126    }
127
128    fn poll_shutdown(
129        self: Pin<&mut Self>,
130        cx: &mut std::task::Context<'_>,
131    ) -> Poll<std::io::Result<()>> {
132        self.project().inner.poll_shutdown(cx)
133    }
134
135    fn poll_write_vectored(
136        self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138        bufs: &[std::io::IoSlice<'_>],
139    ) -> Poll<std::io::Result<usize>> {
140        let this = self.project();
141        let res = this.inner.poll_write_vectored(cx, bufs);
142        if let Poll::Ready(Ok(n)) = res {
143            this.metrics.total_written = this.metrics.total_written.saturating_add(n);
144        }
145        res
146    }
147
148    fn is_write_vectored(&self) -> bool {
149        self.inner.is_write_vectored()
150    }
151}
152
153impl<S: SessionStream> SessionStream for LoggingStream<S> {
154    fn set_read_timeout(&mut self, timeout: Option<Duration>) {
155        self.inner.set_read_timeout(timeout)
156    }
157
158    fn peer_addr(&self) -> Result<SocketAddr> {
159        self.inner.peer_addr()
160    }
161}