{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Simplex.Messaging.Transport.WebSockets (WS (..)) where

import qualified Control.Exception as E
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.X509 as X
import qualified Network.TLS as T
import Network.WebSockets
import Network.WebSockets.Stream (Stream)
import qualified Network.WebSockets.Stream as S
import Simplex.Messaging.Transport
  ( ALPN,
    Transport (..),
    TransportConfig (..),
    TransportError (..),
    TransportPeer (..),
    STransportPeer (..),
    TransportPeerI (..),
    closeTLS,
    smpBlockSize,
    withTlsUnique,
  )
import Simplex.Messaging.Transport.Buffer (trimCR)
import System.IO.Error (isEOFError)

-- | A websocket connection carried over an established TLS context.
--
-- The phantom parameter @p@ records which side of the connection this is;
-- 'getWS' uses it (via 'sTransportPeer') to choose the websocket handshake role.
data WS (p :: TransportPeer) = WS
  { -- | TLS-unique value captured via 'withTlsUnique' when the connection was set up.
    tlsUniq :: ByteString,
    -- | ALPN protocol negotiated on the TLS context, if any.
    wsALPN :: Maybe ALPN,
    -- | websockets 'Stream' adapter reading from and writing to the TLS context.
    wsStream :: Stream,
    -- | Websocket connection established on top of 'wsStream'.
    wsConnection :: Connection,
    -- | Transport configuration supplied when the connection was created.
    wsTransportConfig :: TransportConfig,
    -- | Flag supplied by the caller and exposed as 'certificateSent'
    -- (presumably: whether this side sent a certificate during the TLS handshake).
    wsCertSent :: Bool,
    -- | Peer certificate chain supplied by the caller, exposed as 'getPeerCertChain'.
    wsPeerCert :: X.CertificateChain
  }

-- | Websocket options used for both accepted and initiated connections:
-- per-message compression is disabled, a single frame's payload is limited
-- to 'smpBlockSize' bytes, and a whole message's data to 65536 bytes.
websocketsOpts :: ConnectionOptions
websocketsOpts =
  defaultConnectionOptions
    { connectionCompressionOptions = NoCompression,
      connectionFramePayloadSizeLimit = SizeLimit $ fromIntegral smpBlockSize,
      connectionMessageDataSizeLimit = SizeLimit 65536
    }

-- | WebSocket transport layered on a TLS context.
--
-- The accessor methods return the values captured by 'getWS' when the
-- connection was set up. Each block / line is exchanged as exactly one
-- websocket data message.
instance Transport WS where
  transportName _ = "WebSockets"
  {-# INLINE transportName #-}
  transportConfig = wsTransportConfig
  {-# INLINE transportConfig #-}
  getTransportConnection = getWS
  {-# INLINE getTransportConnection #-}
  certificateSent = wsCertSent
  {-# INLINE certificateSent #-}
  getPeerCertChain = wsPeerCert
  {-# INLINE getPeerCertChain #-}
  getSessionALPN = wsALPN
  {-# INLINE getSessionALPN #-}
  tlsUnique = tlsUniq
  {-# INLINE tlsUnique #-}
  -- Closes the websockets stream that wraps the TLS context.
  closeConnection = S.close . wsStream
  {-# INLINE closeConnection #-}

  -- Receive one websocket data message; it must be exactly @n@ bytes long,
  -- otherwise 'TEBadBlock' is thrown.
  cGet :: WS p -> Int -> IO ByteString
  cGet c n = do
    s <- receiveData (wsConnection c)
    if B.length s == n then pure s else E.throwIO TEBadBlock

  -- Send the block as a single binary websocket message.
  cPut :: WS p -> ByteString -> IO ()
  cPut = sendBinaryData . wsConnection

  -- Receive one websocket data message holding a single line.
  -- The message must end in '\n', otherwise 'TEBadBlock' is thrown (this
  -- also covers an empty message). The '\n' is removed first and 'trimCR'
  -- is applied to the remainder, so that for a CRLF ending the '\r' is
  -- actually the last character 'trimCR' sees (it is expected to drop
  -- one trailing '\r').
  getLn :: WS p -> IO ByteString
  getLn c = do
    s <- receiveData (wsConnection c)
    case B.unsnoc s of
      Just (line, '\n') -> pure $ trimCR line
      _ -> E.throwIO TEBadBlock

-- | Run the websocket handshake over an already-established TLS context and
-- bundle the result into a 'WS'.
--
-- The TLS-unique value is obtained through 'withTlsUnique'. The peer type @p@
-- selects the websocket role:
--
-- * 'STServer': wait for the client's opening request and accept it;
-- * 'STClient': send an opening request for path @"/"@ with an empty host
--   and no extra headers.
--
-- Both roles use 'websocketsOpts'. The negotiated ALPN is read from the TLS
-- context after the websocket handshake completes.
getWS :: forall p. TransportPeerI p => TransportConfig -> Bool -> X.CertificateChain -> T.Context -> IO (WS p)
getWS cfg wsCertSent wsPeerCert cxt = withTlsUnique @WS @p cxt connectWS
  where
    connectWS :: ByteString -> IO (WS p)
    connectWS tlsUniq = do
      s <- makeTLSContextStream cxt
      wsConnection <- connectPeer s
      wsALPN <- T.getNegotiatedProtocol cxt
      pure
        WS
          { tlsUniq,
            wsALPN,
            wsStream = s,
            wsConnection,
            wsTransportConfig = cfg,
            wsCertSent,
            wsPeerCert
          }
    -- Pick the handshake direction from the type-level peer.
    connectPeer :: Stream -> IO Connection
    connectPeer = case sTransportPeer @p of
      STServer -> acceptClientRequest
      STClient -> sendClientRequest
    acceptClientRequest :: Stream -> IO Connection
    acceptClientRequest s = makePendingConnectionFromStream s websocketsOpts >>= acceptRequest
    sendClientRequest :: Stream -> IO Connection
    sendClientRequest s = newClientConnection s "" "/" websocketsOpts []

-- | Adapt a TLS context into a websockets 'Stream'.
--
-- Reads use 'T.recvData'; writes use 'T.sendData'. Writing 'Nothing'
-- (closing the stream) closes the TLS connection with 'closeTLS'.
--
-- End of input is reported to websockets as 'Nothing' when:
--
-- * 'T.recvData' returns an empty chunk. websockets' own socket-backed
--   stream uses the same convention, and tls returns an empty chunk once
--   the peer has sent close_notify.
-- * TLS throws @PostHandshake Error_EOF@.
-- * An 'IOError' satisfying 'isEOFError' is thrown.
--
-- Any other exception is re-thrown unchanged.
makeTLSContextStream :: T.Context -> IO Stream
makeTLSContextStream cxt = S.makeStream readStream writeStream
  where
    readStream :: IO (Maybe ByteString)
    readStream =
      (nonEmpty <$> T.recvData cxt)
        `E.catches` [E.Handler handleTlsEOF, E.Handler handleEOF]
      where
        -- An empty chunk means the TLS session was closed by the peer;
        -- passing @Just ""@ on would not mark the stream as closed.
        nonEmpty :: ByteString -> Maybe ByteString
        nonEmpty bs
          | B.null bs = Nothing
          | otherwise = Just bs
        handleTlsEOF :: T.TLSException -> IO (Maybe ByteString)
        handleTlsEOF = \case
          T.PostHandshake T.Error_EOF -> pure Nothing
          e -> E.throwIO e
        handleEOF :: IOError -> IO (Maybe ByteString)
        handleEOF e
          | isEOFError e = pure Nothing
          | otherwise = E.throwIO e
    -- websockets hands over lazy ByteStrings, which 'T.sendData' accepts directly.
    writeStream :: Maybe LB.ByteString -> IO ()
    writeStream = maybe (closeTLS cxt) (T.sendData cxt)