1use crate::workspace::ipc::auth::capability::CapabilitySet;
2use crate::workspace::ipc::error::Error;
3use crate::workspace::ipc::process_manager::ProcessManager;
4use crate::workspace::ipc::protocol::{
5 Event, MethodId, Request, Response, RpcError,
6};
7use crate::workspace::ipc::services::network::api::{
8 METHOD_FETCH, METHOD_READ_FILE, NetworkService,
9 SERVICE_NAME as NETWORK_SERVICE_NAME,
10};
11use crate::workspace::ipc::services::process::api::{
12 METHOD_SHUTDOWN_TREE, ProcessService, SERVICE_NAME as PROCESS_SERVICE_NAME,
13 ShutdownTreeRequest,
14};
15use crate::workspace::ipc::types::{ConnectionId, ProcessId};
16use async_trait::async_trait;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20use tokio::time::{self, Instant};
21
22use tracing::Instrument as _;
23
24#[derive(Clone, Debug, Eq, PartialEq, Hash)]
26struct RateKey {
27 conn_id: ConnectionId,
28 service: String,
29 method: String,
30}
31
32impl RateKey {
33 fn bytes(conn_id: ConnectionId, service: String, method: String) -> Self {
34 Self {
35 conn_id,
36 service,
37 method,
38 }
39 }
40}
41
/// Token bucket used for byte-rate limiting (one token per byte).
///
/// The bucket starts full, never holds more than `capacity` tokens, and is
/// topped up lazily from elapsed time on every [`TokenBucket::try_take`] call.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: u128,
    /// Tokens currently available.
    tokens: u128,
    /// Instant up to which elapsed time has already been turned into tokens.
    last_refill: Instant,
    /// Sustained refill rate in tokens (bytes) per second.
    rate_bytes_per_sec: u64,
}

impl TokenBucket {
    /// Creates a full bucket refilling at `rate` tokens per second, holding at
    /// most `capacity` tokens, with `now` as the refill reference point.
    fn new(rate: u64, capacity: u64, now: Instant) -> Self {
        let capacity = u128::from(capacity);
        Self {
            capacity,
            tokens: capacity,
            last_refill: now,
            rate_bytes_per_sec: rate,
        }
    }

    /// Refills the bucket up to `now`, then tries to withdraw `amount` tokens.
    ///
    /// Returns `true` and deducts the tokens when enough are available;
    /// returns `false` and leaves the balance untouched otherwise.
    fn try_take(&mut self, amount: u64, now: Instant) -> bool {
        self.refill(now);

        let wanted = u128::from(amount);
        if self.tokens < wanted {
            return false;
        }
        self.tokens -= wanted;
        true
    }

    /// Credits the tokens earned between `last_refill` and `now`.
    ///
    /// Only whole tokens are credited. The time that paid for them is consumed
    /// from `last_refill`, while the leftover fraction of a token stays
    /// "banked" as elapsed time. Without this, a caller polling more often
    /// than once per `1 / rate` seconds would reset the clock on every call
    /// and never see the bucket refill.
    fn refill(&mut self, now: Instant) {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        let rate = u128::from(self.rate_bytes_per_sec);
        if rate == 0 {
            // Nothing ever accrues; just keep the reference point current.
            self.last_refill = now;
            return;
        }

        let elapsed_ns = now.duration_since(self.last_refill).as_nanos();
        // Saturate instead of overflowing on absurdly long gaps.
        let earned = elapsed_ns.saturating_mul(rate) / NANOS_PER_SEC;
        let missing = self.capacity.saturating_sub(self.tokens);

        if earned >= missing {
            // Bucket is (or becomes) full: surplus time earns no credit.
            self.tokens = self.capacity;
            self.last_refill = now;
            return;
        }
        if earned == 0 {
            // Less than one token's worth of time has passed; keep waiting
            // from the same reference point so the fraction is not lost.
            return;
        }

        self.tokens += earned;

        // `earned < missing <= u64::MAX`, so this product cannot overflow.
        let consumed_ns = earned * NANOS_PER_SEC / rate;
        let previous = self.last_refill;
        let advanced = u64::try_from(consumed_ns)
            .ok()
            .and_then(|ns| previous.checked_add(Duration::from_nanos(ns)));
        self.last_refill = advanced.unwrap_or(now);
    }
}
78
/// Shareable, one-way cancellation flag for a single request.
///
/// Clones share the same underlying flag, so a cancel issued through any clone
/// is observed by all of them. Nothing in this type clears the flag again.
#[derive(Clone, Debug)]
pub struct RequestCancellation {
    cancelled: Arc<std::sync::atomic::AtomicBool>,
}

impl RequestCancellation {
    /// Creates a flag in the "not cancelled" state.
    pub fn new() -> Self {
        let flag = std::sync::atomic::AtomicBool::new(false);
        Self {
            cancelled: Arc::new(flag),
        }
    }

    /// Marks the request as cancelled.
    pub fn cancel(&self) {
        use std::sync::atomic::Ordering;

        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Reports whether [`RequestCancellation::cancel`] has been called on this
    /// flag or on any of its clones.
    pub fn is_cancelled(&self) -> bool {
        use std::sync::atomic::Ordering;

        self.cancelled.load(Ordering::SeqCst)
    }
}

impl Default for RequestCancellation {
    /// Same as [`RequestCancellation::new`]: a flag that is not cancelled.
    fn default() -> Self {
        Self::new()
    }
}
115
116#[derive(Debug, Clone)]
117struct RequestContext {
118 cancellation: RequestCancellation,
119}
120
121tokio::task_local! {
122 static IPC_REQUEST_CONTEXT: RequestContext;
123}
124
125pub async fn scope_request_cancellation<T>(
127 cancellation: RequestCancellation,
128 fut: impl std::future::Future<Output = T>,
129) -> T {
130 IPC_REQUEST_CONTEXT
131 .scope(RequestContext { cancellation }, fut)
132 .await
133}
134
135pub fn is_cancelled() -> bool {
138 IPC_REQUEST_CONTEXT
139 .try_with(|ctx| ctx.cancellation.is_cancelled())
140 .unwrap_or(false)
141}
142
143#[async_trait]
149pub trait Router: Send + Sync {
150 async fn register_connection(
152 &self,
153 ctx: ConnectionContext,
154 ) -> Result<(), Error>;
155
156 async fn dispatch(
158 &self,
159 ctx: &ConnectionContext,
160 request: Request,
161 ) -> Result<Response, Error>;
162
163 async fn emit_event(&self, event: Event) -> Result<(), Error>;
165
166 fn is_authorized(
168 &self,
169 ctx: &ConnectionContext,
170 method: &MethodId,
171 ) -> Result<(), RpcError>;
172
173 async fn observe_cancel(
178 &self,
179 _ctx: &ConnectionContext,
180 _id: u64,
181 ) -> Result<(), Error> {
182 Ok(())
183 }
184}
185
186#[derive(Debug, Clone)]
188pub struct ConnectionContext {
189 pub id: ConnectionId,
190 pub capabilities: CapabilitySet,
191 pub metadata: Option<serde_json::Value>,
193}
194
195#[derive(Debug)]
198pub struct HeartbeatTracker {
199 process_manager: Arc<dyn ProcessManager>,
200 check_interval: Duration,
201 max_silence: Option<Duration>,
204 state: Mutex<HashMap<ConnectionId, HeartbeatEntry>>,
205}
206
207#[derive(Debug, Clone)]
208struct HeartbeatEntry {
209 pid: ProcessId,
210 last_heartbeat: Instant,
211}
212
213impl HeartbeatTracker {
214 pub fn new(
219 process_manager: Arc<dyn ProcessManager>,
220 check_interval: Duration,
221 allowed_missed_intervals: u32,
222 ) -> Arc<Self> {
223 let max_silence = check_interval.checked_mul(allowed_missed_intervals);
224 let tracker = Arc::new(Self {
225 process_manager,
226 check_interval,
227 max_silence,
228 state: Mutex::new(HashMap::new()),
229 });
230 Self::spawn_checker(tracker.clone());
231 tracker
232 }
233
234 fn spawn_checker(tracker: Arc<Self>) {
235 let _ = tokio::spawn(async move {
236 tracker.run_checker().await;
237 });
238 }
239
240 async fn run_checker(self: Arc<Self>) {
241 let mut interval = time::interval(self.check_interval);
242
243 loop {
244 interval.tick().await;
245
246 let Some(max_silence) = self.max_silence else {
247 continue;
248 };
249
250 let now = Instant::now();
251 let to_terminate: Vec<(ConnectionId, ProcessId)> = {
252 let state_lock = self.state.lock();
253 let state = match state_lock {
254 Ok(s) => s,
255 Err(_) => break, };
257 state
258 .iter()
259 .filter_map(|(conn_id, entry)| {
260 let since = now.duration_since(entry.last_heartbeat);
261 if since > max_silence {
262 Some((*conn_id, entry.pid))
263 } else {
264 None
265 }
266 })
267 .collect()
268 };
269
270 if to_terminate.is_empty() {
271 continue;
272 }
273
274 for (conn_id, pid) in to_terminate {
275 let span = tracing::info_span!(
276 "ipc.heartbeat.timeout",
277 conn_id = %conn_id,
278 process_id = %pid
279 );
280
281 async {
282 let _ =
283 self.process_manager.terminate_tree(pid, true).await;
284 if let Ok(mut state) = self.state.lock() {
285 state.remove(&conn_id);
286 }
287 }
288 .instrument(span)
289 .await;
290 }
291 }
292 }
293
294 pub fn track_connection(&self, connection: ConnectionId, pid: ProcessId) {
296 let entry = HeartbeatEntry {
297 pid,
298 last_heartbeat: Instant::now(),
299 };
300
301 if let Ok(mut state) = self.state.lock() {
302 state.insert(connection, entry);
303 }
304 }
305
306 pub fn record_heartbeat(&self, connection: &ConnectionId) {
308 if let Ok(mut state) = self.state.lock() {
309 if let Some(entry) = state.get_mut(connection) {
310 entry.last_heartbeat = Instant::now();
311 }
312 }
313 }
314
315 pub fn remove(&self, connection: &ConnectionId) {
317 if let Ok(mut state) = self.state.lock() {
318 state.remove(connection);
319 }
320 }
321
322 #[cfg(test)]
324 fn is_tracked(&self, connection: &ConnectionId) -> bool {
325 if let Ok(state) = self.state.lock() {
326 state.contains_key(connection)
327 } else {
328 false
329 }
330 }
331}
332
333#[derive(Debug)]
336pub struct IpcRouter {
337 connections: std::sync::RwLock<Vec<ConnectionContext>>,
339 network_service: Option<Arc<dyn NetworkService>>,
341 process_service: Option<Arc<dyn ProcessService>>,
342 rate_limiter: tokio::sync::Mutex<HashMap<RateKey, TokenBucket>>,
343}
344
345impl Default for IpcRouter {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
351impl IpcRouter {
352 pub fn new() -> Self {
354 Self {
355 connections: std::sync::RwLock::new(Vec::new()),
356 network_service: None,
357 process_service: None,
358 rate_limiter: tokio::sync::Mutex::new(HashMap::new()),
359 }
360 }
361
362 pub fn with_process_service(
364 mut self,
365 svc: Arc<dyn ProcessService>,
366 ) -> Self {
367 self.process_service = Some(svc);
368 self
369 }
370
371 pub fn with_network_service(
373 mut self,
374 svc: Arc<dyn NetworkService>,
375 ) -> Self {
376 self.network_service = Some(svc);
377 self
378 }
379
380 fn should_rate_limit(service: &str) -> bool {
381 service == "io" || service == "network"
382 }
383
384 fn match_quotas(
385 ctx: &ConnectionContext,
386 method: &MethodId,
387 ) -> Result<
388 Option<crate::workspace::ipc::auth::capability::QuotaSet>,
389 RpcError,
390 > {
391 use crate::workspace::ipc::auth::capability::{
392 MethodRule, ServiceName,
393 };
394
395 let service_key = ServiceName(method.service.clone());
396 let Some(rules) = ctx.capabilities.allowed.get(&service_key) else {
397 return Err(RpcError::unauthorized(format!(
398 "service '{}' is not allowed",
399 method.service
400 )));
401 };
402
403 let matched: Option<&MethodRule> = rules
404 .iter()
405 .find(|rule| rule.method.matches(&method.service, &method.method));
406
407 let Some(rule) = matched else {
408 return Err(RpcError::unauthorized(format!(
409 "method '{}.{}' is not allowed",
410 method.service, method.method
411 )));
412 };
413
414 Ok(rule.quotas.clone())
415 }
416
417 async fn enforce_rate_limits(
418 &self,
419 ctx: &ConnectionContext,
420 method: &MethodId,
421 request_bytes: usize,
422 ) -> Result<(), RpcError> {
423 if !Self::should_rate_limit(method.service.as_str()) {
424 return Ok(());
425 }
426
427 let quotas = Self::match_quotas(ctx, method)?;
428 let Some(quotas) = quotas else {
429 return Ok(());
430 };
431
432 let now = Instant::now();
433
434 if let Some(rate) = quotas.bytes_per_sec {
435 let Some(burst) = quotas.effective_burst_bytes() else {
436 return Ok(());
437 };
438
439 let cost_u64 = u64::try_from(request_bytes).map_err(|e| {
440 RpcError::capability_denied(format!(
441 "request size too large for limiter: {e}"
442 ))
443 })?;
444
445 let key = RateKey::bytes(
446 ctx.id,
447 method.service.clone(),
448 method.method.clone(),
449 );
450
451 let mut guard = self.rate_limiter.lock().await;
452 let bucket = guard
453 .entry(key)
454 .or_insert_with(|| TokenBucket::new(rate, burst, now));
455
456 if !bucket.try_take(cost_u64, now) {
457 return Err(RpcError::capability_denied(format!(
458 "rate limit exceeded for {}.{} (bytes/sec)",
459 method.service, method.method
460 )));
461 }
462 }
463
464 Ok(())
465 }
466
467 fn decode_shutdown_request(
468 args: Vec<u8>,
469 ) -> Result<ShutdownTreeRequest, RpcError> {
470 if args.is_empty() {
471 return Ok(ShutdownTreeRequest::default());
472 }
473 postcard::from_bytes::<ShutdownTreeRequest>(&args).map_err(|e| {
474 RpcError {
475 code: "invalid_args".into(),
476 message: format!("invalid workspace.shutdown args: {e}"),
477 }
478 })
479 }
480 fn error_response(id: u64, code: &str, message: &str) -> Response {
482 Response {
483 id,
484 ok: false,
485 result: None,
486 error: Some(RpcError {
487 code: code.into(),
488 message: message.into(),
489 }),
490 }
491 }
492
493 fn success_response(id: u64, bytes: Vec<u8>) -> Response {
495 Response {
496 id,
497 ok: true,
498 result: Some(bytes),
499 error: None,
500 }
501 }
502
503 fn check_service<T>(
505 &self,
506 service: &Option<T>,
507 service_name: &str,
508 id: u64,
509 ) -> Result<(), Response> {
510 if service.is_none() {
511 return Err(Self::error_response(
512 id,
513 "not_implemented",
514 &format!("{service_name} service not configured"),
515 ));
516 }
517 Ok(())
518 }
519
520 fn decode_request<T: serde::de::DeserializeOwned>(
522 &self,
523 args: &[u8],
524 method: &str,
525 id: u64,
526 ) -> Result<T, Response> {
527 postcard::from_bytes(args).map_err(|e| {
528 Self::error_response(
529 id,
530 "invalid_args",
531 &format!("invalid {method}.args: {e}"),
532 )
533 })
534 }
535
536 async fn handle_service_call<T: serde::Serialize, E: std::fmt::Display>(
538 &self,
539 call: impl std::future::Future<Output = Result<T, E>>,
540 id: u64,
541 ) -> Result<Response, Error> {
542 match call.await {
543 Ok(resp) => {
544 let bytes = postcard::to_stdvec(&resp).map_err(|e| {
545 anyhow::anyhow!("failed to serialize response: {e}")
546 })?;
547 Ok(Self::success_response(id, bytes))
548 }
549 Err(e) => Ok(Self::error_response(id, "internal", &e.to_string())),
550 }
551 }
552
553 async fn dispatch_process(
554 &self,
555 request: Request,
556 ) -> Result<Response, Error> {
557 if let Err(resp) =
558 self.check_service(&self.process_service, "process", request.id)
559 {
560 return Ok(resp);
561 }
562 let svc = self.process_service.as_ref().unwrap();
563
564 match request.method.method.as_str() {
567 METHOD_SHUTDOWN_TREE => {
568 let shutdown_req = Self::decode_shutdown_request(request.args)
569 .map_err(|e| {
570 Self::error_response(request.id, &e.code, &e.message)
571 })?;
572 self.handle_service_call(
573 svc.shutdown_tree(shutdown_req),
574 request.id,
575 )
576 .await
577 }
578 _ => Ok(Self::error_response(
579 request.id,
580 "not_implemented",
581 "method not implemented",
582 )),
583 }
584 }
585
586 async fn dispatch_network(
587 &self,
588 request: Request,
589 ) -> Result<Response, Error> {
590 if let Err(resp) =
591 self.check_service(&self.network_service, "network", request.id)
592 {
593 return Ok(resp);
594 }
595 let svc = self.network_service.as_ref().unwrap();
596
597 match request.method.method.as_str() {
598 METHOD_FETCH => {
599 let req = self.decode_request(
600 &request.args,
601 "network.fetch",
602 request.id,
603 )?;
604 self.handle_service_call(svc.fetch(req), request.id).await
605 }
606 METHOD_READ_FILE => {
607 let req = self.decode_request(
608 &request.args,
609 "network.read_file",
610 request.id,
611 )?;
612 self.handle_service_call(svc.read_file(req), request.id)
613 .await
614 }
615 _ => Ok(Self::error_response(
616 request.id,
617 "not_implemented",
618 "method not implemented",
619 )),
620 }
621 }
622}
623
624#[async_trait]
625impl Router for IpcRouter {
626 async fn register_connection(
628 &self,
629 ctx: ConnectionContext,
630 ) -> Result<(), Error> {
631 if let Ok(mut guard) = self.connections.write() {
632 if let Some(pos) = guard.iter().position(|c| c.id == ctx.id) {
633 guard[pos] = ctx;
634 } else {
635 guard.push(ctx);
636 }
637 }
638 Ok(())
639 }
640
641 async fn dispatch(
643 &self,
644 ctx: &ConnectionContext,
645 request: Request,
646 ) -> Result<Response, Error> {
647 let span = tracing::info_span!(
648 "ipc.dispatch",
649 conn_id = %ctx.id,
650 request_id = request.id,
651 service = %request.method.service,
652 method = %request.method.method
653 );
654
655 async move {
656 match self.is_authorized(ctx, &request.method) {
657 Ok(()) => {
658 if let Err(e) = self
659 .enforce_rate_limits(
660 ctx,
661 &request.method,
662 request.args.len(),
663 )
664 .await
665 {
666 return Ok(Response {
667 id: request.id,
668 ok: false,
669 result: None,
670 error: Some(e),
671 });
672 }
673
674 if request.method.service == PROCESS_SERVICE_NAME {
675 return self.dispatch_process(request).await;
676 }
677
678 if request.method.service == NETWORK_SERVICE_NAME {
679 return self.dispatch_network(request).await;
680 }
681
682 Ok(Response {
683 id: request.id,
684 ok: false,
685 result: None,
686 error: Some(RpcError {
687 code: "not_implemented".into(),
688 message: "method not implemented".into(),
689 }),
690 })
691 }
692 Err(e) => Ok(Response {
693 id: request.id,
694 ok: false,
695 result: None,
696 error: Some(e),
697 }),
698 }
699 }
700 .instrument(span)
701 .await
702 }
703
704 async fn emit_event(&self, event: Event) -> Result<(), Error> {
705 let span = tracing::info_span!("ipc.emit_event", topic = ?event.topic);
706 async move {
707 Ok(())
709 }
710 .instrument(span)
711 .await
712 }
713
714 fn is_authorized(
716 &self,
717 ctx: &ConnectionContext,
718 method: &MethodId,
719 ) -> Result<(), RpcError> {
720 use crate::workspace::ipc::auth::capability::{
721 MethodRule, ServiceName,
722 };
723
724 let service_key = ServiceName(method.service.clone());
725 let Some(rules) = ctx.capabilities.allowed.get(&service_key) else {
726 return Err(RpcError {
727 code: "unauthorized".into(),
728 message: format!("service '{}' is not allowed", method.service),
729 });
730 };
731
732 let allowed = rules.iter().any(|rule: &MethodRule| {
733 rule.method.matches(&method.service, &method.method)
734 });
735
736 if allowed {
737 Ok(())
738 } else {
739 Err(RpcError {
740 code: "unauthorized".into(),
741 message: format!(
742 "method '{}.{}' is not allowed",
743 method.service, method.method
744 ),
745 })
746 }
747 }
748
749 async fn observe_cancel(
750 &self,
751 _ctx: &ConnectionContext,
752 _id: u64,
753 ) -> Result<(), Error> {
754 Ok(())
755 }
756}
757
758impl IpcRouter {
759 pub fn get(&self, id: &ConnectionId) -> Option<ConnectionContext> {
761 if let Ok(guard) = self.connections.read() {
762 guard.iter().find(|c| c.id == *id).cloned()
763 } else {
764 None
765 }
766 }
767}
768
#[cfg(test)]
mod tests {
    use super::*;

    use crate::workspace::ipc::auth::capability::{
        CapabilitySet, MethodRule, MethodSelector, QuotaSet, ServiceName,
    };
    use crate::workspace::ipc::process_manager::{
        ChildHandle, ProcessManager, SpawnParams,
    };
    use crate::workspace::ipc::protocol::{MethodId, Request};
    use crate::workspace::ipc::services::network::api::{
        BytesResponse as NetBytesResponse, FetchRequest, METHOD_FETCH,
        METHOD_READ_FILE, MockNetworkService, ReadFileRequest,
        SERVICE_NAME as NETWORK_SERVICE_NAME,
    };
    use crate::workspace::ipc::services::process::MockProcessService;
    use crate::workspace::ipc::{
        assert_ipc_response_error, assert_ipc_response_ok,
    };

    use crate::workspace::ipc::types::{ChildKind, ProcessId};
    use anyhow::Result;
    use std::collections::HashMap;

    /// Rule that allows `selector` without any quotas.
    fn allow(selector: MethodSelector) -> MethodRule {
        MethodRule {
            method: selector,
            quotas: None,
        }
    }

    /// Context with the default connection id whose only capability is
    /// `rules` for `service`.
    fn ctx_with_rules(
        service: &str,
        rules: Vec<MethodRule>,
    ) -> ConnectionContext {
        let mut allowed: HashMap<ServiceName, Vec<MethodRule>> = HashMap::new();
        allowed.insert(ServiceName(service.to_string()), rules);
        ConnectionContext {
            id: Default::default(),
            capabilities: CapabilitySet {
                allowed,
                global_limits: None,
            },
            metadata: None,
        }
    }

    /// Request for `service.method` with the default id and `args` as payload.
    fn req_with_args(service: &str, method: &str, args: Vec<u8>) -> Request {
        Request {
            id: Default::default(),
            method: MethodId {
                service: service.into(),
                method: method.into(),
            },
            args,
        }
    }

    /// Request for `service.method` with the default id and no arguments.
    fn req(service: &str, method: &str) -> Request {
        req_with_args(service, method, vec![])
    }

    /// An exact selector lets the named method through and nothing else.
    #[crate::ctb_test(tokio::test)]
    async fn auth_exact_allows() -> Result<()> {
        let router = IpcRouter::new();
        let ctx = ctx_with_rules(
            "svc",
            vec![allow(MethodSelector::Exact("do_work".into()))],
        );

        let resp = router.dispatch(&ctx, req("svc", "do_work")).await?;
        assert!(!resp.ok);
        assert_eq!(resp.error.unwrap().code, "not_implemented");

        let resp2 = router.dispatch(&ctx, req("svc", "other")).await?;
        assert!(!resp2.ok);
        assert_eq!(resp2.error.unwrap().code, "unauthorized");
        Ok(())
    }

    /// A prefix selector lets matching methods through and rejects others.
    #[crate::ctb_test(tokio::test)]
    async fn auth_prefix() -> Result<()> {
        let router = IpcRouter::new();
        let ctx = ctx_with_rules(
            "svc",
            vec![allow(MethodSelector::Prefix("do_".into()))],
        );

        let ok_resp = router.dispatch(&ctx, req("svc", "do_stuff")).await?;
        assert!(!ok_resp.ok);
        assert_eq!(ok_resp.error.unwrap().code, "not_implemented");

        let deny_resp = router.dispatch(&ctx, req("svc", "list")).await?;
        assert!(!deny_resp.ok);
        assert_eq!(deny_resp.error.unwrap().code, "unauthorized");
        Ok(())
    }

    /// `Any` allows every method of its service; other services stay denied.
    #[crate::ctb_test(tokio::test)]
    async fn auth_any_and_missing_service() -> Result<()> {
        let router = IpcRouter::new();

        let ctx_any = ctx_with_rules("svc", vec![allow(MethodSelector::Any)]);

        let ok_resp = router.dispatch(&ctx_any, req("svc", "x")).await?;
        assert!(!ok_resp.ok);
        assert_eq!(ok_resp.error.unwrap().code, "not_implemented");

        let deny_resp = router.dispatch(&ctx_any, req("other", "x")).await?;
        assert!(!deny_resp.ok);
        assert_eq!(deny_resp.error.unwrap().code, "unauthorized");
        Ok(())
    }

    /// Selectors may be written fully qualified as `service.method`.
    #[crate::ctb_test(tokio::test)]
    async fn auth_fully_qualified_selectors() -> Result<()> {
        let router = IpcRouter::new();
        let ctx = ctx_with_rules(
            "svc",
            vec![
                allow(MethodSelector::Exact("svc.do_a".into())),
                allow(MethodSelector::Prefix("svc.do_".into())),
            ],
        );

        let r1 = router.dispatch(&ctx, req("svc", "do_a")).await?;
        assert_eq!(r1.error.unwrap().code, "not_implemented");

        let r2 = router.dispatch(&ctx, req("svc", "do_b")).await?;
        assert_eq!(r2.error.unwrap().code, "not_implemented");

        let r3 = router.dispatch(&ctx, req("svc", "list")).await?;
        assert_eq!(r3.error.unwrap().code, "unauthorized");

        Ok(())
    }

    /// A registered context can be read back by its connection id.
    #[crate::ctb_test(tokio::test)]
    async fn register_stores_context() -> Result<()> {
        let router = IpcRouter::new();
        let ctx = ConnectionContext {
            id: ConnectionId::default(),
            capabilities: CapabilitySet::default(),
            metadata: Some(serde_json::json!({"k":"v"})),
        };
        router.register_connection(ctx.clone()).await?;
        let got = router.get(&ctx.id).expect("stored");
        assert_eq!(got.id, ctx.id);
        assert_eq!(got.metadata, ctx.metadata);
        Ok(())
    }

    /// Process manager double that records every `terminate_tree` call.
    #[derive(Debug, Default)]
    struct MockProcessManager {
        terminated: std::sync::Mutex<Vec<(ProcessId, bool)>>,
    }

    impl MockProcessManager {
        /// Snapshot of the recorded `(pid, force)` terminations.
        fn terminations(&self) -> Vec<(ProcessId, bool)> {
            self.terminated
                .lock()
                .map(|recorded| recorded.clone())
                .unwrap_or_default()
        }
    }

    #[async_trait]
    impl ProcessManager for MockProcessManager {
        async fn spawn_child(
            &self,
            _params: SpawnParams,
        ) -> Result<ChildHandle, Error> {
            Ok(ChildHandle {
                pid: ProcessId::default(),
                kind: ChildKind::Renderer,
                connection: None,
            })
        }

        async fn attach_connection(
            &self,
            _pid: ProcessId,
            _conn: ConnectionId,
        ) -> Result<(), Error> {
            Ok(())
        }

        async fn list_children(&self) -> Result<Vec<ChildHandle>, Error> {
            Ok(Vec::new())
        }

        async fn terminate_tree(
            &self,
            pid: ProcessId,
            force: bool,
        ) -> Result<(), Error> {
            if let Ok(mut recorded) = self.terminated.lock() {
                recorded.push((pid, force));
            }
            Ok(())
        }
    }

    /// A silent connection gets its process tree force-terminated exactly
    /// once and is then forgotten by the tracker.
    #[crate::ctb_test(tokio::test)]
    async fn heartbeat_miss_triggers_termination_and_cleanup() -> Result<()> {
        let manager = Arc::new(MockProcessManager::default());
        let tracker = HeartbeatTracker::new(
            manager.clone(),
            Duration::from_millis(20),
            1,
        );

        let conn = ConnectionId::default();
        let pid = ProcessId::default();
        tracker.track_connection(conn, pid);

        // Several check intervals without any heartbeat.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let terms = manager.terminations();
        assert_eq!(terms.len(), 1);
        assert_eq!(terms[0].0, pid);
        assert!(terms[0].1);
        assert!(!tracker.is_tracked(&conn));

        Ok(())
    }

    /// `process.shutdown_tree` with empty arguments reaches the process
    /// service and succeeds.
    #[crate::ctb_test(tokio::test)]
    async fn routes_workspace_shutdown() -> Result<()> {
        let router =
            IpcRouter::new().with_process_service(Arc::new(MockProcessService));

        let ctx = ctx_with_rules(
            "process",
            vec![allow(MethodSelector::Exact("shutdown_tree".into()))],
        );

        let resp = router.dispatch(&ctx, req("process", "shutdown_tree")).await?;
        assert_ipc_response_ok(&resp);
        assert!(resp.error.is_none());
        Ok(())
    }

    /// Once the burst is spent, the next request is refused with
    /// `capability_denied`.
    #[crate::ctb_test(tokio::test)]
    async fn rate_limit_bytes_per_sec_is_enforced_for_io_and_network()
    -> Result<()> {
        let router = IpcRouter::new();

        let ctx = ctx_with_rules(
            "io",
            vec![MethodRule {
                method: MethodSelector::Exact("read".into()),
                quotas: Some(QuotaSet {
                    bytes_per_sec: Some(10),
                    ops_per_sec: None,
                    burst: Some(10),
                }),
            }],
        );

        let mk_req = |n: usize| req_with_args("io", "read", vec![0u8; n]);

        // Exactly the burst: passes the limiter, then hits the missing
        // service.
        let r1 = router.dispatch(&ctx, mk_req(10)).await?;
        assert_eq!(r1.error.unwrap().code, "not_implemented");

        // One more byte right away: the bucket is empty.
        let r2 = router.dispatch(&ctx, mk_req(1)).await?;
        assert_ipc_response_error(&r2);
        assert_eq!(r2.error.unwrap().code, "capability_denied");

        Ok(())
    }

    /// `network.fetch` is decoded, served by the mock and encoded back.
    #[crate::ctb_test(tokio::test)]
    async fn routes_network_fetch() -> anyhow::Result<()> {
        let router = IpcRouter::new()
            .with_network_service(Arc::new(MockNetworkService::new()));

        let ctx = ctx_with_rules(
            NETWORK_SERVICE_NAME,
            vec![allow(MethodSelector::Exact(METHOD_FETCH.into()))],
        );

        let args = postcard::to_stdvec(&FetchRequest {
            url: "https://example.com".into(),
        })?;
        let request = req_with_args(NETWORK_SERVICE_NAME, METHOD_FETCH, args);

        let resp = router.dispatch(&ctx, request).await?;
        assert_ipc_response_ok(&resp);
        let decoded: NetBytesResponse =
            postcard::from_bytes(&resp.result.unwrap())?;
        assert_eq!(decoded.bytes, b"Example HTTPS Fetch".to_vec());
        Ok(())
    }

    /// `network.read_file` is decoded, served by the mock and encoded back.
    #[crate::ctb_test(tokio::test)]
    async fn routes_network_read_file() -> anyhow::Result<()> {
        let router = IpcRouter::new()
            .with_network_service(Arc::new(MockNetworkService::new()));

        let ctx = ctx_with_rules(
            NETWORK_SERVICE_NAME,
            vec![allow(MethodSelector::Exact(METHOD_READ_FILE.into()))],
        );

        let args: Vec<u8> = postcard::to_stdvec(&ReadFileRequest {
            path: "/tmp/example/file.txt".into(),
        })?;
        let request =
            req_with_args(NETWORK_SERVICE_NAME, METHOD_READ_FILE, args);

        let resp = router.dispatch(&ctx, request).await?;
        assert_ipc_response_ok(&resp);
        let decoded: NetBytesResponse =
            postcard::from_bytes(&resp.result.unwrap())?;
        assert_eq!(decoded.bytes, b"Example File Read".to_vec());
        Ok(())
    }
}