{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Cardano.Crypto.Libsodium.Memory.Internal (
  -- * High-level memory management
  MLockedForeignPtr (..),
  withMLockedForeignPtr,
  allocMLockedForeignPtr,
  finalizeMLockedForeignPtr,
  traceMLockedForeignPtr,
  -- * Low-level memory function
  mlockedAlloca,
  mlockedAllocaSized,
  sodiumMalloc,
  sodiumFree,
) where

import Control.Exception (bracket, mask_)
import Control.Monad (when)
import Data.Coerce (coerce)
import Data.Proxy (Proxy (..))
import Foreign.C.Error (errnoToIOError, getErrno)
import Foreign.C.Types (CSize (..))
import Foreign.ForeignPtr (ForeignPtr, newForeignPtr, withForeignPtr, finalizeForeignPtr)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (Storable (alignment, sizeOf, peek))
import GHC.TypeLits (KnownNat, natVal)
import GHC.IO.Exception (ioException)
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))

import Cardano.Foreign
import Cardano.Crypto.Libsodium.C

-- | Foreign pointer to securely allocated memory.
--
-- Values produced by 'allocMLockedForeignPtr' wrap memory obtained from
-- 'sodiumMalloc'. The 'SFP' constructor itself accepts any 'ForeignPtr'.
--
-- The 'NoThunks' instance is derived via 'OnlyCheckWhnfNamed', under the
-- name @"MLockedForeignPtr"@, so the pointed-to memory is never inspected.
newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: ForeignPtr a }
  deriving NoThunks via OnlyCheckWhnfNamed "MLockedForeignPtr" (MLockedForeignPtr a)

-- | Run an action with the raw pointer underlying an 'MLockedForeignPtr'.
--
-- This is 'withForeignPtr' coerced through the newtype, so its contract
-- applies: the pointer is only kept alive for the duration of the action and
-- must not be used after the action returns.
withMLockedForeignPtr :: forall a b. MLockedForeignPtr a -> (Ptr a -> IO b) -> IO b
withMLockedForeignPtr = coerce (withForeignPtr @a @b)

-- | Run the finalizers attached to an 'MLockedForeignPtr' immediately.
--
-- This is 'finalizeForeignPtr' coerced through the newtype. For pointers
-- created by 'allocMLockedForeignPtr', the finalizer releases the memory.
-- The pointer must not be dereferenced afterwards.
finalizeMLockedForeignPtr :: forall a. MLockedForeignPtr a -> IO ()
finalizeMLockedForeignPtr = coerce (finalizeForeignPtr @a)

-- | Debugging helper: 'peek' the value stored behind an 'MLockedForeignPtr'
-- and 'print' it to stdout.
--
-- This copies the contents of the locked memory onto the console. It is
-- deprecated so that any use left in the code triggers a compiler warning.
traceMLockedForeignPtr :: (Storable a, Show a) => MLockedForeignPtr a -> IO ()
traceMLockedForeignPtr fptr =
    withMLockedForeignPtr fptr $ \ptr -> peek ptr >>= print

{-# DEPRECATED traceMLockedForeignPtr "Don't leave traceMLockedForeignPtr in production" #-}

-- | Allocate secure memory using 'c_sodium_malloc'.
--
-- <https://libsodium.gitbook.io/doc/memory_management>
--
-- Enough memory for one value of type @a@ is requested through
-- 'sodiumMalloc'. The size is @'sizeOf' a@ rounded up to a multiple of
-- @'alignment' a@. Per the libsodium page above, the returned address is only
-- suitably aligned when the size is such a multiple. The memory is released
-- by a 'c_sodium_free_funptr' finalizer, which runs either when the pointer is
-- garbage collected or when 'finalizeMLockedForeignPtr' is called.
--
-- Allocation and finalizer registration run under 'mask_'. Otherwise an
-- asynchronous exception delivered between the two steps would leak the
-- allocation. 'mlockedAlloca' gets the same guarantee from 'bracket'.
--
-- Throws an 'IOError' if 'sodiumMalloc' fails.
allocMLockedForeignPtr :: forall a. Storable a => IO (MLockedForeignPtr a)
allocMLockedForeignPtr = mask_ $ do
    ptr <- sodiumMalloc size
    SFP <$> newForeignPtr c_sodium_free_funptr ptr
  where
    -- 'sizeOf' and 'alignment' must not evaluate their argument. This
    -- 'undefined' only fixes the type at which they are used.
    typeWitness :: a
    typeWitness = undefined

    size :: CSize
    size = fromIntegral (roundUpTo (alignment typeWitness) (sizeOf typeWitness))

    -- Round @n@ up to the next multiple of @align@. A value that is already
    -- a multiple is returned unchanged.
    roundUpTo :: Int -> Int -> Int
    roundUpTo align n = case n `divMod` align of
        (_, 0) -> n
        (q, _) -> (q + 1) * align

-- | Allocate @size@ bytes with 'c_sodium_malloc'.
--
-- If the C call returns 'nullPtr', this throws an 'IOError' built from the
-- current @errno@, with location @"c_sodium_malloc"@. Otherwise the caller
-- owns the returned memory and must release it, either with 'sodiumFree' or
-- with a 'c_sodium_free_funptr' finalizer.
sodiumMalloc :: CSize -> IO (Ptr a)
sodiumMalloc size = do
    ptr <- c_sodium_malloc size
    when (ptr == nullPtr) $ do
        errno <- getErrno
        ioException $ errnoToIOError "c_sodium_malloc" errno Nothing Nothing
    return ptr

-- | Release memory obtained from 'sodiumMalloc' by calling 'c_sodium_free'.
sodiumFree :: Ptr a -> IO ()
sodiumFree = c_sodium_free

-- | Allocate @size@ bytes with 'sodiumMalloc' and pass the pointer to the
-- continuation.
--
-- The memory is freed with 'sodiumFree' when the continuation finishes, even
-- if it throws, because allocation and release go through 'bracket'. The
-- pointer must not escape the continuation.
mlockedAlloca :: forall a b. CSize -> (Ptr a -> IO b) -> IO b
mlockedAlloca size = bracket (sodiumMalloc size) sodiumFree

-- | Like 'mlockedAlloca', but the byte count is the type-level natural @n@
-- and the continuation receives the buffer as a 'SizedPtr' @n@.
--
-- NOTE(review): 'fromInteger' wraps silently if @n@ does not fit in 'CSize'.
mlockedAllocaSized :: forall n b. KnownNat n => (SizedPtr n -> IO b) -> IO b
mlockedAllocaSized k = mlockedAlloca size (k . SizedPtr)
  where
    size :: CSize
    size = fromInteger (natVal (Proxy @n))