{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnboxedTuples #-}

module Cardano.Crypto.Util
  ( Empty
  , SignableRepresentation(..)
  , getRandomWord64

    -- * Simple serialisation used in mock instances
  , readBinaryWord64
  , writeBinaryWord64
  , readBinaryNatural
  , writeBinaryNatural
  , splitsAt

  -- * Low level conversions
  , bytesToNatural
  , naturalToBytes

  -- * ByteString manipulation
  , slice

  -- * Base16 conversion
  , decodeHexByteString
  , decodeHexString
  , decodeHexStringQ
  )
where

import           Control.Monad (unless)
import           Data.Bifunctor (first)
import           Data.Char (isAscii)
import           Data.Word
import           Numeric.Natural
import           Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC8
import qualified Data.ByteString.Internal as BS
import           Data.ByteString (ByteString)
import           Data.ByteString.Base16 as BS16
import           Language.Haskell.TH

import           GHC.Exts (Addr#, Int#, Word#)
import qualified GHC.Exts    as GHC
import qualified GHC.Natural as GHC
import           Foreign.ForeignPtr (withForeignPtr)

import           Crypto.Random (MonadRandom (..))

#if __GLASGOW_HASKELL__ >= 900
-- Use the GHC version here because this is compiler dependent, and only indirectly lib dependent.
import           GHC.Num.Integer (integerFromAddr#)
import           GHC.IO (IO (..), unsafeDupablePerformIO)
#else
import qualified GHC.Integer.GMP.Internals as GMP
import           GHC.IO (unsafeDupablePerformIO)
#endif

-- | A constraint that every type satisfies: the class has no methods and a
-- single catch-all instance.
class Empty a where

-- Catch-all instance (needs FlexibleInstances): every type is 'Empty'.
instance Empty a where



--
-- Signable
--

-- | Types that can be turned into the bytes fed to signing and verification.
--
class SignableRepresentation a where
    -- | The byte representation used for signing / verifying a value.
    getSignableRepresentation :: a -> ByteString

-- | A 'ByteString' is signed exactly as given, without any transformation.
instance SignableRepresentation ByteString where
    getSignableRepresentation bytes = bytes


--
-- Random source used in some mock instances
--

-- | Draw a random 'Word64' by requesting 8 random bytes from the
-- 'MonadRandom' source and interpreting them with 'readBinaryWord64'
-- (first byte most significant).
getRandomWord64 :: MonadRandom m => m Word64
getRandomWord64 = do
  randomBytes <- getRandomBytes 8
  pure (readBinaryWord64 randomBytes)


--
-- Really simple serialisation used in some mock instances
--

-- | Read a 'ByteString' as a big-endian unsigned number (the first byte is the
-- most significant).
--
-- An empty input yields @0@. Only the last 8 bytes affect the result: earlier
-- bytes are shifted out of the 64-bit accumulator.
readBinaryWord64 :: ByteString -> Word64
readBinaryWord64 = BS.foldl' step 0
  where
    -- After shifting left by 8 the low byte is zero, so OR-ing the new byte in
    -- is the same as adding it.
    step :: Word64 -> Word8 -> Word64
    step acc byte = (acc `shiftL` 8) .|. fromIntegral byte


-- | Read a 'ByteString' as an arbitrarily large big-endian unsigned number
-- (the first byte is the most significant). An empty input yields @0@.
readBinaryNatural :: ByteString -> Natural
readBinaryNatural input =
  BS.foldl' (\total byte -> total * 256 + fromIntegral byte) 0 input


-- | Encode a 'Word64' as exactly 8 bytes, most significant byte first.
-- This is the inverse of 'readBinaryWord64' on 8-byte inputs.
writeBinaryWord64 :: Word64 -> ByteString
writeBinaryWord64 word =
  -- Byte at position i (counting from the least significant end) is
  -- word >> (8*i); emit positions 7..0 so the output is big endian.
  BS.pack [ fromIntegral (word `shiftR` (8 * i)) | i <- [7, 6 .. 0] ]

-- | Encode a 'Natural' as exactly @bytes@ bytes, most significant byte first.
--
-- Only the low @bytes@ bytes of the number are kept; higher bytes are
-- silently dropped. A zero or negative @bytes@ yields the empty string.
writeBinaryNatural :: Int -> Natural -> ByteString
writeBinaryNatural bytes number =
  -- Byte position i (from the least significant end) is number >> (8*i);
  -- walk positions bytes-1 .. 0 to produce big-endian output. The
  -- conversion to Word8 keeps only the low 8 bits.
  BS.pack
    [ fromIntegral (number `shiftR` (8 * i))
    | i <- [bytes - 1, bytes - 2 .. 0]
    ]

-- | Split a 'ByteString' into consecutive chunks of the given sizes.
--
-- * Each size in the list takes that many bytes off the front of the input.
-- * If the sizes are exhausted and input remains, the remainder is returned
--   as one final extra chunk; if nothing remains, no extra chunk is added.
-- * If the input becomes too short for the next requested size, splitting
--   stops there: the chunks produced so far are returned and the rest of the
--   input is dropped. Callers can detect this by checking the result length.
splitsAt :: [Int] -> ByteString -> [ByteString]
splitsAt [] remaining
  | BS.null remaining = []
  | otherwise         = [remaining]
splitsAt (size : sizes) remaining
  | BS.length remaining >= size = chunk : splitsAt sizes rest
  | otherwise                   = []
  where
    (chunk, rest) = BS.splitAt size remaining

-- | Create a 'Natural' out of a 'ByteString', in big endian (first byte most
-- significant).
--
-- This is fast enough to use in production: the bytes are imported directly
-- into an 'Integer' by 'bytesToInteger' and then converted.
--
bytesToNatural :: ByteString -> Natural
bytesToNatural bytes = GHC.naturalFromInteger (bytesToInteger bytes)

-- | The inverse of 'bytesToNatural': encode a 'Natural' as @len@ big-endian
-- bytes via 'writeBinaryNatural'.
--
-- Round-tripping only works when the number fits in @len@ bytes. Higher bytes
-- are silently dropped, and smaller numbers are left-padded with zero bytes.
-- Note that this is a naive implementation and only suitable for tests.
--
naturalToBytes :: Int -> Natural -> ByteString
naturalToBytes len number = writeBinaryNatural len number

-- | Interpret the bytes of a 'ByteString' as an unsigned big-endian 'Integer',
-- importing them straight from the underlying buffer (respecting the
-- 'ByteString''s offset and length).
bytesToInteger :: ByteString -> Integer
bytesToInteger (BS.PS fptr (GHC.I# offset#) (GHC.I# length#)) =
    -- This should be safe since we're only reading from the ByteString (which
    -- is immutable) and the import allocates fresh memory for the Integer,
    -- i.e., there is no mutation involved.
    unsafeDupablePerformIO $
      withForeignPtr fptr $ \(GHC.Ptr base#) ->
        -- The last parameter (`1#`) tells the import function to use big
        -- endian encoding.
        importIntegerFromAddr
          (base# `GHC.plusAddr#` offset#)
          (GHC.int2Word# length#)
          1#
  where
    importIntegerFromAddr :: Addr# -> Word# -> Int# -> IO Integer
#if __GLASGOW_HASKELL__ >= 900
    -- Use the GHC version here because this is compiler dependent, and only
    -- indirectly lib dependent. Note the argument order differs from the
    -- legacy GMP function: size comes before address.
    importIntegerFromAddr addr# size# bigEndian# = IO $ \s0 ->
      case integerFromAddr# size# addr# bigEndian# s0 of
        (# s1, result #) -> (# s1, result #)
#else
    importIntegerFromAddr = GMP.importIntegerFromAddr
#endif

-- | @slice offset size bs@ returns up to @size@ bytes of @bs@ starting at
-- byte @offset@.
--
-- Out-of-range requests are truncated rather than rejected: an offset at or
-- past the end yields the empty string, and a size running past the end
-- stops at the end of the input.
slice :: Word -> Word -> ByteString -> ByteString
slice offset size = BS.take (clampToInt size) . BS.drop (clampToInt offset)
  where
    -- A plain 'fromIntegral' from 'Word' to 'Int' wraps values above
    -- 'maxBound :: Int' to negative numbers, which 'BS.take' / 'BS.drop'
    -- treat as 0 (so e.g. a huge size returned nothing and a huge offset
    -- returned everything). Saturate instead, so huge values mean "to the end".
    clampToInt :: Word -> Int
    clampToInt w = fromIntegral (min w (fromIntegral (maxBound :: Int)))

-- | Decode a base16 'ByteString', checking that the decoded result is exactly
-- @lenExpected@ bytes long.
--
-- Returns 'Left' with a descriptive message when the input is not valid hex,
-- or when the decoded length differs from @lenExpected@.
decodeHexByteString :: ByteString -> Int -> Either String ByteString
decodeHexByteString bsHex lenExpected =
  case BS16.decode bsHex of
    Left decodeErr -> Left ("Malformed hex: " ++ decodeErr)
    Right decoded
      | lenActual == lenExpected -> Right decoded
      | otherwise ->
          Left ("Expected in decoded form to be: " ++ show lenExpected
                ++ " bytes, but got: " ++ show lenActual)
      where
        lenActual = BS.length decoded


-- | Decode a base16 'String', checking that the decoded result is exactly
-- @lenExpected@ bytes long.
--
-- Unlike 'decodeHexByteString', this function accepts an optional @0x@
-- prefix, which is stripped before decoding. Only a lowercase @0x@ is
-- recognised. Input without the prefix is decoded as-is.
--
-- Returns 'Left' if the (unprefixed) input contains non-ASCII characters, is
-- not valid hex, or decodes to the wrong number of bytes.
decodeHexString :: String -> Int -> Either String ByteString
decodeHexString hexStr' lenExpected = do
  let hexStr :: String
      hexStr =
        case hexStr' of
          '0' : 'x' : str -> str
          str             -> str
  -- 'BSC8.pack' keeps only the low 8 bits of each 'Char', so non-ASCII input
  -- would otherwise be silently mangled; reject it up front instead.
  unless (all isAscii hexStr) $
    Left $ "Input string contains invalid characters: " ++ hexStr
  decodeHexByteString (BSC8.pack hexStr) lenExpected

-- | Decode a `String` with Hex characters, while ensuring expected length,
-- for use as a Template Haskell splice.
--
-- The input is validated at compile time with 'decodeHexString'. Malformed
-- hex, non-ASCII characters or a wrong decoded length abort compilation via
-- 'fail'. On success the generated expression calls 'decodeHexString' on the
-- same (lifted) arguments, so decoding happens again at run time. Its 'error'
-- branch cannot fire, because that call already succeeded at compile time.
decodeHexStringQ :: String -> Int -> Q Exp
decodeHexStringQ hexStr n =
  case decodeHexString hexStr n of
    -- Name this function (not 'decodeHexByteString') so the compile error
    -- points at the splice the user actually wrote.
    Left err -> fail $ "<decodeHexStringQ>: " ++ err
    Right _  -> [| either error id (decodeHexString hexStr n) |]