{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
#if USE_DEFAULT_SIGNATURES
{-# LANGUAGE DefaultSignatures #-}
#endif
{-# LANGUAGE TypeFamilies #-}
-- Foreign.ForeignPtr is unsafe before GHC-7.10
#if __GLASGOW_HASKELL__ >= 704 && MIN_VERSION_base(4,8,0)
{-# LANGUAGE Safe #-}
#elif __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.StateVar
-- Copyright   :  (c) Edward Kmett 2014-2019, Sven Panne 2009-2021
-- License     :  BSD3
-- 
-- Maintainer  :  Sven Panne <[email protected]>
-- Stability   :  stable
-- Portability :  portable
--
-- State variables are references in the IO monad, like 'IORef's or parts of
-- the OpenGL state. Note that state variables are not necessarily writable or
-- readable, they may come in read-only or write-only flavours, too. As a very
-- simple example for a state variable, consider an explicitly allocated memory
-- buffer. This buffer could easily be converted into a 'StateVar':
--
-- @
-- makeStateVarFromPtr :: Storable a => Ptr a -> StateVar a
-- makeStateVarFromPtr p = makeStateVar (peek p) (poke p)
-- @
--
-- The example below puts 11 into a state variable (i.e. into the buffer),
-- increments the contents of the state variable by 22, and finally prints the
-- resulting content:
--
-- @
--   do p <- malloc :: IO (Ptr Int)
--      let v = makeStateVarFromPtr p
--      v $= 11
--      v $~ (+ 22)
--      x <- get v
--      print x
-- @
--
-- In fact, a 'Ptr' can be used directly through the same API, without wrapping:
--
-- @
--   do p <- malloc :: IO (Ptr Int)
--      p $= 11
--      p $~ (+ 22)
--      x <- get p
--      print x
-- @
--
-- 'IORef's are state variables, too, so an example with them looks extremely
-- similar:
--
-- @
--   do v <- newIORef (0 :: Int)
--      v $= 11
--      v $~ (+ 22)
--      x <- get v
--      print x
-- @
--------------------------------------------------------------------------------

module Data.StateVar
  (
  -- * Readable State Variables
    HasGetter(get)
  , GettableStateVar, makeGettableStateVar
  -- * Writable State Variables
  , HasSetter(($=)), ($=!)
  , SettableStateVar(SettableStateVar), makeSettableStateVar
  -- * Updatable State Variables
  , HasUpdate(($~), ($~!))
  , StateVar(StateVar), makeStateVar
  , mapStateVar
  ) where

import Control.Concurrent.STM
import Control.Monad.IO.Class
import Data.IORef
import Data.Typeable
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
#if MIN_VERSION_base(4,12,0)
import Data.Functor.Contravariant
#endif

--------------------------------------------------------------------
-- * StateVar
--------------------------------------------------------------------

-- | A concrete implementation of a readable and writable state variable,
-- carrying one IO action to read the value and another IO action to write the
-- new value. This data type represents a piece of mutable, imperative state
-- with possible side-effects. These tend to encapsulate all sorts of tricky
-- behavior in external libraries, and may well throw exceptions. Inhabitants
-- __should__ satisfy the following properties:
--
-- * In the absence of concurrent mutation from other threads or a thrown
-- exception:
--
-- @
-- do x <- 'get' v; v '$=' y; v '$=' x
-- @
--
-- should restore the previous state.
--
-- * Ideally, in the absence of thrown exceptions:
--
-- @
-- v '$=' a >> 'get' v
-- @
--
-- should return @a@, regardless of @a@. In practice some 'StateVar's only
-- permit a very limited range of value assignments, and do not report failure.
data StateVar a = StateVar (IO a) (a -> IO ()) deriving Typeable

#if MIN_VERSION_base(4,12,0)
-- | Adapting a setter to a new input type: the resulting variable converts
-- each incoming value with @f@ before handing it to the original setter.
instance Contravariant SettableStateVar where
  contramap f (SettableStateVar k) = SettableStateVar (k . f)
  {-# INLINE contramap #-}
#endif

-- | Construct a 'StateVar' from two IO actions, one for reading and one for
-- writing.
makeStateVar
  :: IO a         -- ^ getter
  -> (a -> IO ()) -- ^ setter
  -> StateVar a
makeStateVar = StateVar
{-# INLINE makeStateVar #-}

-- | Change the type of a 'StateVar' by supplying conversions in both
-- directions: values read from the original variable are converted with the
-- second function, and values to be written are converted with the first
-- function before being passed to the original setter.
mapStateVar :: (b -> a) -> (a -> b) -> StateVar a -> StateVar b
mapStateVar ba ab (StateVar ga sa) = StateVar (fmap ab ga) (sa . ba)
{-# INLINE mapStateVar #-}

-- | A write-only state variable: nothing more than the IO action that stores
-- a new value.
newtype SettableStateVar a
  = SettableStateVar (a -> IO ())
  deriving (Typeable)

-- | Construct a 'SettableStateVar' from an IO action for writing.
makeSettableStateVar
  :: (a -> IO ()) -- ^ setter
  -> SettableStateVar a
makeSettableStateVar = SettableStateVar
{-# INLINE makeSettableStateVar #-}

-- | A concrete implementation of a read-only state variable is simply an IO
-- action to read the value.
--
-- This is a type synonym, not a new type, so any @IO a@ action can be used
-- wherever a @GettableStateVar a@ is expected, and vice versa.
type GettableStateVar = IO

-- | Construct a 'GettableStateVar' from an IO action. Since
-- 'GettableStateVar' is a synonym for 'IO', this is the identity function; it
-- exists for symmetry with 'makeStateVar' and 'makeSettableStateVar'.
makeGettableStateVar
  :: IO a -- ^ getter
  -> GettableStateVar a
makeGettableStateVar = id
{-# INLINE makeGettableStateVar #-}

--------------------------------------------------------------------
-- * HasSetter
--------------------------------------------------------------------

-- | The class of all writable state variables. A variable of type @s@ stores
-- values of type @v@; the functional dependency @s -> v@ means the variable's
-- type alone determines the type of value it accepts.
class HasSetter s v | s -> v where
  -- | Write a new value into a state variable.
  ($=) :: MonadIO m => s -> v -> m ()

-- Both setter operators are right-associative at precedence 2.
infixr 2 $=
infixr 2 $=!

-- | This is a variant of '$=' which is strict in the value to be set: the
-- value is evaluated to weak head normal form (via '$!') before it is written.
($=!) :: (HasSetter t a, MonadIO m) => t -> a -> m ()
p $=! a = (p $=) $! a
{-# INLINE ($=!) #-}

-- | Writing runs the wrapped setter action.
instance HasSetter (SettableStateVar a) a where
  SettableStateVar f $= a = liftIO (f a)
  {-# INLINE ($=) #-}

-- | Writing runs the 'StateVar''s setter action; the getter is ignored.
instance HasSetter (StateVar a) a where
  StateVar _ s $= a = liftIO $ s a
  {-# INLINE ($=) #-}

-- | Writing stores the value at the pointed-to address via 'poke'.
instance Storable a => HasSetter (Ptr a) a where
  p $= a = liftIO $ poke p a
  {-# INLINE ($=) #-}

-- | Writing replaces the reference's contents via 'writeIORef'.
instance HasSetter (IORef a) a where
  p $= a = liftIO $ writeIORef p a
  {-# INLINE ($=) #-}

-- | Writing replaces the 'TVar''s contents in its own STM transaction.
instance HasSetter (TVar a) a where
  p $= a = liftIO $ atomically $ writeTVar p a
  {-# INLINE ($=) #-}

-- | Writing delegates to the 'Ptr' instance inside 'withForeignPtr', so the
-- foreign pointer is kept alive for the duration of the write.
instance Storable a => HasSetter (ForeignPtr a) a where
  p $= a = liftIO $ withForeignPtr p ($= a)
  {-# INLINE ($=) #-}

--------------------------------------------------------------------
-- * HasUpdate
--------------------------------------------------------------------

infixr 2 $~, $~!

-- | This is the class of all updatable state variables. An update reads a
-- value of type @a@ and writes back a value of type @b@, hence the
-- 'HasSetter' @t b@ superclass.
--
-- When @USE_DEFAULT_SIGNATURES@ is enabled, instances whose input and output
-- types coincide may omit both methods and fall back to a read-then-write
-- default.
class HasSetter t b => HasUpdate t a b | t -> a b where
  -- | Transform the contents of a state variable with a given function.
  ($~) :: MonadIO m => t -> (a -> b) -> m ()
#if USE_DEFAULT_SIGNATURES
  default ($~) :: (MonadIO m, a ~ b, HasGetter t a) => t -> (a -> b) -> m ()
  ($~) = defaultUpdate
#endif
  -- | This is a variant of '$~' which is strict in the transformed value.
  ($~!) :: MonadIO m => t -> (a -> b) -> m ()
#if USE_DEFAULT_SIGNATURES
  default ($~!) :: (MonadIO m, a ~ b, HasGetter t a) => t -> (a -> b) -> m ()
  ($~!) = defaultUpdateStrict
#endif

-- | Generic implementation of '$~' for variables that are both readable and
-- writable: read the current value, then write back @f@ applied to it.
-- The read and the write are two separate actions, so nothing here prevents
-- another thread from writing in between.
defaultUpdate :: (MonadIO m, a ~ b, HasGetter t a, HasSetter t a) => t -> (a -> b) -> m ()
defaultUpdate r f = liftIO $ do
  a <- get r
  r $= f a

-- | Strict counterpart of 'defaultUpdate': the transformed value is written
-- with '$=!', so it is forced to weak head normal form before being stored.
-- As with 'defaultUpdate', the read and the write are separate actions.
defaultUpdateStrict :: (MonadIO m, a ~ b, HasGetter t a, HasSetter t a) => t -> (a -> b) -> m ()
defaultUpdateStrict r f = liftIO $ do
  a <- get r
  r $=! f a

-- | Updates use the generic read-then-write implementations.
instance HasUpdate (StateVar a) a a where
  ($~) = defaultUpdate
  ($~!) = defaultUpdateStrict

-- | Updates use the generic read-then-write implementations ('peek' then
-- 'poke' through the 'HasGetter' and 'HasSetter' instances).
instance Storable a => HasUpdate (Ptr a) a a where
  ($~) = defaultUpdate
  ($~!) = defaultUpdateStrict

-- | Updates are performed atomically with 'atomicModifyIORef' /
-- 'atomicModifyIORef''.
instance HasUpdate (IORef a) a a where
  -- Lazy update: the new contents are stored as an unevaluated @f a@.
  r $~ f  = liftIO $ atomicModifyIORef r $ \a -> (f a,())
#if MIN_VERSION_base(4,6,0)
  -- Strict update: 'atomicModifyIORef'' forces the new contents.
  r $~! f = liftIO $ atomicModifyIORef' r $ \a -> (f a,())
#else
  -- Without 'atomicModifyIORef'', also return the new value from the atomic
  -- modification so that it can be forced right afterwards.
  r $~! f = liftIO $ do
    s <- atomicModifyIORef r $ \a -> let s = f a in (s, s)
    s `seq` return ()
#endif

-- | Each update reads and writes the 'TVar' within a single STM transaction.
instance HasUpdate (TVar a) a a where
  r $~ f = liftIO $ atomically $ do
    a <- readTVar r
    writeTVar r (f a)
  -- Strict variant: force the new value before storing it.
  r $~! f = liftIO $ atomically $ do
    a <- readTVar r
    writeTVar r $! f a

-- | Updates delegate to the 'Ptr' instance inside 'withForeignPtr', so the
-- foreign pointer is kept alive for the duration of the update.
instance Storable a => HasUpdate (ForeignPtr a) a a where
  p $~ f = liftIO $ withForeignPtr p ($~ f)
  p $~! f = liftIO $ withForeignPtr p ($~! f)

--------------------------------------------------------------------
-- * HasGetter
--------------------------------------------------------------------

-- | The class of all readable state variables. Reading a variable of type @s@
-- yields a value of type @v@; the functional dependency @s -> v@ means the
-- variable's type alone determines the type of value read.
class HasGetter s v | s -> v where
  -- | Read the current value of a state variable in any 'MonadIO'.
  get :: MonadIO m => s -> m v

-- | Reading runs the 'StateVar''s getter action; the setter is ignored.
instance HasGetter (StateVar a) a where
  get (StateVar g _) = liftIO g
  {-# INLINE get #-}

-- | Reading runs 'readTVar' in its own STM transaction.
instance HasGetter (TVar a) a where
  get = liftIO . atomically . readTVar
  {-# INLINE get #-}

-- | An IO action is itself a read-only state variable ('GettableStateVar' is
-- a synonym for 'IO'); reading it simply runs the action.
instance HasGetter (IO a) a where
  get = liftIO
  {-# INLINE get #-}

-- | Reading an STM action runs it as a transaction with 'atomically'.
instance HasGetter (STM a) a where
  get = liftIO . atomically
  {-# INLINE get #-}

-- | Reading loads the value at the pointed-to address via 'peek'.
instance Storable a => HasGetter (Ptr a) a where
  get = liftIO . peek
  {-# INLINE get #-}

-- | Reading returns the reference's current contents via 'readIORef'.
instance HasGetter (IORef a) a where
  get = liftIO . readIORef
  {-# INLINE get #-}

-- | Reading delegates to the 'Ptr' instance inside 'withForeignPtr', so the
-- foreign pointer is kept alive for the duration of the read.
instance Storable a => HasGetter (ForeignPtr a) a where
  get p = liftIO $ withForeignPtr p get
  {-# INLINE get #-}