{-# LANGUAGE AllowAmbiguousTypes #-}

-- | TermCont-related adapters for Plutarch functions.
module Plutarch.Extra.TermCont (
  pletC,
  pmatchC,
  pletFieldsC,
  ptraceC,
  pguardC,
  pguardC',
  ptryFromC,
) where

import Plutarch.DataRepr (HRec, PDataFields, PFields)
import Plutarch.DataRepr.Internal.Field (
  BindFields,
  Bindings,
  BoundTerms,
 )
import Plutarch.Prelude
import Plutarch.Reducible (Reduce)
import Plutarch.TryFrom (PTryFrom (PTryFromExcess))

{- | Like 'plet', but usable inside a 'TermCont' computation.

The term is handed to 'plet', and the continuation supplied by 'tcont'
receives the term that 'plet' passes on.
-}
pletC :: Term s a -> TermCont s (Term s a)
pletC = tcont . plet

{- | Like 'pmatch', but usable inside a 'TermCont' computation.

The rest of the 'TermCont' block acts as the continuation that 'pmatch'
invokes with the matched Haskell-level value @a s@.
-}
pmatchC :: PlutusType a => Term s a -> TermCont s (a s)
pmatchC = tcont . pmatch

{- | Like 'pletFields', but usable inside a 'TermCont' computation.

The field list @fs@ is the first type variable, so it is meant to be
supplied with a type application (it is forwarded as @pletFields \@fs@).
The rest of the 'TermCont' block receives the resulting 'HRec' of bound
fields.
-}
pletFieldsC ::
  forall fs a s b ps bs.
  ( PDataFields a
  , ps ~ PFields a
  , bs ~ Bindings ps fs
  , BindFields ps bs
  ) =>
  Term s a ->
  TermCont @b s (HRec (BoundTerms ps bs s))
pletFieldsC x = tcont $ pletFields @fs x

{- | Like 'ptrace', but usable inside a 'TermCont' computation.

The rest of the 'TermCont' block (the continuation applied to @()@) is
the term that gets wrapped by 'ptrace' with the given message.

=== Example ===

@
foo :: Term s PUnit
foo = unTermCont $ do
  ptraceC "returning unit!"
  pure $ pconstant ()
@
-}
ptraceC :: Term s PString -> TermCont s ()
ptraceC s = tcont $ \f -> ptrace s (f ())

{- | Trace a message and raise an error if @cond@ is false. Otherwise, continue.

Built from 'pif': the true branch is the rest of the 'TermCont' block,
the false branch is 'ptraceError' applied to the message.

=== Example ===

@
onlyAllow42 :: Term s (PInteger :--> PUnit)
onlyAllow42 = plam $ \i -> unTermCont $ do
  pguardC "expected 42" $ i #== 42
  pure $ pconstant ()
@
-}
pguardC :: Term s PString -> Term s PBool -> TermCont s ()
pguardC s cond = tcont $ \f -> pif cond (f ()) $ ptraceError s

{- | Stop computation and return the given term if @cond@ is false. Otherwise, continue.

Built from 'pif': the true branch is the rest of the 'TermCont' block,
the false branch is the supplied fallback term. The fallback therefore
has to have the same type as the final result of the block, which is why
the result type of the 'TermCont' is fixed to @a@.

=== Example ===

@
is42 :: Term s (PInteger :--> PBool)
is42 = plam $ \i -> unTermCont $ do
  pguardC' (pconstant False) $ i #== 42
  pure $ pconstant True
@
-}
pguardC' :: Term s a -> Term s PBool -> TermCont @a s ()
pguardC' r cond = tcont $ \f -> pif cond (f ()) r

{- | 'TermCont' producing version of 'ptryFrom'.

The rest of the 'TermCont' block receives the pair that 'ptryFrom' passes
to its continuation: the converted term together with the reduced
'PTryFromExcess'. The target type @b@ is the first type variable so it can
be chosen with a type application.
-}
ptryFromC :: forall b r a s. PTryFrom a b => Term s a -> TermCont @r s (Term s b, Reduce (PTryFromExcess a b s))
ptryFromC = tcont . ptryFrom