-- |
-- Module      : Crypto.Number.Basic
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <[email protected]>
-- Stability   : experimental
-- Portability : Good

{-# LANGUAGE BangPatterns #-}
module Crypto.Number.Basic
    ( sqrti
    , gcde
    , areEven
    , log2
    , numBits
    , numBytes
    , asPowerOf2AndOdd
    ) where

import Data.Bits

import Crypto.Number.Compat

-- | @sqrti i@ returns two integers @(l, b)@ so that @l <= sqrt i <= b@.
--
-- The implementation is quite naive. It takes a first guess derived from the
-- number of decimal digits of @i@, brackets the root by repeatedly doubling
-- or halving that guess, then narrows the bracket by bisection. For @i > 2@,
-- unless the first guess squares exactly to @i@, the result satisfies
-- @b == l + 1@ with @l * l < i <= b * b@.
--
-- Calls 'error' when @i@ is negative.
sqrti :: Integer -> (Integer, Integer)
sqrti i
    | i < 0     = error "cannot compute negative square root"
    | i == 0    = (0, 0)
    | i == 1    = (1, 1)
    | i == 2    = (1, 2)
    | otherwise = loop x0
  where
    nbdigits :: Int
    nbdigits = length (show i)

    -- First guess: 2 * 10^k for an even digit count, 6 * 10^k for an odd one.
    x0n :: Int
    x0n = (if even nbdigits then nbdigits - 2 else nbdigits - 1) `div` 2

    x0 :: Integer
    x0 | even nbdigits = 2 * 10 ^ x0n
       | otherwise     = 6 * 10 ^ x0n

    loop :: Integer -> (Integer, Integer)
    loop x = case compare (sq x) i of
        LT -> iterUp x
        EQ -> (x, x)
        GT -> iterDown x

    -- Guess too small: keep doubling until the square reaches i.
    iterUp :: Integer -> (Integer, Integer)
    iterUp lb
        | sq ub >= i = iter lb ub
        | otherwise  = iterUp ub
      where ub = lb * 2

    -- Guess too large: keep halving until the square drops below i.
    iterDown :: Integer -> (Integer, Integer)
    iterDown ub
        | sq lb >= i = iterDown lb
        | otherwise  = iter lb ub
      where lb = ub `div` 2

    -- Bisection, maintaining the invariant sq lb < i <= sq ub.
    iter :: Integer -> Integer -> (Integer, Integer)
    iter lb ub
        | lb == ub         = (lb, ub)
        | lb + 1 == ub     = (lb, ub)
        | sq (lb + d) >= i = iter lb (ub - d)
        | otherwise        = iter (lb + d) ub
      where d = (ub - lb) `div` 2

    sq :: Integer -> Integer
    sq a = a * a

-- | Extended Euclidean algorithm.
--
-- @gcde a b@ returns @(x, y, d)@ such that @a * x + b * y == d@, where @d@
-- is the greatest common divisor of @a@ and @b@.
--
-- The GMP primitive is used when available. Otherwise a pure fallback based
-- on 'divMod' is used, which flips all signs when needed so that @d >= 0@.
gcde :: Integer -> Integer -> (Integer, Integer, Integer)
gcde a b = onGmpUnsupported (gmpGcde a b) $
    if d < 0 then (-x, -y, -d) else (x, y, d)
  where
    (d, x, y) = f (a, 1, 0) (b, 0, 1)

    -- Every triple (r, s, t) keeps the invariant a * s + b * t == r.
    -- Once the newest remainder is 0, the previous triple holds the gcd
    -- together with its Bezout coefficients.
    f t              (0, _, _)    = t
    f (a', sa, ta) t@(b', sb, tb) =
        let (q, r) = a' `divMod` b'
        in  f t (r, sa - q * sb, ta - q * tb)

-- | Check whether every integer in the list is even.
--
-- Returns 'True' for the empty list.
areEven :: [Integer] -> Bool
areEven = all even

-- | Compute the binary logarithm of an integer.
--
-- The GMP primitive is used when available. For @n >= 1@ the pure fallback
-- computes @floor (logBase 2 n)@. It returns 0 for any @n < 2@, including
-- zero and negative numbers.
log2 :: Integer -> Int
log2 n = onGmpUnsupported (gmpLog2 n) $ imLog 2 n
  where
    -- Integer logarithm to base b, recursing on the squared base b * b.
    -- http://www.haskell.org/pipermail/haskell-cafe/2008-February/039465.html
    imLog :: Integer -> Integer -> Int
    imLog b x
        | x < b     = 0
        | otherwise = (x `div` (b ^ l)) `doDiv` l
      where
        -- Twice the logarithm to base b*b is a lower estimate of log_b x.
        l = 2 * imLog (b * b) x

        -- Finish by dividing by b one step at a time.
        doDiv x' l'
            | x' < b    = l'
            | otherwise = (x' `div` b) `doDiv` (l' + 1)
{-# INLINE log2 #-}

-- | Compute the number of bits needed to represent an integer.
--
-- Zero is reported as needing 1 bit. The sign is ignored: the pure fallback
-- measures @abs n@.
numBits :: Integer -> Int
numBits n = gmpSizeInBits n `onGmpUnsupported` fallback
  where
    -- 'abs' is required here. For a negative value, repeated floor 'divMod'
    -- by 256 converges to -1 instead of 0 ((-1) `divMod` 256 == (-1, 255)),
    -- so the byte loop would never terminate.
    fallback :: Int
    fallback
        | n == 0    = 1
        | otherwise = computeBits 0 (abs n)

    -- Consume the (non-negative) value one byte at a time, counting 8 bits
    -- per full byte, then add the bit length of the most significant byte.
    computeBits :: Int -> Integer -> Int
    computeBits !acc i
        | q == 0    = acc + byteBits r
        | otherwise = computeBits (acc + 8) q
      where (q, r) = i `divMod` 256

    -- Bit length of a value in [0, 255] (0 for 0).
    byteBits :: Integer -> Int
    byteBits r
        | r >= 0x80 = 8
        | r >= 0x40 = 7
        | r >= 0x20 = 6
        | r >= 0x10 = 5
        | r >= 0x08 = 4
        | r >= 0x04 = 3
        | r >= 0x02 = 2
        | r >= 0x01 = 1
        | otherwise = 0

-- | Compute the number of bytes needed to represent an integer.
--
-- The GMP primitive is used when available. Otherwise the result of
-- 'numBits' is rounded up to whole bytes, so zero needs 1 byte.
numBytes :: Integer -> Int
numBytes n = gmpSizeInBytes n `onGmpUnsupported` ((numBits n + 7) `div` 8)

-- | Express an integer as a power of two times an odd number.
--
-- @asPowerOf2AndOdd a@ returns @(e, o)@ with @a == 2 ^ e * o@ and @o@ odd.
-- The sign of @a@ is carried by @o@. Zero is special-cased to @(0, 0)@.
asPowerOf2AndOdd :: Integer -> (Int, Integer)
asPowerOf2AndOdd a
    | a == 0       = (0, 0)
    | odd a        = (0, a)
    | a < 0        = let (e, a1) = asPowerOf2AndOdd (abs a) in (e, -a1)
    | isPowerOf2 a = (log2 a, 1)
    | otherwise    = loop a 0
  where
    -- A positive power of two has exactly one bit set.
    isPowerOf2 :: Integer -> Bool
    isPowerOf2 n = n /= 0 && (n .&. (n - 1)) == 0

    -- Strip factors of two, counting them. The bang keeps the exponent
    -- from accumulating a chain of (+ 1) thunks.
    loop :: Integer -> Int -> (Int, Integer)
    loop n !pw
        | even n    = loop (n `div` 2) (pw + 1)
        | otherwise = (pw, n)