{-# LANGUAGE DeriveAnyClass        #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE DerivingStrategies    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# OPTIONS_GHC -fno-omit-interface-pragmas #-}
{-# OPTIONS_GHC -fplugin-opt PlutusTx.Plugin:debug-context #-}
{-# OPTIONS_GHC -fno-ignore-interface-pragmas #-}
module PlutusTx.Sqrt(
    Sqrt (..)
    , rsqrt
    , isqrt
    ) where

import PlutusTx.IsData (makeIsDataIndexed)
import PlutusTx.Lift (makeLift)
import PlutusTx.Prelude (Integer, divide, negate, otherwise, ($), (*), (+), (<), (<=), (==))
import PlutusTx.Ratio (Rational, denominator, numerator, unsafeRatio)
import Prelude qualified as Haskell

-- | Integer square-root representation, discarding imaginary integers.
--
-- Produced by 'rsqrt' and 'isqrt'. The constructor tells the caller whether
-- the carried 'Integer' is the exact root or only its integer part.
data Sqrt
  -- | The number was negative, so we don't even attempt to compute it;
  -- just note that the result would be imaginary.
  = Imaginary
  -- | An exact integer result. The 'rsqrt' of 4 is @'Exactly' 2@.
  | Exactly Integer
  -- | The 'Integer' component (i.e. the floor) of a non-integral result. The
  -- 'rsqrt' of 2 is @'Approximately' 1@.
  | Approximately Integer
  deriving stock (Haskell.Show, Haskell.Eq)

{-# INLINABLE rsqrt #-}
-- | Calculates the square root of a ratio of integers, reported as a 'Sqrt'.
--
-- With @n@ the numerator and @d@ the denominator of the argument:
--
-- * 'Imaginary' when @n * d < 0@, i.e. the ratio is negative;
-- * @'Exactly' k@ when @k * k * d == n@, i.e. the root is the integer @k@;
-- * @'Approximately' k@ otherwise, where @k@ is the floor of the root.
--
-- As x / 0 is undefined, calling this function with @d = 0@ results in an
-- error.
rsqrt :: Rational -> Sqrt
rsqrt r
    | n * d < 0 = Imaginary
    | n == 0    = Exactly 0
    | n == d    = Exactly 1
    | n < d     = Approximately 0
    -- NOTE(review): this guard is only reached when d < n < 0. If a
    -- 'Rational' can carry a negative denominator, inputs with n < d < 0
    -- (e.g. -4 / -1) are caught by the @n < d@ guard above instead and
    -- reported as @Approximately 0@. Presumably 'Rational' always has a
    -- positive denominator, making this branch unreachable -- TODO confirm.
    | n < 0     = rsqrt $ unsafeRatio (negate n) (negate d)
    -- Here n > d > 0, so the root lies in [1, 1 + n `divide` d).
    | otherwise = go 1 $ 1 + divide n d
  where
    n :: Integer
    n = numerator r

    d :: Integer
    d = denominator r

    -- Bisection on the integer interval [l, u), maintaining
    -- @l * l * d <= n < u * u * d@. Comparing against @n@ after multiplying
    -- by @d@ keeps the search in integer arithmetic.
    go :: Integer -> Integer -> Sqrt
    go l u
        | l * l * d == n = Exactly l
        -- The interval has shrunk to a single candidate that is not exact.
        | u == (l + 1)   = Approximately l
        | otherwise      =
              let
                m :: Integer
                m = divide (l + u) 2
              in
                if m * m * d <= n then go m u
                                  else go l m

{-# INLINABLE isqrt #-}
-- | Calculates the integer-component of the sqrt of @n@.
--
-- Builds the ratio @n / 1@ with 'unsafeRatio' and delegates to 'rsqrt', so
-- the result is whatever 'rsqrt' reports for that ratio.
isqrt :: Integer -> Sqrt
isqrt n = rsqrt (unsafeRatio n 1)

-- Template Haskell splices generating instances for 'Sqrt' via 'makeLift'
-- (from "PlutusTx.Lift") and 'makeIsDataIndexed' (from "PlutusTx.IsData").
--
-- Each constructor is paired with an explicit index: 'Imaginary' is 0,
-- 'Exactly' is 1 and 'Approximately' is 2.
-- NOTE(review): presumably these indices select the constructor tag used in
-- the generated data encoding, so renumbering them would change how existing
-- values are (de)serialised -- confirm before editing.
makeLift ''Sqrt
makeIsDataIndexed ''Sqrt [ ('Imaginary,     0)
                         , ('Exactly,       1)
                         , ('Approximately, 2)
                         ]