{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif

#ifndef MIN_VERSION_transformers
#define MIN_VERSION_transformers(x,y,z) 1
#endif

#ifndef MIN_VERSION_mtl
#define MIN_VERSION_mtl(x,y,z) 1
#endif

--------------------------------------------------------------------
-- |
-- Copyright   :  (C) Edward Kmett 2013-2015, (c) Google Inc. 2012
-- License     :  BSD-style (see the file LICENSE)
-- Maintainer  :  Edward Kmett <[email protected]>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- This module supplies a \'pure\' monad transformer that can be used for
-- mock-testing code that throws exceptions, so long as those exceptions
-- are always thrown with 'throwM'.
--
-- Do not mix 'CatchT' with 'IO'. Choose one or the other for the
-- bottom of your transformer stack!
--------------------------------------------------------------------

module Control.Monad.Catch.Pure (
    -- * Transformer
    -- $transformer
    CatchT(..), Catch
  , runCatch
  , mapCatchT

  -- * Typeclass
  -- $mtl
  , module Control.Monad.Catch
  ) where

#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 706)
import Prelude hiding (foldr)
#else
import Prelude hiding (catch, foldr)
#endif

import Control.Applicative
import Control.Monad.Catch
import qualified Control.Monad.Fail as Fail
import Control.Monad.Reader as Reader
import Control.Monad.RWS
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable
#endif
import Data.Functor.Identity
import Data.Traversable as Traversable

------------------------------------------------------------------------------
-- $mtl
-- The @mtl@-style typeclass
------------------------------------------------------------------------------

------------------------------------------------------------------------------
-- $transformer
-- The @transformers@-style monad transformer
------------------------------------------------------------------------------

-- | Add 'Exception' handling abilities to a 'Monad' by carrying an
-- @'Either' 'SomeException'@ alongside every result of the base monad.
--
-- 'CatchT' is an alternative /base/ monad for mocking code that raises
-- exceptions solely via 'throwM'. It should /never/ be used in combination
-- with 'IO'.
--
-- Note that the 'IO' monad already has these abilities, so stacking 'CatchT'
-- on top of it adds no value and can be confusing. An exception raised with
-- 'error' is not recorded by 'CatchT' and escapes its 'catch'; only 'throwM'
-- goes through 'CatchT':
--
-- >>> (error "Hello!" :: IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
--
-- >>> runCatchT $ (error "Hello!" :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- *** Exception: Hello!
--
-- >>> runCatchT $ (throwM (ErrorCall "Hello!") :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
newtype CatchT m a = CatchT
  { runCatchT :: m (Either SomeException a)
    -- ^ Unwrap the computation, exposing either the uncaught exception
    -- ('Left') or the result ('Right').
  }

-- | 'CatchT' over 'Identity': a pure computation that may end with a
-- 'SomeException' thrown via 'throwM'.
type Catch = CatchT Identity

-- | Run a pure 'Catch' computation. The result is 'Left' with the uncaught
-- exception, or 'Right' with the final value.
runCatch :: Catch a -> Either SomeException a
runCatch action = runIdentity (runCatchT action)

-- | Maps over a successful result. A recorded exception passes through
-- untouched.
instance Monad m => Functor (CatchT m) where
  fmap f = CatchT . liftM (fmap f) . runCatchT

-- | 'pure' wraps a value as a success. '<*>' is 'ap', so it sequences
-- through the 'Monad' instance and stops at the first recorded exception.
instance Monad m => Applicative (CatchT m) where
  pure = CatchT . return . Right
  mf <*> mx = mf `ap` mx

-- | Bind short-circuits. Once the base action yields @'Left' e@, the
-- continuation is skipped and @'Left' e@ is returned as-is.
instance Monad m => Monad (CatchT m) where
  return = pure
  CatchT m >>= k = CatchT (m >>= either (return . Left) (runCatchT . k))
#if !(MIN_VERSION_base(4,13,0))
  fail = Fail.fail
#endif

-- | 'fail' records the message as a 'userError' in the exception channel.
-- It does not raise it at runtime.
instance Monad m => Fail.MonadFail (CatchT m) where
  fail msg = CatchT (return (Left (toException (userError msg))))

-- | Ties the knot in the base monad. The fixed point is passed to @f@
-- lazily. Forcing it when the computation produced an exception is an
-- 'error'.
instance MonadFix m => MonadFix (CatchT m) where
  mfix f = CatchT (mfix (runCatchT . f . unRight))
    where
      unRight (Right r) = r
      unRight _         = error "empty mfix argument"

-- | Folds the successful results inside the base container. A recorded
-- exception contributes 'mempty'.
instance Foldable m => Foldable (CatchT m) where
  foldMap f = foldMap (either (const mempty) f) . runCatchT

-- | Traverses the successful results inside the base container. A recorded
-- exception is re-wrapped unchanged, without running @f@.
instance (Monad m, Traversable m) => Traversable (CatchT m) where
  traverse f = fmap CatchT . Traversable.traverse step . runCatchT
    where
      step (Left e)  = pure (Left e)
      step (Right a) = fmap Right (f a)

-- | Delegates to the 'MonadPlus' instance: 'empty' is 'mzero' and '<|>' is
-- 'mplus'.
instance Monad m => Alternative (CatchT m) where
  empty = mzero
  x <|> y = x `mplus` y

-- | 'mzero' records an empty 'userError'. 'mplus' runs its second argument
-- only when the first one yields an exception, and discards that exception.
instance Monad m => MonadPlus (CatchT m) where
  mzero = CatchT (return (Left (toException (userError ""))))
  mplus (CatchT first) (CatchT second) =
    CatchT (first >>= either (const second) (return . Right))

-- | Runs the base action and wraps its result in 'Right'. 'lift' never
-- records an exception in the 'CatchT' channel.
instance MonadTrans CatchT where
  lift = CatchT . liftM Right

-- | Lifts the 'IO' action through the base monad and wraps its result in
-- 'Right'.
instance MonadIO m => MonadIO (CatchT m) where
  liftIO = CatchT . liftM Right . liftIO

-- | Records the exception (as a 'SomeException') in the 'Left' channel of
-- the base monad. It does not raise it at runtime.
instance Monad m => MonadThrow (CatchT m) where
  throwM e = CatchT (return (Left (toException e)))
-- | Runs the handler only for a recorded exception that 'fromException' can
-- convert to the handler's argument type. Other exceptions, and successful
-- results, pass through unchanged.
instance Monad m => MonadCatch (CatchT m) where
  catch (CatchT m) handler = CatchT (m >>= either recover (return . Right))
    where
      recover e =
        maybe (return (Left e)) (runCatchT . handler) (fromException e)
-- | Note: This instance is only valid if the underlying monad has a single
-- exit point!
--
-- For example, @IO@ or @Either@ would be invalid base monads, but
-- @Reader@ or @State@ would be acceptable.
--
-- Masking is a no-op here: 'mask' and 'uninterruptibleMask' pass 'id' as
-- the restore function.
--
-- 'generalBracket' behaves as follows:
--
-- * If @acquire@ yields an exception, nothing is released and the
--   exception is returned.
-- * Otherwise @release@ runs after @use@, with 'ExitCaseSuccess' or
--   'ExitCaseException'.
-- * In the exception case, the exception is re-thrown after @release@.
instance Monad m => MonadMask (CatchT m) where
  mask body = body id
  uninterruptibleMask body = body id
  generalBracket acquire release use =
    CatchT (runCatchT acquire >>= either (return . Left) withResource)
    where
      -- Run the user action, then hand its outcome to 'finish'.
      withResource resource =
        runCatchT (use resource) >>= runCatchT . finish resource
      -- Release the resource according to how 'use' ended.
      finish resource (Left e) =
        release resource (ExitCaseException e) >> throwM e
      finish resource (Right b) = do
        c <- release resource (ExitCaseSuccess b)
        return (b, c)

-- | Every state operation is forwarded to the base monad with 'lift'.
instance MonadState s m => MonadState s (CatchT m) where
  get = lift get
  put s = lift (put s)
#if MIN_VERSION_mtl(2,1,0)
  state f = lift (state f)
#endif

-- | Environment operations are forwarded to the base monad:
--
-- * 'ask' (and 'reader', on mtl >= 2.1) are lifted with 'lift'.
-- * 'local' runs the unwrapped computation under the modified environment.
--
-- 'reader' is forwarded explicitly, mirroring how the 'MonadState' and
-- 'MonadWriter' instances forward 'state' and 'writer'. This lets the base
-- monad's own 'reader' be used instead of the class default built from
-- 'ask'.
instance MonadReader e m => MonadReader e (CatchT m) where
  ask = lift ask
  local f (CatchT m) = CatchT (local f m)
#if MIN_VERSION_mtl(2,1,0)
  reader = lift . reader
#endif

-- | Writer operations are forwarded to the base monad.
--
-- * 'listen' pairs a successful result with the output it produced. If the
--   computation ends in an exception, the exception is returned as-is.
-- * 'pass' applies 'id' to the output when there is no result to supply a
--   transformation.
instance MonadWriter w m => MonadWriter w (CatchT m) where
  tell w = lift (tell w)
  listen action = CatchT $ do
    (result, w) <- listen (runCatchT action)
    return $! either Left (\r -> Right (r, w)) result
  pass action = CatchT . pass $ do
    result <- runCatchT action
    -- '$!' forces the Either (and the inner pair) before handing it to
    -- the base 'pass'.
    return $! either (\l -> (Left l, id)) (\(r, f) -> (Right r, f)) result
#if MIN_VERSION_mtl(2,1,0)
  writer = CatchT . liftM Right . writer
#endif

-- | 'MonadRWS' for 'CatchT'. The instance body is empty: everything it needs
-- comes from the 'MonadReader', 'MonadWriter' and 'MonadState' instances for
-- 'CatchT' in this module.
instance MonadRWS r w s m => MonadRWS r w s (CatchT m)

-- | Map the unwrapped computation using the given function. The function
-- may also change the base monad.
--
-- @'runCatchT' ('mapCatchT' f m) = f ('runCatchT' m)@
mapCatchT :: (m (Either SomeException a) -> n (Either SomeException b))
          -> CatchT m a
          -> CatchT n b
mapCatchT f = CatchT . f . runCatchT