module Simplex.Messaging.Server.MsgStore.Journal.SharedLock
  ( withLockWaitShared,
    withLockMapWaitShared,
    withSharedWaitLock,
  )
where

import Control.Concurrent.STM
import qualified Control.Exception as E
import Control.Monad
import Data.Text (Text)
import Simplex.Messaging.Agent.Lock
import Simplex.Messaging.Agent.Client (getMapLock)
import Simplex.Messaging.Protocol (RecipientId)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (($>>), ($>>=))

-- | Run an action while holding @lock@, after waiting until the shared lock is
-- not held for the same recipient ID.
--
-- Acquisition is a single STM transaction. It retries while @shared@ contains
-- @rId@ (see 'waitShared'), then puts @name@ into @lock@; 'putTMVar' blocks
-- while @lock@ is already full. @lock@ is emptied again when the action
-- completes or throws ('E.bracket_').
withLockWaitShared :: RecipientId -> Lock -> TMVar RecipientId -> Text -> IO a -> IO a
withLockWaitShared rId lock shared name =
  E.bracket_
    (atomically $ waitShared rId shared >> putTMVar lock name)
    -- bracket_ discards the release result, so the taken name needs no 'void'
    (atomically $ takeTMVar lock)

-- | Like 'withLockWaitShared', but the lock for @rId@ is taken from the map
-- @locks@ rather than passed in directly.
--
-- In one STM transaction this waits until @shared@ does not hold @rId@
-- (see 'waitShared'). It then obtains the lock with 'getMapLock' and puts
-- @name@ into it via 'getPutLock'. The lock is emptied with 'takeTMVar' when
-- the action completes or throws.
-- NOTE(review): blocking/creation behaviour of 'getPutLock' / 'getMapLock' is
-- defined in Simplex.Messaging.Agent.Lock / Agent.Client; confirm there.
withLockMapWaitShared :: RecipientId -> TMap RecipientId Lock -> TMVar RecipientId -> Text -> IO a -> IO a
withLockMapWaitShared rId locks shared name a =
  E.bracket
    (atomically $ waitShared rId shared >> getPutLock (getMapLock locks) rId name)
    (atomically . takeTMVar)
    (const a)

-- | Retry the enclosing STM transaction while @shared@ holds @rId@.
--
-- Returns immediately when @shared@ is empty or holds a different ID.
-- 'tryReadTMVar' does not empty @shared@.
waitShared :: RecipientId -> TMVar RecipientId -> STM ()
waitShared rId shared =
  tryReadTMVar shared >>= mapM_ (\rId' -> when (rId == rId') retry)

-- | Run an action while holding the shared lock for @rId@, after waiting until
-- the lock for @rId@ in @locks@ (if any) is not held.
--
-- Both the wait and the acquisition happen in one STM transaction. @shared@ is
-- a single 'TMVar', so 'putTMVar' also blocks while any ID holds the shared
-- lock. @shared@ is emptied again when the action completes or throws.
withSharedWaitLock :: RecipientId -> TMap RecipientId Lock -> TMVar RecipientId -> IO a -> IO a
withSharedWaitLock rId locks shared =
  E.bracket_
    (atomically $ waitLock >> putTMVar shared rId)
    (atomically $ takeTMVar shared)
  where
    -- Look up the lock for rId and read its holder; retry if it is held.
    -- Presumably '$>>=' / '$>>' continue only on 'Just' (i.e. a missing or
    -- empty lock lets the transaction proceed) — confirm in Simplex.Messaging.Util.
    -- The explicit signature avoids the ambiguous result type (previously 'Any').
    waitLock :: STM (Maybe ())
    waitLock = TM.lookup rId locks $>>= tryReadTMVar $>> retry