{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}

module Simplex.Chat.Mobile.WebRTC
  ( cChatEncryptMedia,
    cChatDecryptMedia,
    chatEncryptMedia,
    chatDecryptMedia,
    reservedSize,
  ) where

import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import qualified Crypto.Cipher.Types as AES
import Data.Bifunctor (bimap)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Base64.URL as U
import Data.Either (fromLeft)
import Data.Word (Word8)
import Foreign.C (CInt, CString, newCAString)
import Foreign.Ptr (Ptr)
import Foreign.StablePtr
import Simplex.Chat.Controller (ChatController (..))
import Simplex.Chat.Mobile.Shared
import qualified Simplex.Messaging.Crypto as C
import UnliftIO (atomically)

-- | C entry point: encrypt a media frame in place.
--
-- The chat controller behind the 'StablePtr' supplies the random generator
-- used for the per-frame IV (see 'chatEncryptMedia'). The remaining
-- arguments are the key string, the frame buffer and its length, handled by
-- 'cTransformMedia'. The returned C string is empty on success and holds the
-- error message otherwise.
cChatEncryptMedia :: StablePtr ChatController -> CString -> Ptr Word8 -> CInt -> IO CString
cChatEncryptMedia cc = cTransformMedia (chatEncryptMedia cc)

-- | C entry point: decrypt a media frame in place.
--
-- Arguments are the key string, the frame buffer and its length. The
-- returned C string is empty on success and holds the error message
-- otherwise (see 'cTransformMedia' and 'chatDecryptMedia').
cChatDecryptMedia :: CString -> Ptr Word8 -> CInt -> IO CString
cChatDecryptMedia cKey cFrame cFrameLen =
  cTransformMedia chatDecryptMedia cKey cFrame cFrameLen

-- | Shared C-facing driver for the media frame transforms.
--
-- Steps:
--
-- * Reads the NUL-terminated key string at @cKey@.
-- * Reads the frame from @cFrame@ via 'getByteString' (presumably
--   @cFrameLen@ bytes — confirm in "Simplex.Chat.Mobile.Shared").
-- * Runs @f key frame@.
-- * On success, writes the result back into @cFrame@ with 'putByteString'.
--
-- Returns a newly allocated C string (via 'newCAString'; presumably freed by
-- the caller). It is empty on success and holds the error message on
-- failure.
--
-- A result longer than the buffer is reported as an error instead of being
-- silently dropped. Otherwise the caller would receive "success" while the
-- buffer still held the untransformed input (e.g. the plaintext frame).
cTransformMedia :: (ByteString -> ByteString -> ExceptT String IO ByteString) -> CString -> Ptr Word8 -> CInt -> IO CString
cTransformMedia f cKey cFrame cFrameLen = do
  key <- B.packCString cKey
  frame <- getByteString cFrame cFrameLen
  runExceptT (f key frame >>= putFrame) >>= newCAString . fromLeft ""
  where
    -- Only write when the result fits, so we never write past the end of the
    -- caller's buffer; otherwise fail loudly.
    putFrame :: ByteString -> ExceptT String IO ()
    putFrame s
      | B.length s <= fromIntegral cFrameLen = liftIO $ putByteString cFrame s
      | otherwise = throwError "transformed frame is larger than the frame buffer"
{-# INLINE cTransformMedia #-}

-- | Encrypt a media frame whose last 'reservedSize' bytes are reserved.
--
-- Processing:
--
-- * The first @length frame - reservedSize@ bytes are encrypted with
--   'C.encryptAESNoPad'.
-- * The IV is freshly drawn from the controller's @random@ generator.
--
-- The result is @ciphertext <> authTag <> iv@. Fails (via 'checkFrameLen'
-- or 'decodeKey') when the frame is too short or the base64url key cannot
-- be decoded; crypto errors are rendered with 'show'.
chatEncryptMedia :: StablePtr ChatController -> ByteString -> ByteString -> ExceptT String IO ByteString
chatEncryptMedia cc keyStr frame = do
  ChatController {random = drg} <- liftIO (deRefStablePtr cc)
  payloadLen <- checkFrameLen frame
  aesKey <- decodeKey keyStr
  -- IV is generated only after the frame and key have been validated.
  nonce <- atomically (C.randomGCMIV drg)
  let payload = B.take payloadLen frame
  (authTag, cipherText) <- withExceptT show (C.encryptAESNoPad aesKey nonce payload)
  pure $ B.concat [cipherText, BA.convert (C.unAuthTag authTag), C.unGCMIV nonce]

-- | Decrypt a media frame produced by 'chatEncryptMedia'.
--
-- The frame is split into three parts:
--
-- * the ciphertext: the first @length frame - reservedSize@ bytes;
-- * the auth tag: the next 'C.authTagSize' bytes;
-- * the IV: the remaining bytes, parsed by 'C.gcmIV'.
--
-- The decrypted bytes are followed by 'framePad' ('reservedSize' zero
-- bytes), presumably so the output length matches the input buffer. Crypto
-- and IV-parsing errors are rendered with 'show'.
chatDecryptMedia :: ByteString -> ByteString -> ExceptT String IO ByteString
chatDecryptMedia keyStr frame = do
  payloadLen <- checkFrameLen frame
  aesKey <- decodeKey keyStr
  let (cipherText, trailer) = B.splitAt payloadLen frame
      (tagBytes, ivBytes) = B.splitAt C.authTagSize trailer
      authTag = C.AuthTag . AES.AuthTag . BA.convert $ tagBytes
  withExceptT show $ do
    nonce <- liftEither (C.gcmIV ivBytes)
    (<> framePad) <$> C.decryptAESNoPad aesKey nonce cipherText authTag

-- | Length of the frame's payload, i.e. everything before the trailing
-- 'reservedSize' bytes.
--
-- Throws if the frame is shorter than 'reservedSize'.
checkFrameLen :: ByteString -> ExceptT String IO Int
checkFrameLen frame =
  payloadLen <$ when (payloadLen < 0) (throwError "frame has no [reserved space for] IV and/or auth tag")
  where
    payloadLen = B.length frame - reservedSize
{-# INLINE checkFrameLen #-}

-- | Decode a base64url-encoded key (via 'U.decode') into a 'C.Key'.
--
-- Decoding errors are prefixed with @"invalid key: "@. The key length is
-- not checked here.
decodeKey :: ByteString -> ExceptT String IO C.Key
decodeKey keyStr = liftEither $ bimap ("invalid key: " <>) C.Key $ U.decode keyStr
{-# INLINE decodeKey #-}

-- | Number of trailing bytes each media frame reserves for the GCM IV plus
-- the authentication tag.
reservedSize :: Int
reservedSize = C.gcmIVSize + C.authTagSize

-- | 'reservedSize' zero bytes, appended to decrypted frames in place of the
-- auth tag and IV.
framePad :: ByteString
framePad = B.pack (replicate reservedSize 0)