1use crate::workspace::ipc::auth::capability::TokenValidator;
13use crate::workspace::ipc::error::Error;
14use crate::workspace::ipc::protocol::RpcError;
15use crate::workspace::ipc::protocol::{
16 Cancel, Event, Message, Request, Response, StreamControl,
17};
18use crate::workspace::ipc::protocol::{Hello, HelloErr, HelloOk};
19use crate::workspace::ipc::router::{ConnectionContext, Router};
20use crate::workspace::ipc::transport::FramedConnection;
21use crate::workspace::ipc::types::ConnectionId;
22use crate::workspace::ipc::types::{RequestId, StreamId};
23use anyhow::{Context as _, anyhow};
24use async_trait::async_trait;
25use bytes::Bytes;
26use futures_core::Stream;
27use std::collections::HashMap;
28use std::pin::Pin;
29use std::sync::Arc;
30use tokio::sync::{mpsc, oneshot};
31use tokio_stream::wrappers::ReceiverStream;
32use tracing::Instrument as _;
33
34#[async_trait]
36pub trait RpcClient: Send + Sync {
37 async fn send_request(&self, req: Request) -> Result<(), Error>;
39
40 async fn recv_response(&self, id: RequestId) -> Result<Response, Error>;
42
43 async fn cancel(&self, cancel: Cancel) -> Result<(), Error>;
45
46 fn events(&self) -> Pin<Box<dyn Stream<Item = Event> + Send>>;
48}
49
50#[async_trait]
52pub trait StreamManager: Send + Sync {
53 async fn start(&self, ctl: StreamControl) -> Result<(), Error>;
54 async fn next(&self, ctl: StreamControl) -> Result<(), Error>;
55 async fn end(&self, ctl: StreamControl) -> Result<(), Error>;
56
57 fn stream_incoming(
59 &self,
60 id: StreamId,
61 ) -> Pin<Box<dyn Stream<Item = Bytes> + Send>>;
62}
63
64#[async_trait]
66pub trait Session: Send + Sync {
67 async fn send(&self, msg: &Message) -> Result<(), Error>;
69
70 async fn recv(&self) -> Result<Option<Message>, Error>;
72
73 async fn client_handshake(
78 &self,
79 hello: Hello,
80 ) -> Result<crate::workspace::ipc::auth::capability::CapabilitySet, Error>
81 where
82 Self: Sized,
83 {
84 self.send(&Message::Hello(hello)).await?;
85 loop {
86 let Some(msg) = self.recv().await? else {
87 return Err(to_ipc_error(anyhow!("eof before HelloOk")));
88 };
89 match msg {
90 Message::HelloOk(ok) => {
91 return Ok(ok.bound_capabilities);
92 }
93 Message::HelloErr(err) => {
94 return Err(to_ipc_error(anyhow!(
95 "handshake failed: {}",
96 err.message
97 )));
98 }
99 _ => {}
101 }
102 }
103 }
104
105 async fn server_handshake<V: TokenValidator, R: Router>(
110 &self,
111 validator: &V,
112 router: &R,
113 connection_id: ConnectionId,
114 ) -> Result<crate::workspace::ipc::auth::capability::CapabilitySet, Error>
115 where
116 Self: Sized,
117 {
118 let Some(msg) = self.recv().await? else {
119 return Err(to_ipc_error(anyhow!("eof awaiting Hello")));
120 };
121 let Message::Hello(hello) = msg else {
122 return Err(to_ipc_error(anyhow!("expected Hello")));
123 };
124
125 let bound = match validator.validate(&hello.token) {
126 Ok(set) => set,
127 Err(e) => {
128 let _ = self
130 .send(&Message::HelloErr(HelloErr::new(format!("{e:#}"))))
131 .await;
132 return Err(to_ipc_error(anyhow!("token invalid: {e:#}")));
133 }
134 };
135
136 self.send(&Message::HelloOk(HelloOk::new(bound.clone())))
138 .await?;
139
140 let ctx = ConnectionContext {
142 id: connection_id,
143 capabilities: bound.clone(),
144 metadata: hello.client_info.as_ref().map(|ci| {
145 serde_json::json!({
146 "name": ci.name,
147 "version": ci.version,
148 "process_kind": ci.process_kind
149 })
150 }),
151 };
152 router.register_connection(ctx).await?;
153
154 Ok(bound)
155 }
156}
157
/// Suggested `events_capacity` argument for `DefaultRpcClient::new`: the
/// number of peer events that can be buffered before delivery backs up.
const DEFAULT_EVENTS_CAPACITY: usize = 64;
159
/// A `Session` that exchanges postcard-encoded `Message`s over a framed
/// transport `T` (one message per frame).
pub struct FramedSession<T> {
    inner: T,
}

impl<T> FramedSession<T> {
    /// Wraps the framed transport `inner`.
    ///
    /// Construction places no bounds on `T`. Only the `Session` impl needs
    /// `T: FramedConnection + Clone + 'static`, so the former `T: Clone`
    /// bound here was unnecessary.
    pub fn new(inner: T) -> Self {
        Self { inner }
    }
}
172
173#[async_trait]
174impl<T> Session for FramedSession<T>
175where
176 T: FramedConnection + Clone + 'static,
177{
178 async fn send(&self, msg: &Message) -> Result<(), Error> {
180 let bytes = postcard::to_stdvec(msg)
181 .context("postcard serialize Message")
182 .map_err(to_ipc_error)?;
183 self.inner.send_frame(Bytes::from(bytes)).await
184 }
185
186 async fn recv(&self) -> Result<Option<Message>, Error> {
188 let maybe = self.inner.recv_frame().await?;
189 let Some(frame) = maybe else {
190 return Ok(None);
191 };
192 let msg = postcard::from_bytes::<Message>(&frame)
193 .context("postcard deserialize Message")
194 .map_err(to_ipc_error)?;
195 Ok(Some(msg))
196 }
197}
198
199fn to_ipc_error(e: anyhow::Error) -> Error {
201 e.into()
202}
203
204#[async_trait]
209pub trait BlobStore: Send + Sync {
210 async fn put(
211 &self,
212 token: crate::workspace::ipc::types::BlobToken,
213 bytes: Bytes,
214 ) -> Result<(), Error>;
215
216 async fn take(
217 &self,
218 token: &crate::workspace::ipc::types::BlobToken,
219 ) -> Result<Bytes, Error>;
220
221 async fn len(&self) -> Result<usize, Error>;
222}
223
224#[derive(Debug, Default)]
226pub struct InMemoryBlobStore {
227 inner: tokio::sync::Mutex<
228 HashMap<crate::workspace::ipc::types::BlobToken, Bytes>,
229 >,
230}
231
232#[async_trait]
233impl BlobStore for InMemoryBlobStore {
234 async fn put(
235 &self,
236 token: crate::workspace::ipc::types::BlobToken,
237 bytes: Bytes,
238 ) -> Result<(), Error> {
239 let mut guard = self.inner.lock().await;
240 guard.insert(token, bytes);
241 Ok(())
242 }
243
244 async fn take(
245 &self,
246 token: &crate::workspace::ipc::types::BlobToken,
247 ) -> Result<Bytes, Error> {
248 let mut guard = self.inner.lock().await;
249 let Some(bytes) = guard.remove(token) else {
250 return Err(to_ipc_error(anyhow!("unknown blob token")));
251 };
252 Ok(bytes)
253 }
254
255 async fn len(&self) -> Result<usize, Error> {
256 let guard = self.inner.lock().await;
257 Ok(guard.len())
258 }
259}
260
261pub struct DefaultRpcClient<S: Session> {
270 session: Arc<S>,
271 connection_id: Option<ConnectionId>,
272 inflight: tokio::sync::Mutex<HashMap<RequestId, oneshot::Sender<Response>>>,
273 pending: tokio::sync::Mutex<HashMap<RequestId, Response>>,
274 events_tx: mpsc::Sender<Event>,
275 events_rx: std::sync::Mutex<Option<mpsc::Receiver<Event>>>,
276 streams: Arc<tokio::sync::Mutex<HashMap<StreamId, IncomingStreamState>>>,
277}
278
279impl<S: Session + 'static> DefaultRpcClient<S> {
280 pub fn new(
283 session: std::sync::Arc<S>,
284 events_capacity: usize,
285 ) -> std::sync::Arc<Self> {
286 Self::new_with_connection_id(session, None, events_capacity)
287 }
288
289 pub fn new_with_connection_id(
291 session: std::sync::Arc<S>,
292 connection_id: Option<ConnectionId>,
293 events_capacity: usize,
294 ) -> std::sync::Arc<Self> {
295 let (events_tx, events_rx) = mpsc::channel(events_capacity);
296 let client = std::sync::Arc::new(Self {
297 session: session.clone(),
298 connection_id,
299 inflight: tokio::sync::Mutex::new(HashMap::new()),
300 pending: tokio::sync::Mutex::new(HashMap::new()),
301 events_tx,
302 events_rx: std::sync::Mutex::new(Some(events_rx)),
303 streams: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
304 });
305
306 Self::spawn_dispatcher(client.clone());
307 client
308 }
309
310 async fn handle_stream_control(
311 &self,
312 ctl: StreamControl,
313 ) -> Result<(), Error> {
314 let id = match &ctl {
315 StreamControl::Start { id, .. } => *id,
316 StreamControl::Next { id, .. } => *id,
317 StreamControl::End { id, .. } => *id,
318 };
319
320 let span = if let Some(conn_id) = self.connection_id {
321 tracing::info_span!("ipc.rpc.stream_control", conn_id = %conn_id, stream_id = id)
322 } else {
323 tracing::info_span!("ipc.rpc.stream_control", stream_id = id)
324 };
325
326 async move {
327 match ctl {
328 StreamControl::Start { .. } => {
329 let mut guard = self.streams.lock().await;
330 guard
331 .entry(id)
332 .or_insert_with(|| IncomingStreamState::new(64));
333 Ok(())
334 }
335 StreamControl::Next { chunk, .. } => {
336 let Some(chunk) = chunk else {
337 return Ok(());
339 };
340
341 let mut guard = self.streams.lock().await;
342 let st = guard
343 .entry(id)
344 .or_insert_with(|| IncomingStreamState::new(64));
345 st.tx.send(Bytes::from(chunk)).await.map_err(|e| {
346 to_ipc_error(anyhow!("stream send failed: {e}"))
347 })?;
348 Ok(())
349 }
350 StreamControl::End { .. } => {
351 let mut guard = self.streams.lock().await;
352 guard.remove(&id);
353 Ok(())
354 }
355 }
356 }
357 .instrument(span)
358 .await
359 }
360
361 #[cfg(test)]
362 async fn stream_count(&self) -> usize {
363 let guard = self.streams.lock().await;
364 guard.len()
365 }
366}
367
368struct IncomingStreamState {
369 tx: mpsc::Sender<Bytes>,
370 rx: Option<mpsc::Receiver<Bytes>>,
371}
372
373impl IncomingStreamState {
374 fn new(capacity: usize) -> Self {
375 let (tx, rx) = mpsc::channel(capacity);
376 Self { tx, rx: Some(rx) }
377 }
378}
379
380impl<S: Session + 'static> DefaultRpcClient<S> {
381 fn spawn_dispatcher(client: std::sync::Arc<Self>) {
382 tokio::spawn(async move {
383 loop {
384 let msg = match client.session.recv().await {
385 Ok(Some(m)) => m,
386 Ok(None) => {
387 let mut inflight = client.inflight.lock().await;
389 for (id, sender) in inflight.drain() {
390 let _ = sender.send(Response {
391 id,
392 ok: false,
393 result: None,
394 error: Some(RpcError {
395 code: "eof".to_string(),
396 message:
397 "session closed while awaiting response"
398 .to_string(),
399 }),
400 });
401 }
402 break;
403 }
404 Err(e) => {
405 let mut inflight = client.inflight.lock().await;
407 for (id, sender) in inflight.drain() {
408 let _ = sender.send(Response {
409 id,
410 ok: false,
411 result: None,
412 error: Some(RpcError {
413 code: "session_recv_error".to_string(),
414 message: format!("{e:#}"),
415 }),
416 });
417 }
418 break;
419 }
420 };
421
422 match msg {
423 Message::Response(resp) => {
424 let id = response_id(&resp);
425 let span = if let Some(conn_id) = client.connection_id {
426 tracing::info_span!("ipc.rpc.recv_response_msg", conn_id = %conn_id, request_id = id)
427 } else {
428 tracing::info_span!(
429 "ipc.rpc.recv_response_msg",
430 request_id = id
431 )
432 };
433 async {
434 if let Some(sender) =
435 client.inflight.lock().await.remove(&id)
436 {
437 let _ = sender.send(resp);
438 } else {
439 client.pending.lock().await.insert(id, resp);
440 }
441 }
442 .instrument(span)
443 .await;
444 }
445 Message::Event(ev) => {
446 if let Err(_e) = client.events_tx.send(ev).await {
448 }
450 }
451 Message::Stream(ctl) => {
452 let _ = client.handle_stream_control(ctl).await;
453 }
454 Message::Hello(_)
456 | Message::HelloOk(_)
457 | Message::HelloErr(_)
458 | Message::Request(_)
459 | Message::Cancel(_) => {}
460 }
461 }
462 });
463 }
464
465 async fn handle_pending_response(
466 &self,
467 id: RequestId,
468 ) -> Result<Response, Error> {
469 if let Some(resp) = self.pending.lock().await.remove(&id) {
470 return Ok(resp);
471 }
472
473 let (tx, rx) = oneshot::channel();
474 {
475 let mut inflight = self.inflight.lock().await;
476 if let Some(resp) = self.pending.lock().await.remove(&id) {
477 return Ok(resp);
478 }
479 inflight.insert(id, tx);
480 }
481
482 match rx.await {
483 Ok(resp) => Ok(resp),
484 Err(e) => Err(to_ipc_error(anyhow!(
485 "response waiter dropped for id={id:?}: {e}"
486 ))),
487 }
488 }
489}
490
491#[async_trait]
492impl<S: Session + 'static> RpcClient for DefaultRpcClient<S> {
493 async fn send_request(&self, req: Request) -> Result<(), Error> {
495 let span = if let Some(conn_id) = self.connection_id {
496 tracing::info_span!("ipc.rpc.send_request", conn_id = %conn_id, request_id = req.id, service = %req.method.service, method = %req.method.method)
497 } else {
498 tracing::info_span!("ipc.rpc.send_request", request_id = req.id, service = %req.method.service, method = %req.method.method)
499 };
500 async move { self.session.send(&Message::Request(req)).await }
501 .instrument(span)
502 .await
503 }
504
505 async fn recv_response(&self, id: RequestId) -> Result<Response, Error> {
508 let span = if let Some(conn_id) = self.connection_id {
509 tracing::info_span!("ipc.rpc.recv_response", conn_id = %conn_id, request_id = id)
510 } else {
511 tracing::info_span!("ipc.rpc.recv_response", request_id = id)
512 };
513 async move { self.handle_pending_response(id).await }
514 .instrument(span)
515 .await
516 }
517
518 async fn cancel(&self, cancel: Cancel) -> Result<(), Error> {
522 let span = if let Some(conn_id) = self.connection_id {
523 tracing::info_span!("ipc.rpc.cancel", conn_id = %conn_id, request_id = cancel.id)
524 } else {
525 tracing::info_span!("ipc.rpc.cancel", request_id = cancel.id)
526 };
527 async move {
528 if let Some(sender) = self.inflight.lock().await.remove(&cancel.id)
530 {
531 let _ = sender.send(Response {
532 id: cancel.id,
533 ok: false,
534 result: None,
535 error: Some(RpcError::cancelled(
536 "request cancelled by client",
537 )),
538 });
539 }
540 let _ = self.pending.lock().await.remove(&cancel.id);
541
542 self.session.send(&Message::Cancel(cancel)).await
543 }
544 .instrument(span)
545 .await
546 }
547
548 fn events(&self) -> Pin<Box<dyn Stream<Item = Event> + Send>> {
549 let mut guard = match self.events_rx.lock() {
550 Ok(g) => g,
551 Err(_) => return Box::pin(futures_util::stream::empty()),
552 };
553
554 if let Some(rx) = guard.take() {
555 Box::pin(ReceiverStream::new(rx))
556 } else {
557 Box::pin(futures_util::stream::empty())
558 }
559 }
560}
561
562#[async_trait]
563impl<S: Session + 'static> StreamManager for DefaultRpcClient<S> {
564 async fn start(&self, ctl: StreamControl) -> Result<(), Error> {
565 self.session.send(&Message::Stream(ctl)).await
566 }
567
568 async fn next(&self, ctl: StreamControl) -> Result<(), Error> {
569 self.session.send(&Message::Stream(ctl)).await
570 }
571
572 async fn end(&self, ctl: StreamControl) -> Result<(), Error> {
573 self.session.send(&Message::Stream(ctl)).await
574 }
575
576 fn stream_incoming(
577 &self,
578 id: StreamId,
579 ) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
580 let streams = self.streams.clone();
581 let (tx, rx) = mpsc::channel::<Bytes>(1);
582
583 tokio::spawn(async move {
584 let mut guard = streams.lock().await;
585 if let Some(state) = guard.get_mut(&id) {
586 if let Some(mut original_rx) = state.rx.take() {
587 drop(guard);
588 while let Some(bytes) = original_rx.recv().await {
589 if tx.send(bytes).await.is_err() {
590 break;
591 }
592 }
593 }
594 }
595 });
596
597 Box::pin(ReceiverStream::new(rx))
598 }
599}
600
601fn response_id(resp: &Response) -> RequestId {
605 resp.id
606}
607
#[cfg(test)]
mod tests {
    use crate::workspace::ipc::protocol::MethodId;

    use super::*;
    use anyhow::{Context as _, Result};

    /// Channel-backed `Session` double: `send` pushes onto `to_server`, and
    /// `recv` pops whatever the test injects into `from_server`.
    struct MockSession {
        to_server: mpsc::Sender<Message>,
        from_server: tokio::sync::Mutex<mpsc::Receiver<Message>>,
    }

    #[async_trait]
    impl Session for MockSession {
        async fn send(&self, msg: &Message) -> Result<(), Error> {
            let outbound = msg.clone();
            self.to_server
                .send(outbound)
                .await
                .map_err(|e| to_ipc_error(anyhow!("send failed: {}", e)))
        }

        async fn recv(&self) -> Result<Option<Message>, Error> {
            Ok(self.from_server.lock().await.recv().await)
        }
    }

    /// Returns the mock session, the receiver observing what the client sent,
    /// and the sender used to inject server messages.
    fn make_mock_session() -> (
        std::sync::Arc<MockSession>,
        mpsc::Receiver<Message>,
        mpsc::Sender<Message>,
    ) {
        let (outbound_tx, outbound_rx) = mpsc::channel(16);
        let (inbound_tx, inbound_rx) = mpsc::channel(16);
        let session = std::sync::Arc::new(MockSession {
            to_server: outbound_tx,
            from_server: tokio::sync::Mutex::new(inbound_rx),
        });
        (session, outbound_rx, inbound_tx)
    }

    /// Builds an argument-less request against the `test` service.
    fn test_request(id: RequestId, method: &str) -> Request {
        Request {
            id,
            method: MethodId {
                service: "test".to_string(),
                method: method.to_string(),
            },
            args: vec![],
        }
    }

    #[crate::ctb_test(tokio::test)]
    async fn rpc_correlation_out_of_order() -> Result<()> {
        let (session, mut to_server_rx, from_server_tx) = make_mock_session();
        let client = DefaultRpcClient::new(session, DEFAULT_EVENTS_CAPACITY);

        tokio::try_join!(
            client.send_request(test_request(1, "method1")),
            client.send_request(test_request(2, "method2"))
        )?;

        // Requests reach the server in submission order.
        for expected in [1, 2] {
            let sent = to_server_rx
                .recv()
                .await
                .with_context(|| format!("missing req{expected}"))?;
            assert!(matches!(sent, Message::Request(r) if r.id == expected));
        }

        let second_client = client.clone();
        let first_waiter =
            tokio::spawn(async move { client.recv_response(1).await });
        let second_waiter =
            tokio::spawn(async move { second_client.recv_response(2).await });

        // Answer in reverse order so correlation, not arrival order, matters.
        for (id, payload) in [(2, 0xBB), (1, 0xAA)] {
            let resp = Response {
                id,
                ok: true,
                result: Some(vec![payload]),
                error: None,
            };
            from_server_tx.send(Message::Response(resp)).await?;
        }

        let (first, second) = tokio::try_join!(first_waiter, second_waiter)?;
        let (first, second) = (first?, second?);
        assert_eq!(first.id, 1);
        assert_eq!(first.result, Some(vec![0xAA]));
        assert_eq!(second.id, 2);
        assert_eq!(second.result, Some(vec![0xBB]));
        Ok(())
    }

    #[crate::ctb_test(tokio::test)]
    async fn rpc_cancellation() -> Result<()> {
        let (session, mut to_server_rx, from_server_tx) = make_mock_session();
        let client = DefaultRpcClient::new(session, DEFAULT_EVENTS_CAPACITY);

        client.send_request(test_request(1, "method")).await?;
        let sent_req = to_server_rx.recv().await.context("missing request")?;
        assert!(matches!(sent_req, Message::Request(r) if r.id == 1));

        client.cancel(Cancel { id: 1 }).await?;
        let sent_cancel =
            to_server_rx.recv().await.context("missing cancel")?;
        assert!(matches!(sent_cancel, Message::Cancel(c) if c.id == 1));

        // The server's own cancellation reply must still be retrievable.
        let cancelled = Response {
            id: 1,
            ok: false,
            result: None,
            error: Some(RpcError::cancelled("request cancelled")),
        };
        from_server_tx.send(Message::Response(cancelled)).await?;

        let recv_resp = client.recv_response(1).await?;
        assert_eq!(recv_resp.id, 1);
        assert!(!recv_resp.ok);
        assert_eq!(recv_resp.error.unwrap().code, "cancelled");
        Ok(())
    }

    #[ignore = "test is incomplete"]
    #[crate::ctb_test(tokio::test)]
    async fn server_cancel_stops_inflight_fetch() -> Result<()> {
        Ok(())
    }
}