{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Simplex.Chat.Remote.Transport where

import Control.Monad
import Control.Monad.Except
import Data.ByteString (ByteString)
import Data.ByteString.Builder (Builder, byteString)
import qualified Data.ByteString.Lazy as LB
import Data.Word (Word32)
import Simplex.Chat.Remote.Types
import Simplex.FileTransfer.Description (FileDigest (..))
import Simplex.FileTransfer.Transport (ReceiveFileError (..), receiveSbFile, sendEncFile)
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Encoding
import Simplex.Messaging.Util (liftError', liftEitherWith)
import Simplex.RemoteControl.Types (RCErrorType (..))
import UnliftIO
import UnliftIO.Directory (getFileSize)

-- | A file ready to be sent encrypted: its handle and its size (the size does
-- not include 'C.authTagSize', which 'sendEncryptedFile' adds when announcing
-- the length), paired with the secret-box stream state created by
-- 'prepareEncryptedFile'.
type EncryptedFile = ((Handle, Word32), LC.SbState)

-- | Pair a file handle and its size with a freshly initialised secret-box
-- stream state for the given key and nonce, ready for 'sendEncryptedFile'.
--
-- Fails with @'PRERemoteControl' 'RCEEncrypt'@ if 'LC.sbInit' rejects the
-- key/nonce pair.
prepareEncryptedFile :: C.SbKeyNonce -> (Handle, Word32) -> ExceptT RemoteProtocolError IO EncryptedFile
prepareEncryptedFile (sk, nonce) f = do
  sbState <- liftEitherWith (const $ PRERemoteControl RCEEncrypt) $ LC.sbInit sk nonce
  pure (f, sbState)

-- | Stream an encrypted file through @send@.
--
-- A header is sent first: the marker byte @'\x01'@ followed by the
-- 'smpEncode'd 'Word32' payload length, which is the file size plus
-- 'C.authTagSize'. The encrypted contents are then written by 'sendEncFile'
-- using the prepared stream state.
sendEncryptedFile :: EncryptedFile -> (Builder -> IO ()) -> IO ()
sendEncryptedFile ((h, sz), sbState) send = do
  -- header layout must match what receiveEncryptedFile reads (1 + 4 bytes)
  send $ byteString $ smpEncode ('\x01', sz + fromIntegral C.authTagSize)
  sendEncFile h send sbState sz

-- | Receive a file sent by 'sendEncryptedFile', decrypt it into @toPath@ and
-- verify its SHA-512 digest.
--
-- @getChunk n@ is used to pull @n@ bytes from the incoming stream (presumably
-- exactly @n@ bytes; confirm against the caller). The expected header is the
-- marker byte @'\x01'@ followed by a 'Word32' length equal to
-- @fileSize + 'C.authTagSize'@.
--
-- Errors:
--
-- * 'RPENoFile': the marker byte is not @'\x01'@;
-- * 'RPEInvalidBody': the length field cannot be decoded;
-- * 'RPEFileSize': the announced length does not match @fileSize@, or
--   'receiveSbFile' reports 'RFESize';
-- * @'PRERemoteControl' 'RCEDecrypt'@: the decryption state cannot be
--   initialised, or 'receiveSbFile' reports 'RFECrypto';
-- * 'RPEFileDigest': the digest of the written file differs from @fileDigest@.
--
-- NOTE(review): if a failure happens after @toPath@ was opened, the partially
-- written or unverified file is left in place; confirm callers clean it up.
receiveEncryptedFile :: C.SbKeyNonce -> (Int -> IO ByteString) -> Word32 -> FileDigest -> FilePath -> ExceptT RemoteProtocolError IO ()
receiveEncryptedFile (sk, nonce) getChunk fileSize fileDigest toPath = do
  marker <- liftIO $ getChunk 1
  unless (marker == "\x01") $ throwError RPENoFile
  size <- liftError' RPEInvalidBody $ smpDecode <$> getChunk 4
  -- the sender announces the size including the auth tag
  unless (size == fileSize + fromIntegral C.authTagSize) $ throwError RPEFileSize
  sbState <- liftEitherWith (const $ PRERemoteControl RCEDecrypt) $ LC.sbInit sk nonce
  liftError' fErr $ withFile toPath WriteMode $ \h -> receiveSbFile getChunk h sbState fileSize
  -- the lazy read is fully consumed by the comparison below, closing the handle
  digest <- liftIO $ LC.sha512Hash <$> LB.readFile toPath
  unless (FileDigest digest == fileDigest) $ throwError RPEFileDigest
  where
    fErr :: ReceiveFileError -> RemoteProtocolError
    fErr RFESize = RPEFileSize
    fErr RFECrypto = PRERemoteControl RCEDecrypt

-- | Compute the size and SHA-512 digest of a local file.
--
-- Fails with 'RPEFileSize' when the file is larger than
-- @'maxBound' :: 'Word32'@, since the size is returned as a 'Word32'.
getFileInfo :: FilePath -> ExceptT RemoteProtocolError IO (Word32, FileDigest)
getFileInfo filePath = do
  -- Check the size first so oversized files are rejected before being read.
  fileSize' <- getFileSize filePath
  when (fileSize' > toInteger (maxBound :: Word32)) $ throwError RPEFileSize
  -- LB.readFile is lazy: force the (strict) hash here so the file is read and
  -- its handle closed now, and any read error surfaces inside this function
  -- instead of wherever the digest is first demanded.
  digest <- liftIO $ do
    contents <- LB.readFile filePath
    pure $! LC.sha512Hash contents
  pure (fromInteger fileSize', FileDigest digest)