module Simplex.Messaging.Agent.Store.SQLite.Util
  ( SQLiteFunc,
    SQLiteFuncFinal,
    mkSQLiteFunc,
    mkSQLiteAggStep,
    mkSQLiteAggFinal,
    createStaticFunction,
    createStaticAggregate,
  ) where

import Control.Exception (SomeException, catch, mask_)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.IORef
import Database.SQLite3.Direct (Database (..), FuncArgs (..), FuncContext (..))
import Database.SQLite3.Bindings
import Foreign.C.String
import Foreign.Ptr
import Foreign.StablePtr
import Foreign.Storable

-- | Three callback pointers bundled into one value.
--
-- In this module it is built as (scalar function, null, null) by
-- 'createStaticFunction' and as (null, step, final) by
-- 'createStaticAggregate'; a 'StablePtr' to it is handed to SQLite as the
-- user-data pointer of the registered function.
data CFuncPtrs = CFuncPtrs (FunPtr CFunc) (FunPtr CFunc) (FunPtr CFuncFinal)

-- | Shape of a scalar-function or aggregate-step callback:
-- the SQLite function context, the argument count and the argument values.
type SQLiteFunc = Ptr CContext -> CArgCount -> Ptr (Ptr CValue) -> IO ()

-- | Shape of an aggregate-finalizer callback: it receives only the SQLite
-- function context.
type SQLiteFuncFinal = Ptr CContext -> IO ()

-- | Adapt a Haskell scalar function working on 'FuncContext' / 'FuncArgs'
-- to the raw 'SQLiteFunc' callback shape.
--
-- The raw context and argument pointers are wrapped in their newtypes and
-- the call runs under 'catchAsResultError', so a Haskell exception thrown by
-- @f@ is reported to SQLite as the function's error result instead of
-- escaping the callback.
mkSQLiteFunc :: (FuncContext -> FuncArgs -> IO ()) -> SQLiteFunc
mkSQLiteFunc f cxt nArgs cvals = catchAsResultError cxt $ f (FuncContext cxt) (FuncArgs nArgs cvals)
{-# INLINE mkSQLiteFunc #-}

-- | Register a scalar SQL function implemented by a caller-supplied function pointer.
--
-- Based on createFunction from Database.SQLite3.Direct, but uses static function pointer to avoid dynamic wrapper that triggers DCL.
--
-- Arguments: the database, the SQL name of the function, its argument
-- count, whether to flag it as deterministic ('c_SQLITE_DETERMINISTIC'),
-- and the function pointer itself.
--
-- Returns @Right ()@ when SQLite reports success, otherwise @Left@ with the
-- decoded error. The whole registration runs under 'mask_'.
createStaticFunction :: Database -> ByteString -> CArgCount -> Bool -> FunPtr SQLiteFunc -> IO (Either Error ())
createStaticFunction (Database db) name nArgs isDet funPtr = mask_ $ do
  u <- newStablePtr $ CFuncPtrs funPtr nullFunPtr nullFunPtr
  let flags = if isDet then c_SQLITE_DETERMINISTIC else 0
  r <- B.useAsCString name $ \namePtr ->
    toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs flags (castStablePtrToPtr u) funPtr nullFunPtr nullFunPtr nullFunPtr
  -- No destroy callback is registered (last argument is null), so SQLite
  -- never releases the user-data pointer. When registration fails the
  -- pointer was not attached to any function, so release it here rather
  -- than leaking it. On success it stays registered as user data.
  case r of
    Left _ -> freeStablePtr u
    Right _ -> pure ()
  pure r

-- | Build the raw step callback of an aggregate from an initial state and a
-- Haskell step function.
--
-- On each call the current state is loaded (or created from @initSt@ when
-- the aggregate context holds no state yet), passed to @xStep@ together with
-- the wrapped context and arguments, and the returned state is stored back.
-- Everything runs under 'catchAsResultError', so exceptions become the SQL
-- function's error result.
mkSQLiteAggStep :: a -> (FuncContext -> FuncArgs -> a -> IO a) -> SQLiteFunc
mkSQLiteAggStep initSt xStep cxt nArgs cvals = catchAsResultError cxt $ do
  -- we store the aggregate state in the buffer returned by
  -- c_sqlite3_aggregate_context as a StablePtr pointing to an IORef that
  -- contains the actual aggregate state
  aggCtx <- getAggregateContext cxt
  if aggCtx == nullPtr
    -- SQLite could not provide the context buffer; reading from it would
    -- crash, so raise an error that catchAsResultError reports to SQLite.
    then ioError $ userError "mkSQLiteAggStep: sqlite3_aggregate_context returned NULL"
    else do
      aggStPtr <- peek aggCtx
      aggStRef <-
        if castStablePtrToPtr aggStPtr /= nullPtr
          then deRefStablePtr aggStPtr
          else do
            -- first step of this aggregation: allocate the state and
            -- remember it in the context buffer for later calls
            ref <- newIORef initSt
            refPtr <- newStablePtr ref
            poke aggCtx refPtr
            return ref
      aggSt <- readIORef aggStRef
      aggSt' <- xStep (FuncContext cxt) (FuncArgs nArgs cvals) aggSt
      writeIORef aggStRef aggSt'

-- | Build the raw finalizer callback of an aggregate from an initial state
-- and a Haskell finalizer.
--
-- When the aggregate context holds no state (the stored pointer is null),
-- @xFinal@ is called with @initSt@. Otherwise it is called with the stored
-- state, and the 'StablePtr' created by the step callback is freed
-- afterwards. @xFinal@ always runs under 'catchAsResultError', so the
-- pointer is freed even when @xFinal@ throws.
mkSQLiteAggFinal :: a -> (FuncContext -> a -> IO ()) -> SQLiteFuncFinal
mkSQLiteAggFinal initSt xFinal cxt = do
  aggCtx <- getAggregateContext cxt
  if aggCtx == nullPtr
    -- SQLite could not provide the context buffer; report an error instead
    -- of reading through a null pointer.
    then catchAsResultError cxt $ ioError $ userError "mkSQLiteAggFinal: sqlite3_aggregate_context returned NULL"
    else do
      aggStPtr <- peek aggCtx
      if castStablePtrToPtr aggStPtr == nullPtr
        then catchAsResultError cxt $ xFinal (FuncContext cxt) initSt
        else do
          catchAsResultError cxt $ do
            aggStRef <- deRefStablePtr aggStPtr
            aggSt <- readIORef aggStRef
            xFinal (FuncContext cxt) aggSt
          freeStablePtr aggStPtr

-- | Ask SQLite for the aggregate context buffer of the given function
-- context, requesting exactly enough bytes to hold one 'StablePtr'.
--
-- The result may be a null pointer; callers check for that before reading
-- from it.
getAggregateContext :: Ptr CContext -> IO (Ptr a)
getAggregateContext cxt = c_sqlite3_aggregate_context cxt stPtrSize
  where
    -- sizeOf only inspects the type, so the undefined value is never forced
    stPtrSize = fromIntegral $ sizeOf (undefined :: StablePtr ())

-- | Register an aggregate SQL function implemented by caller-supplied step
-- and finalizer function pointers.
--
-- Based on createAggregate from Database.SQLite3.Direct, but uses static function pointers to avoid dynamic wrappers that trigger DCL.
--
-- Arguments: the database, the SQL name of the aggregate, its argument
-- count, the step function pointer and the finalizer function pointer.
--
-- Returns @Right ()@ when SQLite reports success, otherwise @Left@ with the
-- decoded error. The whole registration runs under 'mask_'.
createStaticAggregate :: Database -> ByteString -> CArgCount -> FunPtr SQLiteFunc -> FunPtr SQLiteFuncFinal -> IO (Either Error ())
createStaticAggregate (Database db) name nArgs stepPtr finalPtr = mask_ $ do
  u <- newStablePtr $ CFuncPtrs nullFunPtr stepPtr finalPtr
  r <- B.useAsCString name $ \namePtr ->
    toResult () <$> c_sqlite3_create_function_v2 db namePtr nArgs 0 (castStablePtrToPtr u) nullFunPtr stepPtr finalPtr nullFunPtr
  -- No destroy callback is registered (last argument is null), so SQLite
  -- never releases the user-data pointer. When registration fails the
  -- pointer was not attached to any function, so release it here rather
  -- than leaking it. On success it stays registered as user data.
  case r of
    Left _ -> freeStablePtr u
    Right _ -> pure ()
  pure r

-- | Convert a 'CError' to an 'Either Error', in the common case where
-- SQLITE_OK signals success and anything else signals an error.
--
-- The first argument is the value to return on success; any non-zero code
-- is decoded with 'decodeError' and returned as 'Left'.
--
-- Note that SQLITE_OK == 0.
toResult :: a -> CError -> Either Error a
toResult a (CError 0) = Right a
toResult _ code = Left $ decodeError code

-- | Run an action inside an SQLite function callback and call
-- c_sqlite3_result_error in the event of an error.
--
-- Any exception ('SomeException') thrown by the action is caught; its
-- 'show' output is passed to SQLite as the error message for the given
-- function context, so the exception does not propagate out of the callback.
catchAsResultError :: Ptr CContext -> IO () -> IO ()
catchAsResultError ctx action = catch action $ \exn -> do
  let msg = show (exn :: SomeException)
  withCAStringLen msg $ \(ptr, len) ->
    c_sqlite3_result_error ctx ptr (fromIntegral len)