{-# LANGUAGE RankNTypes #-}

{- | Pre-compiling Plutarch functions and applying them.

 Speeds up benchmarking and testing.
-}
module Plutarch.Extra.Precompile (
  -- Exporting the data constructor on purpose, since users might want to
  -- deserialize compiled terms.  If someone wants to subvert type safety using
  -- Scripts, they can do that regardless of this export.
  CompiledTerm (..),
  debuggableScript,
  compile',
  toDebuggableScript,
  applyCompiledTerm,
  applyCompiledTerm',
  applyCompiledTerm2,
  applyCompiledTerm2',
  (##),
  (##~),
  (###),
  (###~),
  pliftCompiled',
  pliftCompiled,
) where

import Data.Text (Text)
import Data.Text qualified as Text
import GHC.Stack (HasCallStack)
import Optics.Getter (view)
import Plutarch.Evaluate (evalScript)
import Plutarch.Extra.DebuggableScript (
  DebuggableScript,
  applyDebuggableArg,
  finalEvalDebuggableScript,
  mustCompileD,
  mustEvalD,
 )
import Plutarch.Internal (
  Config (Config),
  RawTerm (RCompiled),
  Term (Term),
  TermResult (TermResult),
  TracingMode (DetTracing),
  tracingMode,
 )
import Plutarch.Lift (
  LiftError (
    LiftError_CompilationError,
    LiftError_EvalError,
    LiftError_FromRepr,
    LiftError_KnownTypeError
  ),
  PUnsafeLiftDecl (PLifted),
  plift',
 )
import Plutarch.Script (Script (Script))
import PlutusCore.Builtin (KnownTypeError (KnownTypeEvaluationFailure, KnownTypeUnliftingError))
import UntypedPlutusCore qualified as UPLC

{- | Type-safe wrapper for compiled Plutarch functions.

 The index @a@ only appears in the type: the runtime representation is just a
 'DebuggableScript'. Nothing checks that the wrapped script really has Plutarch
 type @a@, so constructing values with the (exported) 'CompiledTerm' data
 constructor directly bypasses the type safety this wrapper provides.

 @since 3.8.0
-}
newtype CompiledTerm (a :: S -> Type) = CompiledTerm DebuggableScript

{- | Unwrap a 'CompiledTerm', discarding its type index and returning the
 underlying 'DebuggableScript'.

 @since 3.8.0
-}
debuggableScript ::
  forall (a :: S -> Type).
  CompiledTerm a ->
  DebuggableScript
debuggableScript (CompiledTerm x) = x

{- | Compile a closed Plutarch 'Term' to a 'CompiledTerm'.

 Beware, the Script inside contains everything it needs. You can end up with
 multiple copies of the same helper function through compiled terms (including
 RHS terms compiled by '##' and '##~').

 Compilation is done by 'mustCompileD', whose result type has no room for a
 failure value; presumably it throws if compilation fails (TODO confirm).

 @since 3.0.2
-}
compile' ::
  forall (a :: S -> Type).
  (forall (s :: S). Term s a) ->
  CompiledTerm a
compile' t = CompiledTerm $ mustCompileD t

{- | Convert a 'CompiledTerm' to a 'DebuggableScript'.

 Equivalent to 'debuggableScript': the type index is discarded and the wrapped
 script is returned unchanged.

 @since 3.0.2
-}
toDebuggableScript ::
  forall (a :: S -> Type).
  CompiledTerm a ->
  DebuggableScript
toDebuggableScript (CompiledTerm dscript) = dscript

{- | Apply a 'CompiledTerm' to a closed Plutarch 'Term'.

 Evaluates the argument before applying. You want this for benchmarking the
 compiled function. Helps to avoid tainting the measurement by input
 conversions.

 The argument is compiled with 'mustCompileD', evaluated with 'mustEvalD', and
 then applied to the function script with 'applyDebuggableArg'.

 @since 3.0.2
-}
applyCompiledTerm ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  (forall (s :: S). Term s a) ->
  CompiledTerm b
applyCompiledTerm (CompiledTerm sf) a =
  CompiledTerm $ applyDebuggableArg sf (mustEvalD $ mustCompileD a)

{- | Apply a 'CompiledTerm' to a closed Plutarch 'Term'.

 Does NOT evaluate the argument before applying. Using this seems to save very
 little overhead, not worth it for efficiency. Only use it to make argument
 evaluation count for benchmarking.

 The argument is compiled with 'mustCompileD' and applied to the function
 script as-is with 'applyDebuggableArg'.

 @since 3.0.2
-}
applyCompiledTerm' ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  (forall (s :: S). Term s a) ->
  CompiledTerm b
applyCompiledTerm' (CompiledTerm sf) a =
  CompiledTerm $ applyDebuggableArg sf (mustCompileD a)

{- | Apply a 'CompiledTerm' to a 'CompiledTerm'.

 Evaluates the argument before applying. You want this for benchmarking the
 compiled function. Helps to avoid tainting the measurement by input
 conversions.

 The argument script is evaluated with 'mustEvalD' and then applied to the
 function script with 'applyDebuggableArg'.

 @since 3.0.2
-}
applyCompiledTerm2 ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  CompiledTerm a ->
  CompiledTerm b
applyCompiledTerm2 (CompiledTerm sf) (CompiledTerm sa) =
  CompiledTerm $ applyDebuggableArg sf (mustEvalD sa)

{- | Apply a 'CompiledTerm' to a 'CompiledTerm'.

 Does NOT evaluate the argument before applying. Using this seems to save very
 little overhead, not worth it for efficiency. Only use it to make argument
 evaluation count for benchmarking.

 The argument script is applied to the function script as-is with
 'applyDebuggableArg'.

 @since 3.0.2
-}
applyCompiledTerm2' ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  CompiledTerm a ->
  CompiledTerm b
applyCompiledTerm2' (CompiledTerm sf) (CompiledTerm sa) =
  CompiledTerm $ applyDebuggableArg sf sa

{- | Alias for 'applyCompiledTerm'.

 Evaluates the argument before applying it. Left-associative with precedence
 8, so @f ## x ## y@ applies @f@ to @x@ and then the result to @y@.

 @since 3.0.2
-}
(##) ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  (forall (s :: S). Term s a) ->
  CompiledTerm b
(##) = applyCompiledTerm

infixl 8 ##

{- | Alias for 'applyCompiledTerm\''.

 Does NOT evaluate the argument before applying it. Left-associative with
 precedence 8, like '##'.

 @since 3.0.2
-}
(##~) ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  (forall (s :: S). Term s a) ->
  CompiledTerm b
(##~) = applyCompiledTerm'

infixl 8 ##~

{- | Alias for 'applyCompiledTerm2'.

 Evaluates the (already compiled) argument before applying it.
 Left-associative with precedence 7.

 @since 3.0.2
-}
(###) ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  CompiledTerm a ->
  CompiledTerm b
(###) = applyCompiledTerm2

infixl 7 ###

{- | Alias for 'applyCompiledTerm2\''.

 Does NOT evaluate the (already compiled) argument before applying it.
 Left-associative with precedence 7, like '###'.

 @since 3.0.2
-}
(###~) ::
  forall (a :: S -> Type) (b :: S -> Type).
  CompiledTerm (a :--> b) ->
  CompiledTerm a ->
  CompiledTerm b
(###~) = applyCompiledTerm2'

infixl 7 ###~

{- | Embed an already-compiled 'Script' as a Plutarch 'Term'.

 The script's UPLC program body is wrapped in 'RCompiled' with no hoisted
 terms. The argument passed to the function inside 'Term' is ignored
 (via 'const'); presumably that is fine because the script is closed, so verify
 this before reusing the helper elsewhere.

 The Plutarch type @a@ is chosen by the caller and is NOT checked against the
 script.
-}
scriptToTerm :: forall (a :: S -> Type) (s :: S). Script -> Term s a
scriptToTerm (Script prog) =
  Term $ const $ pure $ TermResult (RCompiled $ UPLC._progTerm prog) []

-- | Make a human-readable message from a 'LiftError'.
liftErrorMsg :: LiftError -> String
-- There is no Show instance for LiftError:
-- We would need to get Show for 'KnownTypeError' into 'PlutusCore.Builtin',
-- then Show for 'LiftError' into Plutarch.
-- Though seeing the data constructors only would not be very informative
-- anyway.
liftErrorMsg = \case
  LiftError_FromRepr ->
    "pconstantFromRepr returned 'Nothing'"
  LiftError_KnownTypeError (KnownTypeUnliftingError unliftErr) ->
    "incorrect type: " <> show unliftErr
  LiftError_KnownTypeError KnownTypeEvaluationFailure ->
    "absurd evaluation failure"
  LiftError_EvalError e ->
    "erring term: " <> show e
  LiftError_CompilationError msg ->
    "compilation failed: " <> Text.unpack msg

{- | Convert a 'CompiledTerm' to the associated Haskell value. Fail otherwise.

 This will fully evaluate the compiled term, and convert the resulting value.

 Returns 'Left' (with the collected traces) when:

 * evaluating the script fails ('LiftError_EvalError'), or
 * lifting the evaluated script fails with some other 'LiftError' and lifting
   the evaluated debug script fails with the same 'LiftError'.

 Calls 'error' instead when the situation indicates an internal inconsistency:
 lifting an already evaluated script reports an evaluation or compilation
 error, the debug script fails to evaluate although the script succeeded,
 lifting the debug script succeeds where the script's lifting failed, or the
 two produce different 'LiftError's.

 @since 3.0.2
-}
pliftCompiled' ::
  forall (p :: S -> Type).
  PUnsafeLiftDecl p =>
  CompiledTerm p ->
  Either (LiftError, [Text]) (PLifted p)
pliftCompiled' ct =
  case res of
    Left evalError -> Left (LiftError_EvalError evalError, traces)
    Right evaluatedScript ->
      case liftEvaluated evaluatedScript of
        Right lifted -> Right lifted
        Left (LiftError_EvalError evalError) ->
          error . unlines $
            [ "Lifting EVALUATED compiled term resulted in "
                <> "LiftError_EvalError!"
            , show evalError
            ]
        Left (LiftError_CompilationError compilationMsg) ->
          error . unlines $
            [ "Lifting evaluated COMPILED term resulted in "
                <> "LiftError_CompilationError!"
            , Text.unpack compilationMsg
            ]
        Left liftError -> handleOtherLiftError liftError
  where
    -- Final evaluation of the compiled term; its traces are what the caller
    -- gets back when this evaluation fails.
    (res, _, traces) = finalEvalDebuggableScript . debuggableScript $ ct

    -- Evaluation of the debug variant of the script. Lazily bound, so it is
    -- only forced when a LiftError has to be cross-checked.
    (res', _, traces') = evalScript . view #debugScript . debuggableScript $ ct

    -- Lift an evaluated script back to Haskell. Shared by the main path and
    -- the debug cross-check so both use the same configuration.
    liftEvaluated :: Script -> Either LiftError (PLifted p)
    liftEvaluated script =
      plift' (Config {tracingMode = DetTracing}) (scriptToTerm @p script)

    -- Cross-check a LiftError from the main script against the debug script.
    -- Only a matching LiftError is reported as 'Left'; any disagreement is an
    -- internal inconsistency and aborts with 'error'.
    handleOtherLiftError :: LiftError -> Either (LiftError, [Text]) (PLifted p)
    handleOtherLiftError liftError =
      case res' of
        Left evalError ->
          error . unlines $
            [ "Script succeeded, but corresponding debug Script failed!"
            , show evalError
            , "Debug Script traces:"
            , Text.unpack (Text.unlines traces')
            , "The debug Script was tried because of a LiftError."
            , "The original LiftError of the succeeded Script:"
            , liftErrorMsg liftError
            ]
        Right evaluatedDebugScript ->
          case liftEvaluated evaluatedDebugScript of
            Right _ ->
              error . unlines $
                [ "Lifting evaluated compiled term resulted in a "
                    <> "LiftError, but lifting the debug version "
                    <> "succeeded!"
                , "The LiftError:"
                , liftErrorMsg liftError
                ]
            Left liftError'
              | liftError' == liftError -> Left (liftError, traces')
              | otherwise ->
                  error . unlines $
                    [ "Lifting Script and corresponding debug "
                        <> "Script resulted in different "
                        <> "LiftErrors!"
                    , "Original LiftError:"
                    , liftErrorMsg liftError
                    , "Debug Script LiftError:"
                    , liftErrorMsg liftError'
                    ]

{- | Like 'pliftCompiled\'' but throws on failure.

 On a 'Left' result, calls 'error' with a message built from the 'LiftError'
 and the collected traces. Because of the 'HasCallStack' constraint, the call
 stack reported by 'error' includes this function's call site.

 @since 3.0.2
-}
pliftCompiled ::
  forall (p :: S -> Type).
  (HasCallStack, PLift p) =>
  CompiledTerm p ->
  PLifted p
pliftCompiled ct =
  case pliftCompiled' ct of
    Left (liftError, traces) ->
      error $
        unlines
          [ "Lifting compiled term failed:"
          , liftErrorMsg liftError
          , "Traces:"
          , Text.unpack (Text.unlines traces)
          ]
    Right x -> x