{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | 'Transport' implementation that exchanges data as WebSocket messages
-- carried over an already-established TLS 'T.Context'.
module Simplex.Messaging.Transport.WebSockets (WS (..)) where

import qualified Control.Exception as E
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.X509 as X
import qualified Network.TLS as T
import Network.WebSockets
import Network.WebSockets.Stream (Stream)
import qualified Network.WebSockets.Stream as S
import Simplex.Messaging.Transport
  ( ALPN,
    Transport (..),
    TransportConfig (..),
    TransportError (..),
    TransportPeer (..),
    STransportPeer (..),
    TransportPeerI (..),
    closeTLS,
    smpBlockSize,
    withTlsUnique,
  )
import Simplex.Messaging.Transport.Buffer (trimCR)
import System.IO.Error (isEOFError)

-- | A WebSocket connection layered on a TLS context, indexed by which side
-- of the connection ('TransportPeer') it represents.
data WS (p :: TransportPeer) = WS
  { -- | Value handed to the connect continuation by 'withTlsUnique'
    -- (presumably the TLS channel-binding value; see 'withTlsUnique').
    tlsUniq :: ByteString,
    -- | Protocol negotiated via ALPN on the TLS context, if any.
    wsALPN :: Maybe ALPN,
    -- | Byte stream adapted from the TLS context by 'makeTLSContextStream'.
    wsStream :: Stream,
    -- | WebSocket connection running on top of 'wsStream'.
    wsConnection :: Connection,
    -- | Transport configuration supplied to 'getWS'.
    wsTransportConfig :: TransportConfig,
    -- | Certificate-sent flag supplied to 'getWS'.
    wsCertSent :: Bool,
    -- | Peer certificate chain supplied to 'getWS'.
    wsPeerCert :: X.CertificateChain
  }

-- | WebSocket options used by both peers: compression is disabled, a single
-- frame payload is limited to 'smpBlockSize' bytes and a whole message to
-- 65536 bytes.
websocketsOpts :: ConnectionOptions
websocketsOpts =
  defaultConnectionOptions
    { connectionCompressionOptions = NoCompression,
      connectionFramePayloadSizeLimit = SizeLimit $ fromIntegral smpBlockSize,
      connectionMessageDataSizeLimit = SizeLimit 65536
    }

instance Transport WS where
  transportName _ = "WebSockets"
  {-# INLINE transportName #-}
  transportConfig = wsTransportConfig
  {-# INLINE transportConfig #-}
  getTransportConnection = getWS
  {-# INLINE getTransportConnection #-}
  certificateSent = wsCertSent
  {-# INLINE certificateSent #-}
  getPeerCertChain = wsPeerCert
  {-# INLINE getPeerCertChain #-}
  getSessionALPN = wsALPN
  {-# INLINE getSessionALPN #-}
  tlsUnique = tlsUniq
  {-# INLINE tlsUnique #-}
  -- Only the underlying 'Stream' is closed; no WebSocket close frame
  -- ('sendClose') is sent here.
  closeConnection = S.close . wsStream
  {-# INLINE closeConnection #-}

  -- Receive exactly one WebSocket data message. Its length must equal @n@,
  -- otherwise 'TEBadBlock' is thrown.
  cGet :: WS p -> Int -> IO ByteString
  cGet c n = do
    s <- receiveData (wsConnection c)
    if B.length s == n then pure s else E.throwIO TEBadBlock

  -- Send the whole block as a single binary WebSocket message.
  cPut :: WS p -> ByteString -> IO ()
  cPut = sendBinaryData . wsConnection

  -- Receive one WebSocket data message that must end with '\n' and return
  -- it without the line terminator. A preceding '\r' is removed by 'trimCR'
  -- (assumed to drop a single trailing CR), so both "\n" and "\r\n" endings
  -- are accepted. Messages without a trailing '\n' throw 'TEBadBlock'.
  --
  -- The '\n' must be stripped *before* 'trimCR' is applied: trimming first
  -- is a no-op for any '\n'-terminated input and would leave a stray '\r'
  -- on CRLF-terminated lines.
  getLn :: WS p -> IO ByteString
  getLn c = do
    s <- receiveData (wsConnection c)
    case B.unsnoc s of
      Just (line, '\n') -> pure $ trimCR line
      _ -> E.throwIO TEBadBlock

-- | Build a 'WS' connection over an established TLS context.
--
-- The TLS context is wrapped into a websockets 'Stream'. On that stream, the
-- server side accepts the client's WebSocket request, while the client side
-- sends a request with an empty host and path "/". The ALPN result is then
-- read from the TLS context.
getWS :: forall p. TransportPeerI p => TransportConfig -> Bool -> X.CertificateChain -> T.Context -> IO (WS p)
getWS cfg wsCertSent wsPeerCert cxt = withTlsUnique @WS @p cxt connectWS
  where
    connectWS :: ByteString -> IO (WS p)
    connectWS tlsUniq = do
      s <- makeTLSContextStream cxt
      wsConnection <- connectPeer s
      wsALPN <- T.getNegotiatedProtocol cxt
      pure $
        WS
          { tlsUniq,
            wsALPN,
            wsStream = s,
            wsConnection,
            wsTransportConfig = cfg,
            wsCertSent,
            wsPeerCert
          }
    -- The WebSocket handshake direction depends on the peer type index.
    connectPeer :: Stream -> IO Connection
    connectPeer = case sTransportPeer @p of
      STServer -> acceptClientRequest
      STClient -> sendClientRequest
    acceptClientRequest :: Stream -> IO Connection
    acceptClientRequest s = makePendingConnectionFromStream s websocketsOpts >>= acceptRequest
    sendClientRequest :: Stream -> IO Connection
    sendClientRequest s = newClientConnection s "" "/" websocketsOpts []

-- | Adapt a TLS context to a websockets 'Stream'.
--
-- Reads come from 'T.recvData'. A TLS 'T.Error_EOF' raised after the
-- handshake, or an IO EOF error, is reported as end of stream ('Nothing');
-- any other exception is re-thrown. Writing 'Nothing' (stream close) calls
-- 'closeTLS'; writing a chunk sends it with 'T.sendData'.
makeTLSContextStream :: T.Context -> IO Stream
makeTLSContextStream cxt = S.makeStream readStream writeStream
  where
    readStream :: IO (Maybe ByteString)
    readStream =
      (Just <$> T.recvData cxt) `E.catches` [E.Handler handleTlsEOF, E.Handler handleEOF]
      where
        handleTlsEOF = \case
          T.PostHandshake T.Error_EOF -> pure Nothing
          e -> E.throwIO e
        handleEOF e = if isEOFError e then pure Nothing else E.throwIO e
    writeStream :: Maybe LB.ByteString -> IO ()
    writeStream = maybe (closeTLS cxt) (T.sendData cxt)