{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Simplex.Messaging.Agent.Store.SQLite.Common
  ( DBStore (..),
    DBOpts (..),
    SQLiteFuncDef (..),
    SQLiteFuncPtrs (..),
    withConnection,
    withConnection',
    withTransaction,
    withTransaction',
    withTransactionPriority,
    withSavepoint,
    dbBusyLoop,
    storeKey,
  )
where

import Control.Concurrent (threadDelay)
import Control.Concurrent.STM (retry)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import Database.SQLite.Simple (SQLError)
import qualified Database.SQLite.Simple as SQL
import Database.SQLite3.Bindings
import Foreign.Ptr
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
import Simplex.Messaging.Agent.Store.SQLite.Util
import Simplex.Messaging.Util (ifM, unlessM)
import qualified UnliftIO.Exception as E
import UnliftIO.MVar
import UnliftIO.STM

-- | Select the key value to retain: @Just key@ when @keepKey@ is set or the
-- key is empty, 'Nothing' otherwise.
--
-- NOTE(review): the result type matches the contents of 'dbKey' in 'DBStore';
-- presumably this is what gets stored there — confirm at the call site.
storeKey :: ScrubbedBytes -> Bool -> Maybe ScrubbedBytes
storeKey key keepKey
  | keepKey || BA.null key = Just key
  | otherwise = Nothing

-- | Runtime state of an open SQLite-backed store.
data DBStore = DBStore
  { -- | Filesystem path of the database.
    dbFilePath :: FilePath,
    -- | Custom SQL functions described by 'SQLiteFuncDef'
    -- (presumably registered on the connection when it is opened — confirm).
    dbFunctions :: [SQLiteFuncDef],
    -- | Key retained in memory, if any
    -- (NOTE(review): presumably the result of 'storeKey' — confirm).
    dbKey :: TVar (Maybe ScrubbedBytes),
    -- | Number of priority requests currently waiting for or holding the
    -- connection; non-priority callers of 'withConnectionPriority' wait
    -- until it drops to 0.
    dbSem :: TVar Int,
    -- | The connection; callers in this module take it with 'withMVar',
    -- which serialises access.
    dbConnection :: MVar DB.Connection,
    -- | Closed flag (presumably set when the store is closed — not used in
    -- this module; confirm).
    dbClosed :: TVar Bool,
    -- | Presumably 'True' when the database was newly created — confirm.
    dbNew :: Bool
  }

-- | Options used to open a SQLite-backed store.
data DBOpts = DBOpts
  { -- | Filesystem path of the database.
    dbFilePath :: FilePath,
    -- | Custom SQL functions described by 'SQLiteFuncDef'.
    dbFunctions :: [SQLiteFuncDef],
    -- | Database key (NOTE(review): presumably the encryption key — confirm).
    dbKey :: ScrubbedBytes,
    -- | Whether to keep the key in memory after opening
    -- (presumably fed to 'storeKey' — confirm).
    keepKey :: Bool,
    -- | Presumably requests a VACUUM when opening — confirm against the
    -- store-opening code.
    vacuum :: Bool,
    -- | Query-tracking setting handed to the "Simplex.Messaging.Agent.Store.SQLite.DB" layer.
    track :: DB.TrackQueries
  }

-- | Definition of a custom SQLite function or aggregate, for example
-- @SQLiteFuncDef "func_name" 2 (SQLiteFuncPtr True func)@ or
-- @SQLiteFuncDef "aggr_name" 3 (SQLiteAggrPtrs step final)@.
data SQLiteFuncDef = SQLiteFuncDef
  { -- | Name of the function as seen from SQL
    -- (presumably passed to the SQLite registration call — confirm).
    funcName :: ByteString,
    -- | Number of arguments the function takes.
    argCount :: CArgCount,
    -- | Scalar callback or aggregate step/final callbacks.
    funcPtrs :: SQLiteFuncPtrs
  }

-- | C callbacks implementing a 'SQLiteFuncDef': either a single scalar
-- function or an aggregate made of a step and a final callback.
data SQLiteFuncPtrs
  = -- | Scalar function.
    SQLiteFuncPtr
      { -- | Deterministic flag (presumably mapped to @SQLITE_DETERMINISTIC@
        -- at registration — confirm).
        deterministic :: Bool,
        -- | Function callback.
        funcPtr :: FunPtr SQLiteFunc
      }
  | -- | Aggregate function.
    SQLiteAggrPtrs
      { -- | Step callback.
        stepPtr :: FunPtr SQLiteFunc,
        -- | Final callback.
        finalPtr :: FunPtr SQLiteFuncFinal
      }

-- | Run an action with exclusive use of the store's connection.
--
-- * Priority callers (@True@) increment 'dbSem' for the whole time they wait
--   for and hold the connection, and decrement it afterwards (clamped at 0).
-- * Non-priority callers first wait until 'dbSem' is 0. After taking the
--   connection they check again. If a priority request has appeared in the
--   meantime, they release the connection without running the action and
--   go back to waiting.
withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a
withConnectionPriority DBStore {dbSem, dbConnection} priority action
  | priority = E.bracket_ enterPriority leavePriority (withMVar dbConnection action)
  | otherwise = runYielding
  where
    -- Non-priority path: repeat until the action has run with no priority
    -- request pending.
    runYielding = do
      awaitNoPriority
      outcome <- withMVar dbConnection $ \db ->
        ifM noPriority (Just <$> action db) (pure Nothing)
      case outcome of
        Just r -> pure r
        Nothing -> runYielding
    enterPriority = atomically $ modifyTVar' dbSem (+ 1)
    leavePriority = atomically $ modifyTVar' dbSem decrement
    -- Never lets the counter go below 0.
    decrement n
      | n > 0 = n - 1
      | otherwise = 0
    -- Cheap non-transactional check first; only block in STM when a
    -- priority request is pending.
    awaitNoPriority = unlessM noPriority $ atomically $ do
      pending <- readTVar dbSem
      if pending == 0 then pure () else retry
    noPriority = (== 0) <$> readTVarIO dbSem

-- | Run an action on the store's connection as a non-priority caller
-- (see 'withConnectionPriority').
withConnection :: DBStore -> (DB.Connection -> IO a) -> IO a
withConnection store action = withConnectionPriority store False action

-- | Like 'withConnection', but passes the action the underlying
-- sqlite-simple connection ('DB.conn') rather than the wrapper.
withConnection' :: DBStore -> (SQL.Connection -> IO a) -> IO a
withConnection' store action = withConnection store (\db -> action (DB.conn db))

-- | Like 'withTransaction', but passes the action the underlying
-- sqlite-simple connection ('DB.conn') rather than the wrapper.
withTransaction' :: DBStore -> (SQL.Connection -> IO a) -> IO a
withTransaction' store action = withTransaction store (\db -> action (DB.conn db))

-- | Run an action inside an immediate transaction as a non-priority caller
-- (see 'withTransactionPriority').
withTransaction :: DBStore -> (DB.Connection -> IO a) -> IO a
withTransaction store = withTransactionPriority store False
{-# INLINE withTransaction #-}

-- | Run an action in an immediate transaction ('SQL.withImmediateTransaction')
-- on the store's connection, with the given priority
-- (see 'withConnectionPriority').
--
-- The transaction is wrapped in 'dbBusyLoop'. If it fails with a busy or
-- locked error, the whole transaction, including the action, is run again
-- after a back-off. The connection stays held throughout.
withTransactionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a
withTransactionPriority store priority action =
  withConnectionPriority store priority $ \db@DB.Connection {conn} ->
    dbBusyLoop $ SQL.withImmediateTransaction conn (action db)

-- | SQLite stand-in for the PostgreSQL version's savepoint helper; no
-- savepoint is created.
--
-- The connection and savepoint name are ignored. The action is simply run,
-- and any 'SQLError' it throws is returned as 'Left'. The signature matches
-- the PostgreSQL version so the two back ends share one interface.
withSavepoint :: DB.Connection -> SQL.Query -> IO a -> IO (Either SQLError a)
withSavepoint _conn _savepointName = E.try
{-# INLINE withSavepoint #-}

-- | Run an action, retrying it while it fails with a busy or locked SQLite
-- error.
--
-- Handling of an 'SQLError' thrown by the action:
--
-- * If its code is 'SQL.ErrorBusy' or 'SQL.ErrorLocked', the thread sleeps
--   and then runs the action again.
-- * The first sleep is 500 µs; each later one is @delay * 9 `div` 8@.
-- * Every sleep is deducted from a budget that starts at 3,000,000 µs (3 s).
--   Once the remaining budget is no longer greater than the next sleep, the
--   error is rethrown.
-- * Any other 'SQLError' is rethrown immediately.
--
-- Exceptions of other types are not caught.
dbBusyLoop :: forall a. IO a -> IO a
dbBusyLoop action = attempt 500 3000000
  where
    attempt :: Int -> Int -> IO a
    attempt delay budget = action `E.catch` onError
      where
        onError :: SQLError -> IO a
        onError err
          | budget > delay && isBusyOrLocked (SQL.sqlError err) = do
              threadDelay delay
              attempt ((delay * 9) `div` 8) (budget - delay)
          | otherwise = E.throwIO err
    isBusyOrLocked code = code == SQL.ErrorBusy || code == SQL.ErrorLocked