{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Cardano.Crypto.Libsodium.MLockedBytes.Internal (
    MLockedSizedBytes (..),
    mlsbZero,
    mlsbFromByteString,
    mlsbFromByteStringCheck,
    mlsbToByteString,
    mlsbUseAsCPtr,
    mlsbUseAsSizedPtr,
    mlsbFinalize,
) where

import Control.DeepSeq (NFData (..))
import Data.Proxy (Proxy (..))
import Foreign.C.Types (CSize (..))
import Foreign.ForeignPtr (castForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import GHC.TypeLits (KnownNat, natVal)
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))
import System.IO.Unsafe (unsafeDupablePerformIO)
import Data.Word (Word8)

import Cardano.Foreign
import Cardano.Crypto.Libsodium.Memory.Internal
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.PinnedSizedBytes

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI

{- HLINT ignore "Reduce duplication" -}

-- | A buffer of @n@ bytes held behind an 'MLockedForeignPtr'.
--
-- The 'NoThunks' instance is derived via 'OnlyCheckWhnfNamed', so only the
-- outermost constructor is checked for thunks (the payload lives behind a
-- foreign pointer and is not traversed).
newtype MLockedSizedBytes n = MLSB (MLockedForeignPtr (PinnedSizedBytes n))
  deriving NoThunks via OnlyCheckWhnfNamed "MLockedSizedBytes" (MLockedSizedBytes n)

-- | Equality is defined through the 'Ord' instance: two buffers are equal
-- exactly when 'compare' returns 'EQ'.
instance KnownNat n => Eq (MLockedSizedBytes n) where
    x == y = compare x y == EQ

-- | Ordering is delegated to 'c_sodium_compare', called on the two underlying
-- buffers with a length of @n@ bytes. Its 'Int' result is compared against
-- @0@ to produce the 'Ordering'.
--
-- NOTE(review): 'c_sodium_compare' presumably binds libsodium's
-- @sodium_compare@ (constant-time comparison) -- confirm against
-- "Cardano.Crypto.Libsodium.C".
instance KnownNat n => Ord (MLockedSizedBytes n) where
    -- The comparison only reads from both buffers, which is why it is run
    -- through 'unsafeDupablePerformIO'. Both pointers are kept alive for the
    -- duration of the call by the nested 'withMLockedForeignPtr' scopes.
    compare (MLSB x) (MLSB y) = unsafeDupablePerformIO $
        withMLockedForeignPtr x $ \x' ->
        withMLockedForeignPtr y $ \y' -> do
            res <- c_sodium_compare x' y' (CSize (fromIntegral size))
            return (compare res 0)
      where
        -- Number of bytes to compare, taken from the type-level size.
        size = natVal (Proxy @n)

-- | The rendering never includes the buffer contents: the value itself is
-- ignored and only a placeholder with the type-level size is shown, e.g.
-- @_ :: MLockedSizedBytes 32@.
instance KnownNat n => Show (MLockedSizedBytes n) where
    showsPrec d _ = showParen (d > 10)
        $ showString "_ :: MLockedSizedBytes "
        . showsPrec 11 (natVal (Proxy @n))

-- | Forcing to normal form only evaluates the wrapped pointer to WHNF;
-- the bytes it points at are not touched.
instance NFData (MLockedSizedBytes n) where
    rnf (MLSB p) = seq p ()

-- | A buffer whose @n@ bytes are all zero.
--
-- Note: this doesn't need to allocate mlocked memory,
-- but we do that for consistency
mlsbZero :: forall n. KnownNat n => MLockedSizedBytes n
mlsbZero = unsafeDupablePerformIO $ do
    fptr <- allocMLockedForeignPtr
    withMLockedForeignPtr fptr $ \ptr -> do
        -- Overwrite all @n@ bytes with 0; the pointer returned by memset
        -- is not needed.
        _ <- c_memset (castPtr ptr) 0 size
        return ()
    return (MLSB fptr)
  where
    -- Buffer length in bytes, taken from the type-level size.
    size  :: CSize
    size = fromInteger (natVal (Proxy @n))

-- | Copy a 'BS.ByteString' into a freshly allocated 'MLockedSizedBytes'.
--
-- Exactly @min (length bs) n@ bytes are copied, so:
--
-- * input longer than @n@ bytes is silently truncated;
-- * for input shorter than @n@ bytes the remaining bytes are left as the
--   allocator produced them (NOTE(review): confirm what
--   'allocMLockedForeignPtr' initialises fresh memory to).
--
-- Use 'mlsbFromByteStringCheck' to reject input of the wrong length.
mlsbFromByteString :: forall n. KnownNat n => BS.ByteString -> MLockedSizedBytes n
mlsbFromByteString bs = unsafeDupablePerformIO $ BS.useAsCStringLen bs $ \(ptrBS, len) -> do
    fptr <- allocMLockedForeignPtr
    withMLockedForeignPtr fptr $ \ptr -> do
        -- Never copy more than the destination can hold.
        _ <- c_memcpy (castPtr ptr) ptrBS (fromIntegral (min len size))
        return ()
    return (MLSB fptr)
  where
    -- Destination capacity in bytes, taken from the type-level size.
    size  :: Int
    size = fromInteger (natVal (Proxy @n))

-- | Length-checked variant of 'mlsbFromByteString'.
--
-- Returns 'Nothing' unless the input is exactly @n@ bytes long; otherwise
-- the bytes are copied into a freshly allocated 'MLockedSizedBytes'.
mlsbFromByteStringCheck :: forall n. KnownNat n => BS.ByteString -> Maybe (MLockedSizedBytes n)
mlsbFromByteStringCheck bs
    | BS.length bs /= size = Nothing
    | otherwise = Just $ unsafeDupablePerformIO $ BS.useAsCStringLen bs $ \(ptrBS, len) -> do
    fptr <- allocMLockedForeignPtr
    withMLockedForeignPtr fptr $ \ptr -> do
        -- The guard above already established @len == size@, so the 'min'
        -- is only a defensive bound on the copy.
        _ <- c_memcpy (castPtr ptr) ptrBS (fromIntegral (min len size))
        return ()
    return (MLSB fptr)
  where
    -- Required input length in bytes, taken from the type-level size.
    size  :: Int
    size = fromInteger (natVal (Proxy @n))

-- | View the buffer as a 'BS.ByteString' of length @n@ without copying.
--
-- /Note:/ the resulting 'BS.ByteString' will still refer to secure memory,
-- but the types don't prevent it from being exposed.
--
-- Because the result shares the underlying foreign pointer with the
-- argument, it presumably must not be used after 'mlsbFinalize' has been
-- called on that argument (NOTE(review): confirm against
-- 'finalizeMLockedForeignPtr').
mlsbToByteString :: forall n. KnownNat n => MLockedSizedBytes n -> BS.ByteString
mlsbToByteString (MLSB (SFP fptr)) =
    -- 'BSI.fromForeignPtr' builds the same value as the 'BSI.PS' constructor
    -- (pointer, offset 0, length), but is not deprecated on newer bytestring.
    BSI.fromForeignPtr (castForeignPtr fptr) 0 size
  where
    -- Length of the view in bytes, taken from the type-level size.
    size  :: Int
    size = fromInteger (natVal (Proxy @n))

-- | Run an action with a raw 'Word8' pointer to the start of the buffer.
--
-- The pointer is obtained through 'withMLockedForeignPtr' and is only
-- meant to be used inside the continuation.
mlsbUseAsCPtr :: MLockedSizedBytes n -> (Ptr Word8 -> IO r) -> IO r
mlsbUseAsCPtr (MLSB x) k = withMLockedForeignPtr x (k . castPtr)

-- | Run an action with a 'SizedPtr' to the buffer, which keeps the size @n@
-- in the pointer's type.
--
-- The pointer is obtained through 'withMLockedForeignPtr' and is only
-- meant to be used inside the continuation.
mlsbUseAsSizedPtr :: MLockedSizedBytes n -> (SizedPtr n -> IO r) -> IO r
mlsbUseAsSizedPtr (MLSB x) k = withMLockedForeignPtr x (k . ptrPsbToSizedPtr)

-- | Calls 'finalizeMLockedForeignPtr' on the underlying pointer.
--
-- This function invalidates its argument: the 'MLockedSizedBytes' must not
-- be used again afterwards.
mlsbFinalize :: MLockedSizedBytes n -> IO ()
mlsbFinalize (MLSB ptr) = finalizeMLockedForeignPtr ptr