{-# LANGUAGE LambdaCase #-}

module Simplex.Chat.Types.MemberRelations
  ( IntroductionDirection (..),
    MemberRelation (..),
    toIntroDirInt,
    fromIntroDirInt,
    toRelationInt,
    fromRelationInt,
    getRelation,
    getRelation',
    getRelationsIndexes,
    setRelation,
    setRelations,
    setRelationConnected,
    setNewRelation,
    setNewRelations,
  )
where

import Control.Monad
import Data.Bits (shiftL, shiftR, (.&.), (.|.), complement)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Internal (toForeignPtr, unsafeCreate)
import Data.Int (Int64)
import Data.Word (Word8)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Marshal.Utils (copyBytes, fillBytes)
import Foreign.Ptr (plusPtr)
import Foreign.Storable (peekByteOff, pokeByteOff)

-- | Direction of an introduction between the subject member (the "owner" of
-- a relations vector) and the member referenced by an entry in that vector.
data IntroductionDirection
  = IDSubjectIntroduced -- Member described by vector (subject member, vector "owner") is introduced to member referenced in vector
  | IDReferencedIntroduced -- Member referenced in vector is introduced to subject member
  deriving (Eq, Show)

-- | Encode an introduction direction as a number:
-- 0 for 'IDSubjectIntroduced', 1 for 'IDReferencedIntroduced'.
toIntroDirInt :: IntroductionDirection -> Word8
toIntroDirInt = \case
  IDSubjectIntroduced -> 0
  IDReferencedIntroduced -> 1

-- | Decode an introduction direction from its numeric encoding.
-- Any value other than 0 or 1 falls back to 'IDSubjectIntroduced'.
fromIntroDirInt :: Word8 -> IntroductionDirection
fromIntroDirInt = \case
  0 -> IDSubjectIntroduced
  1 -> IDReferencedIntroduced
  _ -> IDSubjectIntroduced

-- | Status of the relation between the subject member and a referenced member.
--
-- The constructor order is significant: the derived 'Ord' instance is used to
-- decide whether a new status supersedes the current one.
data MemberRelation
  = MRNew
  | MRIntroduced
  | MRSubjectConnected -- Subject member notified about connection to referenced member
  | MRReferencedConnected -- Referenced member notified about connection to subject member
  | MRConnected -- Both members notified about connection
  deriving (Eq, Ord, Show)

-- | Encode a relation status as a number in the range 0-4,
-- following the constructor order of 'MemberRelation'.
toRelationInt :: MemberRelation -> Word8
toRelationInt = \case
  MRNew -> 0
  MRIntroduced -> 1
  MRSubjectConnected -> 2
  MRReferencedConnected -> 3
  MRConnected -> 4

-- | Decode a relation status from its numeric encoding.
-- Any value outside 0-4 falls back to 'MRNew'.
fromRelationInt :: Word8 -> MemberRelation
fromRelationInt = \case
  0 -> MRNew
  1 -> MRIntroduced
  2 -> MRSubjectConnected
  3 -> MRReferencedConnected
  4 -> MRConnected
  _ -> MRNew

-- Bit layout of each relation byte, from most to least significant bit:
-- 4 reserved | 1 direction (bit 3) | 3 status (bits 0-2)

-- | Get the relation status of a member at a given index from the relations vector.
-- Returns 'MRNew' if the index is negative or the vector is not long enough
-- (lazy initialization).
getRelation :: Int64 -> ByteString -> MemberRelation
getRelation i v = snd $ getRelation' i v

-- | Get both direction and status of a member at a given index from the relations vector.
-- Returns (IDSubjectIntroduced, MRNew) if the index is negative or the vector
-- is not long enough (lazy initialization).
getRelation' :: Int64 -> ByteString -> (IntroductionDirection, MemberRelation)
getRelation' i v
  -- The bounds check is done in Int64: widening the length is always safe,
  -- whereas narrowing the index to Int first could wrap where Int is 32-bit.
  | i < 0 || i >= fromIntegral (B.length v) = (IDSubjectIntroduced, MRNew)
  | otherwise =
      let b = v `B.index` fromIntegral i
       in (fromIntroDirInt $ (b .&. directionMask) `shiftR` 3, fromRelationInt $ b .&. statusMask)

-- | Get the indexes of members with the given relation status from the relations vector.
-- Indexes are returned in ascending order; an empty vector yields an empty list.
getRelationsIndexes :: MemberRelation -> ByteString -> [Int64]
getRelationsIndexes r v = [i | i <- [0 .. fromIntegral (B.length v) - 1], getRelation i v == r]

-- | Set the relation status of a member at a given index in the relations vector.
-- Preserves the introduction direction. Expands the vector lazily if needed.
-- A negative index leaves the vector unchanged.
setRelation :: Int64 -> MemberRelation -> ByteString -> ByteString
setRelation i r v
  | i >= 0 = setRelations [(i, r)] v
  | otherwise = v

-- | Set multiple relation statuses at once.
-- Preserves the introduction direction. Expands the vector lazily if needed.
setRelations :: [(Int64, MemberRelation)] -> ByteString -> ByteString
setRelations = setRelations_ $ \r b ->
  -- clear only the status bits, keeping direction and reserved bits intact
  (b .&. complement statusMask) .|. toRelationInt r

-- | Set relation to connected state based on passed status and current status.
-- newStatus should be MRSubjectConnected or MRReferencedConnected, otherwise returns vector unchanged.
-- Logic:
-- - if newStatus is complementary to oldStatus -> set MRConnected
-- - if newStatus > oldStatus (by enum order) -> set newStatus
-- - otherwise don't update
setRelationConnected :: Int64 -> MemberRelation -> ByteString -> ByteString
setRelationConnected i newStatus v
  | newStatus /= MRSubjectConnected && newStatus /= MRReferencedConnected = v
  | otherwise = case status' of
      Nothing -> v
      Just s -> setRelation i s v
  where
    oldStatus = getRelation i v
    -- Status to write, or Nothing when the current status must be kept.
    status' = case (oldStatus, newStatus) of
      -- complementary statuses -> MRConnected
      (MRSubjectConnected, MRReferencedConnected) -> Just MRConnected
      (MRReferencedConnected, MRSubjectConnected) -> Just MRConnected
      -- newStatus > oldStatus -> set newStatus
      _ | newStatus > oldStatus -> Just newStatus
        | otherwise -> Nothing

-- | Set a new relation with both direction and status at a given index.
-- Expands the vector lazily if needed. A negative index leaves the vector unchanged.
setNewRelation :: Int64 -> IntroductionDirection -> MemberRelation -> ByteString -> ByteString
setNewRelation i dir r v
  | i >= 0 = setNewRelations [(i, (dir, r))] v
  | otherwise = v

-- | Set multiple new relations with both direction and status at once.
-- Expands the vector lazily if needed. Bits outside the direction and status
-- fields of each updated byte are preserved.
setNewRelations :: [(Int64, (IntroductionDirection, MemberRelation))] -> ByteString -> ByteString
setNewRelations = setRelations_ $ \(dir, r) b ->
  (b .&. relationMask) .|. (toIntroDirInt dir `shiftL` 3) .|. toRelationInt r
  where
    -- keeps everything except the direction and status bits
    relationMask = complement (statusMask .|. directionMask)

-- | Apply per-index byte updates to a relations vector, returning a new vector.
--
-- For every @(index, r)@ pair with a non-negative index, the byte at that index
-- is replaced with @updateByte r currentByte@. Updates are applied in list
-- order, so a later update to the same index sees the result of an earlier one.
-- If the largest index lies beyond the end of the vector, the result is
-- extended to cover it and the new bytes start out as zero. Pairs with negative
-- indexes are ignored; if no pair remains, the input is returned as is.
setRelations_ :: (r -> Word8 -> Word8) -> [(Int64, r)] -> ByteString -> ByteString
setRelations_ updateByte relations v = case filter ((>= 0) . fst) relations of
  [] -> v
  valid ->
    let (fp, off, len) = toForeignPtr v
        -- 'maximum' is safe here: 'valid' is non-empty in this branch
        newLen = max len $ fromIntegral $ maximum (map fst valid) + 1
     in unsafeCreate newLen $ \ptr -> do
          -- start from a copy of the existing bytes
          withForeignPtr fp $ \vPtr -> copyBytes ptr (vPtr `plusPtr` off) len
          -- unsafeCreate does not initialise memory: zero the extension
          when (newLen > len) $ fillBytes (ptr `plusPtr` len) 0 (newLen - len)
          forM_ valid $ \(ix, r) ->
            let i = fromIntegral ix :: Int
             in pokeByteOff ptr i . updateByte r =<< peekByteOff ptr i

-- | Mask selecting the relation status field of a relation byte.
statusMask :: Word8
statusMask = 0x07 -- bits 0-2

-- | Mask selecting the introduction direction field of a relation byte.
directionMask :: Word8
directionMask = 0x08 -- bit 3