ctoolbox/workspace/ipc/
transport.rs1use crate::workspace::ipc::error::Error;
14use async_trait::async_trait;
15use bytes::{Bytes, BytesMut};
16use futures::{SinkExt, StreamExt};
17use interprocess::local_socket::{
18 GenericFilePath, ListenerOptions,
19 tokio::{Listener as TokioListener, Stream as TokioStream, prelude::*},
20};
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, Ordering};
23use tokio::sync::Mutex;
24use tokio_util::codec::{Framed, LengthDelimitedCodec};
26
27#[async_trait]
30pub trait FramedConnection: Send + Sync {
31 async fn send_frame(&self, data: Bytes) -> Result<(), Error>;
33 async fn recv_frame(&self) -> Result<Option<Bytes>, Error>;
35 async fn close(&self) -> Result<(), Error>;
37}
38
39#[async_trait]
41pub trait TransportFactory: Send + Sync {
42 type Conn: FramedConnection;
43
44 async fn connect(&self, endpoint: &str) -> Result<Self::Conn, Error>;
46
47 async fn bind(
49 &self,
50 endpoint: &str,
51 ) -> Result<Box<dyn TransportListener<Conn = Self::Conn>>, Error>;
52}
53
54#[async_trait]
55pub trait TransportListener: Send + Sync {
56 type Conn: FramedConnection;
57
58 async fn accept(&self) -> Result<Option<Self::Conn>, Error>;
60
61 async fn close(&self) -> Result<(), Error>;
63}
64
65pub struct LocalSocketFramedConnection {
68 inner: Mutex<Framed<TokioStream, LengthDelimitedCodec>>,
69}
70
71#[async_trait]
72impl FramedConnection for LocalSocketFramedConnection {
73 async fn send_frame(&self, data: Bytes) -> Result<(), Error> {
75 let mut guard = self.inner.lock().await;
76 let mut bm = BytesMut::with_capacity(data.len());
77 bm.extend_from_slice(&data);
78 SinkExt::send(&mut *guard, bm.into())
79 .await
80 .map_err(Error::from)
81 }
82
83 async fn recv_frame(&self) -> Result<Option<Bytes>, Error> {
85 let mut guard = self.inner.lock().await;
86 match StreamExt::next(&mut *guard).await {
87 Some(Ok(bm)) => Ok(Some(bm.freeze())),
88 Some(Err(e)) => Err(Error::from(e)),
89 None => Ok(None),
90 }
91 }
92
93 async fn close(&self) -> Result<(), Error> {
95 let mut guard = self.inner.lock().await;
96 SinkExt::close(&mut *guard).await.map_err(Error::from)
98 }
99}
100
101pub struct LocalSocketTransportFactory;
103
104#[async_trait]
105impl TransportFactory for LocalSocketTransportFactory {
106 type Conn = LocalSocketFramedConnection;
107
108 async fn connect(&self, endpoint: &str) -> Result<Self::Conn, Error> {
110 #[cfg(windows)]
112 let name = endpoint
113 .to_ns_name::<GenericNamespaced>()
114 .map_err(Error::from)?;
115 #[cfg(unix)]
116 let name = endpoint
117 .to_fs_name::<GenericFilePath>()
118 .map_err(Error::from)?;
119 let stream = TokioStream::connect(name).await.map_err(Error::from)?;
120 let framed = Framed::new(stream, LengthDelimitedCodec::new());
121 Ok(LocalSocketFramedConnection {
122 inner: Mutex::new(framed),
123 })
124 }
125
126 async fn bind(
128 &self,
129 endpoint: &str,
130 ) -> Result<Box<dyn TransportListener<Conn = Self::Conn>>, Error> {
131 #[cfg(windows)]
133 let name = endpoint
134 .to_ns_name::<GenericNamespaced>()
135 .map_err(Error::from)?;
136 #[cfg(unix)]
137 let name = endpoint
138 .to_fs_name::<GenericFilePath>()
139 .map_err(Error::from)?;
140 let listener = ListenerOptions::new()
141 .name(name)
142 .create_tokio()
143 .map_err(Error::from)?;
144 Ok(Box::new(LocalSocketTransportListener::new(listener)))
145 }
146}
147
148pub struct LocalSocketTransportListener {
150 inner: Arc<TokioListener>,
151 closed: AtomicBool,
152}
153
154impl LocalSocketTransportListener {
155 pub fn new(listener: TokioListener) -> Self {
156 Self {
157 inner: Arc::new(listener),
158 closed: AtomicBool::new(false),
159 }
160 }
161}
162
163#[async_trait]
164impl TransportListener for LocalSocketTransportListener {
165 type Conn = LocalSocketFramedConnection;
166
167 async fn accept(&self) -> Result<Option<Self::Conn>, Error> {
169 if self.closed.load(Ordering::Relaxed) {
170 return Ok(None);
171 }
172 let stream = match self.inner.accept().await {
173 Ok(s) => s,
174 Err(e) => return Err(Error::from(e)),
175 };
176 let framed = Framed::new(stream, LengthDelimitedCodec::new());
177 Ok(Some(LocalSocketFramedConnection {
178 inner: Mutex::new(framed),
179 }))
180 }
181
182 async fn close(&self) -> Result<(), Error> {
184 self.closed.store(true, Ordering::Release);
185 Ok(())
186 }
188}
189
#[cfg(test)]
mod tests {
    use crate::debug_fmt;

    use super::*;
    use anyhow::{Result, ensure};
    use bytes::Bytes;

    /// Builds an endpoint name unique to this process and call, so parallel
    /// test runs do not collide.
    ///
    /// On Unix the result is a `.sock` path inside the system temp directory;
    /// on Windows it is the bare name.
    fn unique_endpoint() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};

        // A clock set before the epoch just yields 0 rather than failing.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_nanos());
        let endpoint = format!("ctb-echo-{}-{}", std::process::id(), nanos);
        #[cfg(unix)]
        let endpoint = std::env::temp_dir()
            .join(format!("{endpoint}.sock"))
            .to_string_lossy()
            .into_owned();
        endpoint
    }

    /// Round-trips a few frames through a local-socket echo server. Checks
    /// that every frame comes back unchanged and in order.
    #[crate::ctb_test(tokio::test)]
    async fn echo_frames_over_local_socket() -> Result<()> {
        let endpoint = unique_endpoint();

        // Clear a stale socket file left by an earlier run, if any.
        #[cfg(unix)]
        let _ = std::fs::remove_file(&endpoint);

        let factory = LocalSocketTransportFactory;
        let payloads = vec![
            Bytes::from_static(b"hello"),
            Bytes::from_static(b"world"),
            Bytes::from_static(b"!"),
        ];

        let listener = Arc::new(factory.bind(&endpoint).await?);

        // Echo server: accept one connection and bounce every frame back
        // until the client's stream ends. The task is detached.
        let server_listener = Arc::clone(&listener);
        tokio::spawn(async move {
            let conn = match server_listener.accept().await {
                Ok(Some(conn)) => conn,
                Ok(None) => return anyhow::Ok(()),
                Err(err) => return Err(anyhow::Error::from(err)),
            };

            debug_fmt!("Server begins listening");
            while let Some(frame) = conn.recv_frame().await? {
                debug_fmt!("Server echoing frame: {:?}", &frame);
                conn.send_frame(frame).await?;
            }
            debug_fmt!("Server ending loop");
            debug_fmt!("Server ended loop");
            conn.close().await.map_err(anyhow::Error::from)
        });

        let client = factory.connect(&endpoint).await?;
        for payload in payloads.iter().cloned() {
            client.send_frame(payload).await?;
        }
        debug_fmt!("Done sending on {}; now listening for echoes", &endpoint);

        let mut echoes = Vec::with_capacity(payloads.len());
        while echoes.len() < payloads.len() {
            match client.recv_frame().await? {
                Some(echo) => {
                    debug_fmt!("Client received echo: {:?}", &echo);
                    echoes.push(echo);
                }
                None => break,
            }
        }
        debug_fmt!("Received echoes on {}", &endpoint);

        ensure!(echoes.len() == payloads.len(), "echo count mismatch");
        ensure!(
            echoes.iter().zip(&payloads).all(|(echo, sent)| echo == sent),
            "echo payload mismatch"
        );

        listener.close().await?;
        debug_fmt!("Closing listener!");

        #[cfg(unix)]
        let _ = std::fs::remove_file(&endpoint);

        Ok(())
    }
}