ctoolbox/io/webui/
access_log_layer.rs

1//! `AccessLogLayer`: A standalone Tower Layer producing Apache/NCSA Common Log
2//! Format (CLF) lines with log level determined by HTTP status:
3//!
4//!   - error: 5xx
5//!   - warn : 4xx
6//!   - debug: 2xx / 3xx
7//!
8//! Also emits a separate warn log when Time To First Byte (TTFB) exceeds the
9//! configured threshold (default 150ms) for successful (2xx/3xx) responses.
10//!
11//! Does not currently support RFC 1413 identities, nor logging the user
12//! (whether via basic access authentication, digest access authentication, or
13//! otherwise).
14//!
15//! Integrate by adding (and enabling `ConnectInfo` if you want remote IP):
16//!
17//! ```rust,ignore
18//!   use crate::io::webui::access_log_layer::AccessLogLayer;
19//!   app.layer(AccessLogLayer::new(Duration::from_millis(150)))
20//! ```
21//!
//! Ensure you serve with `into_make_service_with_connect_info::<SocketAddr>()`
//! to populate `ConnectInfo<SocketAddr>` for remote address logging (a usable
//! `CF-Connecting-IP` request header, when present, takes precedence over it):
24//!
25//! ```rust,ignore
26//!   axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
27//!       .await?;
28//! ```
29//!
30//! Dependencies you need in Cargo.toml (if not already present):
31//! ```toml
32//!   chrono = { version = "0.4", features = ["clock"] }
33//!   pin-project = "1"
34//! ```
35//!
36//! This module does NOT depend on `tower_http::trace::TraceLayer`.
37//!
38//! Explanation of some of this:
39//! <https://raw.githubusercontent.com/tower-rs/tower/refs/heads/master/guides/building-a-middleware-from-scratch.md>
40
41use std::{
42    future::Future,
43    net::SocketAddr,
44    pin::Pin,
45    sync::{
46        Arc,
47        atomic::{AtomicBool, AtomicUsize, Ordering},
48    },
49    task::{Context, Poll},
50    time::{Duration, Instant},
51};
52
53use axum::extract::connect_info::ConnectInfo;
54use bytes::Buf;
55use chrono::{DateTime, Local};
56use http::{Request, Response, Version};
57use http_body::{Body, Frame};
58use pin_project::{pin_project, pinned_drop};
59use tower::{Layer, Service};
60
61// --------------------------------------
62// Public Layer
63// --------------------------------------
64
/// Tower [`Layer`] that wraps a service in an [`AccessLogService`], producing
/// Common Log Format lines for the responses it returns.
#[derive(Clone, Debug)]
pub struct AccessLogLayer {
    /// Time-to-first-byte threshold handed to every wrapped service for its
    /// slow-response warning.
    slow_ttfb: Duration,
}

impl AccessLogLayer {
    /// Threshold used by the [`Default`] implementation.
    const DEFAULT_SLOW_TTFB: Duration = Duration::from_millis(150);

    /// Builds a layer whose services warn when time to first byte exceeds
    /// `slow_ttfb`.
    pub fn new(slow_ttfb: Duration) -> Self {
        AccessLogLayer { slow_ttfb }
    }
}

impl Default for AccessLogLayer {
    /// Same as `AccessLogLayer::new(Duration::from_millis(150))`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_SLOW_TTFB)
    }
}
81
82impl<S> Layer<S> for AccessLogLayer {
83    type Service = AccessLogService<S>;
84
85    fn layer(&self, inner: S) -> Self::Service {
86        AccessLogService {
87            inner,
88            slow_ttfb: self.slow_ttfb,
89        }
90    }
91}
92
93// --------------------------------------
94// Service
95// --------------------------------------
96
97#[derive(Clone, Debug)]
98pub struct AccessLogService<S> {
99    inner: S,
100    slow_ttfb: Duration,
101}
102
103impl<S, B, ResponseBody> Service<Request<B>> for AccessLogService<S>
104where
105    S: Service<Request<B>, Response = Response<ResponseBody>>,
106    ResponseBody: Body + Send + 'static,
107    ResponseBody::Data: Buf,
108    <ResponseBody as Body>::Error: std::fmt::Display,
109{
110    type Response = Response<CountingBody<ResponseBody>>;
111    type Error = S::Error;
112    type Future = AccessLogFuture<S, B>;
113
114    fn poll_ready(
115        &mut self,
116        cx: &mut Context<'_>,
117    ) -> Poll<Result<(), Self::Error>> {
118        self.inner.poll_ready(cx)
119    }
120
121    fn call(&mut self, req: Request<B>) -> Self::Future {
122        let start_instant = Instant::now();
123        let start_time_local: DateTime<Local> = Local::now();
124
125        let method = req.method().clone();
126        let uri_path = req
127            .uri()
128            .path_and_query()
129            .map_or_else(|| req.uri().path(), http::uri::PathAndQuery::as_str)
130            .to_string();
131        let version = req.version();
132
133        // Remote host if ConnectInfo configured, or CF-Connecting-IP header
134        let headers = req.headers().clone();
135        let cloudflare_ip = headers
136            .get("CF-Connecting-IP")
137            .and_then(|h| h.to_str().ok());
138        let remote_host = if let Some(cloudflare_ip) = cloudflare_ip {
139            cloudflare_ip
140        } else {
141            &req.extensions()
142                .get::<ConnectInfo<SocketAddr>>()
143                .map_or_else(|| "-".into(), |c| c.0.ip().to_string())
144        };
145
146        let fut = self.inner.call(req);
147
148        AccessLogFuture {
149            state: Some(AccessLogState {
150                start_instant,
151                start_time_local,
152                method,
153                path_and_query: uri_path,
154                version,
155                remote_host: remote_host.to_string(),
156                ident: "-".into(),
157                user: "-".into(),
158                ttfb_warn_threshold: self.slow_ttfb,
159            }),
160            inner: fut,
161        }
162    }
163}
164
165// --------------------------------------
166// Future wrapping inner service future
167// --------------------------------------
168
169struct AccessLogState {
170    remote_host: String,
171    /// RFC 1413 identity
172    ident: String,
173    /// HTTP auth user
174    user: String,
175    start_instant: Instant,
176    start_time_local: DateTime<Local>,
177    method: http::Method,
178    path_and_query: String,
179    version: Version,
180    ttfb_warn_threshold: Duration,
181}
182
183#[pin_project]
184pub struct AccessLogFuture<S, B>
185where
186    S: Service<Request<B>>,
187{
188    #[pin]
189    inner: S::Future,
190    state: Option<AccessLogState>,
191}
192
193impl<S, B, ResponseBody> Future for AccessLogFuture<S, B>
194where
195    S: Service<Request<B>, Response = Response<ResponseBody>>,
196    ResponseBody: Body + Send + 'static,
197    ResponseBody::Data: Buf,
198    <ResponseBody as Body>::Error: std::fmt::Display,
199{
200    type Output = Result<Response<CountingBody<ResponseBody>>, S::Error>;
201
202    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
203        let this = self.project();
204        let state = this.state.as_ref().expect("state present while polling");
205
206        match this.inner.poll(cx) {
207            Poll::Pending => Poll::Pending,
208            Poll::Ready(res) => {
209                let response = res?;
210                let ttfb = state.start_instant.elapsed();
211                let status = response.status();
212
213                // Extract content-length if present (used only if no chunks seen)
214                let content_length_guess = response
215                    .headers()
216                    .get(http::header::CONTENT_LENGTH)
217                    .and_then(|v| v.to_str().ok())
218                    .and_then(|s| s.parse::<usize>().ok());
219
220                // Emit slow TTFB warning
221                if ttfb > state.ttfb_warn_threshold {
222                    tracing::warn!(
223                        target:"access",
224                        ttfb_ms = ttfb.as_millis(),
225                        "Slow TTFB ({}ms) {} {}",
226                        ttfb.as_millis(),
227                        state.method,
228                        state.path_and_query
229                    );
230                }
231
232                // Prepare shared log record
233                let shared = Arc::new(SharedLog {
234                    remote_host: state.remote_host.clone(),
235                    ident: state.ident.clone(),
236                    user: state.user.clone(),
237                    start_time_local: state.start_time_local,
238                    method: state.method.clone(),
239                    path_and_query: state.path_and_query.clone(),
240                    protocol: http_version_str(state.version).to_string(),
241                    status: status.as_u16(),
242                    bytes: AtomicUsize::new(content_length_guess.unwrap_or(0)),
243                    logged: AtomicBool::new(false),
244                });
245
246                let headers = response.headers().clone();
247                let body = response.into_body();
248                let counting = CountingBody {
249                    inner: body,
250                    shared: shared.clone(),
251                };
252                // Copy headers from original response
253                let mut response_with_body =
254                    Response::builder().status(status).version(state.version);
255                for (key, value) in &headers {
256                    response_with_body = response_with_body.header(key, value);
257                }
258                let response_with_body = response_with_body
259                    .body(counting)
260                    .expect("response rebuild");
261
262                // Drop state
263                *this.state = None;
264
265                Poll::Ready(Ok(response_with_body))
266            }
267        }
268    }
269}
270
271// --------------------------------------
272// Counting Body
273// --------------------------------------
274
275#[pin_project(PinnedDrop)]
276pub struct CountingBody<B> {
277    #[pin]
278    inner: B,
279    shared: Arc<SharedLog>,
280}
281
282impl<B> Body for CountingBody<B>
283where
284    B: Body + Send + 'static,
285    B::Data: Buf,
286{
287    type Data = B::Data;
288    type Error = B::Error;
289
290    fn poll_frame(
291        self: Pin<&mut Self>,
292        cx: &mut Context<'_>,
293    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
294        let this = self.project();
295        match this.inner.poll_frame(cx) {
296            Poll::Ready(Some(Ok(frame))) => {
297                if let Some(data) = frame.data_ref() {
298                    // Count bytes in this data frame
299                    this.shared
300                        .bytes
301                        .fetch_add(data.remaining(), Ordering::Relaxed);
302                }
303                Poll::Ready(Some(Ok(frame)))
304            }
305            other => other,
306        }
307    }
308
309    fn is_end_stream(&self) -> bool {
310        self.inner.is_end_stream()
311    }
312
313    fn size_hint(&self) -> http_body::SizeHint {
314        self.inner.size_hint()
315    }
316}
317
318#[pinned_drop]
319impl<B> PinnedDrop for CountingBody<B> {
320    fn drop(self: Pin<&mut Self>) {
321        // Log once at end-of-stream / drop
322        if self
323            .shared
324            .logged
325            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
326            .is_ok()
327        {
328            let line = self.shared.clf_line();
329            let status = self.shared.status;
330            match status {
331                500..=599 => tracing::error!(target: "access", "{}", line),
332                400..=499 => tracing::warn!(target: "access", "{}", line),
333                _ => tracing::debug!(target: "access", "{}", line),
334            }
335        }
336    }
337}
338
339// --------------------------------------
340// Shared log record
341// --------------------------------------
342
343struct SharedLog {
344    remote_host: String,
345    ident: String,
346    user: String,
347    start_time_local: DateTime<Local>,
348    method: http::Method,
349    path_and_query: String,
350    protocol: String,
351    status: u16,
352    bytes: AtomicUsize,
353    logged: AtomicBool,
354}
355
356impl SharedLog {
357    fn clf_line(&self) -> String {
358        // [day/Mon/year:HH:MM:SS zone]
359        let ts = self.start_time_local.format("%d/%b/%Y:%H:%M:%S %z");
360        let bytes = self.bytes.load(Ordering::Relaxed);
361        let bytes_owned;
362        let bytes_str: &str = if bytes == 0 {
363            "-"
364        } else {
365            bytes_owned = bytes.to_string();
366            &bytes_owned
367        };
368        format!(
369            r#"{remote} {ident} {user} [{ts}] "{method} {path} {proto}" {status} {bytes}"#,
370            remote = self.remote_host,
371            ident = self.ident,
372            user = self.user,
373            ts = ts,
374            method = self.method,
375            path = self.path_and_query,
376            proto = self.protocol,
377            status = self.status,
378            bytes = bytes_str,
379        )
380    }
381}
382
383fn http_version_str(v: Version) -> &'static str {
384    match v {
385        Version::HTTP_09 => "HTTP/0.9",
386        Version::HTTP_10 => "HTTP/1.0",
387        Version::HTTP_11 => "HTTP/1.1",
388        Version::HTTP_2 => "HTTP/2.0",
389        Version::HTTP_3 => "HTTP/3.0",
390        _ => "HTTP/?",
391    }
392}
393
394// --------------------------------------
395// Optional convenience re-export
396// --------------------------------------
397
398pub mod prelude {
399    pub use super::AccessLogLayer;
400}
401
402// --------------------------------------
403// Tests (basic smoke)
404// --------------------------------------
405
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;
    use http_body_util::BodyExt;
    use tower::{ServiceExt, service_fn};

    /// Smoke test: an empty 200 response passes through the layer, and its
    /// body can be driven to completion, which triggers the log line. Log
    /// output itself is not captured or asserted.
    #[tokio::test]
    async fn test_basic_logging() {
        let layer = AccessLogLayer::default();
        let svc = layer.layer(service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    // Use a body whose Data implements Buf
                    .body(http_body_util::Empty::<Bytes>::new())
                    .unwrap(),
            )
        }));

        let resp = svc
            .oneshot(Request::builder().uri("/test").body(()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }

    /// Status, headers and body bytes of a non-2xx response reach the caller
    /// unchanged through the layer.
    #[tokio::test]
    async fn test_response_passthrough() {
        let layer = AccessLogLayer::new(Duration::from_secs(60));
        let svc = layer.layer(service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .header("x-test", "yes")
                    .body(http_body_util::Full::new(Bytes::from_static(b"hello")))
                    .unwrap(),
            )
        }));

        let resp = svc
            .oneshot(Request::builder().uri("/missing?q=1").body(()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(resp.headers()["x-test"], "yes");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"hello");
    }
}