1use std::collections::HashMap;
2use std::net::{Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use axum::Router;
7use axum::extract::FromRequestParts;
8use axum::response::{Html, IntoResponse, Response};
9use axum_server::tls_rustls::RustlsConfig;
10use handlebars::Handlebars;
11use http::StatusCode;
12use maplit::btreemap;
13use portpicker::pick_unused_port;
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json, to_value};
16use std::time::Duration;
17use tokio::sync::Mutex;
18use tower_http::compression::CompressionLayer;
19use tower_http::cors::CorsLayer;
20
21use crate::formats::markdown::markdown2html;
22use crate::io::webui::access_log_layer::AccessLogLayer;
23use crate::io::webui::routes::build_routes;
24use crate::io::webui::session_auth::{AuthenticatedUser, Session, SharedUser};
25use crate::json_value;
26use crate::storage::{get_asset, register_views};
27use crate::utilities::serde_value::insert_key;
28use crate::utilities::*;
29
30pub mod access_log_layer;
31pub mod error;
32pub mod flexible_form;
33pub mod routes;
34pub mod session_auth;
35pub mod test_helpers;
36pub mod webview;
37pub mod controllers {
38 pub mod app;
39 pub mod auth;
40 pub mod base;
41 pub mod graph;
42 pub mod search;
43 pub mod web;
44}
45
/// Shared state handed to the web UI router and its handlers.
///
/// Every field is an `Arc`, so cloning an `AppState` is cheap and all clones
/// share the same template registry and tables.
#[derive(Clone)]
pub struct AppState {
    /// Handlebars registry used to render views and layouts.
    hbs: Arc<Handlebars<'static>>,
    /// Sessions keyed by an opaque byte string (presumably the session id —
    /// confirm against `session_auth`).
    sessions: Arc<Mutex<HashMap<Vec<u8>, Session>>>,
    /// Byte-string keys grouped per `u64` key (presumably session ids per
    /// user id — confirm).
    sessions_by_user: Arc<Mutex<HashMap<u64, Vec<Vec<u8>>>>>,
    /// Shared users keyed by `u64` (presumably the user id — confirm).
    users: Arc<Mutex<HashMap<u64, SharedUser>>>,
}
54
55impl Default for AppState {
56 fn default() -> Self {
57 let hbs = register_views();
58 Self {
59 hbs: Arc::new(hbs),
60 sessions: Arc::new(Mutex::new(HashMap::new())),
61 users: Arc::new(Mutex::new(HashMap::new())),
62 sessions_by_user: Arc::new(Mutex::new(HashMap::new())),
63 }
64 }
65}
66
67pub trait ViewContext: Serialize {
69 fn with_content(self, content: String) -> serde_json::Value
71 where
72 Self: Sized,
73 {
74 let mut map = serde_json::to_value(self)
76 .expect("context to be serializable")
77 .as_object()
78 .cloned()
79 .unwrap_or_default();
80 map.insert("content".to_string(), Value::String(content));
81 Value::Object(map)
82 }
83}
84
85impl<T: Serialize> ViewContext for T {}
87
/// Template context for the `error` view rendered by `error_response`.
#[derive(Serialize)]
struct ErrorContext {
    /// `Display` text of the error.
    message: String,
    /// Message, HTTP status and `Debug` details combined into one string.
    message_details: String,
}
95
96pub fn start_webui_server() -> u16 {
97 log!("Starting local web UI server");
98 let current_settings =
99 crate::storage::pc_settings::PcSettings::load().unwrap_or_default();
100 let protocol: String =
101 if let Some(ref _cert) = current_settings.tls_certificate {
102 log!("Using HTTPS");
103 "https".to_string()
104 } else {
105 log!("Using HTTP");
106 "http".to_string()
107 };
108 let relevant_port = if protocol == "http" {
109 current_settings.fixed_http_port
110 } else {
111 current_settings.fixed_https_port
112 };
113 let port: u16 = if let Some(port) = relevant_port {
114 log!("Using fixed port from settings: {}", port);
115 port
116 } else {
117 pick_unused_port().expect("No ports free")
118 };
119 let bind_to_ip: String = current_settings.bind_to_ip;
120 log!("Using server address: {}", bind_to_ip.clone());
121 let domain: String = if let Some(domain) = current_settings.domain_name {
122 log!("Using configured domain name: {}", domain.clone());
123 domain
124 } else {
125 bind_to_ip.to_string()
126 };
127 let protocol_clone = protocol.clone();
128 let bind_to_ip_clone = bind_to_ip.clone();
129 let domain_clone = domain.clone();
130 let tls_certificate = current_settings.tls_certificate.clone();
131 let tls_private_key = current_settings.tls_private_key.clone();
132 std::thread::spawn(move || {
133 if let Err(e) = start_webui_server_inner(
134 port,
135 protocol_clone,
136 bind_to_ip_clone,
137 Some(domain_clone),
138 tls_certificate,
139 tls_private_key,
140 ) {
141 log!(format!("Web UI server failed to start: {e:?}"));
142 }
143 });
144
145 if protocol == "https" && port != 80 && current_settings.http_redirect {
147 let redirect_from_port = 80;
148 let bind_to_ip_clone = bind_to_ip.clone();
149 std::thread::spawn(move || {
150 let can_bind = bind_to_ip_clone
152 .parse::<Ipv4Addr>()
153 .ok()
154 .and_then(|ip| {
155 std::net::TcpListener::bind((ip, redirect_from_port)).ok()
156 })
157 .is_some();
158 if !can_bind {
159 log!(
160 "Cannot bind to port 80 for HTTP->HTTPS redirector, skipping"
161 );
162 return;
163 }
164
165 let rt = match tokio::runtime::Builder::new_current_thread()
166 .enable_all()
167 .thread_name("localwebui-redirect")
168 .build()
169 {
170 Ok(rt) => rt,
171 Err(e) => {
172 log!(format!(
173 "Failed building redirector tokio runtime: {e:?}"
174 ));
175 return;
176 }
177 };
178
179 let result = rt.block_on(http_to_https(
180 bind_to_ip_clone,
181 redirect_from_port,
182 Some(port),
183 ));
184 if let Err(e) = result {
185 log!(format!("HTTP->HTTPS redirector failed: {e:?}"));
186 }
187 });
188 }
189
190 let url = format!("{protocol}://{domain}:{port}");
191 let result = webbrowser::open(url.as_str());
192 if let Err(e) = result {
193 log!(format!("Failed to open web browser automatically: {e:?}"));
194 log!(format!(
195 "Please open your web browser and navigate to {url}"
196 ));
197 } else {
198 log!(format!("Web browser opened to {url}"));
199 }
200
201 port
202}
203
204async fn http_to_https(
205 bind_to_ip: String,
206 redirect_from_port: u16,
207 relevant_port: Option<u16>,
208) -> Result<()> {
209 let ip = bind_to_ip.parse::<Ipv4Addr>().with_context(|| {
210 format!("Could not parse bind IP address: {bind_to_ip}")
211 })?;
212 let addr = SocketAddr::from((ip, redirect_from_port));
213
214 axum_server::bind(addr)
215 .serve(
216 Router::new()
217 .fallback(axum::routing::any(
218 move |req: axum::http::Request<_>| async move {
219 let host = req
220 .headers()
221 .get("host")
222 .and_then(|h| h.to_str().ok())
223 .unwrap_or("");
224 let uri = req.uri().to_string();
225 let redirect_to_port = if let Some(relevant_port) =
226 relevant_port
227 && relevant_port != 443
228 {
229 format!(":{}", relevant_port)
230 } else {
231 "".to_string()
232 };
233 let redirect_url = format!(
234 "https://{}{}{}",
235 host, redirect_to_port, uri
236 );
237 axum::response::Redirect::permanent(&redirect_url)
238 },
239 ))
240 .into_make_service(),
241 )
242 .await
243 .context("Error in HTTP to HTTPS redirector")?;
244
245 Ok(())
246}
247
248const SLOW_TTFB_THRESHOLD: Duration = Duration::from_millis(150);
249
250pub fn build_app_router(state: AppState) -> Router {
251 build_routes(state)
252 .layer(AccessLogLayer::new(SLOW_TTFB_THRESHOLD))
253 .layer(CompressionLayer::new())
254 .layer(CorsLayer::permissive())
255}
256
257fn start_webui_server_inner(
258 port: u16,
259 protocol: String,
260 bind_to_ip: String,
261 domain_name: Option<String>,
262 tls_certificate: Option<String>,
263 tls_private_key: Option<String>,
264) -> Result<()> {
265 let hbs = register_views();
267 let state = AppState {
268 hbs: Arc::new(hbs),
269 sessions: Arc::new(Mutex::new(HashMap::new())),
270 users: Arc::new(Mutex::new(HashMap::new())),
271 sessions_by_user: Arc::new(Mutex::new(HashMap::new())),
272 };
273
274 let app = build_app_router(state);
275
276 let rt = tokio::runtime::Builder::new_multi_thread()
278 .enable_all()
279 .thread_name("localwebui-axum")
280 .build()
281 .context("failed building tokio runtime")?;
282
283 rt.block_on(async move {
284 let ip = bind_to_ip.parse::<Ipv4Addr>().with_context(|| {
285 format!("Could not parse bind IP address: {bind_to_ip}")
286 })?;
287 let addr = SocketAddr::from((ip, port));
288
289 let make_service =
291 app.into_make_service_with_connect_info::<SocketAddr>();
292
293 if protocol == "http" {
294 axum_server::bind(addr)
295 .serve(make_service)
296 .await
297 .context("HTTP server exited with error")?;
298 return Ok(());
299 }
300
301 let cert_vec = tls_certificate
302 .context("TLS certificate not provided, cannot start HTTPS server")?
303 .into_bytes();
304 let key_vec = tls_private_key
305 .context("TLS private key not provided, cannot start HTTPS server")?
306 .into_bytes();
307
308 let config = RustlsConfig::from_pem(cert_vec, key_vec)
309 .await
310 .context("Failed to build RustlsConfig from PEM")?;
311
312 axum_server::bind_rustls(addr, config)
313 .serve(make_service)
314 .await
315 .context("HTTPS server exited with error")?;
316
317 Ok(())
318 })
319}
320
321#[derive(Serialize, Clone)]
322pub struct RequestState {
323 route: String,
324 method: String,
325 accept: Option<String>,
326 is_js_request: bool,
327}
328
329impl<S> FromRequestParts<S> for RequestState
330where
331 S: Send + Sync,
332{
333 type Rejection = StatusCode;
334 async fn from_request_parts(
335 parts: &mut axum::http::request::Parts,
336 _state: &S,
337 ) -> Result<Self, Self::Rejection> {
338 Ok(RequestState {
339 route: parts.uri.path().to_string(),
340 method: parts.method.to_string(),
341 accept: parts
342 .headers
343 .get(axum::http::header::ACCEPT)
344 .map(|v| v.to_str().unwrap().to_string()),
345 is_js_request: parts
346 .headers
347 .get("X-CollectiveToolbox-IsJsRequest")
348 .and_then(|v| v.to_str().ok())
349 .map(|s| s.eq_ignore_ascii_case("true"))
350 .unwrap_or(false)
351 })
352 }
353}
354
/// Query-string parameters carrying an optional `page` value.
#[derive(Deserialize)]
pub struct PageQuery {
    /// Optional `page` parameter, kept as a raw string (its meaning is up to
    /// the handler — confirm at call sites).
    page: Option<String>,
}
360
361fn respond_general<T: serde::Serialize>(
364 state: &AppState,
365 req: RequestState,
366 view: &str,
367 data: &T,
368) -> Response {
369 match render_page(&state.hbs, None, view.to_string(), &req, data) {
370 Ok(html) => Html(html).into_response(),
371 Err(e) => error_400(state, &req, e),
372 }
373}
374
375fn respond_page<T: serde::Serialize>(
376 state: &AppState,
377 req: RequestState,
378 view: &str,
379 data: &T,
380) -> Response {
381 match render_page(&state.hbs, Some("page"), view.to_string(), &req, data) {
382 Ok(html) => Html(html).into_response(),
383 Err(e) => error_400(state, &req, e),
384 }
385}
386
387fn respond_markdown_page(
388 state: &AppState,
389 req: RequestState,
390 view: &str,
391) -> Response {
392 let md = get_asset(format!("views/pages/{view}.md").as_str());
393
394 if let Some(md) = md {
395 let page = markdown2html(md);
396
397 return match render_page(
398 &state.hbs,
399 Some("page"),
400 "pages.markdown".to_string(),
401 &req,
402 &json_value!({ "page" => String::from_utf8_lossy(&page).to_string() }),
403 ) {
404 Ok(html) => Html(html).into_response(),
405 Err(e) => error_400(state, &req, e),
406 };
407 } else {
408 return error_404(
409 state,
410 &req,
411 format!("Markdown page not found: {}", view),
412 );
413 }
414}
415
416fn respond_dialog<T: serde::Serialize>(
417 state: &AppState,
418 req: RequestState,
419 view: &str,
420 data: &T,
421) -> Response {
422 match render_page(
423 &state.hbs,
424 Some("dialog"),
425 format!("dialogs.{view}"),
426 &req,
427 data,
428 ) {
429 Ok(html) => Html(html).into_response(),
430 Err(e) => error_400(state, &req, e),
431 }
432}
433
434fn error_500<E: std::fmt::Debug + std::fmt::Display>(
437 state: &AppState,
438 req: &RequestState,
439 e: E,
440) -> Response {
441 error_response(state, req, e, StatusCode::INTERNAL_SERVER_ERROR)
442}
443
444fn error_400<E: std::fmt::Debug + std::fmt::Display>(
445 state: &AppState,
446 req: &RequestState,
447 e: E,
448) -> Response {
449 error_response(state, req, e, StatusCode::BAD_REQUEST)
450}
451
452fn error_401<E: std::fmt::Debug + std::fmt::Display>(
453 state: &AppState,
454 req: &RequestState,
455 e: E,
456) -> Response {
457 error_response(state, req, e, StatusCode::UNAUTHORIZED)
458}
459
460fn error_403<E: std::fmt::Debug + std::fmt::Display>(
461 state: &AppState,
462 req: &RequestState,
463 e: E,
464) -> Response {
465 error_response(state, req, e, StatusCode::FORBIDDEN)
466}
467
468fn error_404<E: std::fmt::Debug + std::fmt::Display>(
469 state: &AppState,
470 req: &RequestState,
471 e: E,
472) -> Response {
473 error_response(state, req, e, StatusCode::NOT_FOUND)
474}
475
476fn recoverable_error<E: std::fmt::Display>(
477 state: &AppState,
478 req: RequestState,
479 e: E,
480) -> Response {
481 let mut response = respond_page(
483 state,
484 req,
485 "layouts._recoverable-error",
486 &btreemap! { "recoverable_error_message".to_string() => e.to_string()},
487 );
488 let status = response.status_mut();
489 *status = StatusCode::BAD_REQUEST;
490 response
491}
492
493fn error_response<E: std::fmt::Debug + std::fmt::Display>(
494 state: &AppState,
495 req: &RequestState,
496 e: E,
497 status_code: StatusCode,
498) -> Response {
499 let accept = req.accept.clone();
500
501 let (message, details) = {
502 let message = e.to_string();
503 let details = format!("{e:?}");
504 (message, details)
505 };
506
507 if let Some(accept) = accept
508 && accept.contains("application/json")
509 {
510 return error_response_json_with_details(
512 message.clone(),
513 details.clone(),
514 status_code,
515 );
516 }
517
518 match render_page(
520 &state.hbs,
521 Some("page"),
522 "error".to_string(),
523 req,
524 &ErrorContext {
525 message: message.clone(),
526 message_details: format!(
527 "{message}\nHTTP Status: {status_code}\ndetails:\n{details}"
528 ),
529 },
530 ) {
531 Ok(html) => {
532 let mut resp = Html(html).into_response();
533 *resp.status_mut() = status_code;
534 resp
535 }
536 Err(e) => error_response_json_with_details(
537 format!("Error rendering error response {e:?}"),
538 details,
539 status_code,
540 ),
541 }
542}
543
544fn error_response_json_with_details<E: std::fmt::Display>(
546 message: E,
547 details: String,
548 status_code: StatusCode,
549) -> Response {
550 let body = json!({
551 "type": "error",
552 "message": message.to_string(),
553 "message_details": details,
554 });
555 (status_code, axum::Json(body)).into_response()
556}
557
558fn render_view<T: serde::Serialize>(
561 hbs: &Handlebars<'_>,
562 view: String,
563 req: &RequestState,
564 data: &T,
565) -> Result<String> {
566 hbs_render(hbs, &view, req, data).context("Could not render view")
567}
568
569fn render_page<T: serde::Serialize>(
570 hbs: &Handlebars<'_>,
571 layout: Option<&str>,
572 view: String,
573 req: &RequestState,
574 data: &T,
575) -> Result<String> {
576 let view_rendered = hbs_render(hbs, view.as_str(), req, data)?;
577
578 let layout_rendered = if let Some(layout) = layout {
579 hbs_render(
580 hbs,
581 format!("layouts.{layout}").as_str(),
582 req,
583 &data.with_content(view_rendered),
584 )
585 } else {
586 Ok(view_rendered)
587 }?;
588
589 hbs_render(hbs, "layouts.app", req, &data.with_content(layout_rendered))
590}
591
592fn hbs_render<T: serde::Serialize>(
593 hbs: &Handlebars<'_>,
594 view: &str,
595 req: &RequestState,
596 data: &T,
597) -> Result<String> {
598 let req_value = to_value(req).context("Could not serialize request")?;
600 insert_key(data, "_request", req_value);
601
602 hbs.render(view, &data).map_err(|e| {
603 anyhow::anyhow!(
604 "Could not render template: {}\n\
605 Template: {:?}\n\
606 Line: {:?}\n\
607 Column: {:?}\n\
608 Reason: {}",
609 view,
610 e.template_name,
611 e.line_no,
612 e.column_no,
613 e.reason(),
614 )
615 })
616}
617
/// Binds `$user` to the lock guard returned by
/// `$shared_user.blocking_lock()`. The guard lives until the end of the
/// enclosing scope.
///
/// `$req` is accepted but not used by the expansion.
///
/// NOTE(review): if `$shared_user` is a `tokio::sync::Mutex`,
/// `blocking_lock` panics when called from within an async runtime context.
/// Confirm that callers run on a blocking thread.
#[macro_export]
macro_rules! get_user {
    ($shared_user:expr, $req:expr, $user:ident) => {
        let $user = $shared_user.blocking_lock(); };
}
626
/// Locks `$shared_user.user` as `$user` and looks up graph `$graph_id` on it
/// as `$graph`.
///
/// It returns early from the enclosing function when:
/// * the graph is missing — `error_400` "Graph not found";
/// * the graph is not writable by the user — `error_403`
///   "User can't write to graph".
///
/// Requirements at the call site:
/// * the enclosing function returns the response type produced by
///   `error_400` / `error_403`;
/// * both helpers are in scope.
///
/// `$state` and `$req` are forwarded to them.
///
/// NOTE(review): if `$shared_user.user` is a `tokio::sync::Mutex`,
/// `blocking_lock` panics when called from within an async runtime context.
/// Confirm that callers run on a blocking thread.
#[macro_export]
macro_rules! get_user_and_graph {
    ($state:expr, $req:expr, $shared_user:expr, $graph_id:expr, $user:ident, $graph:ident) => {
        let $user = $shared_user.user.blocking_lock();
        let Some($graph) = $user.get_graph_by_id($graph_id) else {
            return error_400($state, $req, "Graph not found");
        };
        if !$graph.is_writable_by(&*$user) {
            return error_403($state, $req, "User can't write to graph");
        }
    };
}
641
#[cfg(test)]
#[allow(clippy::unwrap_in_result, clippy::panic_in_result_fn)]
mod tests {
    use super::*;

    /// Compile-time check that `AppState` is `Send + Sync`, which it needs in
    /// order to serve as shared router state across threads.
    #[crate::ctb_test]
    fn test_is_send_and_sync() {
        fn assert_send_sync<S: Send + Sync>() {}
        assert_send_sync::<AppState>();
    }
}