{-# LANGUAGE QuantifiedConstraints #-}

module Plutarch.Extra.Bind (
  -- * Type class
  PBind (..),

  -- * Functions
  pjoin,
) where

import Plutarch.Api.V1.Maybe (PMaybeData (PDJust, PDNothing))
import Plutarch.Extra.Applicative (PApply)
import Plutarch.Extra.Function (pidentity)
import Plutarch.Extra.Functor (PSubcategory)
import Plutarch.Lift (PUnsafeLiftDecl)

{- | Gives the capability to bind a Kleisli arrow over @f@ to a value:
 essentially, the equivalent of Haskell's '>>='. Unlike Haskell, we don't
 require the availability of 'pure': to recover the equivalent of Haskell's
 'Monad', you want both 'Plutarch.Extra.Applicative.PApplicative' and
 'PBind'.

 = Laws

 * @(m '#>>=' f) '#>>=' g@ @=@ @m '#>>=' ('plam' '$' \\x -> (f '#' x) '#>>=' g)@
 * @f '#<*>' x@ @=@ @f '#>>=' ('plam' ('#<$>' x))@

 @since 3.0.1
-}
class (PApply f) => PBind (f :: (S -> Type) -> S -> Type) where
  -- | '>>=', but as a function on 'Term's.
  --
  -- Both the element type being consumed (@a@) and the element type being
  -- produced (@b@) must satisfy the instance's 'PSubcategory' constraint.
  (#>>=) ::
    forall (a :: S -> Type) (b :: S -> Type) (s :: S).
    (PSubcategory f a, PSubcategory f b) =>
    Term s (f a) ->
    Term s (a :--> f b) ->
    Term s (f b)

-- Same fixity as Haskell's '>>=' (left-associative, precedence 1).
infixl 1 #>>=

-- | Short-circuits on 'PNothing'; otherwise applies the Kleisli arrow to
-- the value held by 'PJust'.
--
-- @since 3.0.1
instance PBind PMaybe where
  {-# INLINEABLE (#>>=) #-}
  xs #>>= f = pmatch xs $ \case
    -- Nothing to feed to the arrow: rebuild an empty result.
    PNothing -> pcon PNothing
    PJust t -> f # t

-- | Short-circuits on 'PDNothing'; otherwise decodes the single field of
-- 'PDJust' with 'pfromData' and applies the Kleisli arrow to it.
--
-- @since 3.0.1
instance PBind PMaybeData where
  {-# INLINEABLE (#>>=) #-}
  xs #>>= f = pmatch xs $ \case
    -- Reuse the (empty) field record of the input when rebuilding the
    -- 'PDNothing' result.
    PDNothing t -> pcon (PDNothing t)
    -- The payload sits in field "_0" as 'PAsData'; unwrap before applying.
    PDJust t -> f # pfromData (pfield @"_0" # t)

-- | Maps every element to a list with the Kleisli arrow and concatenates
-- the results in order.
--
-- The traversal recurses at script runtime through 'pfix'. Recursing in
-- Haskell instead (calling '#>>=' on the tail from inside the 'pelimList'
-- handler) never terminates: that handler is an ordinary Haskell function
-- that runs while the term is being built, so the term would be unfolded
-- forever.
--
-- @since 3.0.1
instance PBind PList where
  (#>>=) ::
    forall (a :: S -> Type) (b :: S -> Type) (s :: S).
    Term s (PList a) ->
    Term s (a :--> PList b) ->
    Term s (PList b)
  {-# INLINEABLE (#>>=) #-}
  xs #>>= f = pfix # plam go # xs
    where
      -- One unrolling of the runtime loop; @self@ processes the tail.
      go ::
        Term s (PList a :--> PList b) ->
        Term s (PList a) ->
        Term s (PList b)
      go self =
        pelimList
          (\h t -> pconcat # (f # h) # (self # t))
          pnil

-- | Maps every element to a builtin list with the Kleisli arrow and
-- concatenates the results in order.
--
-- The traversal recurses at script runtime through 'pfix'. Recursing in
-- Haskell instead (calling '#>>=' on the tail from inside the 'pelimList'
-- handler) never terminates: that handler is an ordinary Haskell function
-- that runs while the term is being built, so the term would be unfolded
-- forever.
--
-- @since 3.0.1
instance PBind PBuiltinList where
  (#>>=) ::
    forall (a :: S -> Type) (b :: S -> Type) (s :: S).
    (PUnsafeLiftDecl a, PUnsafeLiftDecl b) =>
    Term s (PBuiltinList a) ->
    Term s (a :--> PBuiltinList b) ->
    Term s (PBuiltinList b)
  {-# INLINEABLE (#>>=) #-}
  xs #>>= f = pfix # plam go # xs
    where
      -- One unrolling of the runtime loop; @self@ processes the tail.
      go ::
        Term s (PBuiltinList a :--> PBuiltinList b) ->
        Term s (PBuiltinList a) ->
        Term s (PBuiltinList b)
      go self =
        pelimList
          (\h t -> pconcat # (f # h) # (self # t))
          pnil

-- | Binds over the second component of the pair. The first components of
-- the input pair and of the pair returned by the Kleisli arrow are combined
-- with '<>', input first.
--
-- @since 3.0.1
instance (forall (s :: S). Semigroup (Term s a)) => PBind (PPair a) where
  {-# INLINEABLE (#>>=) #-}
  xs #>>= f = pmatch xs $ \case
    PPair acc t -> pmatch (f # t) $ \case
      PPair acc' res -> pcon (PPair (acc <> acc') res)

{- | Forwards the /first/ 'PLeft': a 'PLeft' input is rebuilt unchanged and
 the Kleisli arrow is never applied; a 'PRight' value is passed to the arrow.

 @since 3.0.1
-}
instance PBind (PEither e) where
  {-# INLINEABLE (#>>=) #-}
  xs #>>= f = pmatch xs $ \case
    PLeft t -> pcon (PLeft t)
    PRight t -> f # t

{- | \'Flattens\' two identical 'PBind' layers into one, by binding the outer
 layer with 'pidentity': @pjoin # xs@ is @xs '#>>=' 'pidentity'@.

 @since 3.0.1
-}
pjoin ::
  forall (a :: S -> Type) (f :: (S -> Type) -> S -> Type) (s :: S).
  (PBind f, PSubcategory f a, PSubcategory f (f a)) =>
  Term s (f (f a) :--> f a)
pjoin = phoistAcyclic $ plam $ \xs -> xs #>>= pidentity