{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Simplex.Messaging.Agent.RetryInterval
  ( RetryInterval (..),
    RetryInterval2 (..),
    RetryIntervalMode (..),
    RI2State (..),
    withRetryInterval,
    withRetryIntervalCount,
    withRetryForeground,
    withRetryLock2,
    updateRetryInterval2,
    nextRetryDelay,
  )
where

import Control.Concurrent (forkIO, killThread)
import Control.Concurrent.STM (retry)
import Control.Exception (bracket)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Int (Int64)
import Simplex.Messaging.Util (threadDelay', unlessM, whenM)
import UnliftIO.STM

-- | Parameters of a growing retry delay, consumed by 'nextRetryDelay'.
--
-- Delays are passed straight to 'registerDelay' / threadDelay', so the values
-- are in microseconds.
data RetryInterval = RetryInterval
  { -- | delay used before the first retry
    initialInterval :: Int64,
    -- | the delay is kept constant while the total elapsed delay is below this
    increaseAfter :: Int64,
    -- | upper bound the delay grows towards
    maxInterval :: Int64
  }

-- | Two independent retry schedules that 'withRetryLock2' switches between.
data RetryInterval2 = RetryInterval2
  { -- | schedule advanced when the action asks for 'RISlow'
    riSlow :: RetryInterval,
    -- | schedule advanced when the action asks for 'RIFast'
    riFast :: RetryInterval
  }

-- | Current delays of the slow and fast schedules, as handed to the action by
-- 'withRetryLock2' and consumed by 'updateRetryInterval2'.
data RI2State = RI2State
  { -- | current delay of the slow schedule
    slowInterval :: Int64,
    -- | current delay of the fast schedule
    fastInterval :: Int64
  }
  deriving Show

-- | Restart both schedules of a 'RetryInterval2' from previously saved delays.
--
-- For each schedule, 'initialInterval' becomes the saved delay from the
-- 'RI2State' and 'increaseAfter' is reset to 0; 'maxInterval' is kept.
updateRetryInterval2 :: RI2State -> RetryInterval2 -> RetryInterval2
updateRetryInterval2
  RI2State {slowInterval = savedSlow, fastInterval = savedFast}
  RetryInterval2 {riSlow = slowRI, riFast = fastRI} =
    RetryInterval2
      { riSlow = restartAt savedSlow slowRI,
        riFast = restartAt savedFast fastRI
      }
    where
      restartAt :: Int64 -> RetryInterval -> RetryInterval
      restartAt start schedule = schedule {initialInterval = start, increaseAfter = 0}

-- | Which schedule 'withRetryLock2' advances on the next retry.
data RetryIntervalMode
  = RISlow
  | RIFast
  deriving (Eq, Show)

-- | Same as 'withRetryIntervalCount', except that @action@ is not told how
-- many retries have been made; it only receives the current delay and the
-- continuation that waits and retries.
withRetryInterval :: forall m a. MonadIO m => RetryInterval -> (Int64 -> m a -> m a) -> m a
withRetryInterval ri action = withRetryIntervalCount ri $ \_retries -> action

-- | Run @action@ with a growing retry delay.
--
-- @action@ receives the number of retries made so far (starting at 0), the
-- current delay, and a continuation. Running the continuation sleeps for the
-- current delay, then calls @action@ again with the counter incremented and
-- the delay advanced by 'nextRetryDelay'.
withRetryIntervalCount :: forall m a. MonadIO m => RetryInterval -> (Int -> Int64 -> m a -> m a) -> m a
withRetryIntervalCount ri action = callAction 0 0 $ initialInterval ri
  where
    callAction :: Int -> Int64 -> Int64 -> m a
    callAction n elapsed delay = action n delay loop
      where
        loop :: m a
        loop = do
          liftIO $ threadDelay' delay
          let n' = n + 1
              elapsed' = elapsed + delay
          -- Force the counter before recursing: callers that ignore it (such as
          -- withRetryInterval, which passes `const action`) would otherwise
          -- build one extra (+ 1) thunk per retry for the lifetime of the loop.
          n' `seq` callAction n' elapsed' (nextRetryDelay elapsed' delay ri)

-- | Retry loop whose delay starts over when @isForeground@ or @isOnline@
-- becomes 'True'.
--
-- @action@ receives the current delay and a continuation. The continuation
-- waits for the delay, but wakes up early if either condition switches from
-- 'False' to 'True' during the wait. In that case the elapsed time is reset to
-- 0 and the delay to 'initialInterval'. Otherwise the delay is advanced with
-- 'nextRetryDelay'.
withRetryForeground :: forall m a. MonadIO m => RetryInterval -> STM Bool -> STM Bool -> (Int64 -> m a -> m a) -> m a
withRetryForeground ri isForeground isOnline action = go 0 (initialInterval ri)
  where
    go :: Int64 -> Int64 -> m a
    go elapsed delay = action delay waitThenRetry
      where
        waitThenRetry :: m a
        waitThenRetry = do
          -- registerDelay takes an Int, so the delay is capped at maxBound :: Int
          -- (about 36 minutes on 32-bit architectures)
          timer <- registerDelay $ fromIntegral $ min delay (fromIntegral (maxBound :: Int))
          (fgBefore, onlineBefore) <- atomically $ do
            fg <- isForeground
            onl <- isOnline
            pure (fg, onl)
          restarted <- atomically $ do
            foreground <- isForeground
            online <- isOnline
            let cameBack = (not fgBefore && foreground) || (not onlineBefore && online)
            -- block until a condition came back or the timer fired
            unlessM ((cameBack ||) <$> readTVar timer) retry
            pure cameBack
          if restarted
            then go 0 (initialInterval ri)
            else do
              let elapsed' = elapsed + delay
              go elapsed' (nextRetryDelay elapsed' delay ri)

-- | Retry loop that lets @action@ toggle between slow and fast retry intervals.
--
-- @action@ receives the current delays of both schedules ('RI2State') and a
-- continuation. Calling the continuation with 'RISlow' or 'RIFast' waits for
-- that schedule's current delay, advances only that schedule with
-- 'nextRetryDelay', and calls @action@ again. A wait also ends as soon as
-- @lock@ is filled by another thread; @lock@ is left empty after every wait.
withRetryLock2 :: forall m. MonadIO m => RetryInterval2 -> TMVar () -> (RI2State -> (RetryIntervalMode -> m ()) -> m ()) -> m ()
withRetryLock2 RetryInterval2 {riSlow, riFast} lock action =
  callAction (0, initialInterval riSlow) (0, initialInterval riFast)
  where
    -- each schedule is tracked as (elapsed, current delay)
    callAction :: (Int64, Int64) -> (Int64, Int64) -> m ()
    callAction slow fast = action (RI2State (snd slow) (snd fast)) loop
      where
        loop :: RetryIntervalMode -> m ()
        loop = \case
          RISlow -> run slow riSlow (`callAction` fast)
          RIFast -> run fast riFast (callAction slow)
    run :: (Int64, Int64) -> RetryInterval -> ((Int64, Int64) -> m ()) -> m ()
    run (elapsed, delay) ri call = do
      wait delay
      let elapsed' = elapsed + delay
      call (elapsed', nextRetryDelay elapsed' delay ri)
    -- Block until @lock@ is filled or @delay@ has passed, whichever is first,
    -- and take the lock. The timer thread is killed when the wait ends (also
    -- on exceptions), so an early wake-up does not leave a thread sleeping for
    -- the rest of the delay, and an interrupted wait cannot fill @lock@ later.
    wait :: Int64 -> m ()
    wait delay = liftIO $ do
      waiting <- newTVarIO True
      bracket (forkIO $ fireAfterDelay waiting) killThread $ \_ ->
        atomically $ do
          takeTMVar lock
          writeTVar waiting False
      where
        -- fill the lock after the delay, unless the wait has already ended
        fireAfterDelay stillWaiting = do
          threadDelay' delay
          atomically $ whenM (readTVar stillWaiting) $ void $ tryPutTMVar lock ()

-- | Compute the delay to use after the current one.
--
-- The delay is returned unchanged while @elapsed@ is below 'increaseAfter',
-- or once it equals 'maxInterval'. Otherwise it grows by a factor of 1.5,
-- capped at 'maxInterval'. Growth is at least 1: integer scaling by 3/2 maps
-- 0 and 1 back to themselves, which would keep such delays from ever backing
-- off (a zero delay would retry in a busy loop).
nextRetryDelay :: Int64 -> Int64 -> RetryInterval -> Int64
nextRetryDelay elapsed delay RetryInterval {increaseAfter, maxInterval}
  | elapsed < increaseAfter || delay == maxInterval = delay
  | otherwise = min (max (delay + 1) (delay * 3 `div` 2)) maxInterval