ctoolbox/io/webui/
session_auth.rs1use std::sync::Arc;
2use std::time::Instant;
3
4use anyhow::Result;
5use axum::extract::{FromRef, FromRequestParts};
6use axum::http::StatusCode;
7use axum_extra::extract::CookieJar;
8use base64::Engine;
9use base64::engine::general_purpose::URL_SAFE_NO_PAD;
10use http::HeaderMap;
11use rand::RngCore;
12use rand::rngs::ThreadRng;
13use tokio::sync::Mutex;
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16use crate::io::webui::AppState;
17use crate::storage::graph::Graph;
18use crate::storage::user::User;
19use crate::utilities::backtrace_string;
20use crate::{debug, debug_fmt};
21
22pub type SharedUser = Arc<Mutex<User>>;
23pub type SharedGraph = Arc<Mutex<Graph>>;
24
25pub struct AuthenticatedUser {
26 pub user: SharedUser,
27}
28
29impl<S> FromRequestParts<S> for AuthenticatedUser
30where
31 AppState: FromRef<S>,
32 S: Send + Sync,
33{
34 type Rejection = StatusCode;
35 async fn from_request_parts(
36 parts: &mut axum::http::request::Parts,
37 state: &S,
38 ) -> Result<Self, Self::Rejection> {
39 let mut state = AppState::from_ref(state);
40 let session_key_bytes = session_key_from_headers(&parts.headers);
41 let Some(session_key_bytes) = session_key_bytes else {
42 return Err(StatusCode::UNAUTHORIZED);
43 };
44 let user =
45 Session::get_user_by_key(&mut state, &session_key_bytes).await;
46 let Some(user) = user else {
47 debug!("No user found for session key");
48 return Err(StatusCode::UNAUTHORIZED);
49 };
50 Ok(AuthenticatedUser { user })
51 }
52}
53
54pub fn session_key_string_from_headers(headers: &HeaderMap) -> Option<String> {
56 let cookies = CookieJar::from_headers(headers);
60 let session_key =
61 cookies.get("session").map(std::string::ToString::to_string);
62 debug!("Session key: {:?}", &session_key);
63 if session_key.is_none() || session_key.as_ref()?.is_empty() {
64 debug_fmt!("No session key found in cookies {}", backtrace_string());
65 return None;
66 }
67 let session_key = session_key.as_ref()?.strip_prefix("session=")?;
69 Some(session_key.to_string())
70}
71
72pub fn session_key_from_headers(headers: &HeaderMap) -> Option<Vec<u8>> {
74 let session_key = session_key_string_from_headers(headers)?;
75 let Ok(session_key_bytes) = URL_SAFE_NO_PAD.decode(session_key) else {
77 debug!("Invalid session key format");
78 return None;
79 };
80 Some(session_key_bytes)
81}
82
83#[derive(Clone, ZeroizeOnDrop)]
85pub struct Session {
86 key: Vec<u8>,
87 user_id: u64,
88 #[zeroize(skip)]
89 expiry: Instant,
90}
91
92impl Session {
93 pub async fn new(state: &mut AppState, user: User) -> Self {
94 let mut data = [0u8; 128];
95 ThreadRng::default().fill_bytes(&mut data[..]);
96 let key = data.to_vec();
97 data.zeroize();
98 assert!(
101 (Self::get_by_key(state, &key).await).is_none(),
102 "Session key collision"
103 );
104 let session = Self {
105 key: key.clone(),
106 user_id: user.local_id(),
107 expiry: Instant::now() + std::time::Duration::from_secs(3600),
108 };
109 {
111 let mut sessions = state.sessions.lock().await;
112 sessions.insert(key.clone(), session.clone());
113 }
114 {
116 let mut by_user = state.sessions_by_user.lock().await;
117 by_user
118 .entry(user.local_id())
119 .or_insert_with(Vec::new)
120 .push(key.clone());
121 }
122 {
123 let mut users = state.users.lock().await;
124 users.insert(user.local_id(), Arc::new(Mutex::new(user)));
125 }
126 assert!(
127 Self::get_user_by_key(state, &key).await.is_some(),
128 "Failed to retrieve user by session key after creating session"
129 );
130 session
131 }
132
133 pub fn is_expired(&self) -> bool {
134 Instant::now() > self.expiry
135 }
136
137 pub async fn get_by_key(
138 state: &mut AppState,
139 key: &[u8],
140 ) -> Option<Session> {
141 let sessions = state.sessions.lock().await;
142 debug!(format!(
143 "{:?}",
144 sessions
145 .keys()
146 .map(|k| URL_SAFE_NO_PAD.encode(k))
147 .collect::<Vec<_>>()
148 ));
149 let session = sessions.get(key).cloned();
150 debug!(
151 "Maybe got session for key {:?} user_id={:?}",
152 URL_SAFE_NO_PAD.encode(key),
153 session.as_ref().map(|s| s.user_id)
154 );
155 if let Some(sess) = &session {
156 if sess.is_expired() {
157 drop(sessions);
158 Self::invalidate(state, key).await;
159 return None;
160 }
161 debug!(
162 "Got session for key {:?}, user_id={:?}, bt={}",
163 URL_SAFE_NO_PAD.encode(key),
164 session.as_ref().map(|s| s.user_id),
165 backtrace_string()
166 );
167 }
168 session
169 }
170
171 pub async fn get_user_by_key(
172 state: &mut AppState,
173 key: &[u8],
174 ) -> Option<Arc<Mutex<User>>> {
175 let session = Self::get_by_key(state, key).await?;
176 return state.users.lock().await.get(&session.user_id).cloned();
177 }
178
179 pub async fn invalidate(state: &mut AppState, key: &[u8]) {
180 let mut sessions = state.sessions.lock().await;
181 if let Some(session) = sessions.remove(key) {
182 let user = &session.user_id;
183 let mut by_user = state.sessions_by_user.lock().await;
184 by_user
185 .entry(*user)
186 .and_modify(|v| v.retain(|k| k.as_slice() != key))
187 .or_default();
188 if by_user.get(user).is_some_and(std::vec::Vec::is_empty) {
190 state.users.lock().await.remove(user);
191 }
192 }
193 }
194
195 pub async fn invalidate_all_expired(state: &mut AppState) {
196 let expired_keys: Vec<Vec<u8>> = {
198 let sessions = state.sessions.lock().await;
199 sessions
200 .iter()
201 .filter_map(|(k, s)| {
202 if s.is_expired() {
203 Some(k.clone())
204 } else {
205 None
206 }
207 })
208 .collect()
209 }; for key in expired_keys {
212 Self::invalidate(state, &key).await;
213 }
214 }
215
216 pub fn id(&self) -> String {
217 URL_SAFE_NO_PAD.encode(&self.key)
218 }
219}
220
#[cfg(test)]
// Tests assert and unwrap freely; a failing test is supposed to panic.
#[allow(clippy::unwrap_in_result, clippy::panic_in_result_fn)]
mod tests {
    use super::*;
    use crate::io::webui::AppState;
    use crate::storage::user::get_test_user;
    use ctb_test_macro::ctb_test;

    /// An `Instant` one second in the past, used to force a session to look expired.
    fn one_second_ago() -> Instant {
        Instant::now()
            .checked_sub(std::time::Duration::from_secs(1))
            .unwrap()
    }

    #[ctb_test(tokio::test)]
    async fn test_session_new() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let expected_id = user.local_id();

        let session = Session::new(&mut state, user).await;

        assert!(!session.is_expired());
        assert_eq!(session.user_id, expected_id);
        let found = Session::get_by_key(&mut state, &session.key).await;
        assert!(found.is_some());
    }

    #[ctb_test(tokio::test)]
    async fn test_session_get_by_key() {
        let mut state = AppState::default();
        let session =
            Session::new(&mut state, get_test_user(function_name!())).await;

        let found_id = Session::get_by_key(&mut state, &session.key)
            .await
            .map(|s| s.user_id);

        assert_eq!(found_id, Some(session.user_id));
    }

    #[ctb_test(tokio::test)]
    async fn test_session_get_user_by_key() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let expected_id = user.local_id();
        let session = Session::new(&mut state, user).await;

        let shared = Session::get_user_by_key(&mut state, &session.key)
            .await
            .unwrap();

        assert_eq!(shared.lock().await.local_id(), expected_id);
    }

    #[ctb_test(tokio::test)]
    async fn test_session_invalidate() {
        let mut state = AppState::default();
        let session =
            Session::new(&mut state, get_test_user(function_name!())).await;

        Session::invalidate(&mut state, &session.key).await;

        let found = Session::get_by_key(&mut state, &session.key).await;
        assert!(found.is_none());
    }

    #[ctb_test(tokio::test)]
    async fn test_session_invalidate_all_expired() {
        let mut state = AppState::default();
        let session =
            Session::new(&mut state, get_test_user(function_name!())).await;
        if let Some(stored) = state.sessions.lock().await.get_mut(&session.key) {
            stored.expiry = one_second_ago();
        }

        Session::invalidate_all_expired(&mut state).await;

        let found = Session::get_by_key(&mut state, &session.key).await;
        assert!(found.is_none());
    }

    #[ctb_test(tokio::test)]
    async fn test_session_is_expired() {
        let mut session = Session {
            key: vec![1, 2, 3],
            user_id: 1,
            expiry: Instant::now() + std::time::Duration::from_secs(1),
        };
        assert!(!session.is_expired());

        session.expiry = one_second_ago();
        assert!(session.is_expired());
    }

    #[ctb_test(tokio::test)]
    async fn test_session_id() {
        let session = Session {
            key: vec![1, 2, 3],
            user_id: 1,
            expiry: Instant::now(),
        };
        let expected = URL_SAFE_NO_PAD.encode(&session.key);
        assert_eq!(session.id(), expected);
    }
}