ctoolbox/io/webui/
session_auth.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use anyhow::Result;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum_extra::extract::CookieJar;
8use base64::Engine;
9use base64::engine::general_purpose::URL_SAFE_NO_PAD;
10use http::HeaderMap;
11use rand::RngCore;
12use rand::rngs::ThreadRng;
13use tokio::sync::Mutex;
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16use crate::io::webui::AppState;
17use crate::storage::graph::Graph;
18use crate::storage::user::User;
19use crate::utilities::backtrace_string;
20use crate::{debug, debug_fmt};
21
/// A [`User`] shared via `Arc` and guarded by an async (tokio) `Mutex`.
pub type SharedUser = Arc<Mutex<User>>;
/// A [`Graph`] shared via `Arc` and guarded by an async (tokio) `Mutex`.
pub type SharedGraph = Arc<Mutex<Graph>>;
24
25pub struct AuthenticatedUser {
26    pub user: SharedUser,
27}
28
29impl<S> FromRequestParts<S> for AuthenticatedUser
30where
31    AppState: FromRef<S>,
32    S: Send + Sync,
33{
34    type Rejection = StatusCode;
35    async fn from_request_parts(
36        parts: &mut axum::http::request::Parts,
37        state: &S,
38    ) -> Result<Self, Self::Rejection> {
39        let mut state = AppState::from_ref(state);
40        let session_key_bytes = session_key_from_headers(&parts.headers);
41        let Some(session_key_bytes) = session_key_bytes else {
42            return Err(StatusCode::UNAUTHORIZED);
43        };
44        let user =
45            Session::get_user_by_key(&mut state, &session_key_bytes).await;
46        let Some(user) = user else {
47            debug!("No user found for session key");
48            return Err(StatusCode::UNAUTHORIZED);
49        };
50        Ok(AuthenticatedUser { user })
51    }
52}
53
54/// Extract the base64-encoded session key from the cookie
55pub fn session_key_string_from_headers(headers: &HeaderMap) -> Option<String> {
56    // Extract the session key from the Cookie header
57    // The Cookie header may contain multiple cookies, separated by "; "
58    // We need to find the one named "session"
59    let cookies = CookieJar::from_headers(headers);
60    let session_key =
61        cookies.get("session").map(std::string::ToString::to_string);
62    debug!("Session key: {:?}", &session_key);
63    if session_key.is_none() || session_key.as_ref()?.is_empty() {
64        debug_fmt!("No session key found in cookies {}", backtrace_string());
65        return None;
66    }
67    // Remove "session=" prefix
68    let session_key = session_key.as_ref()?.strip_prefix("session=")?;
69    Some(session_key.to_string())
70}
71
72/// Decode the base64-encoded session key from the cookie; return the key bytes
73pub fn session_key_from_headers(headers: &HeaderMap) -> Option<Vec<u8>> {
74    let session_key = session_key_string_from_headers(headers)?;
75    // Decode the base64-encoded session key to get the raw key bytes
76    let Ok(session_key_bytes) = URL_SAFE_NO_PAD.decode(session_key) else {
77        debug!("Invalid session key format");
78        return None;
79    };
80    Some(session_key_bytes)
81}
82
/// Session state (key) for a User. This should not be saved to disk.
///
/// `ZeroizeOnDrop` is derived so the fields not marked `#[zeroize(skip)]`
/// are wiped when a `Session` value is dropped.
#[derive(Clone, ZeroizeOnDrop)]
pub struct Session {
    /// Raw session key bytes; `Session::id` returns them base64url-encoded.
    key: Vec<u8>,
    /// Local id of the user this session belongs to (from `User::local_id`).
    user_id: u64,
    /// Instant after which `Session::is_expired` reports the session as expired.
    #[zeroize(skip)]
    expiry: Instant,
}
91
92impl Session {
93    pub async fn new(state: &mut AppState, user: User) -> Self {
94        let mut data = [0u8; 128];
95        ThreadRng::default().fill_bytes(&mut data[..]);
96        let key = data.to_vec();
97        data.zeroize();
98        // Check if session key already used in state
99        // Fulfil OWASP "ensure that each sessionID is unique"?
100        assert!(
101            (Self::get_by_key(state, &key).await).is_none(),
102            "Session key collision"
103        );
104        let session = Self {
105            key: key.clone(),
106            user_id: user.local_id(),
107            expiry: Instant::now() + std::time::Duration::from_secs(3600),
108        };
109        // Insert the session into the state's sessions map
110        {
111            let mut sessions = state.sessions.lock().await;
112            sessions.insert(key.clone(), session.clone());
113        }
114        // Also insert into sessions_by_user
115        {
116            let mut by_user = state.sessions_by_user.lock().await;
117            by_user
118                .entry(user.local_id())
119                .or_insert_with(Vec::new)
120                .push(key.clone());
121        }
122        {
123            let mut users = state.users.lock().await;
124            users.insert(user.local_id(), Arc::new(Mutex::new(user)));
125        }
126        assert!(
127            Self::get_user_by_key(state, &key).await.is_some(),
128            "Failed to retrieve user by session key after creating session"
129        );
130        session
131    }
132
133    pub fn is_expired(&self) -> bool {
134        Instant::now() > self.expiry
135    }
136
137    pub async fn get_by_key(
138        state: &mut AppState,
139        key: &[u8],
140    ) -> Option<Session> {
141        let sessions = state.sessions.lock().await;
142        debug!(format!(
143            "{:?}",
144            sessions
145                .keys()
146                .map(|k| URL_SAFE_NO_PAD.encode(k))
147                .collect::<Vec<_>>()
148        ));
149        let session = sessions.get(key).cloned();
150        debug!(
151            "Maybe got session for key {:?} user_id={:?}",
152            URL_SAFE_NO_PAD.encode(key),
153            session.as_ref().map(|s| s.user_id)
154        );
155        if let Some(sess) = &session {
156            if sess.is_expired() {
157                drop(sessions);
158                Self::invalidate(state, key).await;
159                return None;
160            }
161            debug!(
162                "Got session for key {:?}, user_id={:?}, bt={}",
163                URL_SAFE_NO_PAD.encode(key),
164                session.as_ref().map(|s| s.user_id),
165                backtrace_string()
166            );
167        }
168        session
169    }
170
171    pub async fn get_user_by_key(
172        state: &mut AppState,
173        key: &[u8],
174    ) -> Option<Arc<Mutex<User>>> {
175        let session = Self::get_by_key(state, key).await?;
176        return state.users.lock().await.get(&session.user_id).cloned();
177    }
178
179    pub async fn invalidate(state: &mut AppState, key: &[u8]) {
180        let mut sessions = state.sessions.lock().await;
181        if let Some(session) = sessions.remove(key) {
182            let user = &session.user_id;
183            let mut by_user = state.sessions_by_user.lock().await;
184            by_user
185                .entry(*user)
186                .and_modify(|v| v.retain(|k| k.as_slice() != key))
187                .or_default();
188            // If no sessions remain for user, remove state.users user
189            if by_user.get(user).is_some_and(std::vec::Vec::is_empty) {
190                state.users.lock().await.remove(user);
191            }
192        }
193    }
194
195    pub async fn invalidate_all_expired(state: &mut AppState) {
196        // collect keys of expired sessions
197        let expired_keys: Vec<Vec<u8>> = {
198            let sessions = state.sessions.lock().await;
199            sessions
200                .iter()
201                .filter_map(|(k, s)| {
202                    if s.is_expired() {
203                        Some(k.clone())
204                    } else {
205                        None
206                    }
207                })
208                .collect()
209        }; // sessions lock dropped here
210
211        for key in expired_keys {
212            Self::invalidate(state, &key).await;
213        }
214    }
215
216    pub fn id(&self) -> String {
217        URL_SAFE_NO_PAD.encode(&self.key)
218    }
219}
220
#[cfg(test)]
#[allow(clippy::unwrap_in_result, clippy::panic_in_result_fn)]
mod tests {
    use super::*;
    use crate::io::webui::AppState;
    use crate::storage::user::get_test_user;
    use ctb_test_macro::ctb_test;

    /// A new session is unexpired, tied to the user, and retrievable by key.
    #[ctb_test(tokio::test)]
    async fn test_session_new() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let user_local_id = user.local_id();
        let session = Session::new(&mut state, user).await;
        assert!(!session.is_expired());
        assert_eq!(session.user_id, user_local_id);
        assert!(
            Session::get_by_key(&mut state, &session.key)
                .await
                .is_some()
        );
    }

    /// Looking a session up by key returns the session for the same user.
    #[ctb_test(tokio::test)]
    async fn test_session_get_by_key() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let session = Session::new(&mut state, user).await;
        let retrieved = Session::get_by_key(&mut state, &session.key).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().user_id, session.user_id);
    }

    /// Looking a user up by session key returns the user that owns it.
    #[ctb_test(tokio::test)]
    async fn test_session_get_user_by_key() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let user_local_id = user.local_id();
        let session = Session::new(&mut state, user).await;
        let retrieved_user =
            Session::get_user_by_key(&mut state, &session.key).await;
        assert!(retrieved_user.is_some());
        // Compare by local id rather than relying on `User` equality.
        assert_eq!(
            retrieved_user.unwrap().lock().await.local_id(),
            user_local_id
        );
    }

    /// An invalidated session can no longer be retrieved.
    #[ctb_test(tokio::test)]
    async fn test_session_invalidate() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let session = Session::new(&mut state, user).await;
        Session::invalidate(&mut state, &session.key).await;
        assert!(
            Session::get_by_key(&mut state, &session.key)
                .await
                .is_none()
        );
    }

    /// The expiry sweep removes sessions whose expiry lies in the past.
    #[ctb_test(tokio::test)]
    async fn test_session_invalidate_all_expired() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let session = Session::new(&mut state, user).await;
        // Manually expire the session by setting expiry to past. Unwrap the
        // lookup so the test fails loudly if the session was never stored,
        // instead of silently skipping the setup.
        {
            let mut sessions = state.sessions.lock().await;
            let stored = sessions.get_mut(&session.key).unwrap();
            stored.expiry = Instant::now()
                .checked_sub(std::time::Duration::from_secs(1))
                .unwrap();
        }
        Session::invalidate_all_expired(&mut state).await;
        // Inspect the map directly: `get_by_key` lazily invalidates expired
        // sessions on its own, so it cannot prove that the sweep removed it.
        {
            let sessions = state.sessions.lock().await;
            assert!(sessions.get(&session.key).is_none());
        }
        assert!(
            Session::get_by_key(&mut state, &session.key)
                .await
                .is_none()
        );
    }

    /// `is_expired` flips once the expiry instant is in the past.
    #[ctb_test(tokio::test)]
    async fn test_session_is_expired() {
        let mut session = Session {
            key: vec![1, 2, 3],
            user_id: 1,
            expiry: Instant::now() + std::time::Duration::from_secs(1),
        };
        assert!(!session.is_expired());
        session.expiry = Instant::now()
            .checked_sub(std::time::Duration::from_secs(1))
            .unwrap();
        assert!(session.is_expired());
    }

    /// `id` is the URL-safe, unpadded base64 encoding of the key.
    #[ctb_test(tokio::test)]
    async fn test_session_id() {
        let session = Session {
            key: vec![1, 2, 3],
            user_id: 1,
            expiry: Instant::now(),
        };
        let id = session.id();
        assert_eq!(id, URL_SAFE_NO_PAD.encode(&session.key));
    }

    // Note: Testing the AuthenticatedUser extractor requires a full axum setup,
    // which is complex for unit tests. Consider integration tests for extractor
    // behavior.
}