ctoolbox/workspace/ipc/
transport.rs

//! Length-delimited framed local-socket transport (Tokio).
//!
//! This module provides a Tokio-based implementation of framed connections over
//! local sockets using a length-delimited codec. It includes both the client
//! and server sides, allowing bidirectional communication between processes
//! on the same machine.
//!
//! The implementation is built on the Tokio asynchronous runtime and the
//! `interprocess` crate for local-socket communication. It provides a
//! high-level API for sending and receiving framed messages and for managing
//! connections and listeners.
12
use crate::workspace::ipc::error::Error;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use futures::{
    SinkExt, StreamExt,
    stream::{SplitSink, SplitStream},
};
#[cfg(windows)]
use interprocess::local_socket::GenericNamespaced;
use interprocess::local_socket::{
    GenericFilePath, ListenerOptions,
    tokio::{Listener as TokioListener, Stream as TokioStream, prelude::*},
};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{Mutex, Notify};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
26
27/// A framed, length-delimited, duplex connection. Multiplexing is layered
28/// above this.
29#[async_trait]
30pub trait FramedConnection: Send + Sync {
31    /// Send one frame.
32    async fn send_frame(&self, data: Bytes) -> Result<(), Error>;
33    /// Receive one frame. Returns None on EOF.
34    async fn recv_frame(&self) -> Result<Option<Bytes>, Error>;
35    /// Close half or full connection gracefully.
36    async fn close(&self) -> Result<(), Error>;
37}
38
39/// Factory to connect or accept connections using local sockets.
40#[async_trait]
41pub trait TransportFactory: Send + Sync {
42    type Conn: FramedConnection;
43
44    /// Client side: connect to an endpoint (e.g., path or named pipe).
45    async fn connect(&self, endpoint: &str) -> Result<Self::Conn, Error>;
46
47    /// Server side: bind and accept incoming connections.
48    async fn bind(
49        &self,
50        endpoint: &str,
51    ) -> Result<Box<dyn TransportListener<Conn = Self::Conn>>, Error>;
52}
53
54#[async_trait]
55pub trait TransportListener: Send + Sync {
56    type Conn: FramedConnection;
57
58    /// Accept the next connection. Returns None when listener is closed.
59    async fn accept(&self) -> Result<Option<Self::Conn>, Error>;
60
61    /// Close the listener.
62    async fn close(&self) -> Result<(), Error>;
63}
64
65/// Concrete framed connection over a Tokio local socket using a
66/// length-delimited codec.
67pub struct LocalSocketFramedConnection {
68    inner: Mutex<Framed<TokioStream, LengthDelimitedCodec>>,
69}
70
71#[async_trait]
72impl FramedConnection for LocalSocketFramedConnection {
73    /// Send one frame.
74    async fn send_frame(&self, data: Bytes) -> Result<(), Error> {
75        let mut guard = self.inner.lock().await;
76        let mut bm = BytesMut::with_capacity(data.len());
77        bm.extend_from_slice(&data);
78        SinkExt::send(&mut *guard, bm.into())
79            .await
80            .map_err(Error::from)
81    }
82
83    /// Receive one frame. Returns None on EOF.
84    async fn recv_frame(&self) -> Result<Option<Bytes>, Error> {
85        let mut guard = self.inner.lock().await;
86        match StreamExt::next(&mut *guard).await {
87            Some(Ok(bm)) => Ok(Some(bm.freeze())),
88            Some(Err(e)) => Err(Error::from(e)),
89            None => Ok(None),
90        }
91    }
92
93    /// Close half or full connection gracefully.
94    async fn close(&self) -> Result<(), Error> {
95        let mut guard = self.inner.lock().await;
96        // Flush and close the codec sink.
97        SinkExt::close(&mut *guard).await.map_err(Error::from)
98    }
99}
100
101/// Concrete transport factory based on Tokio local sockets.
102pub struct LocalSocketTransportFactory;
103
104#[async_trait]
105impl TransportFactory for LocalSocketTransportFactory {
106    type Conn = LocalSocketFramedConnection;
107
108    /// Client side: connect to an endpoint (e.g., path or named pipe).
109    async fn connect(&self, endpoint: &str) -> Result<Self::Conn, Error> {
110        // Build a platform-appropriate name from the provided endpoint string.
111        #[cfg(windows)]
112        let name = endpoint
113            .to_ns_name::<GenericNamespaced>()
114            .map_err(Error::from)?;
115        #[cfg(unix)]
116        let name = endpoint
117            .to_fs_name::<GenericFilePath>()
118            .map_err(Error::from)?;
119        let stream = TokioStream::connect(name).await.map_err(Error::from)?;
120        let framed = Framed::new(stream, LengthDelimitedCodec::new());
121        Ok(LocalSocketFramedConnection {
122            inner: Mutex::new(framed),
123        })
124    }
125
126    /// Server side: bind and accept incoming connections.
127    async fn bind(
128        &self,
129        endpoint: &str,
130    ) -> Result<Box<dyn TransportListener<Conn = Self::Conn>>, Error> {
131        // Build a platform-appropriate name from the provided endpoint string.
132        #[cfg(windows)]
133        let name = endpoint
134            .to_ns_name::<GenericNamespaced>()
135            .map_err(Error::from)?;
136        #[cfg(unix)]
137        let name = endpoint
138            .to_fs_name::<GenericFilePath>()
139            .map_err(Error::from)?;
140        let listener = ListenerOptions::new()
141            .name(name)
142            .create_tokio()
143            .map_err(Error::from)?;
144        Ok(Box::new(LocalSocketTransportListener::new(listener)))
145    }
146}
147
148/// Transport listener for Tokio local sockets.
149pub struct LocalSocketTransportListener {
150    inner: Arc<TokioListener>,
151    closed: AtomicBool,
152}
153
154impl LocalSocketTransportListener {
155    pub fn new(listener: TokioListener) -> Self {
156        Self {
157            inner: Arc::new(listener),
158            closed: AtomicBool::new(false),
159        }
160    }
161}
162
163#[async_trait]
164impl TransportListener for LocalSocketTransportListener {
165    type Conn = LocalSocketFramedConnection;
166
167    /// Accept the next connection. Returns None when listener is closed.
168    async fn accept(&self) -> Result<Option<Self::Conn>, Error> {
169        if self.closed.load(Ordering::Relaxed) {
170            return Ok(None);
171        }
172        let stream = match self.inner.accept().await {
173            Ok(s) => s,
174            Err(e) => return Err(Error::from(e)),
175        };
176        let framed = Framed::new(stream, LengthDelimitedCodec::new());
177        Ok(Some(LocalSocketFramedConnection {
178            inner: Mutex::new(framed),
179        }))
180    }
181
182    /// Close the listener.
183    async fn close(&self) -> Result<(), Error> {
184        self.closed.store(true, Ordering::Release);
185        Ok(())
186        // Dropping the listener will close the underlying handle once all Arcs are gone.
187    }
188}
189
#[cfg(test)]
mod tests {
    use crate::debug_fmt;

    use super::*;
    use anyhow::{Result, ensure};
    use bytes::Bytes;

    /// Builds an endpoint name unique to this process and instant: a `.sock`
    /// path under the temp directory on Unix, a bare namespaced name on
    /// Windows.
    fn unique_endpoint() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};
        let pid = std::process::id();
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let base = format!("ctb-echo-{}-{}", pid, nanos);
        if cfg!(windows) {
            // Namespaced name for Windows named pipes via interprocess local
            // sockets.
            base
        } else {
            std::env::temp_dir()
                .join(format!("{base}.sock"))
                .to_string_lossy()
                .into_owned()
        }
    }

    #[crate::ctb_test(tokio::test)]
    async fn echo_frames_over_local_socket() -> Result<()> {
        let endpoint = unique_endpoint();

        // Best-effort cleanup on Unix in case a stale socket path exists.
        #[cfg(unix)]
        let _ = std::fs::remove_file(&endpoint);

        let factory = LocalSocketTransportFactory;

        let payloads = vec![
            Bytes::from_static(b"hello"),
            Bytes::from_static(b"world"),
            Bytes::from_static(b"!"),
        ];

        // Bind the server listener; it is shared with the server task below.
        let listener = Arc::new(factory.bind(&endpoint).await?);

        // Server task: accept a single connection and echo frames until EOF.
        // The task is detached; the client-side checks below verify the
        // echoed data.
        {
            let listener = Arc::clone(&listener);
            tokio::spawn(async move {
                let conn = match listener.accept().await {
                    Ok(Some(c)) => c,
                    Ok(None) => return anyhow::Ok(()),
                    Err(e) => return Err(anyhow::Error::from(e)),
                };

                debug_fmt!("Server begins listening");
                while let Some(frame) = conn.recv_frame().await? {
                    debug_fmt!("Server echoing frame: {:?}", &frame);
                    conn.send_frame(frame).await?;
                }
                debug_fmt!("Server ended loop");
                conn.close().await.map_err(anyhow::Error::from)
            });
        }

        // Client: connect, send frames, receive echoes.
        let client = factory.connect(&endpoint).await?;
        for p in &payloads {
            client.send_frame(p.clone()).await?;
        }
        debug_fmt!("Done sending on {}; now listening for echoes", &endpoint);

        // Read exactly as many echoes as frames were sent. Reading until EOF
        // instead would hang: the server stops echoing (and closes) only
        // after it sees EOF from the client, but this client never closes its
        // sending side, so neither peer ever observes end-of-stream.
        let mut echoes = Vec::with_capacity(payloads.len());
        for _ in 0..payloads.len() {
            let Some(b) = client.recv_frame().await? else {
                break;
            };
            debug_fmt!("Client received echo: {:?}", &b);
            echoes.push(b);
        }
        debug_fmt!("Received echoes on {}", &endpoint);

        ensure!(echoes.len() == payloads.len(), "echo count mismatch");
        ensure!(echoes == payloads, "echo payload mismatch");

        debug_fmt!("Closing listener!");
        listener.close().await?;

        // Cleanup Unix socket path.
        #[cfg(unix)]
        let _ = std::fs::remove_file(&endpoint);

        Ok(())
    }
}