//! Length-delimited framed local-socket transport (Tokio).
//!
//! This module provides a Tokio-based implementation of framed connections over
//! local sockets using a length-delimited codec. It includes both the client
//! and server sides, allowing for bidirectional communication between processes
//! on the same machine.
//!
//! The implementation is built on Tokio's asynchronous runtime and the
//! `interprocess` crate's local sockets. It provides a high-level API for
//! sending and receiving framed messages and for managing connections and
//! listeners.

use crate::workspace::ipc::error::Error;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use interprocess::local_socket::{
    GenericFilePath, ListenerOptions,
    tokio::{Listener as TokioListener, Stream as TokioStream, prelude::*},
};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex;
// Length-delimited framing layered over the raw local-socket byte stream.
use tokio_util::codec::{Framed, LengthDelimitedCodec};

/// A duplex message channel that transports whole, length-delimited frames.
///
/// Each send or receive moves exactly one frame; splitting the byte stream
/// into frames is the implementation's job. Any multiplexing of logical
/// streams is expected to be built on top of this trait, not inside it.
#[async_trait]
pub trait FramedConnection: Send + Sync {
    /// Writes `data` to the peer as a single frame.
    async fn send_frame(&self, data: Bytes) -> Result<(), Error>;

    /// Reads the next frame from the peer.
    ///
    /// Yields `Ok(None)` once the peer has finished sending (end of stream).
    async fn recv_frame(&self) -> Result<Option<Bytes>, Error>;

    /// Gracefully shuts the connection down; depending on the implementation
    /// this closes either just the sending half or the whole connection.
    async fn close(&self) -> Result<(), Error>;
}

/// Produces [`FramedConnection`]s over local sockets, either by dialing an
/// endpoint (client side) or by listening on one (server side).
#[async_trait]
pub trait TransportFactory: Send + Sync {
    /// Connection type handed out on both the client and the server side.
    type Conn: FramedConnection;

    /// Client side: opens a connection to `endpoint` (for example a socket
    /// path or a named-pipe name).
    async fn connect(&self, endpoint: &str) -> Result<Self::Conn, Error>;

    /// Server side: binds `endpoint` and returns a listener from which
    /// incoming connections can be accepted.
    async fn bind(
        &self,
        endpoint: &str,
    ) -> Result<Box<dyn TransportListener<Conn = Self::Conn>>, Error>;
}

/// Server-side handle that yields incoming [`FramedConnection`]s until it is
/// closed.
#[async_trait]
pub trait TransportListener: Send + Sync {
    /// Connection type produced by [`accept`](TransportListener::accept).
    type Conn: FramedConnection;

    /// Waits for the next incoming connection.
    ///
    /// Yields `Ok(None)` once the listener has been closed.
    async fn accept(&self) -> Result<Option<Self::Conn>, Error>;

    /// Stops the listener from accepting further connections.
    async fn close(&self) -> Result<(), Error>;
}

/// Concrete framed connection over a Tokio local socket using a
/// length-delimited codec.
pub struct LocalSocketFramedConnection {
    inner: Mutex<Framed<TokioStream, LengthDelimitedCodec>>,
}

#[async_trait]
impl FramedConnection for LocalSocketFramedConnection {
    /// Send one frame.
    async fn send_frame(&self, data: Bytes) -> Result<(), Error> {
        let mut guard = self.inner.lock().await;
        let mut bm = BytesMut::with_capacity(data.len());
        bm.extend_from_slice(&data);
        SinkExt::send(&mut *guard, bm.into())
            .await
            .map_err(Error::from)
    }

    /// Receive one frame. Returns None on EOF.
    async fn recv_frame(&self) -> Result<Option<Bytes>, Error> {
        let mut guard = self.inner.lock().await;
        match StreamExt::next(&mut *guard).await {
            Some(Ok(bm)) => Ok(Some(bm.freeze())),
            Some(Err(e)) => Err(Error::from(e)),
            None => Ok(None),
        }
    }

    /// Close half or full connection gracefully.
    async fn close(&self) -> Result<(), Error> {
        let mut guard = self.inner.lock().await;
        // Flush and close the codec sink.
        SinkExt::close(&mut *guard).await.map_err(Error::from)
    }
}

/// Concrete transport factory based on Tokio local sockets.
pub struct LocalSocketTransportFactory;

#[async_trait]
impl TransportFactory for LocalSocketTransportFactory {
    type Conn = LocalSocketFramedConnection;

    /// Client side: connect to an endpoint (e.g., path or named pipe).
    async fn connect(&self, endpoint: &str) -> Result<Self::Conn, Error> {
        // Build a platform-appropriate name from the provided endpoint string.
        #[cfg(windows)]
        let name = endpoint
            .to_ns_name::<GenericNamespaced>()
            .map_err(Error::from)?;
        #[cfg(unix)]
        let name = endpoint
            .to_fs_name::<GenericFilePath>()
            .map_err(Error::from)?;
        let stream = TokioStream::connect(name).await.map_err(Error::from)?;
        let framed = Framed::new(stream, LengthDelimitedCodec::new());
        Ok(LocalSocketFramedConnection {
            inner: Mutex::new(framed),
        })
    }

    /// Server side: bind and accept incoming connections.
    async fn bind(
        &self,
        endpoint: &str,
    ) -> Result<Box<dyn TransportListener<Conn = Self::Conn>>, Error> {
        // Build a platform-appropriate name from the provided endpoint string.
        #[cfg(windows)]
        let name = endpoint
            .to_ns_name::<GenericNamespaced>()
            .map_err(Error::from)?;
        #[cfg(unix)]
        let name = endpoint
            .to_fs_name::<GenericFilePath>()
            .map_err(Error::from)?;
        let listener = ListenerOptions::new()
            .name(name)
            .create_tokio()
            .map_err(Error::from)?;
        Ok(Box::new(LocalSocketTransportListener::new(listener)))
    }
}

/// Transport listener for Tokio local sockets.
pub struct LocalSocketTransportListener {
    inner: Arc<TokioListener>,
    closed: AtomicBool,
}

impl LocalSocketTransportListener {
    pub fn new(listener: TokioListener) -> Self {
        Self {
            inner: Arc::new(listener),
            closed: AtomicBool::new(false),
        }
    }
}

#[async_trait]
impl TransportListener for LocalSocketTransportListener {
    type Conn = LocalSocketFramedConnection;

    /// Accept the next connection. Returns None when listener is closed.
    async fn accept(&self) -> Result<Option<Self::Conn>, Error> {
        if self.closed.load(Ordering::Relaxed) {
            return Ok(None);
        }
        let stream = match self.inner.accept().await {
            Ok(s) => s,
            Err(e) => return Err(Error::from(e)),
        };
        let framed = Framed::new(stream, LengthDelimitedCodec::new());
        Ok(Some(LocalSocketFramedConnection {
            inner: Mutex::new(framed),
        }))
    }

    /// Close the listener.
    async fn close(&self) -> Result<(), Error> {
        self.closed.store(true, Ordering::Release);
        Ok(())
        // Dropping the listener will close the underlying handle once all Arcs are gone.
    }
}

#[cfg(test)]
mod tests {
    use crate::debug_fmt;

    use super::*;
    use anyhow::{Result, ensure};
    use bytes::Bytes;

    /// Builds an endpoint name from the process id and the current time in
    /// nanoseconds, so repeated or concurrent test runs should not collide.
    ///
    /// On Unix this is a `.sock` path under the system temp directory; on
    /// Windows it is a bare name intended for the namespaced (named-pipe) form.
    fn unique_endpoint() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};
        let pid = std::process::id();
        // Falls back to 0 if the clock reads earlier than the Unix epoch.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        #[cfg(unix)]
        {
            let mut p = std::env::temp_dir();
            p.push(format!("ctb-echo-{}-{}.sock", pid, nanos));
            return p.to_string_lossy().into_owned();
        }
        #[cfg(windows)]
        {
            // Namespaced name for Windows named pipes via interprocess local sockets.
            return format!("ctb-echo-{}-{}", pid, nanos);
        }
    }

    // End-to-end check: bind a listener, connect a client, send three frames
    // and expect each to be echoed back unchanged by a spawned server task.
    #[crate::ctb_test(tokio::test)]
    async fn echo_frames_over_local_socket() -> Result<()> {
        let endpoint = unique_endpoint();

        // Best-effort cleanup on Unix in case a stale socket path exists.
        #[cfg(unix)]
        let _ = std::fs::remove_file(&endpoint);

        let factory = LocalSocketTransportFactory;

        let payloads = vec![
            Bytes::from_static(b"hello"),
            Bytes::from_static(b"world"),
            Bytes::from_static(b"!"),
        ];

        // Bind server listener. It must exist before the client connects.
        let listener = factory.bind(&endpoint).await?;
        let listener = Arc::new(listener);

        // Server task: accept a single connection and echo frames until EOF.
        // NOTE(review): the `JoinHandle` returned by `tokio::spawn` is
        // discarded, so any error the server task returns is never observed
        // by this test.
        {
            let listener = Arc::clone(&listener);
            tokio::spawn(async move {
                let conn = match listener.accept().await {
                    Ok(Some(c)) => c,
                    Ok(None) => return anyhow::Ok(()),
                    Err(e) => return Err(anyhow::Error::from(e)),
                };

                debug_fmt!("Server begins listening");
                // Only an EOF (`Ok(None)`) from the client ends this loop.
                loop {
                    let maybe = conn.recv_frame().await;
                    let Some(frame) = maybe? else {
                        debug_fmt!("Server ending loop");
                        break;
                    };
                    debug_fmt!("Server echoing frame: {:?}", &frame);
                    conn.send_frame(frame).await?;
                }
                debug_fmt!("Server ended loop");
                conn.close().await.map_err(anyhow::Error::from)
            })
        };

        // Client: connect, send frames, receive echoes.
        let client = factory.connect(&endpoint).await?;

        // Send payloads.
        for p in &payloads {
            client.send_frame(p.clone()).await?;
        }

        debug_fmt!("Done sending on {}; now listening for echoes", &endpoint);

        // Read exactly the expected number of echoes.
        // FIXME: Reading until end-of-stream instead of counting hangs because
        // nothing ever produces EOF. The client never closes its sending side,
        // so the server's `recv_frame` loop never sees `None`, never reaches
        // `conn.close()`, and therefore never signals EOF back to the client.
        // Calling `client.close().await?` after the send loop should break
        // that cycle if the platform's shutdown is a write-only half-close.
        // NOTE(review): confirm that behavior on every supported platform
        // before switching to the loop below.
        /*let mut echoes = Vec::new();
        loop {
            match client.recv_frame().await? {
                Some(b) => echoes.push(b),
                None => break,
                }
            }*/
        let mut echoes = Vec::with_capacity(payloads.len());
        for _ in 0..payloads.len() {
            let Some(b) = client.recv_frame().await? else {
                break;
            };
            debug_fmt!("Client received echo: {:?}", &b);
            echoes.push(b);
        }
        debug_fmt!("Received echoes on {}", &endpoint);

        // Echoes must match the sent payloads in count and order.
        ensure!(echoes.len() == payloads.len(), "echo count mismatch");
        for (a, b) in echoes.iter().zip(payloads.iter()) {
            ensure!(a == b, "echo payload mismatch");
        }

        // Close listener.
        listener.close().await?;
        debug_fmt!("Closing listener!");

        // Cleanup Unix socket path.
        #[cfg(unix)]
        let _ = std::fs::remove_file(&endpoint);

        Ok(())
    }
}
