//! [`BoxedUint`] square root operations.

use crate::{BitOps, BoxedUint, CtEq, CtGt, CtOption, CtSelect, Limb, SquareRoot};

impl BoxedUint {
    /// Computes ⌊√(`self`)⌋ in constant time.
    ///
    /// The result has the same precision as `self`. Callers can check whether `self` is a
    /// perfect square by squaring the result, or by using [`BoxedUint::checked_sqrt`].
    pub fn sqrt(&self) -> Self {
        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
        //
        // See Hast, "Note on computation of integer square roots"
        // for the proof of the sufficiency of the bound on iterations.
        // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf

        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < 2^b`.
        // The shift cannot overflow since `ceil(b/2) < self.bits_precision()`.
        let mut x = Self::one_with_precision(self.bits_precision());
        x.overflowing_shl_assign((self.bits() + 1) >> 1); // ≥ √(`self`)

        // Most recent non-zero iterate; always safe to use as a divisor.
        let mut nz_x = x.clone();
        // Scratch buffers reused by every division.
        let mut quo = Self::zero_with_precision(self.bits_precision());
        let mut rem = Self::zero_with_precision(self.bits_precision());

        // Run a fixed number of iterations, which depends only on the precision and not on
        // the value of `self`, that is enough to guarantee the result has stabilized.
        // TODO (#378): the tests indicate that just `self.log2_bits()` may be enough.
        let iterations = self.log2_bits() + 2;
        for _ in 0..iterations {
            let x_nonzero = x.is_nonzero();
            nz_x.ct_assign(&x, x_nonzero);

            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`.
            // If `x_i == 0` (only reachable when `self == 0`), divide by the last non-zero
            // iterate instead and skip the addition, so `x` stays zero.
            quo.limbs.copy_from_slice(&self.limbs);
            rem.limbs.copy_from_slice(&nz_x.limbs);
            quo.as_mut_uint_ref().div_rem(rem.as_mut_uint_ref());
            x.conditional_carrying_add_assign(&quo, x_nonzero);
            x.shr1_assign();
        }

        // At this point `nz_x == x_n` and `x == x_{n+1}`, where `n == iterations - 1`
        // (except when `self == 0`: then `x == 0` and `nz_x` is the last non-zero iterate).
        // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
        x.ct_assign(&nz_x, x.ct_gt(&nz_x));
        x
    }

    /// Computes ⌊√(`self`)⌋ in variable time.
    ///
    /// The running time depends on the value of `self` (the loop exits as soon as the
    /// iteration converges), so avoid it for secret inputs. The result has the same
    /// precision as `self`. Callers can check whether `self` is a perfect square by squaring
    /// the result, or by using [`BoxedUint::checked_sqrt_vartime`].
    pub fn sqrt_vartime(&self) -> Self {
        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13

        if self.is_zero_vartime() {
            return Self::zero_with_precision(self.bits_precision());
        }

        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < 2^b`.
        // The shift cannot overflow since `ceil(b/2) < self.bits_precision()`.
        // The initial value of `x` is always greater than zero.
        // A variable-time bit count is fine here: this whole routine is variable-time.
        let mut x = Self::one_with_precision(self.bits_precision());
        x.overflowing_shl_assign_vartime((self.bits_vartime() + 1) >> 1); // ≥ √(`self`)

        // Scratch buffers reused by every division.
        let mut quo = Self::zero_with_precision(self.bits_precision());
        let mut rem = Self::zero_with_precision(self.bits_precision());

        loop {
            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
            quo.limbs.copy_from_slice(&self.limbs);
            rem.limbs.copy_from_slice(&x.limbs);
            quo.as_mut_uint_ref().div_rem_vartime(rem.as_mut_uint_ref());
            quo.carrying_add_assign(&x, Limb::ZERO);
            quo.shr1_assign();

            // If `quo` is the same as `x` or greater, we reached convergence
            // (`x` is guaranteed to either go down or oscillate between
            // `sqrt(self)` and `sqrt(self) + 1`)
            if !x.cmp_vartime(&quo).is_gt() {
                break;
            }
            x.limbs.copy_from_slice(&quo.limbs);
            // Defensive: for `self >= 1` the iterates never drop below ⌊√(`self`)⌋ >= 1,
            // so this guard is not expected to trigger.
            if x.is_zero_vartime() {
                break;
            }
        }

        x
    }

    /// Wrapping square root; identical to [`BoxedUint::sqrt`].
    ///
    /// Computing a square root can never wrap. This function exists so that all operations
    /// are accounted for in the wrapping operations.
    pub fn wrapping_sqrt(&self) -> Self {
        self.sqrt()
    }

    /// Wrapping square root; identical to [`BoxedUint::sqrt_vartime`].
    ///
    /// Computing a square root can never wrap. This function exists so that all operations
    /// are accounted for in the wrapping operations.
    pub fn wrapping_sqrt_vartime(&self) -> Self {
        self.sqrt_vartime()
    }

    /// Performs a checked square root, returning a [`CtOption`] which `is_some` only if
    /// `self` is a perfect square, i.e. ⌊√(`self`)⌋² == `self`.
    pub fn checked_sqrt(&self) -> CtOption<Self> {
        let r = self.sqrt();
        // `r² <= self`, so the product fits in the precision of `self` and cannot wrap.
        let s = r.wrapping_mul(&r);
        CtOption::new(r, self.ct_eq(&s))
    }

    /// Performs a checked square root, returning a [`CtOption`] which `is_some` only if
    /// `self` is a perfect square, i.e. ⌊√(`self`)⌋² == `self`.
    ///
    /// The square root itself is computed in variable time; the final equality check uses
    /// a constant-time comparison.
    pub fn checked_sqrt_vartime(&self) -> CtOption<Self> {
        let r = self.sqrt_vartime();
        // `r² <= self`, so the product fits in the precision of `self` and cannot wrap.
        let s = r.wrapping_mul(&r);
        CtOption::new(r, self.ct_eq(&s))
    }
}

impl SquareRoot for BoxedUint {
    fn sqrt(&self) -> Self {
        self.sqrt()
    }

    fn sqrt_vartime(&self) -> Self {
        self.sqrt_vartime()
    }
}

#[cfg(test)]
mod tests {
    use crate::{BoxedUint, Limb};

    #[cfg(feature = "rand_core")]
    use {
        crate::RandomBits,
        chacha20::ChaCha8Rng,
        rand_core::{RngCore, SeedableRng},
    };

    /// Hex-encoded `(square, root, precision)` triples with `square = (root + 1)^2 - k`
    /// for a small `k`. These inputs use up the maximum number of iterations.
    const MAX_ITERATION_CASES: [(&str, &str, u32); 2] = [
        // k = 583
        (
            "055fa39422bd9f281762946e056535badbf8a6864d45fa3d",
            "0000000000000000000000002516f0832a538b2d98869e21",
            192,
        ),
        // k = 205
        (
            "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
            "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
            256,
        ),
    ];

    /// `(n, ⌊√n⌋)` pairs for small non-squares.
    const NONSQUARES: [(u8, u8); 7] = [(2, 1), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];

    /// Parses a big-endian hex string into a `BoxedUint` of the given precision.
    fn from_hex(hex: &str, precision: u32) -> BoxedUint {
        BoxedUint::from_be_hex(hex, precision).unwrap()
    }

    /// Returns a value of the given precision whose lower half of limbs are all ones;
    /// this is the expected root of the all-ones value at that precision.
    fn lower_half_ones(precision: u32) -> BoxedUint {
        let mut value = BoxedUint::zero_with_precision(precision);
        let half = value.limbs.len() / 2;
        value.limbs[..half].fill(Limb::MAX);
        value
    }

    #[test]
    fn edge() {
        let zero = BoxedUint::zero_with_precision(256);
        assert_eq!(zero.sqrt(), zero);
        let one = BoxedUint::one_with_precision(256);
        assert_eq!(one.sqrt(), one);
        let all_ones = !BoxedUint::zero_with_precision(256);
        assert_eq!(all_ones.sqrt(), lower_half_ones(256));

        for (square, root, precision) in MAX_ITERATION_CASES {
            let square = from_hex(square, precision);
            let root = from_hex(root, precision);
            assert_eq!(square.sqrt(), root);
            assert_eq!(square.sqrt_vartime(), root);
        }
    }

    #[test]
    fn edge_vartime() {
        let zero = BoxedUint::zero_with_precision(256);
        assert_eq!(zero.sqrt_vartime(), zero);
        let one = BoxedUint::one_with_precision(256);
        assert_eq!(one.sqrt_vartime(), one);
        let all_ones = !BoxedUint::zero_with_precision(256);
        assert_eq!(all_ones.sqrt_vartime(), lower_half_ones(256));
    }

    #[test]
    fn simple() {
        // Perfect squares 2^2 ..= 13^2 (169 still fits in a `u8`).
        for root in 2u8..=13 {
            let square = BoxedUint::from(root * root);
            let expected = BoxedUint::from(root);
            assert_eq!(square.sqrt(), expected);
            assert_eq!(square.sqrt_vartime(), expected);
            assert!(square.checked_sqrt().is_some().to_bool());
            assert!(square.checked_sqrt_vartime().is_some().to_bool());
        }
    }

    #[test]
    fn nonsquares() {
        for (n, root) in NONSQUARES {
            assert_eq!(BoxedUint::from(n).sqrt(), BoxedUint::from(root));
        }
        for n in [2u8, 3] {
            assert!(!BoxedUint::from(n).checked_sqrt().is_some().to_bool());
        }
    }

    #[test]
    fn nonsquares_vartime() {
        for (n, root) in NONSQUARES {
            assert_eq!(BoxedUint::from(n).sqrt_vartime(), BoxedUint::from(root));
        }
        for n in [2u8, 3] {
            assert!(
                !BoxedUint::from(n)
                    .checked_sqrt_vartime()
                    .is_some()
                    .to_bool()
            );
        }
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);

        // Squares of random 32-bit values.
        for _ in 0..50 {
            let root = BoxedUint::from(u64::from(rng.next_u32()));
            let square = root.checked_mul(&root).unwrap();
            assert_eq!(square.sqrt(), root);
            assert_eq!(square.sqrt_vartime(), root);
            assert!(square.checked_sqrt().is_some().to_bool());
            assert!(square.checked_sqrt_vartime().is_some().to_bool());
        }

        // Squares of random 512-bit values.
        for _ in 0..50 {
            let root = BoxedUint::random_bits(&mut rng, 512);
            let mut expected = BoxedUint::zero_with_precision(512);
            expected.limbs[..root.limbs.len()].copy_from_slice(&root.limbs);
            let square = root.square();
            assert_eq!(square.sqrt(), expected);
            assert_eq!(square.sqrt_vartime(), expected);
        }
    }
}
