use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use mas_storage::Clock;
use rand::{Rng, RngCore, distributions::Standard, prelude::Distribution as _};
use serde::{Deserialize, Serialize};
use serde_with::{TimestampSeconds, serde_as};
use thiserror::Error;
use crate::cookies::{CookieDecodeError, CookieJar};
/// Failures that can occur while checking a CSRF token.
///
/// Variants are listed in the order the checks run when a form is verified:
/// cookie presence, cookie decoding, expiration, form-value decoding, and
/// finally the token comparison itself.
#[derive(Debug, Error)]
pub enum CsrfError {
    /// No `csrf` cookie was found in the jar.
    #[error("Missing CSRF cookie")]
    Missing,

    /// The `csrf` cookie was present but could not be decoded.
    #[error("could not decode CSRF cookie")]
    DecodeCookie(#[from] CookieDecodeError),

    /// The stored token's expiration time is not after the current time.
    #[error("CSRF token expired")]
    Expired,

    /// The submitted form value is not valid unpadded base64url.
    #[error("could not decode CSRF token")]
    Decode(#[from] base64ct::Error),

    /// The submitted form value does not match the stored token bytes.
    #[error("CSRF token mismatch")]
    Mismatch,
}
/// A CSRF token: 32 token bytes paired with the instant they stop being accepted.
///
/// This is the value stored in the `csrf` cookie. The expiration is serialized
/// as an integer Unix timestamp (whole seconds).
#[serde_as]
#[derive(Debug, Deserialize, Serialize)]
pub struct CsrfToken {
    /// Point in time at which the token becomes invalid.
    #[serde_as(as = "TimestampSeconds<i64>")]
    expiration: DateTime<Utc>,

    /// Raw token bytes; rendered into forms as unpadded base64url.
    token: [u8; 32],
}
impl CsrfToken {
    /// Builds a token from raw bytes that expires `ttl` after `now`.
    fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
        let expiration = now + ttl;
        Self { expiration, token }
    }

    /// Generates a token whose bytes are sampled from `rng` with the
    /// `Standard` distribution, expiring `ttl` after `now`.
    fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
        let token = Standard.sample(&mut rng);
        Self::new(token, now, ttl)
    }

    /// Keeps the same token bytes but moves the expiration to `now + ttl`.
    fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
        Self::new(self.token, now, ttl)
    }

    /// Returns the value to embed in HTML forms: the token bytes encoded as
    /// unpadded base64url.
    #[must_use]
    pub fn form_value(&self) -> String {
        Base64UrlUnpadded::encode_string(&self.token[..])
    }

    /// Checks that a value submitted through an HTML form matches this token.
    ///
    /// # Errors
    ///
    /// Returns [`CsrfError::Decode`] if `form_value` is not valid unpadded
    /// base64url, and [`CsrfError::Mismatch`] if the decoded bytes differ from
    /// the token.
    pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
        let form_value = Base64UrlUnpadded::decode_vec(form_value)?;
        // The form value is attacker-controlled. Compare without an early exit,
        // so response timing does not reveal how many leading bytes matched.
        if Self::constant_time_eq(&self.token, &form_value) {
            Ok(())
        } else {
            Err(CsrfError::Mismatch)
        }
    }

    /// Best-effort constant-time equality for byte slices.
    ///
    /// A length mismatch returns early. That is acceptable because the
    /// expected length (32 bytes) is not secret. For equal lengths, every byte
    /// is visited, and the differences are OR-folded together before being
    /// tested.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y));
        // Discourage the optimizer from turning the fold into an early-exit compare.
        std::hint::black_box(diff) == 0
    }

    /// Returns the token unchanged if `now` is strictly before its expiration.
    ///
    /// # Errors
    ///
    /// Returns [`CsrfError::Expired`] when `now` is at or after the expiration.
    fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
        if now < self.expiration {
            Ok(self)
        } else {
            Err(CsrfError::Expired)
        }
    }
}
/// A submitted form body that carries a CSRF token next to its real fields.
///
/// The `csrf` field holds the submitted token value. Every other field is
/// deserialized into `inner` via `#[serde(flatten)]`.
#[derive(Deserialize)]
pub struct ProtectedForm<T> {
    /// Submitted token value, expected to match [`CsrfToken::form_value`].
    csrf: String,

    /// The remaining form fields.
    #[serde(flatten)]
    inner: T,
}
/// CSRF helpers for a cookie container.
pub trait CsrfExt {
    /// Obtains a CSRF token for rendering a form.
    ///
    /// Returns the token together with the updated container, which the
    /// caller is expected to send back in the response.
    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
    where
        C: Clock,
        R: RngCore;

    /// Verifies the CSRF token of a submitted form and, on success, returns
    /// the form's inner fields.
    ///
    /// # Errors
    ///
    /// Returns a [`CsrfError`] describing why the form was rejected.
    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
    where
        C: Clock;
}
impl CsrfExt for CookieJar {
    /// Reuses the token in the `csrf` cookie if it has not expired, or
    /// generates a new one otherwise.
    ///
    /// Either way, the token's expiration is set to one hour from now and it
    /// is written back to the `csrf` cookie in the returned jar.
    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
    where
        R: RngCore,
        C: Clock,
    {
        let now = clock.now();
        // Lifetime applied to both refreshed and freshly generated tokens.
        let ttl = Duration::try_hours(1).unwrap();

        // An undecodable cookie is logged and treated as if it were absent,
        // so a new token is issued instead of failing.
        let existing = match self.load::<CsrfToken>("csrf") {
            Ok(stored) => stored.and_then(|token| token.verify_expiration(now).ok()),
            Err(e) => {
                tracing::warn!("Failed to decode CSRF cookie: {}", e);
                None
            }
        };

        let token = match existing {
            Some(token) => token.refresh(now, ttl),
            None => CsrfToken::generate(now, rng, ttl),
        };

        let jar = self.save("csrf", &token, false);
        (token, jar)
    }

    /// Checks `form.csrf` against the token stored in the `csrf` cookie.
    ///
    /// # Errors
    ///
    /// - [`CsrfError::DecodeCookie`] if the cookie cannot be decoded.
    /// - [`CsrfError::Missing`] if there is no cookie.
    /// - [`CsrfError::Expired`] if the stored token has expired.
    /// - [`CsrfError::Decode`] or [`CsrfError::Mismatch`] if the submitted
    ///   value is invalid or does not match.
    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
    where
        C: Clock,
    {
        let Some(stored) = self.load::<CsrfToken>("csrf")? else {
            return Err(CsrfError::Missing);
        };

        stored
            .verify_expiration(clock.now())?
            .verify_form_value(&form.csrf)?;

        Ok(form.inner)
    }
}