{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Simplex.Messaging.Session
  ( SessionVar (..),
    getSessVar,
    removeSessVar,
    tryReadSessVar,
  ) where

import Control.Concurrent.STM
import Data.Time (UTCTime)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (($>>=))

-- | A 'TMVar' registered in a session map, tagged with an id and a timestamp.
data SessionVar a = SessionVar
  { -- | The session value. 'getSessVar' creates it empty.
    sessionVar :: TMVar a,
    -- | Id taken from a sequence counter by 'getSessVar'. 'removeSessVar'
    -- compares ids so that it only deletes the exact variable it was given.
    sessionVarId :: Int,
    -- | Timestamp passed to 'getSessVar' when this variable was created.
    sessionVarTs :: UTCTime
  }

-- | Look up the session variable registered under @sessKey@ in @vs@. If there
-- is none, create and register a new one.
--
-- * @Right v@: @v@ was already registered under @sessKey@.
-- * @Left v@: @v@ was created by this call and inserted into @vs@. It holds an
--   empty 'TMVar', takes its id from the current value of @sessSeq@ (which is
--   then incremented), and records @sessionVarTs@ as its timestamp.
--
-- NOTE(review): presumably a caller that receives @Left@ is responsible for
-- filling the 'TMVar' — confirm against callers.
getSessVar :: forall k a. Ord k => TVar Int -> k -> TMap k (SessionVar a) -> UTCTime -> STM (Either (SessionVar a) (SessionVar a))
getSessVar sessSeq sessKey vs sessionVarTs =
  TM.lookup sessKey vs >>= \case
    Just v -> pure (Right v)
    Nothing -> Left <$> newSessionVar
  where
    -- Build a fresh variable and register it under sessKey in the same transaction.
    newSessionVar :: STM (SessionVar a)
    newSessionVar = do
      sessionVar <- newEmptyTMVar
      -- Use the counter's current value as the id and advance the counter.
      sessionVarId <- stateTVar sessSeq $ \next -> (next, next + 1)
      let v = SessionVar {sessionVar, sessionVarId, sessionVarTs}
      TM.insert sessKey v vs
      pure v

-- | Remove the entry for @sessKey@ from @vs@, but only if that entry is the
-- same variable as @v@ (same 'sessionVarId').
--
-- The key may have been re-bound to a different variable since @v@ was
-- obtained. In that case, or if the key is absent, the map is left unchanged.
removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM ()
removeSessVar v sessKey vs =
  TM.lookup sessKey vs >>= \case
    Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs
    _ -> pure ()

-- | Read the value of the session variable registered under @sessKey@ without
-- taking it (the 'TMVar' is left full).
--
-- Returns 'Nothing' when no variable is registered for the key, or when the
-- registered variable's 'TMVar' is still empty.
tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a)
tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar)