use crate::workspace::ipc::error::Error;

#[cfg(unix)]
use anyhow::Context;

use nix::libc;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};

/// Creates an anonymous in-memory file (`memfd`) sized to `size` bytes.
///
/// The descriptor is opened with `MFD_CLOEXEC`, and ownership of it passes to
/// the caller, who is responsible for closing it (e.g. with [`close_fd`]).
///
/// # Errors
///
/// Returns an error if `size` does not fit in `off_t`, or if `memfd_create`
/// or `ftruncate` fails. When `ftruncate` fails, the descriptor is closed
/// before returning, so no descriptor leaks on any error path.
#[allow(unsafe_code)]
#[cfg(all(unix, target_os = "linux"))]
pub fn create_memfd(size: u64) -> Result<i32, Error> {
    use std::ffi::CString;

    let name = CString::new("ctb-blob")
        .context("memfd name contained an interior NUL")?;

    // Convert the length up front: if this fails, no descriptor has been
    // created yet, so nothing can leak. `off_t` (rather than a hard-coded
    // `i64`) is exactly the type `ftruncate` takes on every Linux target.
    let len =
        libc::off_t::try_from(size).context("size too large for off_t")?;

    // SAFETY: `name` is a valid NUL-terminated string that outlives the call;
    // on success the kernel hands back a new descriptor that we now own.
    let fd = unsafe { libc::memfd_create(name.as_ptr(), libc::MFD_CLOEXEC) };
    if fd < 0 {
        return Err(anyhow::anyhow!(
            "memfd_create failed: {}",
            std::io::Error::last_os_error()
        )
        .into());
    }

    // SAFETY: `fd` is the valid descriptor created just above.
    let rc = unsafe { libc::ftruncate(fd, len) };
    if rc != 0 {
        // Capture errno before `close` can overwrite it.
        let err = std::io::Error::last_os_error();
        // Best-effort cleanup; the ftruncate error is what the caller needs.
        // SAFETY: we own `fd` and never use it after this call.
        let _ = unsafe { libc::close(fd) };
        return Err(anyhow::anyhow!("ftruncate failed: {err}").into());
    }

    Ok(fd)
}

/// Fallback `create_memfd` for Unix targets other than Linux.
///
/// This build only implements anonymous memory files on Linux, so the
/// requested size is ignored.
///
/// # Errors
///
/// Always returns an error.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn create_memfd(_size: u64) -> Result<i32, Error> {
    let unsupported =
        anyhow::anyhow!("memfd_create not supported on this Unix platform");
    Err(unsupported.into())
}

/// Duplicates `fd` with `dup(2)` and returns the newly allocated descriptor.
///
/// The caller owns the returned descriptor. Per POSIX, `dup` does not copy
/// the close-on-exec flag, so the duplicate starts out inheritable.
///
/// # Errors
///
/// Returns an error carrying the OS error if `dup` fails.
#[allow(unsafe_code)]
#[cfg(unix)]
pub fn dup_fd(fd: RawFd) -> Result<RawFd, Error> {
    // SAFETY: `dup` only takes a descriptor number; it either fails or
    // returns a brand-new descriptor that becomes the caller's to close.
    let duplicate = unsafe { libc::dup(fd) };
    if duplicate >= 0 {
        Ok(duplicate)
    } else {
        let os_err = std::io::Error::last_os_error();
        Err(anyhow::anyhow!("dup failed: {}", os_err).into())
    }
}

/// Closes `fd` with `close(2)`.
///
/// The caller must not use `fd` again after calling this.
///
/// # Errors
///
/// Returns an error carrying the OS error if `close` reports failure.
#[allow(unsafe_code)]
#[cfg(unix)]
pub fn close_fd(fd: RawFd) -> Result<(), Error> {
    // SAFETY: `close` consumes the descriptor; ownership ends here.
    let status = unsafe { libc::close(fd) };
    if status == 0 {
        return Ok(());
    }
    let os_err = std::io::Error::last_os_error();
    Err(anyhow::anyhow!("close failed: {}", os_err).into())
}

/// Sends `fd` to the peer of `stream` as an `SCM_RIGHTS` control message.
///
/// A single zero byte is sent as the regular payload, because a stream socket
/// must carry at least one byte of ordinary data alongside ancillary data.
/// The caller keeps ownership of `fd`; the kernel installs a separate
/// descriptor for it in the receiving process.
///
/// # Errors
///
/// Returns an error if the control header cannot be built or if `sendmsg`
/// fails. Calls interrupted by a signal (`EINTR`) are retried.
#[allow(unsafe_code, clippy::as_conversions)]
#[cfg(unix)]
pub fn send_fd(
    stream: &std::os::unix::net::UnixStream,
    fd: RawFd,
) -> Result<(), Error> {
    use std::mem;

    /// Control-message storage. The `cmsghdr` member forces the alignment
    /// that dereferencing the pointer from `CMSG_FIRSTHDR` requires; a bare
    /// `[u8; N]` is only byte-aligned.
    #[repr(C)]
    union CmsgBuf {
        _bytes: [u8; 64],
        _align: std::mem::ManuallyDrop<libc::cmsghdr>,
    }

    let byte: [u8; 1] = [0u8];
    let mut iov = libc::iovec {
        // `sendmsg` never writes through `iov_base`; the cast only satisfies
        // the C signature.
        iov_base: byte.as_ptr() as *mut libc::c_void,
        iov_len: byte.len(),
    };

    let mut cmsg_buf = CmsgBuf { _bytes: [0u8; 64] };
    // SAFETY: `msghdr` is a plain C struct for which all-zero bytes are valid.
    let mut msg: libc::msghdr = unsafe { mem::zeroed() };
    msg.msg_iov = std::ptr::addr_of_mut!(iov);
    msg.msg_iovlen = 1;
    msg.msg_control =
        std::ptr::addr_of_mut!(cmsg_buf).cast::<libc::c_void>();
    // `as _`: this field is `size_t` on glibc but `socklen_t` elsewhere.
    msg.msg_controllen = mem::size_of::<CmsgBuf>() as _;

    let fd_size = mem::size_of::<RawFd>() as libc::c_uint;

    // SAFETY: `msg_control` points at `cmsg_buf`, which is zeroed, large
    // enough for one fd and aligned for `cmsghdr`. The header returned by
    // `CMSG_FIRSTHDR` therefore lies inside it, and `CMSG_DATA` addresses the
    // payload right after it.
    unsafe {
        let cmsg = libc::CMSG_FIRSTHDR(std::ptr::addr_of!(msg));
        if cmsg.is_null() {
            return Err(anyhow::anyhow!("CMSG_FIRSTHDR returned null").into());
        }
        (*cmsg).cmsg_level = libc::SOL_SOCKET;
        (*cmsg).cmsg_type = libc::SCM_RIGHTS;
        (*cmsg).cmsg_len = libc::CMSG_LEN(fd_size) as _;

        // `write_unaligned` keeps this sound even if a platform's CMSG_DATA
        // offset is not `int`-aligned.
        libc::CMSG_DATA(cmsg).cast::<RawFd>().write_unaligned(fd);

        // cmsg(3): the control length is the CMSG_SPACE of every header sent.
        msg.msg_controllen = libc::CMSG_SPACE(fd_size) as _;
    }

    loop {
        // SAFETY: `msg` references `iov`, `byte` and `cmsg_buf`, all of which
        // stay alive and unmoved for the duration of the call.
        let rc = unsafe {
            libc::sendmsg(stream.as_raw_fd(), std::ptr::addr_of!(msg), 0)
        };
        if rc >= 0 {
            return Ok(());
        }
        let err = std::io::Error::last_os_error();
        // EINTR means nothing was sent; simply try again.
        if err.kind() != std::io::ErrorKind::Interrupted {
            return Err(anyhow::anyhow!("sendmsg failed: {err}").into());
        }
    }
}

/// Receives one file descriptor sent as an `SCM_RIGHTS` message on `stream`.
///
/// It reads the single-byte payload together with its control message and
/// returns the first descriptor carried. The caller owns the returned
/// descriptor. Any additional descriptors in the same control message are
/// closed so they cannot leak.
///
/// # Errors
///
/// Returns an error if any of the following occurs:
///
/// - `recvmsg` fails (calls interrupted by `EINTR` are retried);
/// - the peer closed the connection;
/// - no control message, or a non-`SCM_RIGHTS` control message, arrived;
/// - the control data was truncated;
/// - the message carried no descriptor.
#[allow(unsafe_code, clippy::as_conversions)]
#[cfg(unix)]
pub fn recv_fd(
    stream: &std::os::unix::net::UnixStream,
) -> Result<RawFd, Error> {
    use std::mem;

    /// Control-message storage. The `cmsghdr` member forces the alignment
    /// that dereferencing the pointer from `CMSG_FIRSTHDR` requires; a bare
    /// `[u8; N]` is only byte-aligned.
    #[repr(C)]
    union CmsgBuf {
        _bytes: [u8; 64],
        _align: std::mem::ManuallyDrop<libc::cmsghdr>,
    }

    let mut byte: [u8; 1] = [0u8];
    let mut iov = libc::iovec {
        iov_base: byte.as_mut_ptr().cast::<libc::c_void>(),
        iov_len: byte.len(),
    };

    let mut cmsg_buf = CmsgBuf { _bytes: [0u8; 64] };
    // SAFETY: `msghdr` is a plain C struct for which all-zero bytes are valid.
    let mut msg: libc::msghdr = unsafe { mem::zeroed() };
    msg.msg_iov = std::ptr::addr_of_mut!(iov);
    msg.msg_iovlen = 1;
    msg.msg_control =
        std::ptr::addr_of_mut!(cmsg_buf).cast::<libc::c_void>();

    let received = loop {
        // `recvmsg` rewrites `msg_controllen`, so restore the full capacity
        // before every attempt. `as _`: the field's type differs by platform.
        msg.msg_controllen = mem::size_of::<CmsgBuf>() as _;
        // SAFETY: `msg` references `iov`, `byte` and `cmsg_buf`, all of which
        // stay alive and unmoved for the duration of the call.
        let rc = unsafe {
            libc::recvmsg(stream.as_raw_fd(), std::ptr::addr_of_mut!(msg), 0)
        };
        if rc >= 0 {
            break rc;
        }
        let err = std::io::Error::last_os_error();
        // EINTR means nothing was received; simply try again.
        if err.kind() != std::io::ErrorKind::Interrupted {
            return Err(anyhow::anyhow!("recvmsg failed: {err}").into());
        }
    };
    if received == 0 {
        return Err(anyhow::anyhow!(
            "peer closed the connection before sending an fd"
        )
        .into());
    }

    // SAFETY: `msg` was filled in by a successful `recvmsg`. `CMSG_FIRSTHDR`
    // only yields a header within the first `msg_controllen` bytes of the
    // `cmsghdr`-aligned `cmsg_buf`, or null.
    let cmsg = unsafe { libc::CMSG_FIRSTHDR(std::ptr::addr_of!(msg)) };
    if cmsg.is_null() {
        return Err(anyhow::anyhow!("no cmsg received").into());
    }

    // SAFETY: `cmsg` is non-null, aligned and initialised (see above).
    let (level, kind, cmsg_len) = unsafe {
        ((*cmsg).cmsg_level, (*cmsg).cmsg_type, (*cmsg).cmsg_len as usize)
    };
    if level != libc::SOL_SOCKET || kind != libc::SCM_RIGHTS {
        return Err(anyhow::anyhow!("unexpected cmsg type").into());
    }

    // Work out how many descriptors follow the header. Clamp to the control
    // bytes actually written so a bogus `cmsg_len` cannot read past them.
    // SAFETY: `CMSG_LEN` is pure arithmetic on its argument.
    let header_len = unsafe { libc::CMSG_LEN(0) } as usize;
    let payload_len = cmsg_len
        .min(msg.msg_controllen as usize)
        .saturating_sub(header_len);
    let fd_count = payload_len / mem::size_of::<RawFd>();

    let mut fds: Vec<RawFd> = Vec::with_capacity(fd_count);
    // SAFETY: the `fd_count` descriptors lie inside the received control
    // bytes of `cmsg_buf`. `read_unaligned` tolerates any payload alignment.
    unsafe {
        let data = libc::CMSG_DATA(cmsg).cast::<RawFd>();
        for i in 0..fd_count {
            fds.push(data.add(i).read_unaligned());
        }
    }

    let truncated = (msg.msg_flags & libc::MSG_CTRUNC) != 0;
    let mut fds = fds.into_iter();
    let first = fds.next();

    // Anything we do not hand back would otherwise leak in this process.
    for extra in fds {
        // SAFETY: the kernel installed `extra` for us and we never use it again.
        let _ = unsafe { libc::close(extra) };
    }

    match first {
        Some(fd) if !truncated => Ok(fd),
        Some(fd) => {
            // SAFETY: `fd` was installed for us and is not returned.
            let _ = unsafe { libc::close(fd) };
            Err(anyhow::anyhow!("control message was truncated").into())
        }
        None => {
            Err(anyhow::anyhow!("SCM_RIGHTS message carried no fd").into())
        }
    }
}
