-- |
-- Module      : Data.ByteArray.ScrubbedBytes
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <[email protected]>
-- Stability   : Stable
-- Portability : GHC
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Data.ByteArray.ScrubbedBytes
    ( ScrubbedBytes
    ) where

import           GHC.Types
import           GHC.Prim
import           GHC.Ptr
import           GHC.Word
#if MIN_VERSION_base(4,15,0)
import           GHC.Exts (unsafeCoerce#)
#endif
#if MIN_VERSION_base(4,9,0)
import           Data.Semigroup
import           Data.Foldable (toList)
#else
import           Data.Monoid
#endif
import           Data.String (IsString(..))
import           Data.Typeable
import           Data.Memory.PtrMethods
import           Data.Memory.Internal.CompatPrim
import           Data.Memory.Internal.Compat     (unsafeDoIO)
import           Data.Memory.Internal.Imports
import           Data.ByteArray.Types
import           Foreign.Storable
#ifdef MIN_VERSION_basement
import           Basement.NormalForm
#endif

-- | A pinned chunk of memory with the following properties:
--
-- * Its contents are overwritten with zeros (scrubbed) once it goes out of
--   scope. Scrubbing is performed by a weak-pointer finalizer, so it happens
--   if and when that finalizer runs.
--
-- * Its 'Show' instance never reveals the contents.
--
-- * Its 'Eq' instance compares contents in constant time when both values
--   have the same length. Values of different lengths are reported unequal
--   immediately, so the length itself is not hidden.
--
-- Note that the 'Ord' instance is /not/ constant time: it stops at the
-- first differing byte.
data ScrubbedBytes = ScrubbedBytes (MutableByteArray# RealWorld)
  deriving (Typeable)

-- | Deliberately opaque: the contents are never rendered, so secrets
-- cannot leak through logging or debug output.
instance Show ScrubbedBytes where
    showsPrec _ _ = showString "<scrubbed-bytes>"

-- | Content equality, delegated to 'scrubbedBytesEq'.
instance Eq ScrubbedBytes where
    x == y = scrubbedBytesEq x y
-- | Byte-wise ordering, delegated to 'scrubbedBytesCompare'.
-- Unlike equality, this is not constant time.
instance Ord ScrubbedBytes where
    compare x y = scrubbedBytesCompare x y
#if MIN_VERSION_base(4,9,0)
-- | Appending allocates a fresh buffer and copies the inputs into it;
-- the inputs themselves are left untouched.
instance Semigroup ScrubbedBytes where
    (<>) b1 b2 = unsafeDoIO (scrubbedBytesAppend b1 b2)
    sconcat bs = unsafeDoIO (scrubbedBytesConcat (toList bs))
#endif
-- | The identity element is a zero-length buffer.
instance Monoid ScrubbedBytes where
    mempty = unsafeDoIO $ newScrubbedBytes 0
#if !(MIN_VERSION_base(4,11,0))
    -- Older base versions need 'mappend' / 'mconcat' spelled out here.
    mappend b1 b2 = unsafeDoIO (scrubbedBytesAppend b1 b2)
    mconcat bs    = unsafeDoIO (scrubbedBytesConcat bs)
#endif
-- | The only field is an unboxed array, so weak head normal form is already
-- normal form: matching the constructor is enough.
instance NFData ScrubbedBytes where
    rnf (ScrubbedBytes _) = ()
#ifdef MIN_VERSION_basement
-- | The only field is an unboxed array, so matching the constructor
-- fully evaluates the value.
instance NormalForm ScrubbedBytes where
    toNormalForm (ScrubbedBytes _) = ()
#endif
-- | String literals are stored one byte per character via
-- 'scrubbedFromChar8' (code points above 255 keep only their low 8 bits).
instance IsString ScrubbedBytes where
    fromString s = scrubbedFromChar8 s

-- | Read access: the length is the buffer size in bytes, and pointer access
-- goes through 'withPtr', which keeps the buffer alive during the action.
instance ByteArrayAccess ScrubbedBytes where
    length b          = sizeofScrubbedBytes b
    withByteArray b f = withPtr b f

-- | Allocation always produces a scrubbed buffer.
instance ByteArray ScrubbedBytes where
    allocRet n f = scrubbedBytesAllocRet n f

-- | Allocate a new pinned, 8-byte-aligned buffer of @n@ bytes.
--
-- * @n < 0@: calls 'error'.
--
-- * @n == 0@: returns an empty buffer with no finalizer (nothing to scrub).
--
-- * @n > 0@: registers a weak-pointer finalizer on the underlying array
--   that overwrites all @n@ bytes with zeros (via 'memSet') when it runs.
--
-- This function itself does not initialise the contents of the buffer.
newScrubbedBytes :: Int -> IO ScrubbedBytes
newScrubbedBytes (I# n)
    | booleanPrim (n <# 0#)  = error "ScrubbedBytes: size must be >= 0"
    | booleanPrim (n ==# 0#) = IO allocEmpty
    | otherwise              = IO allocScrubbed
  where
    -- An empty array holds no data, so no finalizer is attached.
    allocEmpty :: State# RealWorld -> (# State# RealWorld, ScrubbedBytes #)
    allocEmpty st0 =
        case newAlignedPinnedByteArray# 0# 8# st0 of
            (# st1, arr #) -> (# st1, ScrubbedBytes arr #)

    -- Allocate @n@ pinned bytes and attach the zeroing finalizer, keyed on
    -- the raw array.
    allocScrubbed :: State# RealWorld -> (# State# RealWorld, ScrubbedBytes #)
    allocScrubbed st0 =
        case newAlignedPinnedByteArray# n 8# st0 of
            (# st1, arr #) ->
                -- The array is pinned, so its address can be captured once
                -- here and reused by the finalizer.
                let !wrapped = ScrubbedBytes arr
                    !wipe    = zeroFill (byteArrayContents# (unsafeCoerce# arr))
                 in case mkWeak# arr () (finalize wipe wrapped) st1 of
                        (# st2, _ #) -> (# st2, wrapped #)

    -- Overwrite the @n@ bytes starting at the given address with zeros.
    zeroFill :: Addr# -> State# RealWorld -> State# RealWorld
    zeroFill addr st =
        case memSet (Ptr addr) 0 (I# n) of
            IO act -> case act st of
                (# st', _ #) -> st'

#if __GLASGOW_HASKELL__ >= 800
    -- Finalizer body: scrub, then touch the wrapper so the array is still
    -- considered alive until the scrub has finished.
    finalize :: (State# RealWorld -> State# RealWorld)
             -> ScrubbedBytes
             -> State# RealWorld -> (# State# RealWorld, () #)
    finalize wipe sb@(ScrubbedBytes _) = \st1 ->
        case wipe st1 of
            st2 -> case touch# sb st2 of
                st3 -> (# st3, () #)
#else
    -- Older GHCs expect the finalizer as an 'IO' action.
    finalize :: (State# RealWorld -> State# RealWorld) -> ScrubbedBytes -> IO ()
    finalize wipe sb@(ScrubbedBytes _) = IO $ \st1 ->
        case wipe st1 of
            st2 -> case touch# sb st2 of
                st3 -> (# st3, () #)
#endif

-- | Allocate @sz@ scrubbed bytes and run the initialiser @f@ on a pointer
-- to them. Returns the initialiser's result paired with the buffer.
scrubbedBytesAllocRet :: Int -> (Ptr p -> IO a) -> IO (a, ScrubbedBytes)
scrubbedBytesAllocRet sz f =
    newScrubbedBytes sz >>= \ba ->
        fmap (\r -> (r, ba)) (withPtr ba f)

-- | Allocate @sz@ scrubbed bytes, fill them with @f@ and return the buffer.
scrubbedBytesAlloc :: Int -> (Ptr p -> IO ()) -> IO ScrubbedBytes
scrubbedBytesAlloc sz f = do
    ba <- newScrubbedBytes sz
    ba <$ withPtr ba f

-- | Concatenate buffers, in list order, into one newly allocated scrubbed
-- buffer. The inputs are only read, never modified.
scrubbedBytesConcat :: [ScrubbedBytes] -> IO ScrubbedBytes
scrubbedBytesConcat chunks = scrubbedBytesAlloc (sum sizes) fillAll
  where
    sizes :: [Int]
    sizes = map sizeofScrubbedBytes chunks

    -- Byte offset at which each chunk starts in the destination.
    offsets :: [Int]
    offsets = scanl (+) 0 sizes

    fillAll :: Ptr Word8 -> IO ()
    fillAll dst = mapM_ copyOne (zip3 chunks sizes offsets)
      where
        copyOne (chunk, len, off) =
            withPtr chunk $ \src -> memCopy (dst `plusPtr` off) src len

-- | Concatenate two buffers (@b1@ first, then @b2@) into a newly allocated
-- scrubbed buffer.
scrubbedBytesAppend :: ScrubbedBytes -> ScrubbedBytes -> IO ScrubbedBytes
scrubbedBytesAppend b1 b2 = scrubbedBytesAlloc (n1 + n2) fill
  where
    n1 = sizeofScrubbedBytes b1
    n2 = sizeofScrubbedBytes b2

    fill :: Ptr Word8 -> IO ()
    fill out = do
        withPtr b1 $ \src -> memCopy out src n1
        withPtr b2 $ \src -> memCopy (out `plusPtr` n1) src n2


-- | Size of the buffer in bytes.
sizeofScrubbedBytes :: ScrubbedBytes -> Int
sizeofScrubbedBytes (ScrubbedBytes arr) =
    case sizeofMutableByteArray# arr of
        n -> I# n

-- | Run an action with a pointer to the first byte of the buffer.
--
-- The buffer is kept reachable for the whole duration of the action, so its
-- scrubbing finalizer cannot run while the pointer is in use. The pointer
-- must not escape the action.
--
-- On base >= 4.15 (GHC >= 9.0) this uses 'keepAlive#'. A trailing 'touch#'
-- is not reliable there. If GHC can see that the action never returns
-- (e.g. it always throws or loops), it may drop the code after it,
-- including the 'touch#'. That would leave the buffer unreachable, and
-- eligible for scrubbing, while the pointer is still being used (GHC issue
-- #17760; base-4.15 switched 'withForeignPtr' to 'keepAlive#' for the same
-- reason).
withPtr :: ScrubbedBytes -> (Ptr p -> IO a) -> IO a
#if MIN_VERSION_base(4,15,0)
withPtr b@(ScrubbedBytes mba) f = IO (\s ->
    keepAlive# b s (\s1 ->
        case f (Ptr (byteArrayContents# (unsafeCoerce# mba))) of
            IO act -> act s1))
#else
withPtr b@(ScrubbedBytes mba) f = do
    a <- f (Ptr (byteArrayContents# (unsafeCoerce# mba)))
    -- Keep the buffer alive until the action above has finished.
    touchScrubbedBytes b
    return a

-- | Mark the buffer as reachable up to this point of an 'IO' computation.
-- Only needed by the pre-'keepAlive#' implementation of 'withPtr'.
touchScrubbedBytes :: ScrubbedBytes -> IO ()
touchScrubbedBytes (ScrubbedBytes mba) = IO $ \s ->
    case touch# mba s of
        s' -> (# s', () #)
#endif

-- | Content equality intended to run in constant time.
--
-- Buffers of different lengths are reported unequal straight away, so the
-- length is not treated as secret. Buffers of equal length are compared
-- with 'memConstEqual'.
scrubbedBytesEq :: ScrubbedBytes -> ScrubbedBytes -> Bool
scrubbedBytesEq a b
    | lenA == lenB = unsafeDoIO $
        withPtr a $ \pa ->
        withPtr b $ \pb ->
            memConstEqual pa pb lenA
    | otherwise    = False
  where
    lenA = sizeofScrubbedBytes a
    lenB = sizeofScrubbedBytes b

-- | Lexicographic comparison of the raw bytes, each read as an unsigned
-- 'Word8'. When one buffer is a prefix of the other, the shorter is smaller.
--
-- This is /not/ constant time: it stops at the first differing byte.
scrubbedBytesCompare :: ScrubbedBytes -> ScrubbedBytes -> Ordering
scrubbedBytesCompare b1@(ScrubbedBytes m1) b2@(ScrubbedBytes m2) = unsafeDoIO (go 0)
  where
    !n1     = sizeofScrubbedBytes b1
    !n2     = sizeofScrubbedBytes b2
    !common = min n1 n2

    -- Walk the common prefix; once it is exhausted, the lengths decide.
    go :: Int -> IO Ordering
    go !i
        | i == common = pure (compare n1 n2)
        | otherwise   = do
            x <- readByte m1 i
            y <- readByte m2 i
            case compare x y of
                EQ    -> go (i + 1)
                order -> pure order

    readByte :: MutableByteArray# RealWorld -> Int -> IO Word8
    readByte arr (I# i) = IO $ \s ->
        case readWord8Array# arr i s of
            (# s', w #) -> (# s', W8# w #)

-- | Build a 'ScrubbedBytes' from a 'String', storing one byte per character.
--
-- Each character is converted with @fromIntegral . fromEnum@, so only the
-- low 8 bits of its code point are kept. Characters above 255 are silently
-- truncated.
scrubbedFromChar8 :: [Char] -> ScrubbedBytes
scrubbedFromChar8 str =
    unsafeDoIO (scrubbedBytesAlloc (Prelude.length str) writeAll)
  where
    writeAll :: Ptr Word8 -> IO ()
    writeAll base =
        mapM_ (\(off, c) -> pokeElemOff base off (toByte c)) (zip [0 ..] str)

    toByte :: Char -> Word8
    toByte = fromIntegral . fromEnum