ctoolbox/workspace/ipc_old/server.rs

1use crate::utilities::ipc::{Channel, channel_from_json_string};
2use crate::utilities::json::jq;
3use crate::utilities::process::ProcessManager;
4use crate::utilities::{Context, Result, json};
5use crate::workspace::ipc_old::dispatch::ipc_call_method;
6use clap::Parser;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::Mutex;
11use tiny_http::{Response, Server};
12
13#[derive(PartialEq, Clone, Serialize, Deserialize)]
14enum CallState {
15    Created,
16    Claimed,
17    Success,
18    Error,
19}
20
/// One call queued on a channel, plus its progress.
#[derive(Clone, Serialize, Deserialize)]
struct Call {
    // Method name, as extracted from the `.method` field of the `call` message.
    method: String,
    // Arguments, kept as the text extracted from the `.args` field of the
    // `call` message; forwarded verbatim to whoever claims the task.
    args: String,
    // Result text from the matching `response` message; `None` until then.
    response: Option<String>,
    // Current lifecycle state.
    state: CallState,
}
28
/// A [`Call`] paired with its per-channel message id.
///
/// NOTE(review): not referenced by any code in this file — possibly dead.
#[derive(Clone, Serialize, Deserialize)]
struct CallWithId {
    // Message id, as used for the keys of `Calls::calls`.
    id: u64,
    call: Call,
}
34
/// Per-channel queue of calls, keyed by message id.
struct Calls {
    // Queued calls; a call is removed once `poll_for_result` has returned its
    // successful result.
    calls: HashMap<u64, Call>,
    // Highest message id handed out so far; the next `call` gets `last_id + 1`.
    last_id: u64,
}
39
40impl Calls {
41    fn new() -> Calls {
42        Calls {
43            calls: HashMap::new(),
44            last_id: 0,
45        }
46    }
47    fn get(&mut self, id: &u64) -> &mut Call {
48        self.calls.get_mut(id).unwrap()
49    }
50}
51
52pub fn start_ipc_server(
53    process_manager: Arc<Mutex<ProcessManager>>,
54    port: u16,
55    authentication_key: String,
56) {
57    // log!("Starting IPC server");
58
59    let mut queued_calls: HashMap<String, Calls> = HashMap::new();
60
61    let mut known_channels: Vec<Channel> = vec![Channel {
62        name: "ipc_control".to_string(),
63        port,
64        authentication_key: authentication_key.clone(),
65    }];
66
67    let server = Server::http(format!("127.0.0.1:{port}")).unwrap();
68
69    for request in server.incoming_requests() {
70        handle_request(
71            process_manager.clone(),
72            request,
73            &mut known_channels,
74            &mut queued_calls,
75        );
76    }
77}
78
79fn process_request(
80    process_manager: Arc<Mutex<ProcessManager>>,
81    target_channel: Channel,
82    content: String,
83    known_channels: &mut Vec<Channel>,
84    queued_calls: &mut HashMap<String, Calls>,
85) -> Result<String> {
86    let provided_key = target_channel.authentication_key.clone();
87
88    let expected_key: String = known_channels
89        .iter()
90        .find(|c| c.name == target_channel.name)
91        .map_or(String::new(), |c| c.authentication_key.clone());
92
93    if expected_key != provided_key {
94        return Err(anyhow::anyhow!("Invalid authentication key"));
95    }
96
97    let channel_calls: &mut Calls = queued_calls
98        .entry(target_channel.name.clone())
99        .or_insert(Calls::new());
100
101    let message_type = jq(".type", content.as_str()).unwrap_or_default();
102
103    let mut response = None;
104
105    // log!(format!("Received message: {}", content).as_str());
106
107    if message_type == "workspace_call" {
108        let new_task_method = jq(".method", content.as_str()).unwrap();
109        // get Serde value from jq .args
110        let new_task_args = jq(".args", content.as_str()).unwrap();
111        let task_args_val =
112            serde_json::from_str::<serde_json::Value>(new_task_args.as_str())
113                .with_context(|| "Could not parse args as JSON")?;
114        ipc_call_method(
115            &new_task_method,
116            &task_args_val,
117            Some(process_manager),
118        );
119        response = Some(json!({"type": "success"}).to_string());
120    } else if message_type == "add_channel" {
121        let channel: Channel = serde_json::from_str::<Channel>(
122            jq(".channel", content.as_str())
123                .with_context(|| "Could not get channel from body json")?
124                .as_str(),
125        )
126        .with_context(|| "Could not parse channel from body json")?;
127        known_channels.push(channel);
128        response = Some(json!({"type": "success"}).to_string());
129    } else if message_type == "remove_channel" {
130        let name = jq(".channel.name", content.as_str())
131            .with_context(|| "Could not get channel from json")?;
132        let name = name.as_str();
133        known_channels.retain(|c| c.name != name);
134        response = Some(json!({"type": "success"}).to_string());
135    } else if message_type == "ipc_ping" {
136        response = Some(json!({"type": "ipc_ready"}).to_string());
137    } else if message_type == "poll_for_result" {
138        let msgid = jq(".msgid", content.as_str())
139            .with_context(|| "Could not jq msgid")?;
140        if msgid == "null" {
141            return Err(anyhow::anyhow!("No msgid provided"));
142        }
143        let msgid = msgid
144            .parse::<u64>()
145            .with_context(|| "Could not parse msgid")?;
146        let message = &channel_calls.get(&msgid);
147        if message.state == CallState::Success {
148            response = Some(
149                json!({"type": "result", "content": message.response})
150                    .to_string(),
151            );
152            channel_calls.calls.remove(&msgid);
153        } else if message.state == CallState::Error {
154            response = Some(json!({"type": "error"}).to_string());
155        } else {
156            response = Some(json!({"type": "pending"}).to_string());
157        }
158    } else if message_type == "poll_for_task" {
159        response = Some(json!({"type": "no_new_tasks"}));
160        for (msgid, call) in &mut channel_calls.calls {
161            if call.state == CallState::Created {
162                call.state = CallState::Claimed;
163                response = Some(
164                    json!({
165                        "type": "new_task",
166                        "method": call.method,
167                        "args": call.args,
168                        "msgid": msgid,
169                    })
170                    .to_string(),
171                );
172                break;
173            }
174        }
175    } else if message_type == "call" {
176        let method = jq(".method", content.as_str())
177            .with_context(|| "Could not query method")?;
178        let args = jq(".args", content.as_str())
179            .with_context(|| "Could not query args")?;
180        let msgid = channel_calls.last_id + 1;
181        channel_calls.calls.insert(
182            msgid,
183            Call {
184                method: method.clone(),
185                args: args.clone(),
186                response: None,
187                state: CallState::Created,
188            },
189        );
190        channel_calls.last_id = msgid;
191        response =
192            Some(json!({"type": "call_pending", "msgid": msgid}).to_string());
193    } else if message_type == "response" {
194        let msgid: u64 = jq(".msgid", content.as_str())
195            .with_context(|| "Could not jq msgid")?
196            .parse::<u64>()
197            .with_context(|| "Could not parse msgid")?;
198
199        let content = jq(".content", content.as_str())
200            .with_context(|| "Could not jq content")?;
201        channel_calls.get(&msgid).response = Some(content.clone());
202        channel_calls.get(&msgid).state = CallState::Success;
203        response = Some(json!({"type": "success"}).to_string());
204    }
205
206    response.with_context(|| format!("Invalid message type {message_type}"))
207}
208
209fn error_response(request: tiny_http::Request, message: &str) {
210    match request.respond(Response::from_string(
211        json!({"type": "error", "message": message}).to_string(),
212    )) {
213        Ok(()) => {}
214        Err(e) => {
215            eprintln!("Failed to send error response: {e}");
216        }
217    }
218}
219
220fn handle_request(
221    process_manager: Arc<Mutex<ProcessManager>>,
222    mut request: tiny_http::Request,
223    known_channels: &mut Vec<Channel>,
224    queued_calls: &mut HashMap<String, Calls>,
225) {
226    let target_channel = channel_from_json_string(
227        request
228            .headers()
229            .iter()
230            .find(|h| h.field.equiv("X-CollectiveToolbox-IPCAuth"))
231            .map_or_else(
232                || "No channel provided".to_string(),
233                |h| h.value.clone().to_string(),
234            )
235            .as_str(),
236    )
237    .context("Invalid channel header");
238
239    if target_channel.is_err() {
240        error_response(request, "Invalid channel");
241        return;
242    }
243
244    let mut content = String::new();
245    if request
246        .as_reader()
247        .read_to_string(&mut content)
248        .context("Could not read request body")
249        .is_err()
250    {
251        error_response(request, "Could not read request body");
252        return;
253    }
254
255    let response = process_request(
256        process_manager,
257        target_channel.unwrap(),
258        content,
259        known_channels,
260        queued_calls,
261    );
262
263    let response = match response {
264        Ok(r) => r,
265        Err(e) => {
266            json!({"type": "error", "message": e.to_string()}).to_string()
267        }
268    };
269
270    request.respond(Response::from_string(response)).unwrap();
271}