{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Simplex.Messaging.Compression
  ( Compressed,
    maxLengthPassthrough,
    compressionLevel,
    compress1,
    decompress1,
    limitDecompress1,
    decompressedSize,
  ) where

import qualified Codec.Compression.Zstd as Z1
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Simplex.Messaging.Encoding

-- | A message body that is either carried verbatim or carried in compressed form.
data Compressed
  = Passthrough ByteString
  -- ^ Short messages are left intact to skip copying and FFI festivities.
  | Compressed Large
  -- ^ Generic compression using no extra context.

-- | Messages of at most this many bytes are passed through uncompressed
-- (see 'compress1') to avoid compression overhead.
--
-- Sampled from real client data. Messages with length > 180 rapidly gain compression ratio.
maxLengthPassthrough :: Int
maxLengthPassthrough = 180

-- | Compression level handed to @Z1.compress@ by 'compress1'.
--
-- Kept polymorphic in the numeric type so it can be used wherever a
-- numeric level is expected without conversion.
compressionLevel :: Num a => a
compressionLevel = 3

-- | Wire format: a one-character tag followed by the encoded payload.
--
-- * @\'0\'@ - 'Passthrough', followed by the 'smpEncode'd 'ByteString'.
-- * @\'1\'@ - 'Compressed', followed by the 'smpEncode'd 'Large' value.
--
-- Any other tag makes the parser fail with an "unknown Compressed tag" message.
instance Encoding Compressed where
  smpEncode = \case
    Passthrough bytes -> "0" <> smpEncode bytes
    Compressed bytes -> "1" <> smpEncode bytes
  smpP =
    -- The leading tag is parsed as a 'Char' and selects the payload parser.
    smpP >>= \case
      '0' -> Passthrough <$> smpP
      '1' -> Compressed <$> smpP
      x -> fail $ "unknown Compressed tag: " <> show x

-- | Wrap a message for transfer, compressing it only when it is long enough.
--
-- Inputs of at most 'maxLengthPassthrough' bytes are returned as 'Passthrough'
-- unchanged; longer inputs are compressed with @Z1.compress@ at
-- 'compressionLevel' and wrapped in 'Compressed'.
compress1 :: ByteString -> Compressed
compress1 bs
  | B.length bs <= maxLengthPassthrough = Passthrough bs
  | otherwise = Compressed . Large $ Z1.compress compressionLevel bs

-- | Size in bytes of the payload after decompression, when it is known.
--
-- For 'Passthrough' this is simply the length of the stored bytes. For
-- 'Compressed' the answer is whatever @Z1.decompressedSize@ reports for the
-- compressed bytes, which may be 'Nothing'.
decompressedSize :: Compressed -> Maybe Int
decompressedSize = \case
  Passthrough bs -> Just $ B.length bs
  Compressed (Large bs) -> Z1.decompressedSize bs

-- | Recover the original message bytes.
--
-- 'Passthrough' payloads are returned as they are; 'Compressed' payloads are
-- decompressed via 'decompress_', whose error is returned as 'Left'.
--
-- This function applies no bound on the decompressed size; use
-- 'limitDecompress1' when a bound is required.
decompress1 :: Compressed -> Either String ByteString
decompress1 = \case
  Passthrough bs -> Right bs
  Compressed (Large bs) -> decompress_ bs

-- | Like 'decompress1', but refuses to decompress payloads whose declared
-- decompressed size is unknown or greater than @limit@.
--
-- The size is checked with @Z1.decompressedSize@ before any decompression
-- is attempted. 'Passthrough' payloads are returned unchanged and are not
-- checked against @limit@.
limitDecompress1 :: Int -> Compressed -> Either String ByteString
limitDecompress1 limit = \case
  Passthrough bs -> Right bs
  Compressed (Large bs) -> case Z1.decompressedSize bs of
    Just sz | sz <= limit -> decompress_ bs
    -- Covers both an unknown size and a size above the limit. The message
    -- says "compressed size", but the value checked is the decompressed size.
    _ -> Left $ "compressed size not specified or exceeds " <> show limit

-- | Run @Z1.decompress@ on raw compressed bytes and convert its result to 'Either'.
--
-- * @Z1.Error e@ becomes @Left e@.
-- * @Z1.Skip@ becomes an empty 'ByteString'.
-- * @Z1.Decompress bs'@ becomes @Right bs'@.
decompress_ :: ByteString -> Either String ByteString
decompress_ bs = case Z1.decompress bs of
  Z1.Error e -> Left e
  -- NOTE(review): Skip is reported as a successful empty result; confirm
  -- against the zstd package docs that Skip cannot hide a non-empty payload.
  Z1.Skip -> Right mempty
  Z1.Decompress bs' -> Right bs'