{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Simplex.Chat.Messages.Batch
  ( MsgBatch (..),
    batchMessages,
    batchDeliveryTasks1,
  )
where

import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Int (Int64)
import Data.List (foldl')
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Simplex.Chat.Controller (ChatError (..), ChatErrorType (..))
import Simplex.Chat.Delivery
import Simplex.Chat.Messages
import Simplex.Chat.Protocol
import Simplex.Chat.Types (VersionRangeChat)

-- | A batch ready for sending: the encoded batch body together with the
-- messages that were combined to produce it.
data MsgBatch = MsgBatch ByteString [SndMessage]

-- | Batches 'SndMessage's from a list of @Either ChatError SndMessage@ into
-- 'MsgBatch'es whose bodies are JSON arrays of the message bodies.
--
-- * Original errors ('Left') are preserved in place; an error closes the batch
--   that is currently being accumulated.
-- * A batch containing a single message is encoded as that message body as is
--   (no enclosing brackets); a batch of several messages is encoded as
--   @[body1,body2,...]@.
-- * A message whose body alone is longer than @maxLen@ is replaced with
--   @Left (ChatError (CEInternalError "large message <msgId>"))@.
-- * The result is not checked to be valid JSON.
--
-- The list is folded from the right, so batches are filled starting from the
-- last message; the relative order of messages and errors is preserved.
batchMessages :: Int -> [Either ChatError SndMessage] -> [Either ChatError MsgBatch]
batchMessages maxLen = addBatch . foldr addToBatch ([], [], 0, 0)
  where
    -- Accumulator: (completed batches, messages of the batch being built,
    -- length of their bodies joined with commas, number of messages in it).
    msgBatch :: [SndMessage] -> Either a MsgBatch
    msgBatch batch = Right (MsgBatch (encodeMessages batch) batch)
    addToBatch :: Either ChatError SndMessage -> ([Either ChatError MsgBatch], [SndMessage], Int, Int) -> ([Either ChatError MsgBatch], [SndMessage], Int, Int)
    addToBatch (Left err) acc = (Left err : addBatch acc, [], 0, 0) -- step over original error
    addToBatch (Right msg@SndMessage {msgBody}) acc@(batches, batch, len, n)
      -- the message fits into the batch being built
      | batchLen <= maxLen = (batches, msg : batch, len', n + 1)
      -- the message fits only on its own: close the current batch, start a new one
      | msgLen <= maxLen = (addBatch acc, [msg], msgLen, 1)
      -- the message does not fit even on its own: report it as an error
      | otherwise = (errLarge msg : addBatch acc, [], 0, 0)
      where
        msgLen :: Int
        msgLen = B.length msgBody
        -- length of the bodies joined with commas once this message is added
        len' :: Int
        len'
          | n == 0 = msgLen
          | otherwise = msgLen + len + 1 -- 1 accounts for comma
        -- length of the encoded batch once this message is added
        batchLen :: Int
        batchLen
          | n == 0 = len' -- a single message is sent without brackets
          | otherwise = len' + 2 -- 2 accounts for opening and closing brackets
        errLarge :: SndMessage -> Either ChatError b
        errLarge SndMessage {msgId} = Left $ ChatError $ CEInternalError ("large message " <> show msgId)
    -- Prepends the batch being built (if it is not empty) to the completed batches.
    addBatch :: ([Either ChatError MsgBatch], [SndMessage], Int, Int) -> [Either ChatError MsgBatch]
    addBatch (batches, batch, _, n)
      | n == 0 = batches
      | otherwise = msgBatch batch : batches
    encodeMessages :: [SndMessage] -> ByteString
    encodeMessages = \case
      [] -> mempty
      [msg] -> body msg
      msgs -> B.concat ["[", B.intercalate "," (map body msgs), "]"]
    body :: SndMessage -> ByteString
    body SndMessage {msgBody} = msgBody

-- | Batches delivery tasks into @(batch, taskIds, largeTaskIds)@.
--
-- Each task is converted to the body of an @XGrpMsgForward@ chat message
-- (using @vr@ as the chat version range) and the bodies are combined in the
-- original task order:
--
-- * @batch@ is empty when no task was included, the single body as is when
--   one task was included, and @[body1,body2,...]@ otherwise.
-- * @taskIds@ are the ids of the tasks included in @batch@.
-- * @largeTaskIds@ are the ids of the tasks whose body alone is longer than
--   @maxLen@; they are never included in @batch@.
--
-- Once a task that is not too large does not fit into the remaining space,
-- no further tasks are added to the batch, so that a later (smaller) task is
-- never batched ahead of an earlier one. Such tasks are returned in neither
-- list. Oversized tasks are still recorded in @largeTaskIds@ wherever they
-- occur in the input.
batchDeliveryTasks1 :: VersionRangeChat -> Int -> NonEmpty MessageDeliveryTask -> (ByteString, [Int64], [Int64])
batchDeliveryTasks1 vr maxLen = toResult . foldl' addToBatch ([], [], [], 0, 0, False) . L.toList
  where
    -- Accumulator: (bodies in reverse order, batched task ids in reverse order,
    -- large task ids in reverse order, length of bodies joined with commas,
    -- number of bodies, whether the batch is closed for further additions).
    addToBatch :: ([ByteString], [Int64], [Int64], Int, Int, Bool) -> MessageDeliveryTask -> ([ByteString], [Int64], [Int64], Int, Int, Bool)
    addToBatch (msgBodies, taskIds, largeTaskIds, len, n, full) task
      -- too large: skip msgBody, record taskId in largeTaskIds
      | msgLen > maxLen = (msgBodies, taskIds, taskId : largeTaskIds, len, n, full)
      -- fits and no earlier task was left out: include in batch
      | not full && batchLen <= maxLen = (msgBody : msgBodies, taskId : taskIds, largeTaskIds, len', n + 1, False)
      -- does not fit (or the batch is already closed): stop adding further messages
      | otherwise = (msgBodies, taskIds, largeTaskIds, len, n, True)
      where
        MessageDeliveryTask {taskId, senderMemberId, senderMemberName, brokerTs, chatMessage, messageFromChannel = _messageFromChannel} = task
        -- TODO [channels fwd] handle messageFromChannel (null memberId in XGrpMsgForward)
        msgBody :: ByteString
        msgBody =
          let fwdEvt = XGrpMsgForward senderMemberId (Just senderMemberName) chatMessage brokerTs
              cm = ChatMessage {chatVRange = vr, msgId = Nothing, chatMsgEvent = fwdEvt}
           in chatMsgToBody cm
        msgLen :: Int
        msgLen = B.length msgBody
        -- length of the bodies joined with commas once this body is added
        len' :: Int
        len'
          | n == 0 = msgLen
          | otherwise = msgLen + len + 1 -- 1 accounts for comma
        -- length of the encoded batch once this body is added
        batchLen :: Int
        batchLen
          | n == 0 = len' -- a single body is sent without brackets
          | otherwise = len' + 2 -- 2 accounts for opening and closing brackets
    -- Restores the original task order (the fold accumulates in reverse).
    toResult :: ([ByteString], [Int64], [Int64], Int, Int, Bool) -> (ByteString, [Int64], [Int64])
    toResult (msgBodies, taskIds, largeTaskIds, _, _, _) =
      (encodeMessages (reverse msgBodies), reverse taskIds, reverse largeTaskIds)
    encodeMessages :: [ByteString] -> ByteString
    encodeMessages = \case
      [] -> mempty
      [msg] -> msg
      msgs -> B.concat ["[", B.intercalate "," msgs, "]"]