// ctoolbox/workspace/ipc/platform/unix.rs

1use crate::workspace::ipc::error::Error;
2
3#[cfg(unix)]
4use anyhow::Context;
5
6use nix::libc;
7#[cfg(unix)]
8use std::os::unix::io::{AsRawFd, RawFd};
9
10#[allow(unsafe_code)]
11#[cfg(all(unix, target_os = "linux"))]
12pub fn create_memfd(size: u64) -> Result<i32, Error> {
13    use std::ffi::CString;
14
15    let name = CString::new("ctb-blob")
16        .context("memfd name contained an interior NUL")?;
17
18    // Safety: libc call; returns an owned fd on success.
19    let fd = unsafe { libc::memfd_create(name.as_ptr(), libc::MFD_CLOEXEC) };
20    if fd < 0 {
21        return Err(anyhow::anyhow!(
22            "memfd_create failed: {}",
23            std::io::Error::last_os_error()
24        )
25        .into());
26    }
27
28    // Safety: ftruncate sizes the anonymous file backing.
29    let rc = unsafe {
30        libc::ftruncate(
31            fd,
32            i64::try_from(size).context("size too large for off_t")?,
33        )
34    };
35    if rc != 0 {
36        let err = std::io::Error::last_os_error();
37        let _ = unsafe { libc::close(fd) };
38        return Err(anyhow::anyhow!("ftruncate failed: {err}").into());
39    }
40
41    Ok(fd)
42}
43
/// Fallback for Unix targets other than Linux.
///
/// This build has no `memfd_create` path, so the call always fails and
/// `_size` is ignored.
///
/// # Errors
///
/// Always returns an "unsupported" error.
#[cfg(all(unix, not(target_os = "linux")))]
pub fn create_memfd(_size: u64) -> Result<i32, Error> {
    let unsupported =
        anyhow::anyhow!("memfd_create not supported on this Unix platform");
    Err(unsupported.into())
}
51
52#[allow(unsafe_code)]
53#[cfg(unix)]
54pub fn dup_fd(fd: RawFd) -> Result<RawFd, Error> {
55    // Safety: dup returns a new owned fd on success.
56    let new_fd = unsafe { libc::dup(fd) };
57    if new_fd < 0 {
58        return Err(anyhow::anyhow!(
59            "dup failed: {}",
60            std::io::Error::last_os_error()
61        )
62        .into());
63    }
64    Ok(new_fd)
65}
66
67#[allow(unsafe_code)]
68#[cfg(unix)]
69pub fn close_fd(fd: RawFd) -> Result<(), Error> {
70    // Safety: close consumes the fd.
71    let rc = unsafe { libc::close(fd) };
72    if rc != 0 {
73        return Err(anyhow::anyhow!(
74            "close failed: {}",
75            std::io::Error::last_os_error()
76        )
77        .into());
78    }
79    Ok(())
80}
81
82//
83#[allow(unsafe_code, clippy::as_conversions)]
84#[cfg(unix)]
85pub fn send_fd(
86    stream: &std::os::unix::net::UnixStream,
87    fd: RawFd,
88) -> Result<(), Error> {
89    use std::mem;
90
91    let byte: [u8; 1] = [0u8];
92    let mut iov = libc::iovec {
93        iov_base: byte.as_ptr() as *mut libc::c_void,
94        iov_len: byte.len(),
95    };
96
97    let mut cmsg_buf = [0u8; 64];
98    let mut msg: libc::msghdr = unsafe { mem::zeroed() };
99    msg.msg_iov = std::ptr::addr_of_mut!(iov);
100    msg.msg_iovlen = 1;
101    msg.msg_control = cmsg_buf.as_mut_ptr().cast::<libc::c_void>();
102    msg.msg_controllen = cmsg_buf.len();
103
104    // Safety: build SCM_RIGHTS ancillary message.
105    unsafe {
106        let cmsg = libc::CMSG_FIRSTHDR(std::ptr::addr_of!(msg));
107        if cmsg.is_null() {
108            return Err(anyhow::anyhow!("CMSG_FIRSTHDR returned null").into());
109        }
110        (*cmsg).cmsg_level = libc::SOL_SOCKET;
111        (*cmsg).cmsg_type = libc::SCM_RIGHTS;
112        (*cmsg).cmsg_len =
113            libc::CMSG_LEN(mem::size_of::<RawFd>() as u32) as usize;
114
115        let data = libc::CMSG_DATA(cmsg).cast::<RawFd>();
116        *data = fd;
117
118        msg.msg_controllen = (*cmsg).cmsg_len;
119    }
120
121    let rc = unsafe {
122        libc::sendmsg(stream.as_raw_fd(), std::ptr::addr_of!(msg), 0)
123    };
124    if rc < 0 {
125        return Err(anyhow::anyhow!(
126            "sendmsg failed: {}",
127            std::io::Error::last_os_error()
128        )
129        .into());
130    }
131    Ok(())
132}
133
134#[allow(unsafe_code, clippy::as_conversions)]
135#[cfg(unix)]
136pub fn recv_fd(
137    stream: &std::os::unix::net::UnixStream,
138) -> Result<RawFd, Error> {
139    use std::mem;
140
141    let mut byte: [u8; 1] = [0u8];
142    let mut iov = libc::iovec {
143        iov_base: byte.as_mut_ptr().cast::<libc::c_void>(),
144        iov_len: byte.len(),
145    };
146
147    let mut cmsg_buf = [0u8; 64];
148    let mut msg: libc::msghdr = unsafe { mem::zeroed() };
149    msg.msg_iov = std::ptr::addr_of_mut!(iov);
150    msg.msg_iovlen = 1;
151    msg.msg_control = cmsg_buf.as_mut_ptr().cast::<libc::c_void>();
152    msg.msg_controllen = cmsg_buf.len();
153
154    let rc = unsafe {
155        libc::recvmsg(stream.as_raw_fd(), std::ptr::addr_of_mut!(msg), 0)
156    };
157    if rc < 0 {
158        return Err(anyhow::anyhow!(
159            "recvmsg failed: {}",
160            std::io::Error::last_os_error()
161        )
162        .into());
163    }
164
165    // Safety: parse SCM_RIGHTS.
166    unsafe {
167        let cmsg = libc::CMSG_FIRSTHDR(std::ptr::addr_of!(msg));
168        if cmsg.is_null() {
169            return Err(anyhow::anyhow!("no cmsg received").into());
170        }
171        if (*cmsg).cmsg_level != libc::SOL_SOCKET
172            || (*cmsg).cmsg_type != libc::SCM_RIGHTS
173        {
174            return Err(anyhow::anyhow!("unexpected cmsg type").into());
175        }
176        let data = libc::CMSG_DATA(cmsg).cast::<RawFd>();
177        Ok(*data)
178    }
179}