{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}

module Plutarch.Internal.PLam (
  plam,
  pinl,
) where

import Data.Kind (Type)
import Data.Text qualified as Text
import GHC.Stack (HasCallStack, callStack, withFrozenCallStack)
import Plutarch.Internal (
  PType,
  S,
  Term,
  pgetConfig,
  plam',
  punsafeConstantInternal,
  tracingMode,
  (:-->),
  pattern DoTracingAndBinds,
 )
import Plutarch.Internal.PrettyStack (prettyStack)
import Plutarch.Internal.Trace (ptrace)
import PlutusCore qualified as PLC

{- $plam
 Lambda abstraction.

 The 'PLamN' class lets 'plam' accept a Haskell function of any number of
 'Term' arguments and curry it into the corresponding Plutarch function type.

 > id :: Term s (a :--> a)
 > id = plam (\x -> x)

 > const :: Term s (a :--> b :--> a)
 > const = plam (\x y -> x)
-}

-- | Embed a 'Text.Text' value as a Plutus Core constant of the default universe.
--
-- The Plutarch type @a@ of the result is picked by the caller and is /not/
-- checked against the constant's real (text) type, hence the use of
-- 'punsafeConstantInternal'. In this module it only feeds the message
-- argument of 'ptrace'.
mkstring :: Text.Text -> Term s a
mkstring = punsafeConstantInternal . PLC.someValue @Text.Text @PLC.DefaultUni

-- | Haskell function shapes that 'plam' can turn into a Plutarch lambda.
--
-- @a@ is the Haskell type produced after the first 'Term' argument has been
-- supplied, and @b@ is the Plutarch type it corresponds to: a plain
-- @'Term' s x@ maps to @x@, and @'Term' s x -> rest@ maps to @x ':-->' rest'@
-- (see the two instances below). The functional dependencies
-- (@a -> b@ and @s b -> a@) let GHC infer either side from the other.
class PLamN (a :: Type) (b :: PType) (s :: S) | a -> b, s b -> a where
  -- | Abstract a Haskell function over 'Term's into a Plutarch lambda.
  -- The 'HasCallStack' constraint lets the base-case instance record the
  -- call site for tracing.
  plam :: forall c. HasCallStack => (Term s c -> a) -> Term s (c :--> b)

-- | Base case: a one-argument Haskell function from 'Term' to 'Term'.
--
-- The body is wrapped in 'pgetConfig'. When the configuration's tracing mode
-- is 'DoTracingAndBinds', the body is prefixed with a 'ptrace' of the call
-- stack captured at this 'plam' call, rendered by 'prettyStack'. Otherwise the
-- body is used unchanged.
instance {-# OVERLAPPABLE #-} (a' ~ Term s a) => PLamN a' a s where
  plam :: forall c. HasCallStack => (Term s c -> a') -> Term s (c :--> a)
  plam f = plam' $ \arg -> pgetConfig $ \cfg -> withTrace cfg (f arg)
    where
      -- Plain pattern binding (no signature), so under the monomorphism
      -- restriction it is not generalised over 'HasCallStack'. It resolves
      -- to the stack given by this method's constraint, and thus records
      -- the caller of 'plam'.
      site = callStack
      -- Optionally prefix the body with a trace of the captured call stack.
      withTrace cfg body = case tracingMode cfg of
        DoTracingAndBinds -> ptrace (mkstring (prettyStack "L" site)) body
        _ -> body

-- | Inductive case: a Haskell function taking two or more 'Term's is curried
-- into nested Plutarch lambdas, one 'plam'' per argument.
--
-- 'withFrozenCallStack' stops the recursive 'plam' call from adding frames
-- located inside this module. As a result, the base case still captures the
-- stack from the user's original call site.
instance (a' ~ Term s a, PLamN b' b s) => PLamN (a' -> b') (a :--> b) s where
  plam :: forall c. HasCallStack => (Term s c -> a' -> b') -> Term s (c :--> (a :--> b))
  plam f = withFrozenCallStack (plam' (plam . f))

-- | Feed a term to a Haskell-level function over terms.
--
-- This is ordinary Haskell application: @pinl v f@ is just @f v@. No
-- Plutarch-level lambda or binding is created for @v@.
pinl :: Term s a -> (Term s a -> Term s b) -> Term s b
pinl v = ($ v)