ctoolbox/workspace/ipc_old/
server.rs1use crate::utilities::ipc::{Channel, channel_from_json_string};
2use crate::utilities::json::jq;
3use crate::utilities::process::ProcessManager;
4use crate::utilities::{Context, Result, json};
5use crate::workspace::ipc_old::dispatch::ipc_call_method;
6use clap::Parser;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::Mutex;
11use tiny_http::{Response, Server};
12
13#[derive(PartialEq, Clone, Serialize, Deserialize)]
14enum CallState {
15 Created,
16 Claimed,
17 Success,
18 Error,
19}
20
21#[derive(Clone, Serialize, Deserialize)]
22struct Call {
23 method: String,
24 args: String,
25 response: Option<String>,
26 state: CallState,
27}
28
29#[derive(Clone, Serialize, Deserialize)]
30struct CallWithId {
31 id: u64,
32 call: Call,
33}
34
35struct Calls {
36 calls: HashMap<u64, Call>,
37 last_id: u64,
38}
39
40impl Calls {
41 fn new() -> Calls {
42 Calls {
43 calls: HashMap::new(),
44 last_id: 0,
45 }
46 }
47 fn get(&mut self, id: &u64) -> &mut Call {
48 self.calls.get_mut(id).unwrap()
49 }
50}
51
52pub fn start_ipc_server(
53 process_manager: Arc<Mutex<ProcessManager>>,
54 port: u16,
55 authentication_key: String,
56) {
57 let mut queued_calls: HashMap<String, Calls> = HashMap::new();
60
61 let mut known_channels: Vec<Channel> = vec![Channel {
62 name: "ipc_control".to_string(),
63 port,
64 authentication_key: authentication_key.clone(),
65 }];
66
67 let server = Server::http(format!("127.0.0.1:{port}")).unwrap();
68
69 for request in server.incoming_requests() {
70 handle_request(
71 process_manager.clone(),
72 request,
73 &mut known_channels,
74 &mut queued_calls,
75 );
76 }
77}
78
79fn process_request(
80 process_manager: Arc<Mutex<ProcessManager>>,
81 target_channel: Channel,
82 content: String,
83 known_channels: &mut Vec<Channel>,
84 queued_calls: &mut HashMap<String, Calls>,
85) -> Result<String> {
86 let provided_key = target_channel.authentication_key.clone();
87
88 let expected_key: String = known_channels
89 .iter()
90 .find(|c| c.name == target_channel.name)
91 .map_or(String::new(), |c| c.authentication_key.clone());
92
93 if expected_key != provided_key {
94 return Err(anyhow::anyhow!("Invalid authentication key"));
95 }
96
97 let channel_calls: &mut Calls = queued_calls
98 .entry(target_channel.name.clone())
99 .or_insert(Calls::new());
100
101 let message_type = jq(".type", content.as_str()).unwrap_or_default();
102
103 let mut response = None;
104
105 if message_type == "workspace_call" {
108 let new_task_method = jq(".method", content.as_str()).unwrap();
109 let new_task_args = jq(".args", content.as_str()).unwrap();
111 let task_args_val =
112 serde_json::from_str::<serde_json::Value>(new_task_args.as_str())
113 .with_context(|| "Could not parse args as JSON")?;
114 ipc_call_method(
115 &new_task_method,
116 &task_args_val,
117 Some(process_manager),
118 );
119 response = Some(json!({"type": "success"}).to_string());
120 } else if message_type == "add_channel" {
121 let channel: Channel = serde_json::from_str::<Channel>(
122 jq(".channel", content.as_str())
123 .with_context(|| "Could not get channel from body json")?
124 .as_str(),
125 )
126 .with_context(|| "Could not parse channel from body json")?;
127 known_channels.push(channel);
128 response = Some(json!({"type": "success"}).to_string());
129 } else if message_type == "remove_channel" {
130 let name = jq(".channel.name", content.as_str())
131 .with_context(|| "Could not get channel from json")?;
132 let name = name.as_str();
133 known_channels.retain(|c| c.name != name);
134 response = Some(json!({"type": "success"}).to_string());
135 } else if message_type == "ipc_ping" {
136 response = Some(json!({"type": "ipc_ready"}).to_string());
137 } else if message_type == "poll_for_result" {
138 let msgid = jq(".msgid", content.as_str())
139 .with_context(|| "Could not jq msgid")?;
140 if msgid == "null" {
141 return Err(anyhow::anyhow!("No msgid provided"));
142 }
143 let msgid = msgid
144 .parse::<u64>()
145 .with_context(|| "Could not parse msgid")?;
146 let message = &channel_calls.get(&msgid);
147 if message.state == CallState::Success {
148 response = Some(
149 json!({"type": "result", "content": message.response})
150 .to_string(),
151 );
152 channel_calls.calls.remove(&msgid);
153 } else if message.state == CallState::Error {
154 response = Some(json!({"type": "error"}).to_string());
155 } else {
156 response = Some(json!({"type": "pending"}).to_string());
157 }
158 } else if message_type == "poll_for_task" {
159 response = Some(json!({"type": "no_new_tasks"}));
160 for (msgid, call) in &mut channel_calls.calls {
161 if call.state == CallState::Created {
162 call.state = CallState::Claimed;
163 response = Some(
164 json!({
165 "type": "new_task",
166 "method": call.method,
167 "args": call.args,
168 "msgid": msgid,
169 })
170 .to_string(),
171 );
172 break;
173 }
174 }
175 } else if message_type == "call" {
176 let method = jq(".method", content.as_str())
177 .with_context(|| "Could not query method")?;
178 let args = jq(".args", content.as_str())
179 .with_context(|| "Could not query args")?;
180 let msgid = channel_calls.last_id + 1;
181 channel_calls.calls.insert(
182 msgid,
183 Call {
184 method: method.clone(),
185 args: args.clone(),
186 response: None,
187 state: CallState::Created,
188 },
189 );
190 channel_calls.last_id = msgid;
191 response =
192 Some(json!({"type": "call_pending", "msgid": msgid}).to_string());
193 } else if message_type == "response" {
194 let msgid: u64 = jq(".msgid", content.as_str())
195 .with_context(|| "Could not jq msgid")?
196 .parse::<u64>()
197 .with_context(|| "Could not parse msgid")?;
198
199 let content = jq(".content", content.as_str())
200 .with_context(|| "Could not jq content")?;
201 channel_calls.get(&msgid).response = Some(content.clone());
202 channel_calls.get(&msgid).state = CallState::Success;
203 response = Some(json!({"type": "success"}).to_string());
204 }
205
206 response.with_context(|| format!("Invalid message type {message_type}"))
207}
208
209fn error_response(request: tiny_http::Request, message: &str) {
210 match request.respond(Response::from_string(
211 json!({"type": "error", "message": message}).to_string(),
212 )) {
213 Ok(()) => {}
214 Err(e) => {
215 eprintln!("Failed to send error response: {e}");
216 }
217 }
218}
219
220fn handle_request(
221 process_manager: Arc<Mutex<ProcessManager>>,
222 mut request: tiny_http::Request,
223 known_channels: &mut Vec<Channel>,
224 queued_calls: &mut HashMap<String, Calls>,
225) {
226 let target_channel = channel_from_json_string(
227 request
228 .headers()
229 .iter()
230 .find(|h| h.field.equiv("X-CollectiveToolbox-IPCAuth"))
231 .map_or_else(
232 || "No channel provided".to_string(),
233 |h| h.value.clone().to_string(),
234 )
235 .as_str(),
236 )
237 .context("Invalid channel header");
238
239 if target_channel.is_err() {
240 error_response(request, "Invalid channel");
241 return;
242 }
243
244 let mut content = String::new();
245 if request
246 .as_reader()
247 .read_to_string(&mut content)
248 .context("Could not read request body")
249 .is_err()
250 {
251 error_response(request, "Could not read request body");
252 return;
253 }
254
255 let response = process_request(
256 process_manager,
257 target_channel.unwrap(),
258 content,
259 known_channels,
260 queued_calls,
261 );
262
263 let response = match response {
264 Ok(r) => r,
265 Err(e) => {
266 json!({"type": "error", "message": e.to_string()}).to_string()
267 }
268 };
269
270 request.respond(Response::from_string(response)).unwrap();
271}