{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TemplateHaskell #-}

module Simplex.Messaging.Agent.Store.SQLite.DB
  ( BoolInt (..),
    Binary (..),
    Connection (..),
    SlowQueryStats (..),
    TrackQueries (..),
    FromField (..),
    ToField (..),
    SQLError,
    open,
    close,
    execute,
    execute_,
    executeMany,
    query,
    query_,
    blobFieldDecoder,
    fromTextField_,
  )
where

import Control.Concurrent.STM
import Control.Exception
import Control.Monad (when)
import qualified Data.Aeson.TH as J
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time (diffUTCTime, getCurrentTime)
import Data.Typeable (Typeable)
import Database.SQLite.Simple (FromRow, ResultError (..), Query, SQLData (..), SQLError, ToRow)
import qualified Database.SQLite.Simple as SQL
import Database.SQLite.Simple.FromField (FieldParser, FromField (..), returnError)
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok (Ok (Ok))
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Messaging.Parsers (defaultJSON)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (diffToMicroseconds, tshow)

-- | A 'Bool' column wrapper.
--
-- Its 'FromField' / 'ToField' instances are the ones of the wrapped 'Bool',
-- reused via newtype deriving.
newtype BoolInt = BI {unBI :: Bool}
  deriving newtype (FromField, ToField)

-- | A column wrapper around an arbitrary value.
--
-- Its 'FromField' / 'ToField' instances are those of the wrapped type,
-- reused via newtype deriving.
-- NOTE(review): presumably used to mark BLOB-typed columns at call sites;
-- confirm against callers.
newtype Binary a = Binary {fromBinary :: a}
  deriving newtype (FromField, ToField)

-- | An sqlite-simple connection bundled with query-tracking state.
data Connection = Connection
  { -- | The underlying sqlite-simple connection.
    conn :: SQL.Connection,
    -- | Which query timings are recorded (see 'TrackQueries').
    track :: TrackQueries,
    -- | Per-query statistics, keyed by the query text.
    slow :: TMap Query SlowQueryStats
  }

-- | Query-timing policy used by 'timeIt'.
data TrackQueries
  = -- | Record the timing of every query.
    TQAll
  | -- | Record timings only of queries that take longer than the given
    -- number of microseconds.
    TQSlow Int64
  | -- | Record no timings (exceptions raised by queries are still counted).
    TQOff
  deriving stock (Eq)

-- | Statistics accumulated for a single query text.
data SlowQueryStats = SlowQueryStats
  { -- | Number of executions whose timing was recorded.
    count :: Int64,
    -- | Longest recorded duration, in microseconds.
    timeMax :: Int64,
    -- | Mean recorded duration in microseconds; integer-valued, rounded
    -- down at each update.
    timeAvg :: Int64,
    -- | Number of exceptions seen, keyed by the exception's text
    -- (as rendered by @tshow@).
    errs :: Map Text Int
  }
  deriving stock (Show)

-- | Run a query action, recording statistics about it in the connection's
-- 'slow' map under the query text.
--
-- * If the action throws, the exception is counted in 'errs' (keyed by its
--   text) and then re-thrown. This happens regardless of the tracking
--   policy, including 'TQOff'.
-- * Unless the policy is 'TQOff', the action's duration is measured in
--   microseconds. When 'TQAll' is set, or 'TQSlow' is set and the duration
--   exceeds its threshold, the count, maximum and average are updated.
timeIt :: Connection -> Query -> IO a -> IO a
timeIt Connection {slow, track} sql a
  | track == TQOff = makeQuery
  | otherwise = do
      t <- getCurrentTime
      r <- makeQuery
      t' <- getCurrentTime
      -- NOTE(review): getCurrentTime is wall-clock time, so a system clock
      -- adjustment between the two reads can skew the duration (even make it
      -- negative); a monotonic clock would be more robust.
      let diff = diffToMicroseconds $ diffUTCTime t' t
      when (trackQuery diff) $ atomically $ TM.alter (updateQueryStats diff) sql slow
      pure r
  where
    -- Run the action; on any exception, record it against this query and re-throw.
    makeQuery =
      a `catch` \e -> do
        atomically $ TM.alter (Just . updateQueryErrors e) sql slow
        throwIO e
    -- Decide from the policy whether a query of this duration is recorded.
    trackQuery diff = case track of
      TQOff -> False
      TQSlow t -> diff > t
      TQAll -> True
    updateQueryErrors :: SomeException -> Maybe SlowQueryStats -> SlowQueryStats
    updateQueryErrors e Nothing = SlowQueryStats 0 0 0 $ M.singleton (tshow e) 1
    updateQueryErrors e (Just st@SlowQueryStats {errs}) =
      st {errs = M.alter (Just . maybe 1 (+ 1)) (tshow e) errs}
    updateQueryStats :: Int64 -> Maybe SlowQueryStats -> Maybe SlowQueryStats
    updateQueryStats diff Nothing = Just $ SlowQueryStats 1 diff diff M.empty
    updateQueryStats diff (Just SlowQueryStats {count, timeMax, timeAvg, errs}) =
      Just $
        SlowQueryStats
          { count = count + 1,
            timeMax = max timeMax diff,
            -- Incremental mean. Because count + 1 >= 1 and `div` floors, this
            -- equals (timeAvg * count + diff) `div` (count + 1) exactly, but it
            -- never forms the product timeAvg * count, which could overflow Int64.
            timeAvg = timeAvg + (diff - timeAvg) `div` (count + 1),
            errs
          }

-- | Open the SQLite database at the given path, using the given tracking
-- policy and an empty statistics map.
open :: String -> TrackQueries -> IO Connection
open f track = do
  conn <- SQL.open f
  slow <- TM.emptyIO
  pure Connection {conn, slow, track}

-- | Close the underlying sqlite-simple connection.
close :: Connection -> IO ()
close = SQL.close . conn

-- | 'SQL.execute' with parameters, tracked by 'timeIt' under the query text.
execute :: ToRow q => Connection -> Query -> q -> IO ()
execute c sql = timeIt c sql . SQL.execute (conn c) sql
{-# INLINE execute #-}

-- | Parameterless 'SQL.execute_', tracked by 'timeIt' under the query text.
execute_ :: Connection -> Query -> IO ()
execute_ c sql = timeIt c sql $ SQL.execute_ (conn c) sql
{-# INLINE execute_ #-}

-- | 'SQL.executeMany' over a list of parameter rows. The whole batch is
-- tracked by 'timeIt' as a single action under the query text.
executeMany :: ToRow q => Connection -> Query -> [q] -> IO ()
executeMany c sql = timeIt c sql . SQL.executeMany (conn c) sql
{-# INLINE executeMany #-}

-- | 'SQL.query' with parameters, tracked by 'timeIt' under the query text.
query :: (ToRow q, FromRow r) => Connection -> Query -> q -> IO [r]
query c sql = timeIt c sql . SQL.query (conn c) sql
{-# INLINE query #-}

-- | Parameterless 'SQL.query_', tracked by 'timeIt' under the query text.
query_ :: FromRow r => Connection -> Query -> IO [r]
query_ c sql = timeIt c sql $ SQL.query_ (conn c) sql
{-# INLINE query_ #-}

-- | Build a 'FieldParser' for values stored in a BLOB column.
--
-- The blob bytes are passed to the supplied decoder. A decoder 'Left' is
-- reported as 'ConversionFailed' with the message prefixed by
-- @"couldn't parse field: "@. A column of any other SQLite storage type
-- also fails with 'ConversionFailed'.
blobFieldDecoder :: Typeable k => (ByteString -> Either String k) -> FieldParser k
blobFieldDecoder dec = \case
  f@(Field (SQLBlob b) _) ->
    either (returnError ConversionFailed f . ("couldn't parse field: " ++)) Ok (dec b)
  f -> returnError ConversionFailed f "expecting SQLBlob column type"

-- | Parse a TEXT column with the given function.
--
-- If the parser returns 'Nothing', this fails with 'ConversionFailed', and
-- the message includes the offending text. A column of any other SQLite
-- storage type also fails with 'ConversionFailed'.
fromTextField_ :: Typeable a => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ fromText = \case
  f@(Field (SQLText t) _) ->
    case fromText t of
      Just x -> Ok x
      Nothing -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
  f -> returnError ConversionFailed f "expecting SQLText column type"

-- Aeson ToJSON/FromJSON instances for SlowQueryStats, generated with the
-- project's shared 'defaultJSON' options. A naked top-level declaration
-- splice; it must stay after the SlowQueryStats declaration.
J.deriveJSON defaultJSON ''SlowQueryStats