ctoolbox/workspace/ipc/
router.rs

1use crate::workspace::ipc::auth::capability::CapabilitySet;
2use crate::workspace::ipc::error::Error;
3use crate::workspace::ipc::process_manager::ProcessManager;
4use crate::workspace::ipc::protocol::{
5    Event, MethodId, Request, Response, RpcError,
6};
7use crate::workspace::ipc::services::network::api::{
8    METHOD_FETCH, METHOD_READ_FILE, NetworkService,
9    SERVICE_NAME as NETWORK_SERVICE_NAME,
10};
11use crate::workspace::ipc::services::process::api::{
12    METHOD_SHUTDOWN_TREE, ProcessService, SERVICE_NAME as PROCESS_SERVICE_NAME,
13    ShutdownTreeRequest,
14};
15use crate::workspace::ipc::types::{ConnectionId, ProcessId};
16use async_trait::async_trait;
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20use tokio::time::{self, Instant};
21
22use tracing::Instrument as _;
23
24/// Key for rate limiting (connection, service, method).
25#[derive(Clone, Debug, Eq, PartialEq, Hash)]
26struct RateKey {
27    conn_id: ConnectionId,
28    service: String,
29    method: String,
30}
31
32impl RateKey {
33    fn bytes(conn_id: ConnectionId, service: String, method: String) -> Self {
34        Self {
35            conn_id,
36            service,
37            method,
38        }
39    }
40}
41
/// Token bucket used for bytes-per-second rate limiting.
///
/// The bucket holds at most `capacity` whole tokens (bytes) and refills at
/// `rate_bytes_per_sec`. Refill credit that does not yet add up to a whole
/// token is carried over in `credit_remainder`. As a result, frequent calls
/// separated by short gaps still refill the bucket at the configured rate,
/// instead of rounding every gap down to zero tokens.
#[derive(Debug)]
struct TokenBucket {
    /// Maximum number of whole tokens the bucket can hold (the burst size).
    capacity: u128,
    /// Whole tokens currently available.
    tokens: u128,
    /// Instant up to which elapsed time has already been turned into credit.
    last_refill: Instant,
    /// Refill rate in tokens (bytes) per second.
    rate_bytes_per_sec: u64,
    /// Fractional refill credit in units of 1/1_000_000_000 of a token.
    /// Always strictly less than `NANOS_PER_SEC`.
    credit_remainder: u128,
}

impl TokenBucket {
    /// Credit units per whole token. Elapsed nanoseconds multiplied by the
    /// per-second rate yields credit in these units.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Create a full bucket that refills at `rate` tokens per second and
    /// holds at most `capacity` tokens.
    fn new(rate: u64, capacity: u64, now: Instant) -> Self {
        Self {
            capacity: u128::from(capacity),
            tokens: u128::from(capacity),
            last_refill: now,
            rate_bytes_per_sec: rate,
            credit_remainder: 0,
        }
    }

    /// Refill for the time elapsed since the previous call, then try to
    /// consume `amount` tokens.
    ///
    /// On success the tokens are deducted and `true` is returned. Otherwise
    /// the balance is left unchanged and `false` is returned. A request
    /// larger than `capacity` can never succeed.
    fn try_take(&mut self, amount: u64, now: Instant) -> bool {
        self.refill(now);

        let amount = u128::from(amount);
        if self.tokens >= amount {
            self.tokens -= amount;
            true
        } else {
            false
        }
    }

    /// Convert the time elapsed since `last_refill` into whole tokens and
    /// carry any sub-token remainder forward to the next refill.
    fn refill(&mut self, now: Instant) {
        // Saturates to zero if `now` is earlier than `last_refill`.
        let elapsed_ns = now.saturating_duration_since(self.last_refill).as_nanos();
        self.last_refill = now;

        let credit = elapsed_ns
            .saturating_mul(u128::from(self.rate_bytes_per_sec))
            .saturating_add(self.credit_remainder);
        self.tokens = self.tokens.saturating_add(credit / Self::NANOS_PER_SEC);
        self.credit_remainder = credit % Self::NANOS_PER_SEC;

        if self.tokens >= self.capacity {
            // A full bucket cannot bank extra credit.
            self.tokens = self.capacity;
            self.credit_remainder = 0;
        }
    }
}
78
/// Cooperative, request-scoped cancellation flag.
///
/// Clones share one underlying flag, so cancelling through any clone is
/// visible through all of them. Nothing is interrupted automatically:
/// services must poll [`RequestCancellation::is_cancelled`] (or the
/// module-level `is_cancelled` helper) and stop on their own. Cancellation is
/// therefore best-effort.
#[derive(Debug, Clone)]
pub struct RequestCancellation {
    cancelled: Arc<std::sync::atomic::AtomicBool>,
}

impl RequestCancellation {
    /// Create a token that has not been cancelled.
    pub fn new() -> Self {
        let flag = std::sync::atomic::AtomicBool::new(false);
        RequestCancellation {
            cancelled: Arc::new(flag),
        }
    }

    /// Request cancellation. Calling this more than once has no further
    /// effect, and it is observed by every clone of this token.
    pub fn cancel(&self) {
        use std::sync::atomic::Ordering;
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or on any of its clones.
    pub fn is_cancelled(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.cancelled.load(Ordering::SeqCst)
    }
}

impl Default for RequestCancellation {
    /// Equivalent to [`RequestCancellation::new`].
    fn default() -> Self {
        RequestCancellation::new()
    }
}
115
116#[derive(Debug, Clone)]
117struct RequestContext {
118    cancellation: RequestCancellation,
119}
120
121tokio::task_local! {
122    static IPC_REQUEST_CONTEXT: RequestContext;
123}
124
125/// Run `fut` with a request cancellation context installed.
126pub async fn scope_request_cancellation<T>(
127    cancellation: RequestCancellation,
128    fut: impl std::future::Future<Output = T>,
129) -> T {
130    IPC_REQUEST_CONTEXT
131        .scope(RequestContext { cancellation }, fut)
132        .await
133}
134
135/// Returns true if the current task is executing within an IPC request and has
136/// been cancelled.
137pub fn is_cancelled() -> bool {
138    IPC_REQUEST_CONTEXT
139        .try_with(|ctx| ctx.cancellation.is_cancelled())
140        .unwrap_or(false)
141}
142
143/// Central router interface responsible for:
144/// - binding connections to capability sets
145/// - enforcing authorization
146/// - dispatching requests to services
147/// - emitting events
148#[async_trait]
149pub trait Router: Send + Sync {
150    /// Register a new connection after handshake.
151    async fn register_connection(
152        &self,
153        ctx: ConnectionContext,
154    ) -> Result<(), Error>;
155
156    /// Resolve and dispatch a request to a target service method.
157    async fn dispatch(
158        &self,
159        ctx: &ConnectionContext,
160        request: Request,
161    ) -> Result<Response, Error>;
162
163    /// Emit an event to a connection or broadcast to all with appropriate policies.
164    async fn emit_event(&self, event: Event) -> Result<(), Error>;
165
166    /// Check whether a given method is allowed by a connection’s capabilities.
167    fn is_authorized(
168        &self,
169        ctx: &ConnectionContext,
170        method: &MethodId,
171    ) -> Result<(), RpcError>;
172
173    /// Observe a cancellation request for auditing/metrics.
174    ///
175    /// Implementations may ignore this. Cancellation itself is enforced by the
176    /// IPC server loop via cooperative checks and/or task abort.
177    async fn observe_cancel(
178        &self,
179        _ctx: &ConnectionContext,
180        _id: u64,
181    ) -> Result<(), Error> {
182        Ok(())
183    }
184}
185
186/// Context bound to a connection for authorization and audit.
187#[derive(Debug, Clone)]
188pub struct ConnectionContext {
189    pub id: ConnectionId,
190    pub capabilities: CapabilitySet,
191    /// Optional additional metadata (process kind, user, document id, etc.)
192    pub metadata: Option<serde_json::Value>,
193}
194
195/// Tracks connection heartbeats and terminates associated processes when
196/// liveness timeouts are exceeded.
197#[derive(Debug)]
198pub struct HeartbeatTracker {
199    process_manager: Arc<dyn ProcessManager>,
200    check_interval: Duration,
201    /// Maximum silence before a connection is considered dead. `None` disables
202    /// timeouts.
203    max_silence: Option<Duration>,
204    state: Mutex<HashMap<ConnectionId, HeartbeatEntry>>,
205}
206
207#[derive(Debug, Clone)]
208struct HeartbeatEntry {
209    pid: ProcessId,
210    last_heartbeat: Instant,
211}
212
213impl HeartbeatTracker {
214    /// Create a new tracker and start a background checker task.
215    ///
216    /// `allowed_missed_intervals` controls after how many consecutive missed
217    /// intervals the connection is considered dead.
218    pub fn new(
219        process_manager: Arc<dyn ProcessManager>,
220        check_interval: Duration,
221        allowed_missed_intervals: u32,
222    ) -> Arc<Self> {
223        let max_silence = check_interval.checked_mul(allowed_missed_intervals);
224        let tracker = Arc::new(Self {
225            process_manager,
226            check_interval,
227            max_silence,
228            state: Mutex::new(HashMap::new()),
229        });
230        Self::spawn_checker(tracker.clone());
231        tracker
232    }
233
234    fn spawn_checker(tracker: Arc<Self>) {
235        let _ = tokio::spawn(async move {
236            tracker.run_checker().await;
237        });
238    }
239
240    async fn run_checker(self: Arc<Self>) {
241        let mut interval = time::interval(self.check_interval);
242
243        loop {
244            interval.tick().await;
245
246            let Some(max_silence) = self.max_silence else {
247                continue;
248            };
249
250            let now = Instant::now();
251            let to_terminate: Vec<(ConnectionId, ProcessId)> = {
252                let state_lock = self.state.lock();
253                let state = match state_lock {
254                    Ok(s) => s,
255                    Err(_) => break, // Poisoned
256                };
257                state
258                    .iter()
259                    .filter_map(|(conn_id, entry)| {
260                        let since = now.duration_since(entry.last_heartbeat);
261                        if since > max_silence {
262                            Some((*conn_id, entry.pid))
263                        } else {
264                            None
265                        }
266                    })
267                    .collect()
268            };
269
270            if to_terminate.is_empty() {
271                continue;
272            }
273
274            for (conn_id, pid) in to_terminate {
275                let span = tracing::info_span!(
276                    "ipc.heartbeat.timeout",
277                    conn_id = %conn_id,
278                    process_id = %pid
279                );
280
281                async {
282                    let _ =
283                        self.process_manager.terminate_tree(pid, true).await;
284                    if let Ok(mut state) = self.state.lock() {
285                        state.remove(&conn_id);
286                    }
287                }
288                .instrument(span)
289                .await;
290            }
291        }
292    }
293
294    /// Start tracking a connection and its owning process.
295    pub fn track_connection(&self, connection: ConnectionId, pid: ProcessId) {
296        let entry = HeartbeatEntry {
297            pid,
298            last_heartbeat: Instant::now(),
299        };
300
301        if let Ok(mut state) = self.state.lock() {
302            state.insert(connection, entry);
303        }
304    }
305
306    /// Record an observed heartbeat from the given connection.
307    pub fn record_heartbeat(&self, connection: &ConnectionId) {
308        if let Ok(mut state) = self.state.lock() {
309            if let Some(entry) = state.get_mut(connection) {
310                entry.last_heartbeat = Instant::now();
311            }
312        }
313    }
314
315    /// Stop tracking the given connection.
316    pub fn remove(&self, connection: &ConnectionId) {
317        if let Ok(mut state) = self.state.lock() {
318            state.remove(connection);
319        }
320    }
321
322    /// Check if a connection is currently tracked (test-only).
323    #[cfg(test)]
324    fn is_tracked(&self, connection: &ConnectionId) -> bool {
325        if let Ok(state) = self.state.lock() {
326            state.contains_key(connection)
327        } else {
328            false
329        }
330    }
331}
332
333/// A simple in-memory router that registers connections, authorizes requests,
334/// and returns a canned 'not implemented' response for authorized calls.
335#[derive(Debug)]
336pub struct IpcRouter {
337    // Keep a minimal registry. Avoid panics; ignore duplicates by replacing.
338    connections: std::sync::RwLock<Vec<ConnectionContext>>,
339    // New: thin network service adapter
340    network_service: Option<Arc<dyn NetworkService>>,
341    process_service: Option<Arc<dyn ProcessService>>,
342    rate_limiter: tokio::sync::Mutex<HashMap<RateKey, TokenBucket>>,
343}
344
345impl Default for IpcRouter {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
351impl IpcRouter {
352    /// Create a new `SimpleRouter`.
353    pub fn new() -> Self {
354        Self {
355            connections: std::sync::RwLock::new(Vec::new()),
356            network_service: None,
357            process_service: None,
358            rate_limiter: tokio::sync::Mutex::new(HashMap::new()),
359        }
360    }
361
362    /// Set the process service.
363    pub fn with_process_service(
364        mut self,
365        svc: Arc<dyn ProcessService>,
366    ) -> Self {
367        self.process_service = Some(svc);
368        self
369    }
370
371    /// Set the network service (thin adapter).
372    pub fn with_network_service(
373        mut self,
374        svc: Arc<dyn NetworkService>,
375    ) -> Self {
376        self.network_service = Some(svc);
377        self
378    }
379
380    fn should_rate_limit(service: &str) -> bool {
381        service == "io" || service == "network"
382    }
383
384    fn match_quotas(
385        ctx: &ConnectionContext,
386        method: &MethodId,
387    ) -> Result<
388        Option<crate::workspace::ipc::auth::capability::QuotaSet>,
389        RpcError,
390    > {
391        use crate::workspace::ipc::auth::capability::{
392            MethodRule, ServiceName,
393        };
394
395        let service_key = ServiceName(method.service.clone());
396        let Some(rules) = ctx.capabilities.allowed.get(&service_key) else {
397            return Err(RpcError::unauthorized(format!(
398                "service '{}' is not allowed",
399                method.service
400            )));
401        };
402
403        let matched: Option<&MethodRule> = rules
404            .iter()
405            .find(|rule| rule.method.matches(&method.service, &method.method));
406
407        let Some(rule) = matched else {
408            return Err(RpcError::unauthorized(format!(
409                "method '{}.{}' is not allowed",
410                method.service, method.method
411            )));
412        };
413
414        Ok(rule.quotas.clone())
415    }
416
417    async fn enforce_rate_limits(
418        &self,
419        ctx: &ConnectionContext,
420        method: &MethodId,
421        request_bytes: usize,
422    ) -> Result<(), RpcError> {
423        if !Self::should_rate_limit(method.service.as_str()) {
424            return Ok(());
425        }
426
427        let quotas = Self::match_quotas(ctx, method)?;
428        let Some(quotas) = quotas else {
429            return Ok(());
430        };
431
432        let now = Instant::now();
433
434        if let Some(rate) = quotas.bytes_per_sec {
435            let Some(burst) = quotas.effective_burst_bytes() else {
436                return Ok(());
437            };
438
439            let cost_u64 = u64::try_from(request_bytes).map_err(|e| {
440                RpcError::capability_denied(format!(
441                    "request size too large for limiter: {e}"
442                ))
443            })?;
444
445            let key = RateKey::bytes(
446                ctx.id,
447                method.service.clone(),
448                method.method.clone(),
449            );
450
451            let mut guard = self.rate_limiter.lock().await;
452            let bucket = guard
453                .entry(key)
454                .or_insert_with(|| TokenBucket::new(rate, burst, now));
455
456            if !bucket.try_take(cost_u64, now) {
457                return Err(RpcError::capability_denied(format!(
458                    "rate limit exceeded for {}.{} (bytes/sec)",
459                    method.service, method.method
460                )));
461            }
462        }
463
464        Ok(())
465    }
466
467    fn decode_shutdown_request(
468        args: Vec<u8>,
469    ) -> Result<ShutdownTreeRequest, RpcError> {
470        if args.is_empty() {
471            return Ok(ShutdownTreeRequest::default());
472        }
473        postcard::from_bytes::<ShutdownTreeRequest>(&args).map_err(|e| {
474            RpcError {
475                code: "invalid_args".into(),
476                message: format!("invalid workspace.shutdown args: {e}"),
477            }
478        })
479    }
480    /// Helper to create an error response.
481    fn error_response(id: u64, code: &str, message: &str) -> Response {
482        Response {
483            id,
484            ok: false,
485            result: None,
486            error: Some(RpcError {
487                code: code.into(),
488                message: message.into(),
489            }),
490        }
491    }
492
493    /// Helper to create a success response with serialized bytes.
494    fn success_response(id: u64, bytes: Vec<u8>) -> Response {
495        Response {
496            id,
497            ok: true,
498            result: Some(bytes),
499            error: None,
500        }
501    }
502
503    /// Helper to check if a service is configured.
504    fn check_service<T>(
505        &self,
506        service: &Option<T>,
507        service_name: &str,
508        id: u64,
509    ) -> Result<(), Response> {
510        if service.is_none() {
511            return Err(Self::error_response(
512                id,
513                "not_implemented",
514                &format!("{service_name} service not configured"),
515            ));
516        }
517        Ok(())
518    }
519
520    /// Helper to decode a request from bytes.
521    fn decode_request<T: serde::de::DeserializeOwned>(
522        &self,
523        args: &[u8],
524        method: &str,
525        id: u64,
526    ) -> Result<T, Response> {
527        postcard::from_bytes(args).map_err(|e| {
528            Self::error_response(
529                id,
530                "invalid_args",
531                &format!("invalid {method}.args: {e}"),
532            )
533        })
534    }
535
536    /// Helper to handle service call and serialize response.
537    async fn handle_service_call<T: serde::Serialize, E: std::fmt::Display>(
538        &self,
539        call: impl std::future::Future<Output = Result<T, E>>,
540        id: u64,
541    ) -> Result<Response, Error> {
542        match call.await {
543            Ok(resp) => {
544                let bytes = postcard::to_stdvec(&resp).map_err(|e| {
545                    anyhow::anyhow!("failed to serialize response: {e}")
546                })?;
547                Ok(Self::success_response(id, bytes))
548            }
549            Err(e) => Ok(Self::error_response(id, "internal", &e.to_string())),
550        }
551    }
552
553    async fn dispatch_process(
554        &self,
555        request: Request,
556    ) -> Result<Response, Error> {
557        if let Err(resp) =
558            self.check_service(&self.process_service, "process", request.id)
559        {
560            return Ok(resp);
561        }
562        let svc = self.process_service.as_ref().unwrap();
563
564        // FIXME: A process should only be able to shut down itself or its own
565        // subtrees.
566        match request.method.method.as_str() {
567            METHOD_SHUTDOWN_TREE => {
568                let shutdown_req = Self::decode_shutdown_request(request.args)
569                    .map_err(|e| {
570                        Self::error_response(request.id, &e.code, &e.message)
571                    })?;
572                self.handle_service_call(
573                    svc.shutdown_tree(shutdown_req),
574                    request.id,
575                )
576                .await
577            }
578            _ => Ok(Self::error_response(
579                request.id,
580                "not_implemented",
581                "method not implemented",
582            )),
583        }
584    }
585
586    async fn dispatch_network(
587        &self,
588        request: Request,
589    ) -> Result<Response, Error> {
590        if let Err(resp) =
591            self.check_service(&self.network_service, "network", request.id)
592        {
593            return Ok(resp);
594        }
595        let svc = self.network_service.as_ref().unwrap();
596
597        match request.method.method.as_str() {
598            METHOD_FETCH => {
599                let req = self.decode_request(
600                    &request.args,
601                    "network.fetch",
602                    request.id,
603                )?;
604                self.handle_service_call(svc.fetch(req), request.id).await
605            }
606            METHOD_READ_FILE => {
607                let req = self.decode_request(
608                    &request.args,
609                    "network.read_file",
610                    request.id,
611                )?;
612                self.handle_service_call(svc.read_file(req), request.id)
613                    .await
614            }
615            _ => Ok(Self::error_response(
616                request.id,
617                "not_implemented",
618                "method not implemented",
619            )),
620        }
621    }
622}
623
624#[async_trait]
625impl Router for IpcRouter {
626    /// Register a new connection after handshake.
627    async fn register_connection(
628        &self,
629        ctx: ConnectionContext,
630    ) -> Result<(), Error> {
631        if let Ok(mut guard) = self.connections.write() {
632            if let Some(pos) = guard.iter().position(|c| c.id == ctx.id) {
633                guard[pos] = ctx;
634            } else {
635                guard.push(ctx);
636            }
637        }
638        Ok(())
639    }
640
641    /// Resolve and dispatch a request to a target service method.
642    async fn dispatch(
643        &self,
644        ctx: &ConnectionContext,
645        request: Request,
646    ) -> Result<Response, Error> {
647        let span = tracing::info_span!(
648            "ipc.dispatch",
649            conn_id = %ctx.id,
650            request_id = request.id,
651            service = %request.method.service,
652            method = %request.method.method
653        );
654
655        async move {
656            match self.is_authorized(ctx, &request.method) {
657                Ok(()) => {
658                    if let Err(e) = self
659                        .enforce_rate_limits(
660                            ctx,
661                            &request.method,
662                            request.args.len(),
663                        )
664                        .await
665                    {
666                        return Ok(Response {
667                            id: request.id,
668                            ok: false,
669                            result: None,
670                            error: Some(e),
671                        });
672                    }
673
674                    if request.method.service == PROCESS_SERVICE_NAME {
675                        return self.dispatch_process(request).await;
676                    }
677
678                    if request.method.service == NETWORK_SERVICE_NAME {
679                        return self.dispatch_network(request).await;
680                    }
681
682                    Ok(Response {
683                        id: request.id,
684                        ok: false,
685                        result: None,
686                        error: Some(RpcError {
687                            code: "not_implemented".into(),
688                            message: "method not implemented".into(),
689                        }),
690                    })
691                }
692                Err(e) => Ok(Response {
693                    id: request.id,
694                    ok: false,
695                    result: None,
696                    error: Some(e),
697                }),
698            }
699        }
700        .instrument(span)
701        .await
702    }
703
704    async fn emit_event(&self, event: Event) -> Result<(), Error> {
705        let span = tracing::info_span!("ipc.emit_event", topic = ?event.topic);
706        async move {
707            // No-op for now.
708            Ok(())
709        }
710        .instrument(span)
711        .await
712    }
713
714    /// Check whether a given method is allowed by a connection’s capabilities.
715    fn is_authorized(
716        &self,
717        ctx: &ConnectionContext,
718        method: &MethodId,
719    ) -> Result<(), RpcError> {
720        use crate::workspace::ipc::auth::capability::{
721            MethodRule, ServiceName,
722        };
723
724        let service_key = ServiceName(method.service.clone());
725        let Some(rules) = ctx.capabilities.allowed.get(&service_key) else {
726            return Err(RpcError {
727                code: "unauthorized".into(),
728                message: format!("service '{}' is not allowed", method.service),
729            });
730        };
731
732        let allowed = rules.iter().any(|rule: &MethodRule| {
733            rule.method.matches(&method.service, &method.method)
734        });
735
736        if allowed {
737            Ok(())
738        } else {
739            Err(RpcError {
740                code: "unauthorized".into(),
741                message: format!(
742                    "method '{}.{}' is not allowed",
743                    method.service, method.method
744                ),
745            })
746        }
747    }
748
749    async fn observe_cancel(
750        &self,
751        _ctx: &ConnectionContext,
752        _id: u64,
753    ) -> Result<(), Error> {
754        Ok(())
755    }
756}
757
758impl IpcRouter {
759    /// Retrieve a registered connection by id.
760    pub fn get(&self, id: &ConnectionId) -> Option<ConnectionContext> {
761        if let Ok(guard) = self.connections.read() {
762            guard.iter().find(|c| c.id == *id).cloned()
763        } else {
764            None
765        }
766    }
767}
768
769#[cfg(test)]
770mod tests {
771    use super::*;
772
773    use crate::workspace::ipc::auth::capability::{
774        CapabilitySet, MethodRule, MethodSelector, QuotaSet, ServiceName,
775    };
776    use crate::workspace::ipc::process_manager::{
777        ChildHandle, ProcessManager, SpawnParams,
778    };
779    use crate::workspace::ipc::protocol::{MethodId, Request};
780    use crate::workspace::ipc::services::network::api::{
781        BytesResponse as NetBytesResponse, FetchRequest, METHOD_FETCH,
782        METHOD_READ_FILE, MockNetworkService, ReadFileRequest,
783        SERVICE_NAME as NETWORK_SERVICE_NAME,
784    };
785    use crate::workspace::ipc::services::process::MockProcessService;
786    use crate::workspace::ipc::{
787        assert_ipc_response_error, assert_ipc_response_ok,
788    };
789
790    use crate::workspace::ipc::types::{ChildKind, ProcessId};
791    use anyhow::Result;
792    use std::collections::HashMap;
793
794    fn ctx_with_rules(
795        service: &str,
796        rules: Vec<MethodRule>,
797    ) -> ConnectionContext {
798        let mut allowed: HashMap<ServiceName, Vec<MethodRule>> = HashMap::new();
799        allowed.insert(ServiceName(service.to_string()), rules);
800        ConnectionContext {
801            id: Default::default(),
802            capabilities: CapabilitySet {
803                allowed,
804                global_limits: None,
805            },
806            metadata: None,
807        }
808    }
809
810    fn req(service: &str, method: &str) -> Request {
811        Request {
812            id: Default::default(),
813            method: MethodId {
814                service: service.into(),
815                method: method.into(),
816            },
817            args: vec![],
818        }
819    }
820
821    /// Exact selector should allow only the exact method.
822    #[crate::ctb_test(tokio::test)]
823    async fn auth_exact_allows() -> Result<()> {
824        let router = IpcRouter::new();
825        let ctx = ctx_with_rules(
826            "svc",
827            vec![MethodRule {
828                method: MethodSelector::Exact("do_work".into()),
829                quotas: None,
830            }],
831        );
832
833        let resp = router.dispatch(&ctx, req("svc", "do_work")).await?;
834        assert!(!resp.ok);
835        assert_eq!(resp.error.unwrap().code, "not_implemented");
836
837        let resp2 = router.dispatch(&ctx, req("svc", "other")).await?;
838        assert!(!resp2.ok);
839        assert_eq!(resp2.error.unwrap().code, "unauthorized");
840        Ok(())
841    }
842
843    /// Prefix selector should allow matching prefix and deny others.
844    #[crate::ctb_test(tokio::test)]
845    async fn auth_prefix() -> Result<()> {
846        let router = IpcRouter::new();
847        let ctx = ctx_with_rules(
848            "svc",
849            vec![MethodRule {
850                method: MethodSelector::Prefix("do_".into()),
851                quotas: None,
852            }],
853        );
854
855        let ok_resp = router.dispatch(&ctx, req("svc", "do_stuff")).await?;
856        assert!(!ok_resp.ok);
857        assert_eq!(ok_resp.error.unwrap().code, "not_implemented");
858
859        let deny_resp = router.dispatch(&ctx, req("svc", "list")).await?;
860        assert!(!deny_resp.ok);
861        assert_eq!(deny_resp.error.unwrap().code, "unauthorized");
862        Ok(())
863    }
864
865    /// Any selector allows all methods within the service. Absence of service
866    /// entry denies all methods.
867    #[crate::ctb_test(tokio::test)]
868    async fn auth_any_and_missing_service() -> Result<()> {
869        let router = IpcRouter::new();
870
871        // Any allows all for 'svc'
872        let ctx_any = ctx_with_rules(
873            "svc",
874            vec![MethodRule {
875                method: MethodSelector::Any,
876                quotas: None,
877            }],
878        );
879
880        let ok_resp = router.dispatch(&ctx_any, req("svc", "x")).await?;
881        assert!(!ok_resp.ok);
882        assert_eq!(ok_resp.error.unwrap().code, "not_implemented");
883
884        // No entry for 'other' -> unauthorized
885        let deny_resp = router.dispatch(&ctx_any, req("other", "x")).await?;
886        assert!(!deny_resp.ok);
887        assert_eq!(deny_resp.error.unwrap().code, "unauthorized");
888        Ok(())
889    }
890
891    /// Selectors with fully-qualified strings should also match.
892    #[crate::ctb_test(tokio::test)]
893    async fn auth_fully_qualified_selectors() -> Result<()> {
894        let router = IpcRouter::new();
895        let ctx = ctx_with_rules(
896            "svc",
897            vec![
898                MethodRule {
899                    method: MethodSelector::Exact("svc.do_a".into()),
900                    quotas: None,
901                },
902                MethodRule {
903                    method: MethodSelector::Prefix("svc.do_".into()),
904                    quotas: None,
905                },
906            ],
907        );
908
909        // Exact match
910        let r1 = router.dispatch(&ctx, req("svc", "do_a")).await?;
911        assert_eq!(r1.error.unwrap().code, "not_implemented");
912
913        // Prefix match
914        let r2 = router.dispatch(&ctx, req("svc", "do_b")).await?;
915        assert_eq!(r2.error.unwrap().code, "not_implemented");
916
917        // Non-matching method
918        let r3 = router.dispatch(&ctx, req("svc", "list")).await?;
919        assert_eq!(r3.error.unwrap().code, "unauthorized");
920
921        Ok(())
922    }
923
924    /// Register_connection stores the ConnectionContext.
925    #[crate::ctb_test(tokio::test)]
926    async fn register_stores_context() -> Result<()> {
927        let router = IpcRouter::new();
928        let ctx = ConnectionContext {
929            id: ConnectionId::default(),
930            capabilities: CapabilitySet::default(),
931            metadata: Some(serde_json::json!({"k":"v"})),
932        };
933        router.register_connection(ctx.clone()).await?;
934        let got = router.get(&ctx.id).expect("stored");
935        assert_eq!(got.id, ctx.id);
936        assert_eq!(got.metadata, ctx.metadata);
937        Ok(())
938    }
939
940    #[derive(Debug, Default)]
941    struct MockProcessManager {
942        terminated: std::sync::Mutex<Vec<(ProcessId, bool)>>,
943    }
944
945    impl MockProcessManager {
946        fn terminations(&self) -> Vec<(ProcessId, bool)> {
947            if let Ok(guard) = self.terminated.lock() {
948                guard.clone()
949            } else {
950                Vec::new()
951            }
952        }
953    }
954
955    #[async_trait]
956    impl ProcessManager for MockProcessManager {
957        async fn spawn_child(
958            &self,
959            _params: SpawnParams,
960        ) -> Result<ChildHandle, Error> {
961            Ok(ChildHandle {
962                pid: ProcessId::default(),
963                kind: ChildKind::Renderer,
964                connection: None,
965            })
966        }
967
968        async fn attach_connection(
969            &self,
970            _pid: ProcessId,
971            _conn: ConnectionId,
972        ) -> Result<(), Error> {
973            Ok(())
974        }
975
976        async fn list_children(&self) -> Result<Vec<ChildHandle>, Error> {
977            Ok(Vec::new())
978        }
979
980        async fn terminate_tree(
981            &self,
982            pid: ProcessId,
983            force: bool,
984        ) -> Result<(), Error> {
985            if let Ok(mut guard) = self.terminated.lock() {
986                guard.push((pid, force));
987            }
988            Ok(())
989        }
990    }
991
992    /// Heartbeat miss should trigger termination and cleanup.
993    #[crate::ctb_test(tokio::test)]
994    async fn heartbeat_miss_triggers_termination_and_cleanup() -> Result<()> {
995        let manager = Arc::new(MockProcessManager::default());
996        let tracker = HeartbeatTracker::new(
997            manager.clone(),
998            Duration::from_millis(20),
999            1,
1000        );
1001
1002        let conn = ConnectionId::default();
1003        let pid = ProcessId::default();
1004        tracker.track_connection(conn, pid);
1005
1006        // Do not record any heartbeat; allow multiple intervals to pass.
1007        tokio::time::sleep(Duration::from_millis(100)).await;
1008
1009        let terms = manager.terminations();
1010        assert_eq!(terms.len(), 1);
1011        assert_eq!(terms[0].0, pid);
1012        assert!(terms[0].1);
1013        assert!(!tracker.is_tracked(&conn));
1014
1015        Ok(())
1016    }
1017
1018    #[crate::ctb_test(tokio::test)]
1019    async fn routes_workspace_shutdown() -> Result<()> {
1020        let router =
1021            IpcRouter::new().with_process_service(Arc::new(MockProcessService));
1022
1023        let mut allowed: HashMap<ServiceName, Vec<MethodRule>> = HashMap::new();
1024        allowed.insert(
1025            ServiceName("process".to_string()),
1026            vec![MethodRule {
1027                method: MethodSelector::Exact("shutdown_tree".into()),
1028                quotas: None,
1029            }],
1030        );
1031
1032        let ctx = ConnectionContext {
1033            id: Default::default(),
1034            capabilities: CapabilitySet {
1035                allowed,
1036                global_limits: None,
1037            },
1038            metadata: None,
1039        };
1040
1041        let req = Request {
1042            id: Default::default(),
1043            method: MethodId {
1044                service: "process".into(),
1045                method: "shutdown_tree".into(),
1046            },
1047            args: vec![],
1048        };
1049
1050        let resp = router.dispatch(&ctx, req).await?;
1051        assert_ipc_response_ok(&resp);
1052        assert!(resp.error.is_none());
1053        Ok(())
1054    }
1055
1056    /// Exceeding bytes/sec should yield a capability_denied error.
1057    #[crate::ctb_test(tokio::test)]
1058    async fn rate_limit_bytes_per_sec_is_enforced_for_io_and_network()
1059    -> Result<()> {
1060        let router = IpcRouter::new();
1061
1062        let mut allowed: HashMap<ServiceName, Vec<MethodRule>> = HashMap::new();
1063        allowed.insert(
1064            ServiceName("io".to_string()),
1065            vec![MethodRule {
1066                method: MethodSelector::Exact("read".into()),
1067                quotas: Some(QuotaSet {
1068                    bytes_per_sec: Some(10),
1069                    ops_per_sec: None,
1070                    burst: Some(10),
1071                }),
1072            }],
1073        );
1074
1075        let ctx = ConnectionContext {
1076            id: ConnectionId::default(),
1077            capabilities: CapabilitySet {
1078                allowed,
1079                global_limits: None,
1080            },
1081            metadata: None,
1082        };
1083
1084        let mk_req = |n: usize| Request {
1085            id: Default::default(),
1086            method: MethodId {
1087                service: "io".into(),
1088                method: "read".into(),
1089            },
1090            args: vec![0u8; n],
1091        };
1092
1093        // First request consumes full burst.
1094        let r1 = router.dispatch(&ctx, mk_req(10)).await?;
1095        assert_eq!(r1.error.unwrap().code, "not_implemented");
1096
1097        // Second request should be denied immediately.
1098        let r2 = router.dispatch(&ctx, mk_req(1)).await?;
1099        assert_ipc_response_error(&r2);
1100        assert_eq!(r2.error.unwrap().code, "capability_denied");
1101
1102        Ok(())
1103    }
1104
1105    /// Routes network.fetch and returns serialized bytes.
1106    #[crate::ctb_test(tokio::test)]
1107    async fn routes_network_fetch() -> anyhow::Result<()> {
1108        let router = IpcRouter::new()
1109            .with_network_service(Arc::new(MockNetworkService::new()));
1110
1111        let mut allowed: HashMap<ServiceName, Vec<MethodRule>> = HashMap::new();
1112        allowed.insert(
1113            ServiceName(NETWORK_SERVICE_NAME.to_string()),
1114            vec![MethodRule {
1115                method: MethodSelector::Exact(METHOD_FETCH.into()),
1116                quotas: None,
1117            }],
1118        );
1119
1120        let ctx = ConnectionContext {
1121            id: Default::default(),
1122            capabilities: CapabilitySet {
1123                allowed,
1124                global_limits: None,
1125            },
1126            metadata: None,
1127        };
1128
1129        let args = postcard::to_stdvec(&FetchRequest {
1130            url: "https://example.com".into(),
1131        })?;
1132        let req = Request {
1133            id: Default::default(),
1134            method: MethodId {
1135                service: NETWORK_SERVICE_NAME.into(),
1136                method: METHOD_FETCH.into(),
1137            },
1138            args,
1139        };
1140
1141        let resp = router.dispatch(&ctx, req).await?;
1142        assert_ipc_response_ok(&resp);
1143        let decoded: NetBytesResponse =
1144            postcard::from_bytes(&resp.result.unwrap())?;
1145        assert_eq!(
1146            decoded.bytes,
1147            vec![
1148                69, 120, 97, 109, 112, 108, 101, 32, 72, 84, 84, 80, 83, 32,
1149                70, 101, 116, 99, 104
1150            ]
1151        );
1152        Ok(())
1153    }
1154
1155    /// Routes network.read_file and returns serialized bytes.
1156    #[crate::ctb_test(tokio::test)]
1157    async fn routes_network_read_file() -> anyhow::Result<()> {
1158        let router = IpcRouter::new()
1159            .with_network_service(Arc::new(MockNetworkService::new()));
1160
1161        let mut allowed: HashMap<ServiceName, Vec<MethodRule>> = HashMap::new();
1162        allowed.insert(
1163            ServiceName(NETWORK_SERVICE_NAME.to_string()),
1164            vec![MethodRule {
1165                method: MethodSelector::Exact(METHOD_READ_FILE.into()),
1166                quotas: None,
1167            }],
1168        );
1169
1170        let ctx = ConnectionContext {
1171            id: Default::default(),
1172            capabilities: CapabilitySet {
1173                allowed,
1174                global_limits: None,
1175            },
1176            metadata: None,
1177        };
1178
1179        let args: Vec<u8> = postcard::to_stdvec(&ReadFileRequest {
1180            path: "/tmp/example/file.txt".into(),
1181        })?;
1182        let req = Request {
1183            id: Default::default(),
1184            method: MethodId {
1185                service: NETWORK_SERVICE_NAME.into(),
1186                method: METHOD_READ_FILE.into(),
1187            },
1188            args,
1189        };
1190
1191        let resp = router.dispatch(&ctx, req).await?;
1192        assert_ipc_response_ok(&resp);
1193        let decoded: NetBytesResponse =
1194            postcard::from_bytes(&resp.result.unwrap())?;
1195        assert_eq!(
1196            decoded.bytes,
1197            vec![
1198                69, 120, 97, 109, 112, 108, 101, 32, 70, 105, 108, 101, 32, 82,
1199                101, 97, 100
1200            ]
1201        );
1202        Ok(())
1203    }
1204}