-- |
-- Module      : Crypto.KDF.PBKDF2
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <[email protected]>
-- Stability   : experimental
-- Portability : unknown
--
-- Password Based Key Derivation Function 2
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface #-}

module Crypto.KDF.PBKDF2
    ( PRF
    , prfHMAC
    , Parameters(..)
    , generate
    , fastPBKDF2_SHA1
    , fastPBKDF2_SHA256
    , fastPBKDF2_SHA512
    ) where

import           Data.Word
import           Data.Bits
import           Foreign.Marshal.Alloc
import           Foreign.Ptr (plusPtr, Ptr)
import           Foreign.C.Types (CUInt(..), CSize(..))

import           Crypto.Hash (HashAlgorithm)
import qualified Crypto.MAC.HMAC as HMAC

import           Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess, Bytes)
import qualified Crypto.Internal.ByteArray as B
import           Data.Memory.PtrMethods

-- | The pseudo-random function used by PBKDF2.
--
-- Given a password and some input bytes it produces the PRF output bytes.
type PRF password =
       password -- ^ the password
    -> Bytes    -- ^ the content
    -> Bytes    -- ^ prf(password,content)

-- | Build a 'PRF' for PBKDF2 out of HMAC instantiated with the given hash
-- algorithm.
--
-- The algorithm argument is never inspected; it only selects the hash type
-- of the HMAC context.  The context is initialised from the password before
-- the resulting @Bytes -> Bytes@ function is returned, so a partial
-- application to a password can reuse that context for every input.
prfHMAC :: (HashAlgorithm a, ByteArrayAccess password)
        => a
        -> PRF password
prfHMAC alg k = hmacIncr alg (HMAC.initialize k)
  where
    -- The first argument ties the context's hash type to @alg@.  The bang
    -- forces the keyed context before the closure over it is handed back.
    hmacIncr :: HashAlgorithm a => a -> HMAC.Context a -> (Bytes -> Bytes)
    hmacIncr _ !ctx = \b -> B.convert $ HMAC.finalize $ HMAC.update ctx b

-- | Parameters for PBKDF2
data Parameters = Parameters
    { iterCounts   :: Int -- ^ the user-chosen number of iterations of the algorithm, e.g. WPA2 uses 4096.
    , outputLength :: Int -- ^ the number of bytes to generate out of PBKDF2
    }

-- | Derive a key of 'outputLength' bytes from a password and a salt using
-- PBKDF2 with the supplied pseudo-random function.
--
-- The output is produced block by block, each block being as long as the
-- PRF output:
--
-- > f(pass,salt,c,i) = U1 xor U2 xor .. xor Uc
-- > U1 = PRF(pass, salt || BE32(i))
-- > Uc = PRF(pass, Uc-1)
--
-- where @c@ is 'iterCounts' and @i@ is the block number, starting at 1.
-- When 'outputLength' is not a multiple of the PRF output length, only the
-- leading bytes of the last block are kept.
--
-- 'iterCounts' is expected to be positive: with 0 no PRF output is mixed in
-- and the result is all zero bytes, and a negative count is not guarded
-- against (the counter is simply decremented until it equals 0).
generate :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray ba)
         => PRF password
         -> Parameters
         -> password
         -> salt
         -> ba
generate prf params password salt =
    B.allocAndFreeze outLen $ \p -> do
        -- blocks are XOR-accumulated in place, so start from zeroes
        memSet p 0 outLen
        loop 1 outLen p
  where
    outLen  = outputLength params

    -- bind the password once; every block reuses the resulting function
    !runPRF = prf password
    -- length of one PRF output, measured by running the PRF on empty input
    !hLen   = B.length $ runPRF B.empty

    -- Fill the output one complete PRF-sized block at a time.  A trailing
    -- incomplete block is handed to 'partial', which is always the last call.
    loop :: Word32 -> Int -> Ptr Word8 -> IO ()
    loop iterNb len p
        | len == 0   = return ()
        | len < hLen = partial iterNb len p
        | otherwise  = do
            xorBlock p iterNb
            loop (iterNb + 1) (len - hLen) (p `plusPtr` hLen)

    -- Compute a whole block in a zeroed scratch buffer and copy only the
    -- @len@ bytes that still fit in the output.
    partial :: Word32 -> Int -> Ptr Word8 -> IO ()
    partial iterNb len p = allocaBytesAligned hLen 8 $ \tmp -> do
        memSet tmp 0 hLen
        xorBlock tmp iterNb
        memCopy p tmp len

    -- XOR U1 .. Uc for block number @blockIdx@ into the @hLen@ bytes at
    -- @dst@.  The destination must already be zeroed by the caller.
    xorBlock :: Ptr Word8 -> Word32 -> IO ()
    xorBlock dst blockIdx =
        applyMany (iterCounts params) (B.convert salt `B.append` toBS blockIdx)
      where
        applyMany :: Int -> Bytes -> IO ()
        applyMany 0 _     = return ()
        applyMany i uprev = do
            let uData = runPRF uprev
            B.withByteArray uData $ \u -> memXor dst dst u hLen
            applyMany (i - 1) uData

    -- big endian encoding of Word32
    toBS :: ByteArray ba => Word32 -> ba
    toBS w = B.pack [a, b, c, d]
      where a = fromIntegral (w `shiftR` 24)
            b = fromIntegral ((w `shiftR` 16) .&. 0xff)
            c = fromIntegral ((w `shiftR` 8) .&. 0xff)
            d = fromIntegral (w .&. 0xff)
{-# NOINLINE generate #-}

-- | Specialised PBKDF2 that hands the whole derivation to the C routine
-- @cryptonite_fastpbkdf2_hmac_sha1@ instead of going through 'generate'.
--
-- The result is 'outputLength' bytes long.  The lengths and 'iterCounts'
-- are converted to the C types with 'fromIntegral'; no range check is done.
fastPBKDF2_SHA1 :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray out)
                => Parameters
                -> password
                -> salt
                -> out
fastPBKDF2_SHA1 params password salt =
    B.allocAndFreeze (outputLength params) $ \outPtr ->
    B.withByteArray password $ \passPtr ->
    B.withByteArray salt $ \saltPtr ->
        c_cryptonite_fastpbkdf2_hmac_sha1
            passPtr (fromIntegral $ B.length password)
            saltPtr (fromIntegral $ B.length salt)
            (fromIntegral $ iterCounts params)
            outPtr (fromIntegral $ outputLength params)

-- | Specialised PBKDF2 that hands the whole derivation to the C routine
-- @cryptonite_fastpbkdf2_hmac_sha256@ instead of going through 'generate'.
--
-- The result is 'outputLength' bytes long.  The lengths and 'iterCounts'
-- are converted to the C types with 'fromIntegral'; no range check is done.
fastPBKDF2_SHA256 :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray out)
                  => Parameters
                  -> password
                  -> salt
                  -> out
fastPBKDF2_SHA256 params password salt =
    B.allocAndFreeze (outputLength params) $ \outPtr ->
    B.withByteArray password $ \passPtr ->
    B.withByteArray salt $ \saltPtr ->
        c_cryptonite_fastpbkdf2_hmac_sha256
            passPtr (fromIntegral $ B.length password)
            saltPtr (fromIntegral $ B.length salt)
            (fromIntegral $ iterCounts params)
            outPtr (fromIntegral $ outputLength params)

-- | Specialised PBKDF2 that hands the whole derivation to the C routine
-- @cryptonite_fastpbkdf2_hmac_sha512@ instead of going through 'generate'.
--
-- The result is 'outputLength' bytes long.  The lengths and 'iterCounts'
-- are converted to the C types with 'fromIntegral'; no range check is done.
fastPBKDF2_SHA512 :: (ByteArrayAccess password, ByteArrayAccess salt, ByteArray out)
                  => Parameters
                  -> password
                  -> salt
                  -> out
fastPBKDF2_SHA512 params password salt =
    B.allocAndFreeze (outputLength params) $ \outPtr ->
    B.withByteArray password $ \passPtr ->
    B.withByteArray salt $ \saltPtr ->
        c_cryptonite_fastpbkdf2_hmac_sha512
            passPtr (fromIntegral $ B.length password)
            saltPtr (fromIntegral $ B.length salt)
            (fromIntegral $ iterCounts params)
            outPtr (fromIntegral $ outputLength params)


-- FFI binding to the C function @cryptonite_fastpbkdf2_hmac_sha1@ (header
-- @cryptonite_pbkdf2.h@), imported as an @unsafe@ call.
--
-- As called from 'fastPBKDF2_SHA1', the arguments are, in order: password
-- pointer and length, salt pointer and length, iteration count, output
-- pointer and output length.
foreign import ccall unsafe "cryptonite_pbkdf2.h cryptonite_fastpbkdf2_hmac_sha1"
    c_cryptonite_fastpbkdf2_hmac_sha1 :: Ptr Word8 -> CSize
                                      -> Ptr Word8 -> CSize
                                      -> CUInt
                                      -> Ptr Word8 -> CSize
                                      -> IO ()

-- FFI binding to the C function @cryptonite_fastpbkdf2_hmac_sha256@ (header
-- @cryptonite_pbkdf2.h@), imported as an @unsafe@ call.
--
-- As called from 'fastPBKDF2_SHA256', the arguments are, in order: password
-- pointer and length, salt pointer and length, iteration count, output
-- pointer and output length.
foreign import ccall unsafe "cryptonite_pbkdf2.h cryptonite_fastpbkdf2_hmac_sha256"
    c_cryptonite_fastpbkdf2_hmac_sha256 :: Ptr Word8 -> CSize
                                        -> Ptr Word8 -> CSize
                                        -> CUInt
                                        -> Ptr Word8 -> CSize
                                        -> IO ()

-- FFI binding to the C function @cryptonite_fastpbkdf2_hmac_sha512@ (header
-- @cryptonite_pbkdf2.h@), imported as an @unsafe@ call.
--
-- As called from 'fastPBKDF2_SHA512', the arguments are, in order: password
-- pointer and length, salt pointer and length, iteration count, output
-- pointer and output length.
foreign import ccall unsafe "cryptonite_pbkdf2.h cryptonite_fastpbkdf2_hmac_sha512"
    c_cryptonite_fastpbkdf2_hmac_sha512 :: Ptr Word8 -> CSize
                                        -> Ptr Word8 -> CSize
                                        -> CUInt
                                        -> Ptr Word8 -> CSize
                                        -> IO ()