1use crate::workspace::ipc::auth::capability::CapabilityBundle;
10use crate::workspace::ipc::error::Error;
11use crate::workspace::ipc::services::process::api::ShutdownTreeResponse;
12use crate::workspace::ipc::types::ChildKind;
13use crate::workspace::ipc::types::{ConnectionId, ProcessId};
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16use std::future::Future;
17use std::time::Duration;
18use tokio::time;
19
20pub mod unix;
21pub mod windows;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SpawnParams {
26 pub kind: ChildKind,
27 pub program: Option<String>,
29 pub args: Vec<String>,
31 pub env: Vec<(String, String)>,
33 pub cwd: Option<String>,
35 pub capabilities: CapabilityBundle,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ChildHandle {
42 pub pid: ProcessId,
43 pub kind: ChildKind,
44 pub connection: Option<ConnectionId>,
45}
46
47#[async_trait]
50pub trait ProcessManager: Send + Sync + std::fmt::Debug {
51 async fn spawn_child(
52 &self,
53 params: SpawnParams,
54 ) -> Result<ChildHandle, Error>;
55
56 async fn attach_connection(
57 &self,
58 pid: ProcessId,
59 conn: ConnectionId,
60 ) -> Result<(), Error>;
61
62 async fn list_children(&self) -> Result<Vec<ChildHandle>, Error>;
63
64 async fn terminate_tree(
67 &self,
68 pid: ProcessId,
69 force: bool,
70 ) -> Result<(), Error>;
71
72 async fn wait_for_exit(
81 &self,
82 _pid: ProcessId,
83 timeout: Duration,
84 ) -> Result<bool, Error> {
85 time::sleep(timeout).await;
86 Ok(false)
87 }
88
89 async fn kill_tree(&self, pid: ProcessId) -> Result<(), Error> {
93 self.terminate_tree(pid, true).await
94 }
95}
96
/// What happened during a call to [`graceful_shutdown_tree`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct GracefulShutdownOutcome {
    /// The shutdown request was answered with an acknowledgement before the
    /// acknowledgement timeout.
    pub acknowledged: bool,
    /// The final `wait_for_exit` reported that the tree exited.
    pub exited: bool,
    /// `kill_tree` was invoked.
    pub forced: bool,
}
107
108pub async fn graceful_shutdown_tree<F, Fut>(
122 process_manager: &dyn ProcessManager,
123 pid: ProcessId,
124 send_shutdown: F,
125 ack_timeout: Duration,
126 exit_timeout: Duration,
127) -> Result<GracefulShutdownOutcome, Error>
128where
129 F: FnOnce() -> Fut + Send,
130 Fut: Future<Output = Result<ShutdownTreeResponse, Error>> + Send,
131{
132 let acknowledged = match time::timeout(ack_timeout, send_shutdown()).await {
133 Ok(Ok(resp)) => resp.acknowledged,
134 Ok(Err(_)) => false,
135 Err(_elapsed) => false,
136 };
137
138 if acknowledged {
139 let exited = process_manager.wait_for_exit(pid, exit_timeout).await?;
140 if exited {
141 return Ok(GracefulShutdownOutcome {
142 acknowledged,
143 exited,
144 forced: false,
145 });
146 }
147 }
148
149 process_manager.kill_tree(pid).await?;
150 let exited = process_manager.wait_for_exit(pid, exit_timeout).await?;
151
152 Ok(GracefulShutdownOutcome {
153 acknowledged,
154 exited,
155 forced: true,
156 })
157}
158
159#[cfg(unix)]
160pub use unix::TokioProcessManager;
161#[cfg(windows)]
162pub use windows::TokioProcessManager;
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workspace::ipc::types::ConnectionId;
    use anyhow::Result;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// Test double for [`ProcessManager`].
    ///
    /// The simulated child "exits" when `exited` is notified. Calls to
    /// `terminate_tree` and `wait_for_exit` are counted so tests can assert
    /// which shutdown path was taken.
    #[derive(Debug)]
    struct MockProcessManager {
        // Number of `terminate_tree` calls, including those made via `kill_tree`.
        killed: AtomicUsize,
        // Number of `wait_for_exit` calls.
        waited: AtomicUsize,
        // Signalled when the simulated child exits.
        exited: Arc<Notify>,
    }

    impl MockProcessManager {
        fn new(exited: Arc<Notify>) -> Self {
            Self {
                killed: AtomicUsize::new(0),
                waited: AtomicUsize::new(0),
                exited,
            }
        }
    }

    /// Signals `exited` once, after `delay`, from a spawned task.
    ///
    /// Uses `notify_one` rather than `notify_waiters`. `notify_one` stores a
    /// permit when nobody is waiting yet, so the exit is not lost if it fires
    /// before `wait_for_exit` has started polling `notified()`.
    fn signal_exit_after(exited: &Arc<Notify>, delay: Duration) {
        let exited = Arc::clone(exited);
        tokio::spawn(async move {
            time::sleep(delay).await;
            exited.notify_one();
        });
    }

    #[async_trait]
    impl ProcessManager for MockProcessManager {
        async fn spawn_child(&self, _params: SpawnParams) -> Result<ChildHandle, Error> {
            Ok(ChildHandle {
                pid: ProcessId::default(),
                kind: ChildKind::Renderer,
                connection: None,
            })
        }

        async fn attach_connection(
            &self,
            _pid: ProcessId,
            _conn: ConnectionId,
        ) -> Result<(), Error> {
            Ok(())
        }

        async fn list_children(&self) -> Result<Vec<ChildHandle>, Error> {
            Ok(Vec::new())
        }

        async fn terminate_tree(&self, _pid: ProcessId, _force: bool) -> Result<(), Error> {
            self.killed.fetch_add(1, Ordering::SeqCst);
            // The simulated child dies shortly after being killed.
            signal_exit_after(&self.exited, Duration::from_millis(1));
            Ok(())
        }

        async fn wait_for_exit(&self, _pid: ProcessId, timeout: Duration) -> Result<bool, Error> {
            self.waited.fetch_add(1, Ordering::SeqCst);
            Ok(time::timeout(timeout, self.exited.notified()).await.is_ok())
        }
    }

    /// The child acknowledges and exits on its own, so no kill is issued.
    #[crate::ctb_test(tokio::test)]
    async fn graceful_shutdown_cooperative_child() -> Result<()> {
        let exited = Arc::new(Notify::new());
        let pm = MockProcessManager::new(exited.clone());
        let pid = ProcessId::default();

        let outcome = graceful_shutdown_tree(
            &pm,
            pid,
            || async {
                signal_exit_after(&exited, Duration::from_millis(10));
                Ok(ShutdownTreeResponse { acknowledged: true })
            },
            Duration::from_millis(50),
            Duration::from_millis(200),
        )
        .await?;

        assert_eq!(
            outcome,
            GracefulShutdownOutcome {
                acknowledged: true,
                exited: true,
                forced: false,
            }
        );
        assert_eq!(pm.killed.load(Ordering::SeqCst), 0);
        assert_eq!(pm.waited.load(Ordering::SeqCst), 1);
        Ok(())
    }

    /// The child never answers, so the ack times out and the tree is force-killed.
    #[crate::ctb_test(tokio::test)]
    async fn graceful_shutdown_ignoring_child_forces_kill() -> Result<()> {
        let exited = Arc::new(Notify::new());
        let pm = MockProcessManager::new(exited.clone());
        let pid = ProcessId::default();

        let outcome = graceful_shutdown_tree(
            &pm,
            pid,
            || async {
                std::future::pending::<Result<ShutdownTreeResponse, Error>>()
                    .await
            },
            Duration::from_millis(20),
            Duration::from_millis(200),
        )
        .await?;

        assert!(!outcome.acknowledged);
        assert!(outcome.forced);
        assert!(outcome.exited);
        assert_eq!(pm.killed.load(Ordering::SeqCst), 1);
        assert_eq!(pm.waited.load(Ordering::SeqCst), 1);
        Ok(())
    }

    /// The child acknowledges but never exits. The graceful wait times out,
    /// and the tree is then force-killed.
    #[crate::ctb_test(tokio::test)]
    async fn graceful_shutdown_acknowledged_but_stuck_child_forces_kill() -> Result<()> {
        let pm = MockProcessManager::new(Arc::new(Notify::new()));

        let outcome = graceful_shutdown_tree(
            &pm,
            ProcessId::default(),
            || async { Ok(ShutdownTreeResponse { acknowledged: true }) },
            Duration::from_millis(50),
            Duration::from_millis(50),
        )
        .await?;

        assert_eq!(
            outcome,
            GracefulShutdownOutcome {
                acknowledged: true,
                exited: true,
                forced: true,
            }
        );
        assert_eq!(pm.killed.load(Ordering::SeqCst), 1);
        assert_eq!(pm.waited.load(Ordering::SeqCst), 2);
        Ok(())
    }

    /// The child answers but declines, so the tree is force-killed without a
    /// graceful wait.
    #[crate::ctb_test(tokio::test)]
    async fn graceful_shutdown_declined_forces_kill() -> Result<()> {
        let pm = MockProcessManager::new(Arc::new(Notify::new()));

        let outcome = graceful_shutdown_tree(
            &pm,
            ProcessId::default(),
            || async { Ok(ShutdownTreeResponse { acknowledged: false }) },
            Duration::from_millis(50),
            Duration::from_millis(200),
        )
        .await?;

        assert_eq!(
            outcome,
            GracefulShutdownOutcome {
                acknowledged: false,
                exited: true,
                forced: true,
            }
        );
        assert_eq!(pm.killed.load(Ordering::SeqCst), 1);
        assert_eq!(pm.waited.load(Ordering::SeqCst), 1);
        Ok(())
    }
}