use std::sync::Arc;
use std::time::Instant;

use anyhow::Result;
use axum::extract::{FromRef, FromRequestParts};
use axum::http::StatusCode;
use axum_extra::extract::CookieJar;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use http::HeaderMap;
use rand::RngCore;
use rand::rngs::ThreadRng;
use tokio::sync::Mutex;
use zeroize::{Zeroize, ZeroizeOnDrop};

use crate::io::webui::AppState;
use crate::storage::graph::Graph;
use crate::storage::user::User;
use crate::utilities::backtrace_string;
use crate::{debug, debug_fmt};

/// A [`Graph`] shared across tasks behind a `tokio` async mutex.
pub type SharedGraph = Arc<Mutex<Graph>>;
/// A [`User`] shared across tasks behind a `tokio` async mutex.
pub type SharedUser = Arc<Mutex<User>>;

pub struct AuthenticatedUser {
    pub user: SharedUser,
}

impl<S> FromRequestParts<S> for AuthenticatedUser
where
    AppState: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = StatusCode;

    /// Resolve the user behind the request's `session` cookie.
    ///
    /// Rejects with `401 Unauthorized` when the cookie is missing or not valid
    /// base64, or when the decoded key does not map to a live session/user.
    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let mut app_state = AppState::from_ref(state);
        let key = session_key_from_headers(&parts.headers)
            .ok_or(StatusCode::UNAUTHORIZED)?;
        match Session::get_user_by_key(&mut app_state, &key).await {
            Some(user) => Ok(Self { user }),
            None => {
                debug!("No user found for session key");
                Err(StatusCode::UNAUTHORIZED)
            }
        }
    }
}

/// Extract the base64-encoded session key from the `session` cookie.
///
/// The `Cookie` header may carry several `name=value` pairs separated by
/// `"; "`; only the one named `session` is considered. Returns the cookie's
/// value (without the `session=` name prefix), or `None` when the cookie is
/// absent or its value is empty.
pub fn session_key_string_from_headers(headers: &HeaderMap) -> Option<String> {
    let cookies = CookieJar::from_headers(headers);
    // Read the value directly instead of formatting the whole cookie
    // ("session=<value>[; attrs]") and stripping the name prefix. The
    // emptiness check therefore applies to the value itself; previously a
    // bare "session=" cookie slipped through and produced Some("").
    let value = cookies
        .get("session")
        .map(|cookie| cookie.value())
        .filter(|value| !value.is_empty());
    let Some(value) = value else {
        debug_fmt!("No session key found in cookies {}", backtrace_string());
        return None;
    };
    // The value is a bearer credential, so it is deliberately not logged.
    Some(value.to_string())
}

/// Decode the `session` cookie (URL-safe base64 without padding) into the raw
/// session key bytes.
///
/// Returns `None` when no usable session cookie is present, or when its value
/// is not valid URL-safe, unpadded base64.
pub fn session_key_from_headers(headers: &HeaderMap) -> Option<Vec<u8>> {
    let encoded = session_key_string_from_headers(headers)?;
    match URL_SAFE_NO_PAD.decode(encoded) {
        Ok(raw_key) => Some(raw_key),
        Err(_) => {
            debug!("Invalid session key format");
            None
        }
    }
}

/// In-memory session state binding a random key to a user.
///
/// This must never be persisted to disk. Because of `ZeroizeOnDrop`, the
/// non-skipped fields (`user_id`, `key`) are zeroized whenever a `Session`
/// value is dropped; `expiry` is excluded via `#[zeroize(skip)]`.
#[derive(Clone, ZeroizeOnDrop)]
pub struct Session {
    /// Local id of the user this session authenticates.
    user_id: u64,
    /// Raw random key bytes; clients hold the URL-safe base64 form (see `id`).
    key: Vec<u8>,
    /// Instant after which the session is treated as expired.
    #[zeroize(skip)]
    expiry: Instant,
}

impl Session {
    /// How long a freshly created session stays valid.
    const LIFETIME: std::time::Duration = std::time::Duration::from_secs(3600);

    /// Create and register a new session for `user`.
    ///
    /// Generates a 128-byte random key and records the session in
    /// `state.sessions` and `state.sessions_by_user`. It also makes `user`
    /// resolvable through `state.users`.
    ///
    /// If the user is already present in `state.users` (another session is
    /// still registered), that existing shared instance is kept, so all of the
    /// user's sessions resolve to the same [`SharedUser`]. In that case the
    /// passed-in `user` is dropped.
    ///
    /// # Panics
    /// Panics if the generated key is already registered (OWASP: each session
    /// ID must be unique). With 1024 random bits this indicates a broken RNG.
    pub async fn new(state: &mut AppState, user: User) -> Self {
        let mut data = [0u8; 128];
        ThreadRng::default().fill_bytes(&mut data[..]);
        let key = data.to_vec();
        // Wipe the stack buffer; the session's own key copy is zeroized on drop.
        data.zeroize();

        let user_id = user.local_id();
        let session = Self {
            key: key.clone(),
            user_id,
            expiry: Instant::now() + Self::LIFETIME,
        };

        // Hold all three locks, taken in the same order as `invalidate`
        // (sessions -> sessions_by_user -> users). The uniqueness check and
        // every insertion then happen atomically with respect to concurrent
        // lookups/invalidation, without risking a lock-order deadlock.
        let mut sessions = state.sessions.lock().await;
        let mut by_user = state.sessions_by_user.lock().await;
        let mut users = state.users.lock().await;

        assert!(!sessions.contains_key(&key), "Session key collision");
        sessions.insert(key.clone(), session.clone());
        by_user.entry(user_id).or_default().push(key);
        // Do not replace an already-shared user: other sessions and in-flight
        // requests hold clones of that Arc, and swapping it would split state.
        users
            .entry(user_id)
            .or_insert_with(|| Arc::new(Mutex::new(user)));
        session
    }

    /// Whether the current time is past this session's expiry instant.
    pub fn is_expired(&self) -> bool {
        Instant::now() > self.expiry
    }

    /// Look up a live session by its raw key bytes.
    ///
    /// Returns a clone of the session if it exists and has not expired. An
    /// expired session is invalidated (removed from all maps) and `None` is
    /// returned. Session keys are bearer credentials and are never logged;
    /// only the user id is.
    pub async fn get_by_key(
        state: &mut AppState,
        key: &[u8],
    ) -> Option<Session> {
        // The guard is a temporary and is released at the end of this
        // statement, before `invalidate` re-acquires the lock below.
        let session = state.sessions.lock().await.get(key).cloned();
        let Some(session) = session else {
            debug!("No session found for key");
            return None;
        };
        if session.is_expired() {
            debug!("Session for user_id={} expired", session.user_id);
            Self::invalidate(state, key).await;
            return None;
        }
        debug!("Got session for user_id={}", session.user_id);
        Some(session)
    }

    /// Resolve the shared user owning the live session identified by `key`.
    ///
    /// Returns `None` when there is no live session for `key` or its user is
    /// not present in `state.users`.
    pub async fn get_user_by_key(
        state: &mut AppState,
        key: &[u8],
    ) -> Option<Arc<Mutex<User>>> {
        let session = Self::get_by_key(state, key).await?;
        let users = state.users.lock().await;
        users.get(&session.user_id).cloned()
    }

    /// Remove the session identified by `key`, if present.
    ///
    /// Also removes the key from the user's entry in `sessions_by_user`. When
    /// the user has no sessions left, the user's entry is dropped from both
    /// `sessions_by_user` and `state.users`.
    pub async fn invalidate(state: &mut AppState, key: &[u8]) {
        let mut sessions = state.sessions.lock().await;
        let Some(session) = sessions.remove(key) else {
            return;
        };
        let user_id = session.user_id;
        let mut by_user = state.sessions_by_user.lock().await;
        let no_sessions_left = match by_user.get_mut(&user_id) {
            Some(keys) => {
                keys.retain(|k| k.as_slice() != key);
                keys.is_empty()
            }
            None => true,
        };
        if no_sessions_left {
            // Remove the now-empty entry; otherwise `sessions_by_user` keeps
            // one empty Vec per user that ever logged in.
            by_user.remove(&user_id);
            state.users.lock().await.remove(&user_id);
        }
    }

    /// Invalidate every session whose expiry has passed.
    pub async fn invalidate_all_expired(state: &mut AppState) {
        // Snapshot the expired keys first. The sessions guard is a temporary
        // released at the end of this statement, because `invalidate` needs
        // to lock the map itself.
        let expired_keys: Vec<Vec<u8>> = state
            .sessions
            .lock()
            .await
            .iter()
            .filter(|(_, session)| session.is_expired())
            .map(|(key, _)| key.clone())
            .collect();

        for key in expired_keys {
            Self::invalidate(state, &key).await;
        }
    }

    /// The session key encoded as URL-safe base64 without padding; this is
    /// the form `session_key_from_headers` decodes from the cookie.
    pub fn id(&self) -> String {
        URL_SAFE_NO_PAD.encode(&self.key)
    }
}

#[cfg(test)]
// Tests are allowed to unwrap/panic freely; these lints target production code.
#[allow(clippy::unwrap_in_result, clippy::panic_in_result_fn)]
mod tests {
    use super::*;
    use crate::io::webui::AppState;
    use crate::storage::user::get_test_user;
    use ctb_test_macro::ctb_test;

    // A new session is unexpired, bound to the given user, and retrievable.
    #[ctb_test(tokio::test)]
    async fn test_session_new() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let user_local_id = user.local_id();
        let session = Session::new(&mut state, user).await;
        assert!(!session.is_expired());
        assert_eq!(session.user_id, user_local_id);
        assert!(
            Session::get_by_key(&mut state, &session.key)
                .await
                .is_some()
        );
    }

    // Lookup by key returns a session with the same user id.
    #[ctb_test(tokio::test)]
    async fn test_session_get_by_key() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let session = Session::new(&mut state, user).await;
        let retrieved = Session::get_by_key(&mut state, &session.key).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().user_id, session.user_id);
    }

    // Lookup of the user by session key resolves to the registered user.
    #[ctb_test(tokio::test)]
    async fn test_session_get_user_by_key() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let user_local_id = user.local_id();
        let session = Session::new(&mut state, user).await;
        let retrieved_user =
            Session::get_user_by_key(&mut state, &session.key).await;
        assert!(retrieved_user.is_some());
        // Assume User has equality or check local_id
        assert_eq!(
            retrieved_user.unwrap().lock().await.local_id(),
            user_local_id
        );
    }

    // An invalidated session can no longer be looked up.
    #[ctb_test(tokio::test)]
    async fn test_session_invalidate() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let session = Session::new(&mut state, user).await;
        Session::invalidate(&mut state, &session.key).await;
        assert!(
            Session::get_by_key(&mut state, &session.key)
                .await
                .is_none()
        );
    }

    // Sweeping expired sessions removes a session whose expiry was backdated.
    #[ctb_test(tokio::test)]
    async fn test_session_invalidate_all_expired() {
        let mut state = AppState::default();
        let user = get_test_user(function_name!());
        let session = Session::new(&mut state, user).await;
        // Manually expire the session by setting expiry to past
        {
            let mut sessions = state.sessions.lock().await;
            if let Some(s) = sessions.get_mut(&session.key) {
                s.expiry = Instant::now()
                    .checked_sub(std::time::Duration::from_secs(1))
                    .unwrap();
            }
        }
        Session::invalidate_all_expired(&mut state).await;
        assert!(
            Session::get_by_key(&mut state, &session.key)
                .await
                .is_none()
        );
    }

    // is_expired flips once the expiry instant lies in the past.
    #[ctb_test(tokio::test)]
    async fn test_session_is_expired() {
        let mut session = Session {
            key: vec![1, 2, 3],
            user_id: 1,
            expiry: Instant::now() + std::time::Duration::from_secs(1),
        };
        assert!(!session.is_expired());
        session.expiry = Instant::now()
            .checked_sub(std::time::Duration::from_secs(1))
            .unwrap();
        assert!(session.is_expired());
    }

    // id() is the URL-safe, unpadded base64 encoding of the raw key.
    #[ctb_test(tokio::test)]
    async fn test_session_id() {
        let session = Session {
            key: vec![1, 2, 3],
            user_id: 1,
            expiry: Instant::now(),
        };
        let id = session.id();
        assert_eq!(id, URL_SAFE_NO_PAD.encode(&session.key));
    }

    // Note: Testing AuthenticatedUser extractor requires full axum setup, which is complex for unit tests.
    // Consider integration tests for extractor behavior.
}
