1use std::{
42 future::Future,
43 net::SocketAddr,
44 pin::Pin,
45 sync::{
46 Arc,
47 atomic::{AtomicBool, AtomicUsize, Ordering},
48 },
49 task::{Context, Poll},
50 time::{Duration, Instant},
51};
52
53use axum::extract::connect_info::ConnectInfo;
54use bytes::Buf;
55use chrono::{DateTime, Local};
56use http::{Request, Response, Version};
57use http_body::{Body, Frame};
58use pin_project::{pin_project, pinned_drop};
59use tower::{Layer, Service};
60
/// Tower layer that wraps a service in [`AccessLogService`], producing one access-log
/// line per response and warning when time-to-first-byte exceeds a threshold.
#[derive(Clone, Debug)]
pub struct AccessLogLayer {
    /// Time-to-first-byte above which a "Slow TTFB" warning is emitted.
    slow_ttfb: Duration,
}

impl AccessLogLayer {
    /// Threshold used by [`Default::default`].
    const DEFAULT_SLOW_TTFB: Duration = Duration::from_millis(150);

    /// Creates a layer that warns when a response's time-to-first-byte exceeds
    /// `slow_ttfb`.
    pub fn new(slow_ttfb: Duration) -> Self {
        AccessLogLayer { slow_ttfb }
    }
}

impl Default for AccessLogLayer {
    /// A layer with a 150 ms slow-TTFB threshold.
    fn default() -> Self {
        AccessLogLayer::new(AccessLogLayer::DEFAULT_SLOW_TTFB)
    }
}
81
82impl<S> Layer<S> for AccessLogLayer {
83 type Service = AccessLogService<S>;
84
85 fn layer(&self, inner: S) -> Self::Service {
86 AccessLogService {
87 inner,
88 slow_ttfb: self.slow_ttfb,
89 }
90 }
91}
92
/// Service produced by [`AccessLogLayer`]: forwards every request to `inner` and
/// records access-log data for it.
#[derive(Clone, Debug)]
pub struct AccessLogService<S> {
    /// The wrapped service.
    inner: S,
    /// Time-to-first-byte above which a warning is logged.
    slow_ttfb: Duration,
}
102
103impl<S, B, ResponseBody> Service<Request<B>> for AccessLogService<S>
104where
105 S: Service<Request<B>, Response = Response<ResponseBody>>,
106 ResponseBody: Body + Send + 'static,
107 ResponseBody::Data: Buf,
108 <ResponseBody as Body>::Error: std::fmt::Display,
109{
110 type Response = Response<CountingBody<ResponseBody>>;
111 type Error = S::Error;
112 type Future = AccessLogFuture<S, B>;
113
114 fn poll_ready(
115 &mut self,
116 cx: &mut Context<'_>,
117 ) -> Poll<Result<(), Self::Error>> {
118 self.inner.poll_ready(cx)
119 }
120
121 fn call(&mut self, req: Request<B>) -> Self::Future {
122 let start_instant = Instant::now();
123 let start_time_local: DateTime<Local> = Local::now();
124
125 let method = req.method().clone();
126 let uri_path = req
127 .uri()
128 .path_and_query()
129 .map_or_else(|| req.uri().path(), http::uri::PathAndQuery::as_str)
130 .to_string();
131 let version = req.version();
132
133 let headers = req.headers().clone();
135 let cloudflare_ip = headers
136 .get("CF-Connecting-IP")
137 .and_then(|h| h.to_str().ok());
138 let remote_host = if let Some(cloudflare_ip) = cloudflare_ip {
139 cloudflare_ip
140 } else {
141 &req.extensions()
142 .get::<ConnectInfo<SocketAddr>>()
143 .map_or_else(|| "-".into(), |c| c.0.ip().to_string())
144 };
145
146 let fut = self.inner.call(req);
147
148 AccessLogFuture {
149 state: Some(AccessLogState {
150 start_instant,
151 start_time_local,
152 method,
153 path_and_query: uri_path,
154 version,
155 remote_host: remote_host.to_string(),
156 ident: "-".into(),
157 user: "-".into(),
158 ttfb_warn_threshold: self.slow_ttfb,
159 }),
160 inner: fut,
161 }
162 }
163}
164
165struct AccessLogState {
170 remote_host: String,
171 ident: String,
173 user: String,
175 start_instant: Instant,
176 start_time_local: DateTime<Local>,
177 method: http::Method,
178 path_and_query: String,
179 version: Version,
180 ttfb_warn_threshold: Duration,
181}
182
183#[pin_project]
184pub struct AccessLogFuture<S, B>
185where
186 S: Service<Request<B>>,
187{
188 #[pin]
189 inner: S::Future,
190 state: Option<AccessLogState>,
191}
192
193impl<S, B, ResponseBody> Future for AccessLogFuture<S, B>
194where
195 S: Service<Request<B>, Response = Response<ResponseBody>>,
196 ResponseBody: Body + Send + 'static,
197 ResponseBody::Data: Buf,
198 <ResponseBody as Body>::Error: std::fmt::Display,
199{
200 type Output = Result<Response<CountingBody<ResponseBody>>, S::Error>;
201
202 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
203 let this = self.project();
204 let state = this.state.as_ref().expect("state present while polling");
205
206 match this.inner.poll(cx) {
207 Poll::Pending => Poll::Pending,
208 Poll::Ready(res) => {
209 let response = res?;
210 let ttfb = state.start_instant.elapsed();
211 let status = response.status();
212
213 let content_length_guess = response
215 .headers()
216 .get(http::header::CONTENT_LENGTH)
217 .and_then(|v| v.to_str().ok())
218 .and_then(|s| s.parse::<usize>().ok());
219
220 if ttfb > state.ttfb_warn_threshold {
222 tracing::warn!(
223 target:"access",
224 ttfb_ms = ttfb.as_millis(),
225 "Slow TTFB ({}ms) {} {}",
226 ttfb.as_millis(),
227 state.method,
228 state.path_and_query
229 );
230 }
231
232 let shared = Arc::new(SharedLog {
234 remote_host: state.remote_host.clone(),
235 ident: state.ident.clone(),
236 user: state.user.clone(),
237 start_time_local: state.start_time_local,
238 method: state.method.clone(),
239 path_and_query: state.path_and_query.clone(),
240 protocol: http_version_str(state.version).to_string(),
241 status: status.as_u16(),
242 bytes: AtomicUsize::new(content_length_guess.unwrap_or(0)),
243 logged: AtomicBool::new(false),
244 });
245
246 let headers = response.headers().clone();
247 let body = response.into_body();
248 let counting = CountingBody {
249 inner: body,
250 shared: shared.clone(),
251 };
252 let mut response_with_body =
254 Response::builder().status(status).version(state.version);
255 for (key, value) in &headers {
256 response_with_body = response_with_body.header(key, value);
257 }
258 let response_with_body = response_with_body
259 .body(counting)
260 .expect("response rebuild");
261
262 *this.state = None;
264
265 Poll::Ready(Ok(response_with_body))
266 }
267 }
268 }
269}
270
271#[pin_project(PinnedDrop)]
276pub struct CountingBody<B> {
277 #[pin]
278 inner: B,
279 shared: Arc<SharedLog>,
280}
281
282impl<B> Body for CountingBody<B>
283where
284 B: Body + Send + 'static,
285 B::Data: Buf,
286{
287 type Data = B::Data;
288 type Error = B::Error;
289
290 fn poll_frame(
291 self: Pin<&mut Self>,
292 cx: &mut Context<'_>,
293 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
294 let this = self.project();
295 match this.inner.poll_frame(cx) {
296 Poll::Ready(Some(Ok(frame))) => {
297 if let Some(data) = frame.data_ref() {
298 this.shared
300 .bytes
301 .fetch_add(data.remaining(), Ordering::Relaxed);
302 }
303 Poll::Ready(Some(Ok(frame)))
304 }
305 other => other,
306 }
307 }
308
309 fn is_end_stream(&self) -> bool {
310 self.inner.is_end_stream()
311 }
312
313 fn size_hint(&self) -> http_body::SizeHint {
314 self.inner.size_hint()
315 }
316}
317
318#[pinned_drop]
319impl<B> PinnedDrop for CountingBody<B> {
320 fn drop(self: Pin<&mut Self>) {
321 if self
323 .shared
324 .logged
325 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
326 .is_ok()
327 {
328 let line = self.shared.clf_line();
329 let status = self.shared.status;
330 match status {
331 500..=599 => tracing::error!(target: "access", "{}", line),
332 400..=499 => tracing::warn!(target: "access", "{}", line),
333 _ => tracing::debug!(target: "access", "{}", line),
334 }
335 }
336 }
337}
338
339struct SharedLog {
344 remote_host: String,
345 ident: String,
346 user: String,
347 start_time_local: DateTime<Local>,
348 method: http::Method,
349 path_and_query: String,
350 protocol: String,
351 status: u16,
352 bytes: AtomicUsize,
353 logged: AtomicBool,
354}
355
356impl SharedLog {
357 fn clf_line(&self) -> String {
358 let ts = self.start_time_local.format("%d/%b/%Y:%H:%M:%S %z");
360 let bytes = self.bytes.load(Ordering::Relaxed);
361 let bytes_owned;
362 let bytes_str: &str = if bytes == 0 {
363 "-"
364 } else {
365 bytes_owned = bytes.to_string();
366 &bytes_owned
367 };
368 format!(
369 r#"{remote} {ident} {user} [{ts}] "{method} {path} {proto}" {status} {bytes}"#,
370 remote = self.remote_host,
371 ident = self.ident,
372 user = self.user,
373 ts = ts,
374 method = self.method,
375 path = self.path_and_query,
376 proto = self.protocol,
377 status = self.status,
378 bytes = bytes_str,
379 )
380 }
381}
382
383fn http_version_str(v: Version) -> &'static str {
384 match v {
385 Version::HTTP_09 => "HTTP/0.9",
386 Version::HTTP_10 => "HTTP/1.0",
387 Version::HTTP_11 => "HTTP/1.1",
388 Version::HTTP_2 => "HTTP/2.0",
389 Version::HTTP_3 => "HTTP/3.0",
390 _ => "HTTP/?",
391 }
392}
393
394pub mod prelude {
399 pub use super::AccessLogLayer;
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;
    use http_body_util::{BodyExt, Empty, Full};
    use tower::{ServiceExt, service_fn};

    /// An empty 200 response passes through the layer with its status intact and an
    /// empty body.
    #[tokio::test]
    async fn test_basic_logging() {
        let layer = AccessLogLayer::default();
        let svc = layer.layer(service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Empty::<Bytes>::new())
                    .unwrap(),
            )
        }));

        let resp = svc
            .clone()
            .oneshot(Request::builder().uri("/test").body(()).unwrap())
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }

    /// Status, headers and body bytes from the inner service reach the caller unchanged.
    #[tokio::test]
    async fn response_passes_through_unchanged() {
        let svc = AccessLogLayer::default().layer(service_fn(|_req: Request<()>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .header("x-test", "1")
                    .body(Full::new(Bytes::from_static(b"hello")))
                    .unwrap(),
            )
        }));

        let resp = svc
            .oneshot(Request::builder().uri("/missing?q=1").body(()).unwrap())
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(resp.headers()["x-test"], "1");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"hello");
    }
}