use crate::connection::Connection;
use crate::util::refined_tcp_stream::Stream as RefinedStream;
use rustls::ServerConfig;
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use rustls_pki_types::{CertificateDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
use std::error::Error;
use std::io::{Read, Write};
use std::net::{Shutdown, SocketAddr};
use std::sync::{Arc, Mutex};
use zeroize::Zeroizing;

/// An owned rustls server session bundled with the transport it runs over.
///
/// The session lives behind `Arc<Mutex<..>>` so that cloned handles (for example a reader
/// half and a writer half) can be moved to different threads. Every operation takes the
/// same lock, so access is serialized rather than truly concurrent: a `read` that waits on
/// the underlying `Connection` keeps the lock held, and a write or shutdown issued through
/// another handle waits until that read returns.
pub(crate) struct RustlsStream(
    /// Shared TLS session plus transport; every method locks this for its whole duration.
    Arc<Mutex<rustls::StreamOwned<rustls::ServerConnection, Connection>>>,
);

impl RustlsStream {
    pub(crate) fn peer_addr(&mut self) -> std::io::Result<Option<SocketAddr>> {
        self.0
            .lock()
            .expect("Failed to lock SSL stream mutex")
            .sock
            .peer_addr()
    }

    pub(crate) fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
        self.0
            .lock()
            .expect("Failed to lock SSL stream mutex")
            .sock
            .shutdown(how)
    }
}

impl Clone for RustlsStream {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl Read for RustlsStream {
    /// Reads decrypted application data, holding the shared lock for the whole call.
    ///
    /// # Panics
    ///
    /// Panics if the stream mutex was poisoned by a panic in another holder.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut inner = self.0.lock().expect("Failed to lock SSL stream mutex");
        inner.read(buf)
    }
}

impl Write for RustlsStream {
    /// Encrypts and writes `buf` through the shared session while holding its lock.
    ///
    /// # Panics
    ///
    /// Panics if the stream mutex was poisoned by a panic in another holder.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut inner = self.0.lock().expect("Failed to lock SSL stream mutex");
        inner.write(buf)
    }

    /// Flushes the shared session while holding its lock.
    ///
    /// # Panics
    ///
    /// Panics if the stream mutex was poisoned by a panic in another holder.
    fn flush(&mut self) -> std::io::Result<()> {
        let mut inner = self.0.lock().expect("Failed to lock SSL stream mutex");
        inner.flush()
    }
}

pub(crate) struct RustlsContext(Arc<rustls::ServerConfig>);

// Ported from rustls 0.20 to 0.23 with Copilot assistance; review carefully before relying on it.
impl RustlsContext {
    pub(crate) fn from_pem(
        certificates: Vec<u8>,
        private_key: Zeroizing<Vec<u8>>,
    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
        // Parse certificates
        let mut cert_reader = certificates.as_slice();
        let certificate_chain: Vec<CertificateDer<'static>> = certs(&mut cert_reader)
            .filter_map(|res| match res {
                Ok(cert) => Some(cert.into_owned()),
                Err(_) => None, // Or handle error as needed
            })
            .collect();

        if certificate_chain.is_empty() {
            return Err("Couldn't extract certificate chain from config.".into());
        }

        // Parse private key (prefer PKCS8, fallback to RSA)
        let binding = private_key.clone();
        let mut key_reader = binding.as_slice();
        let pkcs8_keys: Vec<PrivatePkcs8KeyDer<'static>> = pkcs8_private_keys(&mut key_reader)
            .filter_map(|res| match res {
                Ok(key) => Some(key),
                Err(_) => None, // Or handle error as needed
            })
            .collect();

        let private_key = if let Some(pkcs8_key) = pkcs8_keys.first() {
            rustls_pki_types::PrivateKeyDer::Pkcs8(pkcs8_key.clone_key())
        } else {
            let mut rsa_reader = private_key.as_slice();
            let rsa_keys: Vec<PrivatePkcs1KeyDer<'static>> = rsa_private_keys(&mut rsa_reader)
                .filter_map(|res| match res {
                    Ok(key) => Some(key),
                    Err(_) => None,
                })
                .collect();

            if let Some(rsa_key) = rsa_keys.first() {
                rustls_pki_types::PrivateKeyDer::Pkcs1(rsa_key.clone_key())
            } else {
                return Err("No valid private key found".into());
            }
        };

        // Build server config. 0.22+ uses .builder().with_no_client_auth().with_single_cert()
        let tls_conf = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certificate_chain, private_key)?;

        Ok(Self(Arc::new(tls_conf)))
    }

    pub(crate) fn accept(
        &self,
        stream: Connection,
    ) -> Result<RustlsStream, Box<dyn Error + Send + Sync + 'static>> {
        let connection = rustls::ServerConnection::new(self.0.clone())?;
        Ok(RustlsStream(Arc::new(Mutex::new(
            rustls::StreamOwned::new(connection, stream),
        ))))
    }
}

impl From<RustlsStream> for RefinedStream {
    fn from(stream: RustlsStream) -> Self {
        Self::Https(stream)
    }
}
