ctoolbox/io/
webui.rs

1use std::collections::HashMap;
2use std::net::{Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use axum::Router;
7use axum::extract::FromRequestParts;
8use axum::response::{Html, IntoResponse, Response};
9use axum_server::tls_rustls::RustlsConfig;
10use handlebars::Handlebars;
11use http::StatusCode;
12use maplit::btreemap;
13use portpicker::pick_unused_port;
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json, to_value};
16use std::time::Duration;
17use tokio::sync::Mutex;
18use tower_http::compression::CompressionLayer;
19use tower_http::cors::CorsLayer;
20
21use crate::formats::markdown::markdown2html;
22use crate::io::webui::access_log_layer::AccessLogLayer;
23use crate::io::webui::routes::build_routes;
24use crate::io::webui::session_auth::{AuthenticatedUser, Session, SharedUser};
25use crate::json_value;
26use crate::storage::{get_asset, register_views};
27use crate::utilities::serde_value::insert_key;
28use crate::utilities::*;
29
30pub mod access_log_layer;
31pub mod error;
32pub mod flexible_form;
33pub mod routes;
34pub mod session_auth;
35pub mod test_helpers;
36pub mod webview;
37pub mod controllers {
38    pub mod app;
39    pub mod auth;
40    pub mod base;
41    pub mod graph;
42    pub mod search;
43    pub mod web;
44}
45
46// Shared application state
47#[derive(Clone)]
48pub struct AppState {
49    hbs: Arc<Handlebars<'static>>,
50    sessions: Arc<Mutex<HashMap<Vec<u8>, Session>>>,
51    sessions_by_user: Arc<Mutex<HashMap<u64, Vec<Vec<u8>>>>>,
52    users: Arc<Mutex<HashMap<u64, SharedUser>>>,
53}
54
55impl Default for AppState {
56    fn default() -> Self {
57        let hbs = register_views();
58        Self {
59            hbs: Arc::new(hbs),
60            sessions: Arc::new(Mutex::new(HashMap::new())),
61            users: Arc::new(Mutex::new(HashMap::new())),
62            sessions_by_user: Arc::new(Mutex::new(HashMap::new())),
63        }
64    }
65}
66
67/// Trait for types that can serve as a context for a Handlebars view.
68pub trait ViewContext: Serialize {
69    /// Used to add or override keys for layouts.
70    fn with_content(self, content: String) -> serde_json::Value
71    where
72        Self: Sized,
73    {
74        // Compose self and content into a merged JSON object.
75        let mut map = serde_json::to_value(self)
76            .expect("context to be serializable")
77            .as_object()
78            .cloned()
79            .unwrap_or_default();
80        map.insert("content".to_string(), Value::String(content));
81        Value::Object(map)
82    }
83}
84
85// Blanket impl for all Serialize types
86impl<T: Serialize> ViewContext for T {}
87
// --- Context structs for built-in views ---
89
90#[derive(Serialize)]
91struct ErrorContext {
92    message: String,
93    message_details: String,
94}
95
96pub fn start_webui_server() -> u16 {
97    log!("Starting local web UI server");
98    let current_settings =
99        crate::storage::pc_settings::PcSettings::load().unwrap_or_default();
100    let protocol: String =
101        if let Some(ref _cert) = current_settings.tls_certificate {
102            log!("Using HTTPS");
103            "https".to_string()
104        } else {
105            log!("Using HTTP");
106            "http".to_string()
107        };
108    let relevant_port = if protocol == "http" {
109        current_settings.fixed_http_port
110    } else {
111        current_settings.fixed_https_port
112    };
113    let port: u16 = if let Some(port) = relevant_port {
114        log!("Using fixed port from settings: {}", port);
115        port
116    } else {
117        pick_unused_port().expect("No ports free")
118    };
119    let bind_to_ip: String = current_settings.bind_to_ip;
120    log!("Using server address: {}", bind_to_ip.clone());
121    let domain: String = if let Some(domain) = current_settings.domain_name {
122        log!("Using configured domain name: {}", domain.clone());
123        domain
124    } else {
125        bind_to_ip.to_string()
126    };
127    let protocol_clone = protocol.clone();
128    let bind_to_ip_clone = bind_to_ip.clone();
129    let domain_clone = domain.clone();
130    let tls_certificate = current_settings.tls_certificate.clone();
131    let tls_private_key = current_settings.tls_private_key.clone();
132    std::thread::spawn(move || {
133        if let Err(e) = start_webui_server_inner(
134            port,
135            protocol_clone,
136            bind_to_ip_clone,
137            Some(domain_clone),
138            tls_certificate,
139            tls_private_key,
140        ) {
141            log!(format!("Web UI server failed to start: {e:?}"));
142        }
143    });
144
145    // If using HTTPS and not on port 80, also start up a HTTP->HTTPS redirector
146    if protocol == "https" && port != 80 && current_settings.http_redirect {
147        let redirect_from_port = 80;
148        let bind_to_ip_clone = bind_to_ip.clone();
149        std::thread::spawn(move || {
150            // Check if we can bind to port 80 on the given IP
151            let can_bind = bind_to_ip_clone
152                .parse::<Ipv4Addr>()
153                .ok()
154                .and_then(|ip| {
155                    std::net::TcpListener::bind((ip, redirect_from_port)).ok()
156                })
157                .is_some();
158            if !can_bind {
159                log!(
160                    "Cannot bind to port 80 for HTTP->HTTPS redirector, skipping"
161                );
162                return;
163            }
164
165            let rt = match tokio::runtime::Builder::new_current_thread()
166                .enable_all()
167                .thread_name("localwebui-redirect")
168                .build()
169            {
170                Ok(rt) => rt,
171                Err(e) => {
172                    log!(format!(
173                        "Failed building redirector tokio runtime: {e:?}"
174                    ));
175                    return;
176                }
177            };
178
179            let result = rt.block_on(http_to_https(
180                bind_to_ip_clone,
181                redirect_from_port,
182                Some(port),
183            ));
184            if let Err(e) = result {
185                log!(format!("HTTP->HTTPS redirector failed: {e:?}"));
186            }
187        });
188    }
189
190    let url = format!("{protocol}://{domain}:{port}");
191    let result = webbrowser::open(url.as_str());
192    if let Err(e) = result {
193        log!(format!("Failed to open web browser automatically: {e:?}"));
194        log!(format!(
195            "Please open your web browser and navigate to {url}"
196        ));
197    } else {
198        log!(format!("Web browser opened to {url}"));
199    }
200
201    port
202}
203
204async fn http_to_https(
205    bind_to_ip: String,
206    redirect_from_port: u16,
207    relevant_port: Option<u16>,
208) -> Result<()> {
209    let ip = bind_to_ip.parse::<Ipv4Addr>().with_context(|| {
210        format!("Could not parse bind IP address: {bind_to_ip}")
211    })?;
212    let addr = SocketAddr::from((ip, redirect_from_port));
213
214    axum_server::bind(addr)
215        .serve(
216            Router::new()
217                .fallback(axum::routing::any(
218                    move |req: axum::http::Request<_>| async move {
219                        let host = req
220                            .headers()
221                            .get("host")
222                            .and_then(|h| h.to_str().ok())
223                            .unwrap_or("");
224                        let uri = req.uri().to_string();
225                        let redirect_to_port = if let Some(relevant_port) =
226                            relevant_port
227                            && relevant_port != 443
228                        {
229                            format!(":{}", relevant_port)
230                        } else {
231                            "".to_string()
232                        };
233                        let redirect_url = format!(
234                            "https://{}{}{}",
235                            host, redirect_to_port, uri
236                        );
237                        axum::response::Redirect::permanent(&redirect_url)
238                    },
239                ))
240                .into_make_service(),
241        )
242        .await
243        .context("Error in HTTP to HTTPS redirector")?;
244
245    Ok(())
246}
247
/// Threshold handed to `AccessLogLayer::new` in `build_app_router`
/// (presumably the time-to-first-byte above which a request is logged as
/// slow — confirm in `access_log_layer`).
const SLOW_TTFB_THRESHOLD: Duration = Duration::from_millis(150);
249
250pub fn build_app_router(state: AppState) -> Router {
251    build_routes(state)
252        .layer(AccessLogLayer::new(SLOW_TTFB_THRESHOLD))
253        .layer(CompressionLayer::new())
254        .layer(CorsLayer::permissive())
255}
256
257fn start_webui_server_inner(
258    port: u16,
259    protocol: String,
260    bind_to_ip: String,
261    domain_name: Option<String>,
262    tls_certificate: Option<String>,
263    tls_private_key: Option<String>,
264) -> Result<()> {
265    // Build templates once and share via state
266    let hbs = register_views();
267    let state = AppState {
268        hbs: Arc::new(hbs),
269        sessions: Arc::new(Mutex::new(HashMap::new())),
270        users: Arc::new(Mutex::new(HashMap::new())),
271        sessions_by_user: Arc::new(Mutex::new(HashMap::new())),
272    };
273
274    let app = build_app_router(state);
275
276    // Run on a dedicated runtime in this thread
277    let rt = tokio::runtime::Builder::new_multi_thread()
278        .enable_all()
279        .thread_name("localwebui-axum")
280        .build()
281        .context("failed building tokio runtime")?;
282
283    rt.block_on(async move {
284        let ip = bind_to_ip.parse::<Ipv4Addr>().with_context(|| {
285            format!("Could not parse bind IP address: {bind_to_ip}")
286        })?;
287        let addr = SocketAddr::from((ip, port));
288
289        // NOTE: into_make_service_with_connect_info enables the ConnectInfo<SocketAddr> extraction
290        let make_service =
291            app.into_make_service_with_connect_info::<SocketAddr>();
292
293        if protocol == "http" {
294            axum_server::bind(addr)
295                .serve(make_service)
296                .await
297                .context("HTTP server exited with error")?;
298            return Ok(());
299        }
300
301        let cert_vec = tls_certificate
302            .context("TLS certificate not provided, cannot start HTTPS server")?
303            .into_bytes();
304        let key_vec = tls_private_key
305            .context("TLS private key not provided, cannot start HTTPS server")?
306            .into_bytes();
307
308        let config = RustlsConfig::from_pem(cert_vec, key_vec)
309            .await
310            .context("Failed to build RustlsConfig from PEM")?;
311
312        axum_server::bind_rustls(addr, config)
313            .serve(make_service)
314            .await
315            .context("HTTPS server exited with error")?;
316
317        Ok(())
318    })
319}
320
321#[derive(Serialize, Clone)]
322pub struct RequestState {
323    route: String,
324    method: String,
325    accept: Option<String>,
326    is_js_request: bool,
327}
328
329impl<S> FromRequestParts<S> for RequestState
330where
331    S: Send + Sync,
332{
333    type Rejection = StatusCode;
334    async fn from_request_parts(
335        parts: &mut axum::http::request::Parts,
336        _state: &S,
337    ) -> Result<Self, Self::Rejection> {
338        Ok(RequestState {
339            route: parts.uri.path().to_string(),
340            method: parts.method.to_string(),
341            accept: parts
342                .headers
343                .get(axum::http::header::ACCEPT)
344                .map(|v| v.to_str().unwrap().to_string()),
345            is_js_request: parts
346            .headers
347            .get("X-CollectiveToolbox-IsJsRequest")
348            .and_then(|v| v.to_str().ok())
349            .map(|s| s.eq_ignore_ascii_case("true"))
350            .unwrap_or(false)
351        })
352    }
353}
354
355#[derive(Deserialize)]
356/// The `PageQuery` struct is meant for extracting query parameters from the request URL, specifically a ?page=N parameter. For example, /nodes?page=2.
357pub struct PageQuery {
358    page: Option<String>,
359}
360
361// ================ Render helpers ================
362
363fn respond_general<T: serde::Serialize>(
364    state: &AppState,
365    req: RequestState,
366    view: &str,
367    data: &T,
368) -> Response {
369    match render_page(&state.hbs, None, view.to_string(), &req, data) {
370        Ok(html) => Html(html).into_response(),
371        Err(e) => error_400(state, &req, e),
372    }
373}
374
375fn respond_page<T: serde::Serialize>(
376    state: &AppState,
377    req: RequestState,
378    view: &str,
379    data: &T,
380) -> Response {
381    match render_page(&state.hbs, Some("page"), view.to_string(), &req, data) {
382        Ok(html) => Html(html).into_response(),
383        Err(e) => error_400(state, &req, e),
384    }
385}
386
387fn respond_markdown_page(
388    state: &AppState,
389    req: RequestState,
390    view: &str,
391) -> Response {
392    let md = get_asset(format!("views/pages/{view}.md").as_str());
393
394    if let Some(md) = md {
395        let page = markdown2html(md);
396
397        return match render_page(
398            &state.hbs,
399            Some("page"),
400            "pages.markdown".to_string(),
401            &req,
402            &json_value!({ "page" => String::from_utf8_lossy(&page).to_string() }),
403        ) {
404            Ok(html) => Html(html).into_response(),
405            Err(e) => error_400(state, &req, e),
406        };
407    } else {
408        return error_404(
409            state,
410            &req,
411            format!("Markdown page not found: {}", view),
412        );
413    }
414}
415
416fn respond_dialog<T: serde::Serialize>(
417    state: &AppState,
418    req: RequestState,
419    view: &str,
420    data: &T,
421) -> Response {
422    match render_page(
423        &state.hbs,
424        Some("dialog"),
425        format!("dialogs.{view}"),
426        &req,
427        data,
428    ) {
429        Ok(html) => Html(html).into_response(),
430        Err(e) => error_400(state, &req, e),
431    }
432}
433
434// ================ Error helpers ================
435
436fn error_500<E: std::fmt::Debug + std::fmt::Display>(
437    state: &AppState,
438    req: &RequestState,
439    e: E,
440) -> Response {
441    error_response(state, req, e, StatusCode::INTERNAL_SERVER_ERROR)
442}
443
444fn error_400<E: std::fmt::Debug + std::fmt::Display>(
445    state: &AppState,
446    req: &RequestState,
447    e: E,
448) -> Response {
449    error_response(state, req, e, StatusCode::BAD_REQUEST)
450}
451
452fn error_401<E: std::fmt::Debug + std::fmt::Display>(
453    state: &AppState,
454    req: &RequestState,
455    e: E,
456) -> Response {
457    error_response(state, req, e, StatusCode::UNAUTHORIZED)
458}
459
460fn error_403<E: std::fmt::Debug + std::fmt::Display>(
461    state: &AppState,
462    req: &RequestState,
463    e: E,
464) -> Response {
465    error_response(state, req, e, StatusCode::FORBIDDEN)
466}
467
468fn error_404<E: std::fmt::Debug + std::fmt::Display>(
469    state: &AppState,
470    req: &RequestState,
471    e: E,
472) -> Response {
473    error_response(state, req, e, StatusCode::NOT_FOUND)
474}
475
476fn recoverable_error<E: std::fmt::Display>(
477    state: &AppState,
478    req: RequestState,
479    e: E,
480) -> Response {
481    // FIXME: Use JS to intercept this (if JS is running) and show a modal dialog instead of a full page
482    let mut response = respond_page(
483        state,
484        req,
485        "layouts._recoverable-error",
486        &btreemap! { "recoverable_error_message".to_string() =>  e.to_string()},
487    );
488    let status = response.status_mut();
489    *status = StatusCode::BAD_REQUEST;
490    response
491}
492
493fn error_response<E: std::fmt::Debug + std::fmt::Display>(
494    state: &AppState,
495    req: &RequestState,
496    e: E,
497    status_code: StatusCode,
498) -> Response {
499    let accept = req.accept.clone();
500
501    let (message, details) = {
502        let message = e.to_string();
503        let details = format!("{e:?}");
504        (message, details)
505    };
506
507    if let Some(accept) = accept
508        && accept.contains("application/json")
509    {
510        // Return JSON error
511        return error_response_json_with_details(
512            message.clone(),
513            details.clone(),
514            status_code,
515        );
516    }
517
518    // Default or for "text/html"
519    match render_page(
520        &state.hbs,
521        Some("page"),
522        "error".to_string(),
523        req,
524        &ErrorContext {
525            message: message.clone(),
526            message_details: format!(
527                "{message}\nHTTP Status: {status_code}\ndetails:\n{details}"
528            ),
529        },
530    ) {
531        Ok(html) => {
532            let mut resp = Html(html).into_response();
533            *resp.status_mut() = status_code;
534            resp
535        }
536        Err(e) => error_response_json_with_details(
537            format!("Error rendering error response {e:?}"),
538            details,
539            status_code,
540        ),
541    }
542}
543
544/// Returns a JSON error response including message and details.
545fn error_response_json_with_details<E: std::fmt::Display>(
546    message: E,
547    details: String,
548    status_code: StatusCode,
549) -> Response {
550    let body = json!({
551        "type": "error",
552        "message": message.to_string(),
553        "message_details": details,
554    });
555    (status_code, axum::Json(body)).into_response()
556}
557
558// ================ Template rendering ================
559
560fn render_view<T: serde::Serialize>(
561    hbs: &Handlebars<'_>,
562    view: String,
563    req: &RequestState,
564    data: &T,
565) -> Result<String> {
566    hbs_render(hbs, &view, req, data).context("Could not render view")
567}
568
569fn render_page<T: serde::Serialize>(
570    hbs: &Handlebars<'_>,
571    layout: Option<&str>,
572    view: String,
573    req: &RequestState,
574    data: &T,
575) -> Result<String> {
576    let view_rendered = hbs_render(hbs, view.as_str(), req, data)?;
577
578    let layout_rendered = if let Some(layout) = layout {
579        hbs_render(
580            hbs,
581            format!("layouts.{layout}").as_str(),
582            req,
583            &data.with_content(view_rendered),
584        )
585    } else {
586        Ok(view_rendered)
587    }?;
588
589    hbs_render(hbs, "layouts.app", req, &data.with_content(layout_rendered))
590}
591
592fn hbs_render<T: serde::Serialize>(
593    hbs: &Handlebars<'_>,
594    view: &str,
595    req: &RequestState,
596    data: &T,
597) -> Result<String> {
598    // Convert data to a HashMap
599    let req_value = to_value(req).context("Could not serialize request")?;
600    insert_key(data, "_request", req_value);
601
602    hbs.render(view, &data).map_err(|e| {
603        anyhow::anyhow!(
604            "Could not render template: {}\n\
605             Template: {:?}\n\
606             Line: {:?}\n\
607             Column: {:?}\n\
608             Reason: {}",
609            view,
610            e.template_name,
611            e.line_no,
612            e.column_no,
613            e.reason(),
614        )
615    })
616}
617
618// Macros for user and graph access
619
/// Binds `$user`, in the caller's scope, to the guard returned by
/// `$shared_user.blocking_lock()`. `$req` is accepted but not used.
///
/// NOTE(review): if `$shared_user` is a `tokio::sync::Mutex`, `blocking_lock`
/// panics when called from an asynchronous context — confirm the call sites.
#[macro_export]
macro_rules! get_user {
    ($shared_user:expr, $req:expr, $user:ident) => {
        // Bound with `let` so the guard outlives the macro invocation.
        let $user = $shared_user.blocking_lock();
    };
}
626
/// Locks `$shared_user.user` into `$user` and looks up the graph `$graph_id`
/// into `$graph`, both bound in the caller's scope.
///
/// Makes the *calling function* return early with `error_400` when the graph
/// does not exist, and with `error_403` when `$user` may not write to it.
/// `error_400` and `error_403` must be in scope at the call site.
///
/// NOTE(review): if `$shared_user.user` is a `tokio::sync::Mutex`,
/// `blocking_lock` panics when called from an asynchronous context — confirm
/// the call sites.
#[macro_export]
macro_rules! get_user_and_graph {
    ($state:expr, $req:expr, $shared_user:expr, $graph_id:expr, $user:ident, $graph:ident) => {
        // Bound with `let` so the guard outlives the macro invocation.
        let $user = $shared_user.user.blocking_lock();
        let Some($graph) = $user.get_graph_by_id($graph_id) else {
            return error_400($state, $req, "Graph not found");
        };
        if !$graph.is_writable_by(&*$user) {
            return error_403($state, $req, "User can't write to graph");
        }
    };
}
641
#[cfg(test)]
#[allow(clippy::unwrap_in_result, clippy::panic_in_result_fn)]
mod tests {
    use super::*;

    /// Compile-time check that `AppState` can be shared across threads; the
    /// test body does nothing at runtime.
    #[crate::ctb_test]
    fn test_is_send_and_sync() {
        fn assert_send_sync<T>()
        where
            T: Send + Sync,
        {
        }

        assert_send_sync::<AppState>();
    }
}