1use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13use anyhow::{Context as _, Result};
14use pin_project::pin_project;
15
16use crate::events::{Event, EventType, Events};
17use crate::net::session::SessionStream;
18
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20
/// Byte counters for a [`LoggingStream`].
///
/// Both counters start at zero and are only ever increased with saturating
/// arithmetic, so they never wrap around.
#[derive(Debug, Default)]
struct Metrics {
    /// Total number of bytes read from the inner stream so far.
    pub total_read: usize,

    /// Total number of bytes the inner stream accepted for writing so far.
    pub total_written: usize,
}

impl Metrics {
    /// Creates a fresh set of counters with every total at zero.
    ///
    /// Equivalent to [`Metrics::default`].
    fn new() -> Self {
        Self::default()
    }
}
38
39#[derive(Debug)]
41#[pin_project]
42pub(crate) struct LoggingStream<S: SessionStream> {
43 #[pin]
44 inner: S,
45
46 account_id: u32,
48
49 events: Events,
51
52 metrics: Metrics,
54
55 peer_addr: SocketAddr,
60}
61
62impl<S: SessionStream> LoggingStream<S> {
63 pub fn new(inner: S, account_id: u32, events: Events) -> Result<Self> {
64 let peer_addr: SocketAddr = inner
65 .peer_addr()
66 .context("Attempt to create LoggingStream over an unconnected stream")?;
67 Ok(Self {
68 inner,
69 account_id,
70 events,
71 metrics: Metrics::new(),
72 peer_addr,
73 })
74 }
75}
76
77impl<S: SessionStream> AsyncRead for LoggingStream<S> {
78 fn poll_read(
79 self: Pin<&mut Self>,
80 cx: &mut Context<'_>,
81 buf: &mut ReadBuf<'_>,
82 ) -> Poll<std::io::Result<()>> {
83 let this = self.project();
84 let old_remaining = buf.remaining();
85
86 let res = this.inner.poll_read(cx, buf);
87
88 if let Poll::Ready(Err(ref err)) = res {
89 let peer_addr = this.peer_addr;
90 let log_message = format!(
91 "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
92 this.metrics.total_read, this.metrics.total_written
93 );
94 this.events.emit(Event {
95 id: *this.account_id,
96 typ: EventType::Warning(log_message),
97 });
98 }
99
100 let n = old_remaining - buf.remaining();
101 this.metrics.total_read = this.metrics.total_read.saturating_add(n);
102
103 res
104 }
105}
106
107impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
108 fn poll_write(
109 self: Pin<&mut Self>,
110 cx: &mut std::task::Context<'_>,
111 buf: &[u8],
112 ) -> Poll<std::io::Result<usize>> {
113 let this = self.project();
114 let res = this.inner.poll_write(cx, buf);
115 if let Poll::Ready(Ok(n)) = res {
116 this.metrics.total_written = this.metrics.total_written.saturating_add(n);
117 }
118 res
119 }
120
121 fn poll_flush(
122 self: Pin<&mut Self>,
123 cx: &mut std::task::Context<'_>,
124 ) -> Poll<std::io::Result<()>> {
125 self.project().inner.poll_flush(cx)
126 }
127
128 fn poll_shutdown(
129 self: Pin<&mut Self>,
130 cx: &mut std::task::Context<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 self.project().inner.poll_shutdown(cx)
133 }
134
135 fn poll_write_vectored(
136 self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 bufs: &[std::io::IoSlice<'_>],
139 ) -> Poll<std::io::Result<usize>> {
140 let this = self.project();
141 let res = this.inner.poll_write_vectored(cx, bufs);
142 if let Poll::Ready(Ok(n)) = res {
143 this.metrics.total_written = this.metrics.total_written.saturating_add(n);
144 }
145 res
146 }
147
148 fn is_write_vectored(&self) -> bool {
149 self.inner.is_write_vectored()
150 }
151}
152
153impl<S: SessionStream> SessionStream for LoggingStream<S> {
154 fn set_read_timeout(&mut self, timeout: Option<Duration>) {
155 self.inner.set_read_timeout(timeout)
156 }
157
158 fn peer_addr(&self) -> Result<SocketAddr> {
159 self.inner.peer_addr()
160 }
161}