1use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13use anyhow::{Context as _, Result};
14use pin_project::pin_project;
15
16use crate::events::{Event, EventType, Events};
17use crate::net::session::SessionStream;
18use crate::tools::usize_to_u64;
19
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21
/// Byte counters accumulated by a `LoggingStream` over its lifetime.
#[derive(Debug, Default)]
struct Metrics {
    /// Total number of bytes read from the underlying stream.
    pub total_read: u64,

    /// Total number of bytes written to the underlying stream.
    pub total_written: u64,
}

impl Metrics {
    /// Creates counters with nothing read or written yet.
    ///
    /// Equivalent to [`Default::default`]; kept as the named constructor used by callers.
    fn new() -> Self {
        Self::default()
    }
}
39
/// Stream wrapper around an inner [`SessionStream`] that counts the bytes
/// transferred through it and reports read errors for the owning account.
#[derive(Debug)]
#[pin_project]
pub(crate) struct LoggingStream<S: SessionStream> {
    /// The wrapped stream. It is structurally pinned so it can be polled
    /// through the `Pin<&mut Self>` receivers of the I/O traits.
    #[pin]
    inner: S,

    /// Account ID attached to the tracing event and the emitted `Event`.
    account_id: u32,

    /// Event emitter used to surface stream errors as warnings.
    events: Events,

    /// Running byte counters for this stream.
    metrics: Metrics,

    /// Peer address captured when the wrapper was created; used in error messages.
    peer_addr: SocketAddr,
}
62
63impl<S: SessionStream> LoggingStream<S> {
64 pub fn new(inner: S, account_id: u32, events: Events) -> Result<Self> {
65 let peer_addr: SocketAddr = inner
66 .peer_addr()
67 .context("Attempt to create LoggingStream over an unconnected stream")?;
68 Ok(Self {
69 inner,
70 account_id,
71 events,
72 metrics: Metrics::new(),
73 peer_addr,
74 })
75 }
76}
77
78impl<S: SessionStream> AsyncRead for LoggingStream<S> {
79 fn poll_read(
80 self: Pin<&mut Self>,
81 cx: &mut Context<'_>,
82 buf: &mut ReadBuf<'_>,
83 ) -> Poll<std::io::Result<()>> {
84 let this = self.project();
85 let old_remaining = buf.remaining();
86
87 let res = this.inner.poll_read(cx, buf);
88
89 if let Poll::Ready(Err(ref err)) = res {
90 let peer_addr = this.peer_addr;
91 let log_message = format!(
92 "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
93 this.metrics.total_read, this.metrics.total_written
94 );
95 tracing::event!(
96 ::tracing::Level::WARN,
97 account_id = *this.account_id,
98 log_message
99 );
100 this.events.emit(Event {
101 id: *this.account_id,
102 typ: EventType::Warning(log_message),
103 });
104 }
105
106 let n = old_remaining - buf.remaining();
107 this.metrics.total_read = this.metrics.total_read.saturating_add(usize_to_u64(n));
108
109 res
110 }
111}
112
113impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
114 fn poll_write(
115 self: Pin<&mut Self>,
116 cx: &mut std::task::Context<'_>,
117 buf: &[u8],
118 ) -> Poll<std::io::Result<usize>> {
119 let this = self.project();
120 let res = this.inner.poll_write(cx, buf);
121 if let Poll::Ready(Ok(n)) = res {
122 this.metrics.total_written = this.metrics.total_written.saturating_add(usize_to_u64(n));
123 }
124 res
125 }
126
127 fn poll_flush(
128 self: Pin<&mut Self>,
129 cx: &mut std::task::Context<'_>,
130 ) -> Poll<std::io::Result<()>> {
131 self.project().inner.poll_flush(cx)
132 }
133
134 fn poll_shutdown(
135 self: Pin<&mut Self>,
136 cx: &mut std::task::Context<'_>,
137 ) -> Poll<std::io::Result<()>> {
138 self.project().inner.poll_shutdown(cx)
139 }
140
141 fn poll_write_vectored(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 bufs: &[std::io::IoSlice<'_>],
145 ) -> Poll<std::io::Result<usize>> {
146 let this = self.project();
147 let res = this.inner.poll_write_vectored(cx, bufs);
148 if let Poll::Ready(Ok(n)) = res {
149 this.metrics.total_written = this.metrics.total_written.saturating_add(usize_to_u64(n));
150 }
151 res
152 }
153
154 fn is_write_vectored(&self) -> bool {
155 self.inner.is_write_vectored()
156 }
157}
158
159impl<S: SessionStream> SessionStream for LoggingStream<S> {
160 fn set_read_timeout(&mut self, timeout: Option<Duration>) {
161 self.inner.set_read_timeout(timeout)
162 }
163
164 fn peer_addr(&self) -> Result<SocketAddr> {
165 self.inner.peer_addr()
166 }
167}