{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Cardano.Crypto.Libsodium.Hash (
    SodiumHashAlgorithm (..),
    digestMLockedStorable,
    digestMLockedBS,
    expandHash,
) where

import Control.Monad (unless)
import Data.Proxy (Proxy (..))
import Foreign.C.Error (errnoToIOError, getErrno)
import Foreign.C.Types (CSize)
import Foreign.Ptr (Ptr, castPtr, nullPtr, plusPtr)
import Foreign.Storable (Storable (sizeOf, poke))
import Data.Word (Word8)
import Data.Type.Equality ((:~:)(..))
import GHC.IO.Exception (ioException)
import GHC.TypeLits
import System.IO.Unsafe (unsafeDupablePerformIO)

import qualified Data.ByteString as BS

import Cardano.Foreign
import Cardano.Crypto.Hash (HashAlgorithm(SizeHash), SHA256, Blake2b_256)
import Cardano.Crypto.PinnedSizedBytes (ptrPsbToSizedPtr)
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.Libsodium.Memory.Internal
import Cardano.Crypto.Libsodium.MLockedBytes.Internal

-------------------------------------------------------------------------------
-- Type-Class
-------------------------------------------------------------------------------

-- | Hash algorithms with a libsodium-backed implementation whose digest is
-- returned as 'MLockedSizedBytes'.
class HashAlgorithm h => SodiumHashAlgorithm h where
    -- | Hash the given number of bytes starting at the input pointer.
    --
    -- This function is in 'IO', but it is "morally pure" and can be run
    -- under 'unsafeDupablePerformIO' (as the helpers in this module do).
    naclDigestPtr
        :: proxy h
        -> Ptr a  -- ^ input
        -> Int    -- ^ input length
        -> IO (MLockedSizedBytes (SizeHash h))

    -- TODO: provide interface for multi-part?
    -- That will be useful to hashing ('1' <> oldseed).

-- | Hash the in-memory representation of a 'Storable' value, using the
-- algorithm selected by the proxy.
--
-- Exactly @'sizeOf' (undefined :: a)@ bytes starting at the pointer are
-- hashed. The computation runs under 'unsafeDupablePerformIO', relying on
-- 'naclDigestPtr' being morally pure.
--
-- NOTE(review): the result is a lazy pure value, so the bytes behind the
-- pointer are read only when the result is first forced. The caller must
-- keep that memory alive and unchanged until then.
digestMLockedStorable
    :: forall h a proxy. (SodiumHashAlgorithm h, Storable a)
    => proxy h -> Ptr a -> MLockedSizedBytes (SizeHash h)
digestMLockedStorable p ptr = unsafeDupablePerformIO $
    -- ScopedTypeVariables: 'a' is the type variable bound in the signature.
    naclDigestPtr p ptr (sizeOf (undefined :: a))

-- | Hash the contents of a 'BS.ByteString', using the algorithm selected by
-- the proxy.
--
-- The computation runs under 'unsafeDupablePerformIO', relying on
-- 'naclDigestPtr' being morally pure.
--
-- NOTE(review): 'BS.useAsCStringLen' is documented to hand the callback a
-- temporary copy of the bytes in ordinary (non-mlocked) memory. Confirm this
-- is acceptable when hashing secret material.
digestMLockedBS
    :: forall h proxy. (SodiumHashAlgorithm h)
    => proxy h -> BS.ByteString -> MLockedSizedBytes (SizeHash h)
digestMLockedBS p bs = unsafeDupablePerformIO $
    BS.useAsCStringLen bs $ \(ptr, len) ->
        naclDigestPtr p (castPtr ptr) len

-------------------------------------------------------------------------------
-- Hash expansion
-------------------------------------------------------------------------------

-- | Derive two new digests from an existing one by hashing it with a
-- one-byte tag prefix.
--
-- The left result is @H(1 || input)@ and the right result is
-- @H(2 || input)@, where @H@ is the algorithm selected by the proxy. The
-- prefixed input is assembled in a temporary buffer from 'mlockedAlloca'.
-- The computation runs under 'unsafeDupablePerformIO', relying on
-- 'naclDigestPtr' being morally pure.
expandHash
    :: forall h proxy. SodiumHashAlgorithm h
    => proxy h
    -> MLockedSizedBytes (SizeHash h)
    -> (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h))
expandHash p (MLSB sfptr) = unsafeDupablePerformIO $
    withMLockedForeignPtr sfptr $ \ptr -> do
        let -- Hash the tag byte followed by a copy of the input digest.
            -- Both halves share this helper so they cannot drift apart.
            digestWithTag :: Word8 -> IO (MLockedSizedBytes (SizeHash h))
            digestWithTag tag = mlockedAlloca size1 $ \buf -> do
                poke buf tag
                _ <- c_memcpy (castPtr (plusPtr buf 1)) ptr size
                naclDigestPtr p buf (fromIntegral size1)

        l <- digestWithTag 1
        r <- digestWithTag 2
        return (l, r)
  where
    -- Length of the prefixed buffer: one tag byte plus the digest.
    size1 :: CSize
    size1 = size + 1

    -- Digest length in bytes, taken from the type-level size.
    size :: CSize
    size = fromInteger $ natVal (Proxy @(SizeHash h))

-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

-- | SHA-256, computed with libsodium's @crypto_hash_sha256@.
instance SodiumHashAlgorithm SHA256 where
    naclDigestPtr :: forall proxy a. proxy SHA256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash SHA256))
    naclDigestPtr _ input inputlen = do
        -- The digest is written straight into freshly allocated mlocked memory.
        output <- allocMLockedForeignPtr
        withMLockedForeignPtr output $ \output' -> do
            res <- c_crypto_hash_sha256 (ptrPsbToSizedPtr output') (castPtr input) (fromIntegral inputlen)
            -- A non-zero return code is reported as an IOException.
            unless (res == 0) $ do
                errno <- getErrno
                ioException $ errnoToIOError "digestMLocked @SHA256: c_crypto_hash_sha256" errno Nothing Nothing

        return (MLSB output)

-- Compile-time check that the manually written digest size of 'SHA256' is
-- the same as libsodium's @CRYPTO_SHA256_BYTES@. 'Refl' only type-checks when
-- the two type-level naturals are equal.
_testSHA256 :: SizeHash SHA256 :~: CRYPTO_SHA256_BYTES
_testSHA256 = Refl

-- | BLAKE2b with a 256-bit digest, computed with libsodium's unkeyed
-- @crypto_generichash_blake2b@.
instance SodiumHashAlgorithm Blake2b_256 where
    naclDigestPtr :: forall proxy a. proxy Blake2b_256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash Blake2b_256))
    naclDigestPtr _ input inputlen = do
        -- The digest is written straight into freshly allocated mlocked memory.
        output <- allocMLockedForeignPtr
        withMLockedForeignPtr output $ \output' -> do
            res <- c_crypto_generichash_blake2b
                output' (fromInteger $ natVal (Proxy @CRYPTO_BLAKE2B_256_BYTES))  -- output
                (castPtr input) (fromIntegral inputlen)  -- input
                nullPtr 0                                -- key, unused
            -- A non-zero return code is reported as an IOException. The
            -- message names the C function that was actually called.
            unless (res == 0) $ do
                errno <- getErrno
                ioException $ errnoToIOError "digestMLocked @Blake2b_256: c_crypto_generichash_blake2b" errno Nothing Nothing

        return (MLSB output)

-- Compile-time check that the manually written digest size of 'Blake2b_256'
-- is the same as libsodium's @CRYPTO_BLAKE2B_256_BYTES@. 'Refl' only
-- type-checks when the two type-level naturals are equal.
_testBlake2b256 :: SizeHash Blake2b_256 :~: CRYPTO_BLAKE2B_256_BYTES
_testBlake2b256 = Refl