{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Simplex.Chat.Messages.Batch
  ( MsgBatch (..),
    BatchMode (..),
    encodeBatchElement,
    encodeFwdElement,
    encodeBinaryBatch,
    batchMessages,
    batchDeliveryTasks1,
  )
where

import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Int (Int64)
import Data.List (foldl')
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Simplex.Chat.Controller (ChatError (..), ChatErrorType (..))
import Simplex.Chat.Delivery
import Simplex.Chat.Messages
import Simplex.Chat.Protocol
import Simplex.Chat.Types (VersionRangeChat)
import Simplex.Messaging.Encoding (Large (..), smpEncode, smpEncodeList)

-- | Wire format used when several messages are combined into one batch.
data BatchMode
  = -- | JSON array of elements: @[msg1,msg2,...]@.
    BMJson
  | -- | Binary: @=<count>(<len:2><body>)*@, i.e. each element carries a 2-byte length prefix.
    BMBinary
  deriving (Eq, Show)

-- | Encode one batch element, prefixing it with its signatures when present.
--
-- * 'Nothing': the body is returned unchanged.
-- * 'Just' a 'SignedMsg': @"/"@, then the SMP encoding of its
--   @(chatBinding, signatures)@ pair, then the body.
--
-- This is the dual of the @'/'@ and @'{'@ cases of @elementP@.
encodeBatchElement :: Maybe SignedMsg -> ByteString -> ByteString
encodeBatchElement signed body = maybe body withSignatures signed
  where
    withSignatures SignedMsg {chatBinding, signatures} =
      B.concat ["/", smpEncode (chatBinding, signatures), body]

-- | One encoded batch together with the messages it contains.
--
-- As built by @batchMessages@, the 'ByteString' is the output of
-- @encodeBatch@ for the chosen 'BatchMode'. The list holds the 'SndMessage's
-- whose encoded elements make up that batch, in the same order as the elements.
data MsgBatch = MsgBatch ByteString [SndMessage]

-- | Group outgoing messages into batches of at most @maxLen@ encoded bytes.
--
-- * 'BMJson': a batch is a JSON array, @[msg1,msg2,...]@.
-- * 'BMBinary': a batch is @=<count>(<len:2><body>)*@.
--
-- Every message is first encoded with 'encodeBatchElement', so its signature
-- prefix (if any) counts toward the limit. A batch holding a single element is
-- that element's encoding, unwrapped.
--
-- 'Left' errors already in the input are passed through at their position. A
-- message whose own encoding exceeds @maxLen@ is replaced by a 'Left'
-- 'CEInternalError'. In both cases, messages on either side of the error never
-- share a batch. Messages and errors keep their original relative order.
batchMessages :: BatchMode -> Int -> [Either ChatError SndMessage] -> [Either ChatError MsgBatch]
batchMessages mode maxLen = closeBatch . foldr step ([], [], [], 0, 0)
  where
    -- Fold state: (completed results, open batch bodies, open batch messages,
    -- total body length of the open batch, element count of the open batch).
    -- The fold runs from the right, so consing onto the lists keeps input order.
    step ::
      Either ChatError SndMessage ->
      ([Either ChatError MsgBatch], [ByteString], [SndMessage], Int, Int) ->
      ([Either ChatError MsgBatch], [ByteString], [SndMessage], Int, Int)
    step (Left e) st = (Left e : closeBatch st, [], [], 0, 0)
    step (Right m) st = addMessage m st

    addMessage ::
      SndMessage ->
      ([Either ChatError MsgBatch], [ByteString], [SndMessage], Int, Int) ->
      ([Either ChatError MsgBatch], [ByteString], [SndMessage], Int, Int)
    addMessage m@SndMessage {msgId, msgBody, signedMsg_} st@(done, openBodies, openMsgs, openLen, openCount)
      -- still room in the open batch: extend it
      | batchLen mode grownLen grownCount <= maxLen =
          (done, piece : openBodies, m : openMsgs, grownLen, grownCount)
      -- fits on its own: close the open batch and start a new one with this element
      | pieceLen <= maxLen = (closeBatch st, [piece], [m], pieceLen, 1)
      -- cannot fit in any batch: close the open batch and report this message
      | otherwise = (tooLarge : closeBatch st, [], [], 0, 0)
      where
        piece = encodeBatchElement signedMsg_ msgBody
        pieceLen = B.length piece
        grownLen = openLen + pieceLen
        grownCount = openCount + 1
        tooLarge = Left . ChatError . CEInternalError $ "large message " <> show msgId

    -- Emit the open batch (if it has any elements) in front of the completed results.
    closeBatch ::
      ([Either ChatError MsgBatch], [ByteString], [SndMessage], Int, Int) ->
      [Either ChatError MsgBatch]
    closeBatch (done, bodies, msgs, _, nElems)
      | nElems == 0 = done
      | otherwise = Right (MsgBatch (encodeBatch mode bodies) msgs) : done

-- | Build one binary batch from relay-group delivery tasks, scanning them in order.
--
-- Returns @(batch, includedTaskIds, largeTaskIds)@:
--
-- * @batch@: the elements of the included tasks (each built by
--   'encodeFwdElement'), encoded with 'encodeBinaryBatch'. It is empty when no
--   task was included.
-- * @includedTaskIds@: ids of the tasks whose elements are in @batch@, in order.
-- * @largeTaskIds@: ids of tasks whose element alone is longer than @maxLen@.
--
-- A task in neither list did not fit in the space that was left. The scan does
-- not stop at such a task, so a later, smaller task may still be included
-- after it. NOTE(review): confirm callers are fine with that possible reordering.
--
-- The binary format is always used; the 'VersionRangeChat' argument is unused.
batchDeliveryTasks1 :: VersionRangeChat -> Int -> NonEmpty MessageDeliveryTask -> (ByteString, [Int64], [Int64])
batchDeliveryTasks1 _vr maxLen tasks =
  finish $ foldl' consider ([], [], [], 0, 0) (L.toList tasks)
  where
    -- Fold state: (included elements, included ids, large ids, total element
    -- length, element count). Lists are built reversed; 'finish' restores order.
    consider ::
      ([ByteString], [Int64], [Int64], Int, Int) ->
      MessageDeliveryTask ->
      ([ByteString], [Int64], [Int64], Int, Int)
    consider acc@(pieces, included, large, usedLen, used) MessageDeliveryTask {taskId, fwdSender, brokerTs = fwdBrokerTs, verifiedMsg}
      -- longer than any batch allows: record it as large, leave the batch alone
      | pieceLen > maxLen = (pieces, included, taskId : large, usedLen, used)
      -- fits in the remaining space: include it
      | binaryBatchLen grownLen (used + 1) <= maxLen =
          (piece : pieces, taskId : included, large, grownLen, used + 1)
      -- no room left for it: skip it (later tasks are still considered)
      | otherwise = acc
      where
        piece = encodeFwdElement GrpMsgForward {fwdSender, fwdBrokerTs} verifiedMsg
        pieceLen = B.length piece
        grownLen = usedLen + pieceLen

    -- Size of an 'encodeBinaryBatch' result: '=' and the count (2 bytes
    -- together) plus a 2-byte length prefix per element. Unlike 'batchLen',
    -- there is no single-element shortcut, because 'encodeBinaryBatch' always wraps.
    binaryBatchLen :: Int -> Int -> Int
    binaryBatchLen bodiesLen nElems = bodiesLen + 2 * nElems + 2

    finish :: ([ByteString], [Int64], [Int64], Int, Int) -> (ByteString, [Int64], [Int64])
    finish (pieces, included, large, _, _) =
      (encodeBinaryBatch (reverse pieces), reverse included, reverse large)

-- | Encode one element of a relay-group batch: @><GrpMsgForward>[/<sigs>]<body>@.
--
-- The element is @>@, then the SMP encoding of the forwarding header, then the
-- verified message's signature block (if any) and body as written by
-- 'encodeBatchElement'. The signature status from 'verifiedMsgParts' is not used.
encodeFwdElement :: GrpMsgForward -> VerifiedMsg 'Json -> ByteString
encodeFwdElement fwd verifiedMsg =
  let (_, signedMsg_, msgBody) = verifiedMsgParts verifiedMsg
   in B.concat [">", smpEncode fwd, encodeBatchElement signedMsg_ msgBody]

-- | Encode the elements of one batch in the given mode.
--
-- * No elements: the empty string.
-- * One element: returned unchanged, with no batch wrapper, in either mode.
--   'batchLen' relies on this by adding no overhead for a single element.
-- * Several elements: a JSON array (@[a,b,...]@) for 'BMJson', or the
--   'encodeBinaryBatch' format (@=<count>(<len:2><body>)*@) for 'BMBinary'.
encodeBatch :: BatchMode -> [ByteString] -> ByteString
encodeBatch _ [] = mempty
encodeBatch _ [msg] = msg
encodeBatch BMJson msgs = B.concat ["[", B.intercalate "," msgs, "]"]
-- Delegate to the shared binary encoder rather than duplicating its format here.
encodeBatch BMBinary msgs = encodeBinaryBatch msgs

-- | Encode elements in the binary batch format @=<count>(<len:2><body>)*@.
--
-- Unlike 'encodeBatch', a single element is still wrapped (there is no
-- single-element shortcut). The reason is that elements passed here may carry
-- their own leading prefix, for example the @>@ written by 'encodeFwdElement'.
-- An empty list encodes to the empty string.
encodeBinaryBatch :: [ByteString] -> ByteString
encodeBinaryBatch = \case
  [] -> mempty
  msgs -> "=" <> smpEncodeList (map Large msgs)

-- | Length 'encodeBatch' would produce for @n@ elements whose encodings
-- (signature prefixes included) total @len@ bytes.
--
-- Zero elements take no space, and a single element is sent unwrapped, so
-- neither case adds overhead. Otherwise:
--
-- * 'BMJson': @n - 1@ commas plus two brackets, i.e. @n + 1@ bytes.
-- * 'BMBinary': @'='@ and the element count (2 bytes together) plus a 2-byte
--   length prefix per element, i.e. @2 * n + 2@ bytes.
--
-- NOTE(review): the binary overhead assumes the count fits in one byte.
-- Confirm how @smpEncodeList@ behaves for more than 255 elements.
batchLen :: BatchMode -> Int -> Int -> Int
batchLen mode len n
  | n == 0 = 0
  | n == 1 = len
  | otherwise = len + overhead
  where
    overhead = case mode of
      BMJson -> n + 1
      BMBinary -> 2 * n + 2