//! `AccessLogLayer`: A standalone Tower Layer producing Apache/NCSA Common Log
//! Format (CLF) lines with log level determined by HTTP status:
//!
//!   - error: 5xx
//!   - warn : 4xx
//!   - debug: 2xx / 3xx
//!
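//! Also emits a separate warn log when Time To First Byte (TTFB) — measured
//! from the request entering this layer until the inner service returns the
//! response head — exceeds the configured threshold (default 150ms) for
//! successful (2xx/3xx) responses.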
//!
//! Does not currently support RFC 1413 identities, nor logging the user
//! (whether via basic access authentication, digest access authentication, or
//! otherwise).
//!
//! Integrate by adding (and enabling `ConnectInfo` if you want remote IP):
//!
//! ```rust,ignore
//!   use crate::io::webui::access_log_layer::AccessLogLayer;
//!   app.layer(AccessLogLayer::new(Duration::from_millis(150)))
//! ```
//!
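//! Ensure you serve with `into_make_service_with_connect_info::<SocketAddr>()`
//! to populate `ConnectInfo<SocketAddr>` for remote address logging. A
//! `CF-Connecting-IP` request header, if present, takes precedence over
//! `ConnectInfo`; only rely on it when all traffic reaches the service through
//! Cloudflare, since clients can otherwise set it themselves: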
//!
//! ```rust,ignore
//!   axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
//!       .await?;
//! ```
//!
//! Besides `axum`, `tower`, `http`, `http-body`, `bytes` and `tracing` (which a
//! typical axum app already has), this module needs in Cargo.toml:
//! ```toml
//!   chrono = { version = "0.4", features = ["clock"] }
//!   pin-project = "1"
//! ```
//!
//! This module does NOT depend on `tower_http::trace::TraceLayer`.
//!
//! Background on the Layer / Service / Future structure used here:
//! <https://raw.githubusercontent.com/tower-rs/tower/refs/heads/master/guides/building-a-middleware-from-scratch.md>

use std::{
    future::Future,
    net::SocketAddr,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    task::{Context, Poll},
    time::{Duration, Instant},
};

use axum::extract::connect_info::ConnectInfo;
use bytes::Buf;
use chrono::{DateTime, Local};
use http::{Request, Response, Version};
use http_body::{Body, Frame};
use pin_project::{pin_project, pinned_drop};
use tower::{Layer, Service};

// --------------------------------------
// Public Layer
// --------------------------------------

/// Slow-TTFB threshold used by [`AccessLogLayer::default`].
const DEFAULT_SLOW_TTFB: Duration = Duration::from_millis(150);

/// Tower [`Layer`] that wraps a service in [`AccessLogService`], producing one
/// Common Log Format line per completed response and a separate warning for
/// slow responses.
#[derive(Clone, Debug)]
pub struct AccessLogLayer {
    /// Latency threshold for the slow-TTFB warning.
    slow_ttfb: Duration,
}

impl AccessLogLayer {
    /// Builds a layer whose slow-TTFB warning threshold is `slow_ttfb`.
    pub fn new(slow_ttfb: Duration) -> Self {
        AccessLogLayer { slow_ttfb }
    }
}

impl Default for AccessLogLayer {
    /// A layer with a 150 ms slow-TTFB threshold.
    fn default() -> Self {
        AccessLogLayer::new(DEFAULT_SLOW_TTFB)
    }
}

impl<S> Layer<S> for AccessLogLayer {
    type Service = AccessLogService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let slow_ttfb = self.slow_ttfb;
        AccessLogService { inner, slow_ttfb }
    }
}

// --------------------------------------
// Service
// --------------------------------------

/// [`Service`] produced by [`AccessLogLayer`]: captures request metadata, then
/// wraps the inner service's future in an [`AccessLogFuture`].
#[derive(Clone, Debug)]
pub struct AccessLogService<S> {
    inner: S,
    /// Threshold for the slow-TTFB warning.
    slow_ttfb: Duration,
}

impl<S, B, ResponseBody> Service<Request<B>> for AccessLogService<S>
where
    S: Service<Request<B>, Response = Response<ResponseBody>>,
    ResponseBody: Body + Send + 'static,
    ResponseBody::Data: Buf,
    <ResponseBody as Body>::Error: std::fmt::Display,
{
    type Response = Response<CountingBody<ResponseBody>>;
    type Error = S::Error;
    type Future = AccessLogFuture<S, B>;

    fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Captures everything the log line needs from `req` (which is then moved
    /// into the inner service) and starts the inner call.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let start_instant = Instant::now();
        let start_time_local: DateTime<Local> = Local::now();

        let method = req.method().clone();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map_or_else(|| req.uri().path(), http::uri::PathAndQuery::as_str)
            .to_string();
        let version = req.version();

        // Remote host: Cloudflare's CF-Connecting-IP header first, then the
        // socket peer address from ConnectInfo, else "-". The header is
        // client-controlled unless all traffic really comes through Cloudflare,
        // so it is only accepted when it parses as an IP address; this also
        // keeps arbitrary text (e.g. spaces) out of the space-delimited CLF line.
        // Building an owned String here avoids cloning the whole header map
        // just to outlive the move of `req` below.
        let remote_host = req
            .headers()
            .get("CF-Connecting-IP")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.trim().parse::<std::net::IpAddr>().ok())
            .or_else(|| {
                req.extensions()
                    .get::<ConnectInfo<SocketAddr>>()
                    .map(|info| info.0.ip())
            })
            .map_or_else(|| "-".to_owned(), |ip| ip.to_string());

        let fut = self.inner.call(req);

        AccessLogFuture {
            state: Some(AccessLogState {
                start_instant,
                start_time_local,
                method,
                path_and_query,
                version,
                remote_host,
                ident: "-".into(),
                user: "-".into(),
                ttfb_warn_threshold: self.slow_ttfb,
            }),
            inner: fut,
        }
    }
}

// --------------------------------------
// Future wrapping inner service future
// --------------------------------------

/// Request metadata captured when the request enters the layer and held by
/// [`AccessLogFuture`] until the inner service yields its response.
struct AccessLogState {
    // --- timing ---
    /// Monotonic start time, used to measure TTFB.
    start_instant: Instant,
    /// Wall-clock start time, rendered as the CLF timestamp.
    start_time_local: DateTime<Local>,
    /// Latency above which the slow-TTFB warning is considered.
    ttfb_warn_threshold: Duration,

    // --- client identity ---
    /// Client address as it will appear in the log line.
    remote_host: String,
    /// RFC 1413 identity
    ident: String,
    /// HTTP auth user
    user: String,

    // --- request line ---
    method: http::Method,
    path_and_query: String,
    /// HTTP version of the request.
    version: Version,
}

/// Future returned by [`AccessLogService::call`]: resolves to the inner
/// response with its body wrapped in a [`CountingBody`].
#[pin_project]
pub struct AccessLogFuture<S, B>
where
    S: Service<Request<B>>,
{
    /// The inner service's response future.
    #[pin]
    inner: S::Future,
    /// Request metadata; `Some` until the inner future completes.
    state: Option<AccessLogState>,
}

impl<S, B, ResponseBody> Future for AccessLogFuture<S, B>
where
    S: Service<Request<B>, Response = Response<ResponseBody>>,
    ResponseBody: Body + Send + 'static,
    ResponseBody::Data: Buf,
    <ResponseBody as Body>::Error: std::fmt::Display,
{
    type Output = Result<Response<CountingBody<ResponseBody>>, S::Error>;

    /// Drives the inner future. Once the response head is available, this
    /// emits the slow-TTFB warning (2xx/3xx only) and wraps the body so the
    /// CLF line is written when the body is dropped.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        let result = match this.inner.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };
        // The inner future is finished, so consume the captured state (moving
        // its strings instead of cloning them) whether it succeeded or not.
        let state = this.state.take().expect("state present while polling");
        let response = result?;

        let ttfb = state.start_instant.elapsed();
        let status = response.status();

        // Only successful/redirect responses get the extra warning, as the
        // module docs state; 4xx/5xx are already logged at warn/error level.
        if (status.is_success() || status.is_redirection())
            && ttfb > state.ttfb_warn_threshold
        {
            tracing::warn!(
                target:"access",
                ttfb_ms = ttfb.as_millis(),
                "Slow TTFB ({}ms) {} {}",
                ttfb.as_millis(),
                state.method,
                state.path_and_query
            );
        }

        // The byte counter starts at zero and is fed only by data frames that
        // `CountingBody` actually yields. Seeding it from Content-Length made
        // every streamed body with that header count twice.
        let shared = Arc::new(SharedLog {
            remote_host: state.remote_host,
            ident: state.ident,
            user: state.user,
            start_time_local: state.start_time_local,
            method: state.method,
            path_and_query: state.path_and_query,
            protocol: http_version_str(state.version).to_string(),
            status: status.as_u16(),
            bytes: AtomicUsize::new(0),
            logged: AtomicBool::new(false),
        });

        // `Response::map` swaps only the body, so status, version, headers and
        // extensions of the inner response pass through untouched.
        Poll::Ready(Ok(response.map(|inner| CountingBody { inner, shared })))
    }
}

// --------------------------------------
// Counting Body
// --------------------------------------

#[pin_project(PinnedDrop)]
pub struct CountingBody<B> {
    #[pin]
    inner: B,
    shared: Arc<SharedLog>,
}

impl<B> Body for CountingBody<B>
where
    B: Body + Send + 'static,
    B::Data: Buf,
{
    type Data = B::Data;
    type Error = B::Error;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();
        match this.inner.poll_frame(cx) {
            Poll::Ready(Some(Ok(frame))) => {
                if let Some(data) = frame.data_ref() {
                    // Count bytes in this data frame
                    this.shared
                        .bytes
                        .fetch_add(data.remaining(), Ordering::Relaxed);
                }
                Poll::Ready(Some(Ok(frame)))
            }
            other => other,
        }
    }

    fn is_end_stream(&self) -> bool {
        self.inner.is_end_stream()
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.inner.size_hint()
    }
}

#[pinned_drop]
impl<B> PinnedDrop for CountingBody<B> {
    fn drop(self: Pin<&mut Self>) {
        // Log once at end-of-stream / drop
        if self
            .shared
            .logged
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
        {
            let line = self.shared.clf_line();
            let status = self.shared.status;
            match status {
                500..=599 => tracing::error!(target: "access", "{}", line),
                400..=499 => tracing::warn!(target: "access", "{}", line),
                _ => tracing::debug!(target: "access", "{}", line),
            }
        }
    }
}

// --------------------------------------
// Shared log record
// --------------------------------------

/// Request/response metadata plus a live byte counter, shared between the
/// response future and the [`CountingBody`] that eventually writes the line.
struct SharedLog {
    /// Client address as rendered in the log.
    remote_host: String,
    /// RFC 1413 identity.
    ident: String,
    /// Authenticated user.
    user: String,
    /// Wall-clock time at which the request entered the layer.
    start_time_local: DateTime<Local>,
    method: http::Method,
    /// Request target: path plus optional query string.
    path_and_query: String,
    /// Request protocol token, e.g. `HTTP/1.1`.
    protocol: String,
    /// Numeric HTTP response status.
    status: u16,
    /// Response body byte count reported in the log line.
    bytes: AtomicUsize,
    /// Set once the line has been emitted, so it is written at most once.
    logged: AtomicBool,
}

impl SharedLog {
    /// Renders the record as one Common Log Format line:
    /// `host ident user [dd/Mon/yyyy:HH:MM:SS zone] "METHOD target PROTO" status bytes`.
    ///
    /// A byte count of zero is written as `-`, per CLF convention. `"` and `\`
    /// in the request target (the `http` crate accepts `"` in paths) are
    /// backslash-escaped, as Apache does, so a crafted path cannot terminate
    /// the quoted request field early and forge the remaining fields.
    fn clf_line(&self) -> String {
        // [day/Mon/year:HH:MM:SS zone]
        let ts = self.start_time_local.format("%d/%b/%Y:%H:%M:%S %z");
        let bytes = match self.bytes.load(Ordering::Relaxed) {
            0 => "-".to_owned(),
            n => n.to_string(),
        };
        format!(
            r#"{remote} {ident} {user} [{ts}] "{method} {path} {proto}" {status} {bytes}"#,
            remote = self.remote_host,
            ident = self.ident,
            user = self.user,
            ts = ts,
            method = self.method,
            path = escape_clf_quoted(&self.path_and_query),
            proto = self.protocol,
            status = self.status,
            bytes = bytes,
        )
    }
}

/// Backslash-escapes `"` and `\` so `s` can sit inside a double-quoted CLF
/// field. Returns `s` unchanged (borrowed) when nothing needs escaping.
fn escape_clf_quoted(s: &str) -> std::borrow::Cow<'_, str> {
    if !s.contains(|c| c == '"' || c == '\\') {
        return std::borrow::Cow::Borrowed(s);
    }
    let mut escaped = String::with_capacity(s.len() + 8);
    for c in s.chars() {
        if c == '"' || c == '\\' {
            escaped.push('\\');
        }
        escaped.push(c);
    }
    std::borrow::Cow::Owned(escaped)
}

/// Maps an HTTP [`Version`] to the protocol token written in the request
/// field of the log line; any version not listed yields `"HTTP/?"`.
fn http_version_str(v: Version) -> &'static str {
    const KNOWN: [(Version, &str); 5] = [
        (Version::HTTP_09, "HTTP/0.9"),
        (Version::HTTP_10, "HTTP/1.0"),
        (Version::HTTP_11, "HTTP/1.1"),
        (Version::HTTP_2, "HTTP/2.0"),
        (Version::HTTP_3, "HTTP/3.0"),
    ];
    KNOWN
        .iter()
        .find(|(known, _)| *known == v)
        .map_or("HTTP/?", |&(_, name)| name)
}

// --------------------------------------
// Optional convenience re-export
// --------------------------------------

pub mod prelude {
    //! Convenience re-exports: glob-import this module to bring
    //! [`AccessLogLayer`] into scope.
    pub use super::AccessLogLayer;
}

// --------------------------------------
// Tests (basic smoke)
// --------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;
    use tower::{ServiceExt, service_fn};

    /// Smoke test: a request flows through the layer, and the wrapped response
    /// can be dropped (the point at which the access-log line is emitted).
    #[tokio::test]
    async fn test_basic_logging() {
        let handler = service_fn(|_req: Request<()>| async {
            let response = Response::builder()
                .status(StatusCode::OK)
                // The body's `Data` type must implement `Buf`.
                .body(http_body_util::Empty::<Bytes>::new())
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });
        let svc = AccessLogLayer::default().layer(handler);

        let request = Request::builder().uri("/test").body(()).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        // Dropping the body triggers the log line. Nothing is captured or
        // asserted here; the test only checks that no step panics.
        drop(response);
    }
}
