1use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13use anyhow::{Context as _, Result};
14use pin_project::pin_project;
15
16use crate::events::{Event, EventType, Events};
17use crate::net::session::SessionStream;
18use crate::tools::usize_to_u64;
19
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21
/// Byte counters for a `LoggingStream`, updated as data is read and written.
#[derive(Debug)]
struct Metrics {
    /// Total number of bytes read through the wrapper so far.
    pub total_read: u64,

    /// Total number of bytes written through the wrapper so far.
    pub total_written: u64,
}

impl Metrics {
    /// Returns counters with nothing read and nothing written yet.
    fn new() -> Self {
        let nothing_yet: u64 = 0;
        Self {
            total_read: nothing_yet,
            total_written: nothing_yet,
        }
    }
}
39
/// Wrapper around a [`SessionStream`] that counts the bytes read and written
/// through it and reports read errors for its account.
#[derive(Debug)]
#[pin_project]
pub(crate) struct LoggingStream<S: SessionStream> {
    /// The wrapped stream. It is structurally pinned so the `poll_*` calls can be
    /// forwarded to it through the pin projection.
    #[pin]
    inner: S,

    /// Account ID attached to the emitted warning events and tracing records.
    account_id: u32,

    /// Event emitter used to report read errors as warning events.
    events: Events,

    /// Running totals of bytes read and written through this wrapper.
    metrics: Metrics,

    /// Remote address captured when the wrapper was created; it is included in
    /// read-error messages.
    peer_addr: SocketAddr,
}
62
63impl<S: SessionStream> LoggingStream<S> {
64 pub fn new(inner: S, account_id: u32, events: Events) -> Result<Self> {
65 let peer_addr: SocketAddr = inner
66 .peer_addr()
67 .context("Attempt to create LoggingStream over an unconnected stream")?;
68 Ok(Self {
69 inner,
70 account_id,
71 events,
72 metrics: Metrics::new(),
73 peer_addr,
74 })
75 }
76}
77
78impl<S: SessionStream> AsyncRead for LoggingStream<S> {
79 #[expect(clippy::arithmetic_side_effects)]
80 fn poll_read(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 buf: &mut ReadBuf<'_>,
84 ) -> Poll<std::io::Result<()>> {
85 let this = self.project();
86 let old_remaining = buf.remaining();
87
88 let res = this.inner.poll_read(cx, buf);
89
90 if let Poll::Ready(Err(ref err)) = res {
91 let peer_addr = this.peer_addr;
92 let log_message = format!(
93 "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
94 this.metrics.total_read, this.metrics.total_written
95 );
96 tracing::event!(
97 ::tracing::Level::WARN,
98 account_id = *this.account_id,
99 log_message
100 );
101 this.events.emit(Event {
102 id: *this.account_id,
103 typ: EventType::Warning(log_message),
104 });
105 }
106
107 let n = old_remaining - buf.remaining();
108 this.metrics.total_read = this.metrics.total_read.saturating_add(usize_to_u64(n));
109
110 res
111 }
112}
113
114impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
115 fn poll_write(
116 self: Pin<&mut Self>,
117 cx: &mut std::task::Context<'_>,
118 buf: &[u8],
119 ) -> Poll<std::io::Result<usize>> {
120 let this = self.project();
121 let res = this.inner.poll_write(cx, buf);
122 if let Poll::Ready(Ok(n)) = res {
123 this.metrics.total_written = this.metrics.total_written.saturating_add(usize_to_u64(n));
124 }
125 res
126 }
127
128 fn poll_flush(
129 self: Pin<&mut Self>,
130 cx: &mut std::task::Context<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 self.project().inner.poll_flush(cx)
133 }
134
135 fn poll_shutdown(
136 self: Pin<&mut Self>,
137 cx: &mut std::task::Context<'_>,
138 ) -> Poll<std::io::Result<()>> {
139 self.project().inner.poll_shutdown(cx)
140 }
141
142 fn poll_write_vectored(
143 self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 bufs: &[std::io::IoSlice<'_>],
146 ) -> Poll<std::io::Result<usize>> {
147 let this = self.project();
148 let res = this.inner.poll_write_vectored(cx, bufs);
149 if let Poll::Ready(Ok(n)) = res {
150 this.metrics.total_written = this.metrics.total_written.saturating_add(usize_to_u64(n));
151 }
152 res
153 }
154
155 fn is_write_vectored(&self) -> bool {
156 self.inner.is_write_vectored()
157 }
158}
159
160impl<S: SessionStream> SessionStream for LoggingStream<S> {
161 fn set_read_timeout(&mut self, timeout: Option<Duration>) {
162 self.inner.set_read_timeout(timeout)
163 }
164
165 fn peer_addr(&self) -> Result<SocketAddr> {
166 self.inner.peer_addr()
167 }
168}