ctoolbox/workspace/ipc_old/dispatch/
db.rs1use anyhow::{Result, anyhow};
5
6use crate::storage::get_storage_dir;
7use crate::utilities::resource_lock::ResourceLock;
8use crate::{debug, warn};
9
10use redb::{Database, ReadableDatabase, ReadableTable, TableDefinition};
13
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, OnceLock, RwLock};
18
/// Which pool partition a database handle belongs to.
///
/// Used together with the database name as the `DB_POOL` key. Tables listed in
/// `UNWRAPPED_TABLES` are opened as `Unwrapped`; everything else as `Wrapped`.
/// NOTE(review): the names suggest wrapped = encrypted-at-rest, but this file
/// only uses the kind for routing/pool keys — confirm the intended meaning.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum DbKind {
    /// Default kind for all tables not in `UNWRAPPED_TABLES`.
    Wrapped,
    /// Kind used for the tables listed in `UNWRAPPED_TABLES`.
    Unwrapped,
}
27
28static DB_POOL: OnceLock<RwLock<HashMap<(DbKind, String), Arc<Database>>>> =
33 OnceLock::new();
34
35fn db_pool() -> &'static RwLock<HashMap<(DbKind, String), Arc<Database>>> {
36 DB_POOL.get_or_init(|| RwLock::new(HashMap::new()))
37}
38
39const DB_FILE_LOCK_FAMILY: &str = "db_file";
55
56fn db_file_lock(name: &str) -> Result<ResourceLock> {
57 ResourceLock::acquire(DB_FILE_LOCK_FAMILY, &name.replace('/', "_"))
60}
61
62static LOCK_SESSIONS: OnceLock<RwLock<HashMap<String, Vec<ResourceLock>>>> =
64 OnceLock::new();
65
66fn lock_sessions() -> &'static RwLock<HashMap<String, Vec<ResourceLock>>> {
67 LOCK_SESSIONS.get_or_init(|| RwLock::new(HashMap::new()))
68}
69
/// Counter behind `new_lock_id`; the first id handed out is `db-lock-1`.
static NEXT_LOCK_ID: AtomicU64 = AtomicU64::new(1);

/// Returns a new session id of the form `db-lock-<n>`, unique within this
/// process (until the `u64` counter wraps).
fn new_lock_id() -> String {
    // Only the atomicity of the increment matters here, so Relaxed suffices.
    format!("db-lock-{}", NEXT_LOCK_ID.fetch_add(1, Ordering::Relaxed))
}
76
77pub fn lock_databases_session<I, S>(names: I) -> Result<String>
94where
95 I: IntoIterator<Item = S>,
96 S: AsRef<str>,
97{
98 let mut dbs: Vec<String> =
100 names.into_iter().map(|s| s.as_ref().to_string()).collect();
101 dbs.sort();
102
103 let mut acquired: Vec<ResourceLock> = Vec::with_capacity(dbs.len());
104 for db in &dbs {
105 match db_file_lock(db) {
106 Ok(lock) => acquired.push(lock),
107 Err(e) => {
108 drop(acquired);
110 return Err(e);
111 }
112 }
113 }
114
115 let session_id = new_lock_id();
116 let mut sessions = lock_sessions().write().expect("LOCK_SESSIONS poisoned");
117 sessions.insert(session_id.clone(), acquired);
118 Ok(session_id)
119}
120
121pub fn unlock_databases_session(session_id: &str) -> Result<()> {
124 let mut sessions = lock_sessions().write().expect("LOCK_SESSIONS poisoned");
125 if let Some(locks) = sessions.remove(session_id) {
126 drop(locks);
127 Ok(())
128 } else {
129 Err(anyhow!("Unknown lock session id: {session_id}"))
130 }
131}
132
133pub fn lock_database_session(name: &str) -> Result<String> {
136 lock_databases_session([name])
137}
138
139fn get_or_open_database(name: &str, kind: DbKind) -> Result<Arc<Database>> {
150 if let Some(db) = db_pool()
152 .read()
153 .expect("DB_POOL poisoned")
154 .get(&(kind, name.to_string()))
155 .cloned()
156 {
157 return Ok(db);
158 }
159
160 let _db_file_guard = db_file_lock(name)?;
163
164 if let Some(db) = db_pool()
166 .read()
167 .expect("DB_POOL poisoned")
168 .get(&(kind, name.to_string()))
169 .cloned()
170 {
171 return Ok(db);
172 }
173
174 let db = open_redb_file(name)?;
179 let arc = Arc::new(db);
180
181 let mut map = db_pool().write().expect("DB_POOL poisoned");
183 let entry = map
184 .entry((kind, name.to_string()))
185 .or_insert_with(|| arc.clone());
186 Ok(entry.clone())
187}
188
189fn open_redb_file(name: &str) -> Result<Database> {
191 let path_to_database = db_path(name)?;
192 if !path_to_database.exists() {
194 let _db_file_guard = db_file_lock(name)?;
196 if !path_to_database.exists() {
197 let _ = Database::create(&path_to_database);
198 }
199 }
200 let db = Database::open(path_to_database)?;
201 Ok(db)
202}
203
204fn db_path(name: &str) -> Result<PathBuf> {
205 Ok(get_storage_dir()?.join(format!("{name}.redb")))
206}
207
208pub fn open<K, V>(table_name: &str) -> Result<TableConnection<K, V>>
213where
214 K: redb::Key + Sized + 'static,
215 V: redb::Value + Sized + 'static,
216{
217 let conn = TableConnection::open(table_name)?;
218 Ok(conn)
219}
220
221pub fn open_u<K, V>(table_name: &str) -> Result<TableConnection<K, V>>
225where
226 K: redb::Key + Sized + 'static,
227 V: redb::Value + Sized + 'static,
228{
229 let conn = TableConnection::open_u(table_name)?;
230 Ok(conn)
231}
232
233pub struct TableConnection<K, V>
235where
236 K: redb::Key + Sized + 'static,
237 V: redb::Value + Sized + 'static,
238{
239 db: Arc<Database>,
241 table_def: TableDefinition<'static, K, V>,
242 table_name: String,
243}
244
/// Tables that `TableConnection::open` routes to `DbKind::Unwrapped`
/// databases; every other table is opened as `DbKind::Wrapped`.
///
/// Entries are kept in sorted order.
const UNWRAPPED_TABLES: [&str; 8] = [
    "users/auth",
    "users/ids",
    "users/ids_rev",
    "users/key_encryption_key_params",
    "users/pictures",
    "users/pubkeys",
    "users/uuids",
    "users/wrapped_dek",
];
255
256impl<K, V> TableConnection<K, V>
257where
258 K: redb::Key + Sized + 'static,
259 V: redb::Value + Sized + 'static,
260{
261 fn open_table(db: Arc<Database>, table_name: &str) -> Result<Self> {
263 let leaked: &'static str =
264 Box::leak(table_name.to_string().into_boxed_str());
265 let table_def = TableDefinition::new(leaked);
266 Ok(Self {
267 db,
268 table_def,
269 table_name: table_name.to_string(),
270 })
271 }
272
273 fn open_u(table_name: &str) -> Result<Self> {
275 let db = get_or_open_database(table_name, DbKind::Unwrapped)?;
276 Self::open_table(db, table_name)
277 }
278
279 pub fn open(table_name: &str) -> Result<Self> {
281 if UNWRAPPED_TABLES.contains(&table_name) {
282 return Self::open_u(table_name);
283 }
284 let db = get_or_open_database(table_name, DbKind::Wrapped)?;
285 Self::open_table(db, table_name)
286 }
287
288 fn acquire_db_lock(&self) -> Result<ResourceLock> {
291 db_file_lock(&self.table_name)
292 }
293
294 fn get<'k, R, F>(
298 &self,
299 key: <K as redb::Value>::SelfType<'k>,
300 map: F,
301 ) -> Option<R>
302 where
303 F: for<'v> FnOnce(<V as redb::Value>::SelfType<'v>) -> R,
304 {
305 let _lock = match self.acquire_db_lock() {
307 Ok(guard) => guard,
308 Err(e) => {
309 warn!(format!(
310 "db: GET lock acquisition failed for {}: {e}: {e:?}",
311 self.table_name
312 ));
313 return None;
314 }
315 };
316
317 let tx = self.db.begin_read().ok()?;
318 let table = tx.open_table(self.table_def).ok()?;
319 table.get(key).ok()?.map(|acc| map(acc.value()))
320 }
321
322 pub fn get_vec<'k>(
324 &self,
325 key: <K as redb::Value>::SelfType<'k>,
326 ) -> Option<Vec<u8>>
327 where
328 V: redb::Value,
329 for<'v> <V as redb::Value>::SelfType<'v>: AsRef<[u8]>,
330 {
331 let key_temp = format!("{key:?}");
332 let res = self.get(key, |x| x.as_ref().to_vec());
333 debug!(format!(
334 "db: GET {}/{:?} -> {:?} (Vec<u8>)",
335 self.table_name, key_temp, res
336 ));
337 res
338 }
339
340 pub fn get_str<'k>(
342 &self,
343 key: <K as redb::Value>::SelfType<'k>,
344 ) -> Option<String>
345 where
346 V: redb::Value,
347 for<'v> <V as redb::Value>::SelfType<'v>: AsRef<str>,
348 {
349 let key_temp = format!("{key:?}");
350 let res = self.get(key, |x| x.as_ref().to_string());
351 debug!(format!(
352 "db: GET {}/{:?} -> {:?} (str)",
353 self.table_name, key_temp, res
354 ));
355 res
356 }
357
358 pub fn get_u64<'k>(
360 &self,
361 key: <K as redb::Value>::SelfType<'k>,
362 ) -> Option<u64>
363 where
364 V: redb::Value,
365 for<'a> <V as redb::Value>::SelfType<'a>: Into<u64>,
366 {
367 let key_temp = format!("{key:?}");
368 let res = self.get(key, |x: <V as redb::Value>::SelfType<'_>| x.into());
369 debug!(format!(
370 "db: GET {}/{:?} -> {:?} (u64)",
371 self.table_name, key_temp, res
372 ));
373 res
374 }
375
376 pub fn put<'k>(
380 &self,
381 key: <K as redb::Value>::SelfType<'_>,
382 value: <V as redb::Value>::SelfType<'_>,
383 ) -> Result<()>
384 where
385 K: redb::Key + Sized,
386 V: redb::Value + Sized,
387 {
388 let _lock = self.acquire_db_lock()?;
390
391 debug!(format!(
392 "db: PUT {}/{:?} -> {:?}",
393 self.table_name, key, value
394 ));
395 let tx = self.db.begin_write()?;
396 {
397 let mut table = tx.open_table(self.table_def)?;
398 table.insert(key, value)?;
399 }
400 tx.commit()?;
401 Ok(())
402 }
403
404 pub fn delete(&self, key: <K as redb::Value>::SelfType<'_>) -> Result<()> {
408 let _lock = self.acquire_db_lock()?;
410
411 debug!(format!("db: DELETE {}/{:?}", self.table_name, key));
412 let tx = self.db.begin_write()?;
413 {
414 let mut table = tx.open_table(self.table_def)?;
415 table.remove(key)?;
416 }
417 tx.commit()?;
418 Ok(())
419 }
420}
421
422pub fn get_str_u64(db: &str, key: &str) -> Option<u64> {
424 let conn: TableConnection<&str, u64> = TableConnection::open(db).ok()?;
425 conn.get_u64(key)
426}
427
428pub fn get_all_u64_keys(db: &str) -> Result<Vec<u64>> {
429 let conn: TableConnection<u64, &str> = TableConnection::open(db)?;
430 let _lock = db_file_lock(db)?;
432
433 let tx = conn.db.begin_read()?;
434 let table = tx.open_table(conn.table_def)?;
435 let mut keys = Vec::new();
436 if let Ok(entry_result) = table.iter() {
437 for entry in entry_result.flatten() {
438 keys.push(entry.0.value().to_owned());
439 }
440 } else {
441 warn!("Could not get iter of users/uuids table");
442 return Err(anyhow::anyhow!("Could not get iter of users/uuids table"));
443 }
444 Ok(keys)
445}
446
447pub fn put_str_u64(db: &str, key: &str, value: u64) -> Result<()> {
449 let conn: TableConnection<&str, u64> = TableConnection::open(db)?;
450 conn.put(key, value)
451}
452
453pub fn delete_str_u64(db: &str, key: &str) -> Result<()> {
455 let conn: TableConnection<&str, u64> = TableConnection::open(db)?;
456 conn.delete(key)
457}
458
459pub fn get_u64_str(db: &str, key: u64) -> Option<String> {
461 let conn: TableConnection<u64, &str> = TableConnection::open(db).ok()?;
462 conn.get_str(key)
463}
464
465pub fn put_u64_str(db: &str, key: u64, value: &str) -> Result<()> {
467 let conn: TableConnection<u64, &str> = TableConnection::open(db)?;
468 conn.put(key, value)
469}
470
471pub fn delete_u64_str(db: &str, key: u64) -> Result<()> {
473 let conn: TableConnection<u64, &str> = TableConnection::open(db)?;
474 conn.delete(key)
475}
476
477pub fn get_u64_bytes(db: &str, key: u64) -> Option<Vec<u8>> {
479 let conn: TableConnection<u64, &[u8]> = TableConnection::open(db).ok()?;
480 conn.get_vec(key)
481}
482
483pub fn put_u64_bytes(db: &str, key: u64, value: &[u8]) -> Result<()> {
485 let conn: TableConnection<u64, &[u8]> = TableConnection::open(db)?;
486 conn.put(key, value)
487}
488
489pub fn delete_u64_bytes(db: &str, key: u64) -> Result<()> {
491 let conn: TableConnection<u64, &[u8]> = TableConnection::open(db)?;
492 conn.delete(key)
493}
494
// The tests return `Result` so they can use `?` while still asserting, which
// is the combination these two Clippy lints flag.
#[cfg(test)]
#[allow(clippy::unwrap_in_result, clippy::panic_in_result_fn)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Name of the scratch database used by one test.
    fn test_db_name(suffix: &str) -> String {
        format!("test_db_{suffix}")
    }

    /// Deletes a test's database file. Only resolving the path can fail; a
    /// failed removal (e.g. the file is already gone) is ignored.
    ///
    /// NOTE(review): the handle stays cached in `DB_POOL`, and a failed
    /// assertion earlier in a test skips this cleanup entirely.
    fn remove_db_file(db: &str) -> Result<()> {
        let path = db_path(db)?;
        let _ = std::fs::remove_file(path.as_path());
        Ok(())
    }

    #[crate::ctb_test]
    fn test_str_u64_roundtrip() -> Result<()> {
        let db = test_db_name("str_u64");
        let (key, value) = ("foo", 12345u64);

        // Hold the database lock for the whole round trip.
        let session = lock_database_session(&db)?;
        put_str_u64(&db, key, value)?;
        assert_eq!(get_str_u64(&db, key), Some(value));
        delete_str_u64(&db, key)?;
        assert_eq!(get_str_u64(&db, key), None);
        unlock_databases_session(&session)?;

        remove_db_file(&db)
    }

    #[crate::ctb_test]
    fn test_u64_str_roundtrip() -> Result<()> {
        let db = test_db_name("u64_str");
        let (key, value) = (42u64, "bar");

        put_u64_str(&db, key, value)?;
        assert_eq!(get_u64_str(&db, key), Some(value.to_string()));
        delete_u64_str(&db, key)?;
        assert_eq!(get_u64_str(&db, key), None);

        remove_db_file(&db)
    }

    #[crate::ctb_test]
    fn test_u64_bytes_roundtrip() -> Result<()> {
        let db = test_db_name("u64_bytes");
        let key = 99u64;
        let value = vec![1u8, 2, 3, 4];

        put_u64_bytes(&db, key, &value)?;
        assert_eq!(get_u64_bytes(&db, key), Some(value.clone()));
        delete_u64_bytes(&db, key)?;
        assert_eq!(get_u64_bytes(&db, key), None);

        remove_db_file(&db)
    }
}