{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

module Simplex.Messaging.Transport.Server
  ( TransportServerConfig (..),
    ServerCredentials (..),
    TLSServerCredential (..),
    SNICredentialUsed,
    AddHTTP,
    mkTransportServerConfig,
    runTransportServerState,
    runTransportServerState_,
    SocketState,
    SocketStats (..),
    newSocketState,
    getSocketStats,
    runTransportServer,
    runTransportServerSocket,
    runLocalTCPServer,
    startTCPServer,
    loadServerCredential,
    loadFingerprint,
    loadFileFingerprint,
    smpServerHandshake,
  )
where

import Control.Applicative ((<|>))
import Control.Logger.Simple
import Control.Monad
import qualified Crypto.Store.X509 as SX
import qualified Data.ByteString as B
import Data.Default (def)
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IM
import Data.List (find)
import Data.Maybe (fromJust, fromMaybe, maybeToList)
import qualified Data.X509 as X
import Data.X509.Validation (Fingerprint (..))
import qualified Data.X509.Validation as XV
import Foreign.C.Error
import GHC.IO.Exception (ioe_errno)
import Network.Socket
import qualified Network.TLS as T
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Shared
import Simplex.Messaging.Util (catchAll_, labelMyThread, tshow, unlessM)
import System.Exit (exitFailure)
import System.IO.Error (tryIOError)
import System.Mem.Weak (Weak, deRefWeak)
import UnliftIO (timeout)
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import UnliftIO.STM

-- | Server-side transport settings shared by all connections of one server.
data TransportServerConfig = TransportServerConfig
  { -- | Copied into the per-connection 'TransportConfig' by 'serverTransportConfig'.
    logTLSErrors :: Bool,
    -- | ALPN protocol list passed to 'supportedTLSServerParams' for every handshake.
    serverALPN :: Maybe [ALPN],
    -- | When set, the TLS handshake requests the client's certificate chain.
    askClientCert :: Bool,
    -- | Not read by the code in this part of the module; presumably used by HTTP-serving callers — verify.
    addCORSHeaders :: Bool,
    -- | Time limit for the TLS setup of a single accepted connection, in microseconds (it is passed to 'timeout').
    tlsSetupTimeout :: Int,
    -- | Currently not forwarded: 'serverTransportConfig' sets the connection's transport timeout to 'Nothing'.
    transportTimeout :: Int
  }
  deriving (Eq, Show)

-- | File paths of the server's TLS key material.
data ServerCredentials = ServerCredentials
  { -- | Optional CA certificate file; the CA certificate private key is not needed for initialization.
    caCertificateFile :: Maybe FilePath,
    -- | Server private key file.
    privateKeyFile :: FilePath,
    -- | Server certificate file.
    certificateFile :: FilePath
  }
  deriving (Show)

-- | Boolean flag alias exported by this module; its consumers are outside this part of the file.
type AddHTTP = Bool

-- | TLS credentials offered by the server.
data TLSServerCredential = TLSServerCredential
  { -- | Credential presented when the SNI credential is not used.
    credential :: T.Credential,
    -- | Credential used when SNI is sent by the client.
    -- It allows presenting a different credential when the server is accessed from a browser.
    sniCredential :: Maybe T.Credential
  }

-- | Whether the TLS handshake of a connection selected 'sniCredential'.
type SNICredentialUsed = Bool

-- | Build a 'TransportServerConfig' from the three settings callers choose.
--
-- The remaining fields get fixed defaults:
-- * CORS headers are disabled.
-- * The TLS setup timeout is 60 seconds. The value is in microseconds because it is passed to 'timeout'.
-- * The transport timeout is 40000000.
mkTransportServerConfig :: Bool -> Maybe [ALPN] -> Bool -> TransportServerConfig
mkTransportServerConfig logErrors alpn requestClientCert =
  TransportServerConfig
    { logTLSErrors = logErrors,
      serverALPN = alpn,
      askClientCert = requestClientCert,
      addCORSHeaders = False,
      tlsSetupTimeout = 60000000, -- 60 s
      transportTimeout = 40000000 -- presumably 40 s, if the unit is also microseconds
    }

-- | Per-connection 'TransportConfig' for server-side TLS connections.
--
-- Only the TLS error logging flag is carried over. The transport timeout is set to
-- 'Nothing', so the configured @transportTimeout@ is currently not forwarded.
serverTransportConfig :: TransportServerConfig -> TransportConfig
serverTransportConfig TransportServerConfig {logTLSErrors = logErrors} =
  -- TransportConfig {logTLSErrors, transportTimeout = Just transportTimeout}
  TransportConfig {logTLSErrors = logErrors, transportTimeout = Nothing}

-- | Run a transport server (plain TCP or WebSockets, depending on @c@) on the given TCP port.
-- Server start and stop are signalled via the passed 'TMVar'.
--
-- Every accepted connection is passed to the supplied handler. A fresh 'SocketState'
-- is created for this server; use 'runTransportServerState' to supply your own.
runTransportServer :: Transport c => TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO ()
runTransportServer started port srvSupported srvCreds cfg server =
  newSocketState >>= \ss -> runTransportServerState ss started port srvSupported srvCreds cfg server

-- | Like 'runTransportServer', but with a caller-provided 'SocketState'.
-- This lets the caller read connection statistics with 'getSocketStats'.
--
-- The single credential is used for every connection (no SNI credential), and the
-- handler receives only the established transport connection.
runTransportServerState :: Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO ()
runTransportServerState ss started port srvSupported cred cfg server =
  runTransportServerState_ ss started port srvSupported tlsCreds cfg handler
  where
    tlsCreds = TLSServerCredential {credential = cred, sniCredential = Nothing}
    -- the socket and the SNI flag are not needed here: only the connection is forwarded
    handler _ established = server (snd established)

-- | Listen on the given TCP port and run a TLS transport server on it.
-- The socket comes from 'startTCPServer' with no explicit host, and threads are
-- labelled with the transport's name.
--
-- The remaining arguments (supported TLS parameters, credentials, config and handler)
-- are passed through unchanged to 'runTransportServerSocketState'.
runTransportServerState_ :: forall c. Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> TLSServerCredential -> TransportServerConfig -> (Socket -> (SNICredentialUsed, c 'TServer) -> IO ()) -> IO ()
runTransportServerState_ ss started port =
  runTransportServerSocketState ss started acquireListener nameLabel
  where
    acquireListener = startTCPServer started Nothing port
    nameLabel = transportName (TProxy :: TProxy c 'TServer)

-- | Run a transport server on sockets from the given provider, using fixed TLS 'T.ServerParams'.
--
-- A fresh 'SocketState' is created. Each connection's transport is built with an empty
-- client certificate chain, and the SNI credential flag is always reported as 'False'.
-- The handler receives only the established transport connection.
runTransportServerSocket :: Transport c => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO ()
runTransportServerSocket started getSocket threadLabel srvParams cfg server =
  newSocketState >>= \ss ->
    runTransportServerSocketState_ ss started getSocket threadLabel (tlsSetupTimeout cfg) handshake serve
  where
    tCfg = serverTransportConfig cfg
    handshake conn = do
      ctx <- connectTLS Nothing tCfg srvParams conn
      h <- getTransportConnection tCfg True (X.CertificateChain []) ctx
      pure (False, h)
    serve _ = server . snd

-- | Run a transport server on sockets from the given provider.
-- The TLS server parameters are built per connection from the supported parameters,
-- the credentials and the config.
--
-- For each connection the handler receives whether the SNI credential was used,
-- together with the established transport. When 'askClientCert' is set, the
-- client's certificate chain is requested during the handshake and passed to the
-- transport. Otherwise an empty chain is used.
runTransportServerSocketState :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> T.Supported -> TLSServerCredential -> TransportServerConfig -> (Socket -> (SNICredentialUsed, c 'TServer) -> IO ()) -> IO ()
runTransportServerSocketState ss started getSocket threadLabel srvSupported srvCreds cfg server =
  runTransportServerSocketState_ ss started getSocket threadLabel (tlsSetupTimeout cfg) setupTLS server
  where
    tCfg = serverTransportConfig cfg
    setupTLS conn = do
      -- handed to supportedTLSServerParams and read back after the handshake;
      -- presumably set to True when the SNI credential is selected — verify
      sniUsed <- newTVarIO False
      let params = supportedTLSServerParams srvSupported srvCreds sniUsed (serverALPN cfg)
      h <- if askClientCert cfg then handshakeWithCert params else handshakePlain params
      sniFlag <- readTVarIO sniUsed
      pure (sniFlag, h)
      where
        handshakeWithCert params = do
          certVar <- newEmptyTMVarIO
          tls <- connectTLS Nothing tCfg (paramsAskClientCert certVar params) conn
          -- close the TLS context if the peer certificate chain cannot be obtained
          chain <- takePeerCertChain certVar `E.onException` closeTLS tls
          getTransportConnection tCfg True chain tls
        handshakePlain params = do
          tls <- connectTLS Nothing tCfg params conn
          getTransportConnection tCfg True (X.CertificateChain []) tls

-- | Core server loop with a provided connection setup action and handler.
--
-- For each accepted connection:
-- 1. Run the setup action within the given timeout (in microseconds).
-- 2. Run the handler on the result.
-- 3. Close the established transport connection afterwards.
--
-- If setup does not finish in time, the connection's thread fails with
-- "tls setup timeout". Threads are labelled with @threadLabel@ for diagnostics.
runTransportServerSocketState_ :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> Int -> (Socket -> IO (SNICredentialUsed, c 'TServer)) -> (Socket -> (SNICredentialUsed, c 'TServer) -> IO ()) -> IO ()
runTransportServerSocketState_ ss started getSocket threadLabel setupTimeout setupTLS server = do
  labelMyThread ("transport server for " <> threadLabel)
  runTCPServerSocket ss started getSocket serveConn
  where
    serveConn conn = do
      labelMyThread (threadLabel <> "/setup")
      E.bracket (establish conn) (closeConnection . snd) (server conn)
    establish conn =
      timeout setupTimeout (setupTLS conn) >>= \case
        Just established -> pure established
        Nothing -> fail "tls setup timeout"

-- | Run a plain TCP server (without TLS) using a fresh 'SocketState'.
--
-- @Just "127.0.0.1"@ is passed as the host to 'startTCPServer', which presumably
-- binds the server to the loopback interface only.
runLocalTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()
runLocalTCPServer started port server = do
  ss <- newSocketState
  let localListener = startTCPServer started (Just "127.0.0.1") port
  runTCPServerSocket ss started localListener server

-- | Wrap a socket provider in a TCP server bracket: accept connections in a loop
-- and handle each one in its own thread.
--
-- The listening socket is released with 'closeServer' when the loop ends.
--
-- For every accepted connection:
-- * The accepted counter is incremented, and its new value becomes the connection id.
-- * The handler is forked.
-- * A weak reference to the handler thread is stored in the clients map.
--
-- When the handler finishes, normally or with an exception:
-- * The connection is removed from the map.
-- * The socket is closed gracefully (5000 ms timeout).
-- * The closed counter is incremented.
--
-- The @closed@ flag stops a connection from being registered if its handler has
-- already finished before the registration transaction runs.
runTCPServerSocket :: SocketState -> TMVar Bool -> IO Socket -> (Socket -> IO ()) -> IO ()
runTCPServerSocket (accepted, gracefullyClosed, clients) started getSocket server =
  E.bracket getSocket (closeServer started clients) acceptLoop
  where
    acceptLoop sock =
      forever $ E.bracketOnError (safeAccept sock) (close . fst) (serveConn . fst)
    serveConn conn = do
      cId <- atomically $ stateTVar accepted nextId
      closed <- newTVarIO False
      let closeConn _ = do
            atomically $ do
              writeTVar closed True
              modifyTVar' clients (IM.delete cId)
            gracefulClose conn 5000 `catchAll_` pure () -- catchAll_ is needed here in case the connection was closed earlier
            atomically $ modifyTVar' gracefullyClosed (+ 1)
      tId <- mkWeakThreadId =<< (server conn `forkFinally` closeConn)
      atomically . unlessM (readTVar closed) $
        modifyTVar' clients (IM.insert cId tId)
    -- increment strictly and return the new value as the connection id
    nextId n = let n' = n + 1 in n' `seq` (n', n')

-- | Accept a connection, recovering from errors in `accept` whenever it is safe.
-- Some errors are safe to retry, while blindly restarting `accept` may trigger a busy loop.
--
-- man accept says:
-- @
-- For  reliable  operation the application should detect the network errors defined for the protocol after accept() and treat them like EAGAIN by retrying.
-- In  the  case  of  TCP/IP, these are ENETDOWN, EPROTO, ENOPROTOOPT, EHOSTDOWN, ENONET, EHOSTUNREACH, EOPNOTSUPP, and ENETUNREACH.
-- @
--
-- ECONNABORTED, EAGAIN and the errors listed above are logged as warnings, and
-- `accept` is retried. Any other error is logged as an error and rethrown.
safeAccept :: Socket -> IO (Socket, SockAddr)
safeAccept sock = tryIOError (accept sock) >>= either onAcceptError pure
  where
    onAcceptError e
      | isRetryable = logWarn msg >> safeAccept sock
      | otherwise = logError msg >> E.throwIO e
      where
        errNo = ioe_errno e
        isRetryable = any ((`elem` retryErrnos) . Errno) errNo
        msg = "socket accept error: " <> tshow e <> foldMap ((", errno=" <>) . tshow) errNo
    retryErrnos = [eCONNABORTED, eAGAIN, eNETDOWN, ePROTO, eNOPROTOOPT, eHOSTDOWN, eNONET, eHOSTUNREACH, eOPNOTSUPP, eNETUNREACH]

-- | Mutable socket bookkeeping for a running server, as a tuple of:
-- the counter reported as accepted sockets, the counter reported as closed
-- sockets, and a map of client threads (held via 'Weak' references) whose
-- size is reported as the number of active sockets.
type SocketState = (TVar Int, TVar Int, TVar (IntMap (Weak ThreadId)))

-- | Point-in-time snapshot of a server's socket counters, as built by
-- 'getSocketStats' from a 'SocketState'.
data SocketStats = SocketStats
  { -- | value of the accepted-sockets counter
    socketsAccepted :: Int,
    -- | value of the closed-sockets counter
    socketsClosed :: Int,
    -- | number of client threads currently tracked in the active map
    socketsActive :: Int,
    -- | @socketsAccepted - socketsClosed - socketsActive@: sockets that were
    -- counted as accepted but are neither counted as closed nor tracked as active
    socketsLeaked :: Int
  }

-- | Create an empty 'SocketState': both counters start at zero and no client
-- threads are tracked.
newSocketState :: IO SocketState
newSocketState = do
  acceptedCount <- newTVarIO 0
  closedCount <- newTVarIO 0
  activeClients <- newTVarIO IM.empty
  pure (acceptedCount, closedCount, activeClients)

-- | Take a snapshot of the socket counters held in a 'SocketState'.
--
-- All three 'TVar's are read inside a single STM transaction. Reading them
-- with separate 'readTVarIO' calls could observe each counter at a different
-- moment while connections are being accepted or closed concurrently, making
-- the derived 'socketsLeaked' value inconsistent. The state is not modified.
getSocketStats :: SocketState -> IO SocketStats
getSocketStats (accepted, closed, active) = do
  (socketsAccepted, socketsClosed, socketsActive) <-
    atomically $ (,,) <$> readTVar accepted <*> readTVar closed <*> (IM.size <$> readTVar active)
  -- sockets counted as accepted that are neither counted as closed nor tracked as active
  let socketsLeaked = socketsAccepted - socketsClosed - socketsActive
  pure SocketStats {socketsAccepted, socketsClosed, socketsActive, socketsLeaked}

-- | Shut the listening server down: close the listening socket, kill every
-- client thread that is still alive (weak references that no longer resolve
-- are skipped), then put 'False' into @started@ if it is still empty.
closeServer :: TMVar Bool -> TVar (IntMap (Weak ThreadId)) -> Socket -> IO ()
closeServer started clients sock = do
  close sock
  clientThreads <- readTVarIO clients
  forM_ clientThreads $ \weakTid ->
    deRefWeak weakTid >>= mapM_ killThread
  _ <- atomically $ tryPutTMVar started False
  pure ()

-- | Resolve a passive (listening) TCP address for @host@ / @port@, bind a
-- socket to it and start listening (backlog 1024), then put 'True' into
-- @started@ if it is still empty and return the listening socket.
--
-- When the resolver returns both, an IPv6 address is preferred over IPv4;
-- resolving neither is reported with a descriptive error instead of an
-- opaque 'fromJust' failure. If any socket setup step (options, bind,
-- listen) throws, the freshly created socket is closed before the exception
-- propagates, so a failed start does not leak a file descriptor.
startTCPServer :: TMVar Bool -> Maybe HostName -> ServiceName -> IO Socket
startTCPServer started host port = withSocketsDo $ resolve >>= open >>= setStarted
  where
    resolve =
      let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream}
       in select <$> getAddrInfo (Just hints) host (Just port)
    -- prefer IPv6, fall back to IPv4; addresses of other families are ignored
    select as = fromMaybe noAddress $ family AF_INET6 <|> family AF_INET
      where
        family f = find ((== f) . addrFamily) as
        noAddress = error $ "startTCPServer: no IPv6 or IPv4 address resolved for " <> fromMaybe "*" host <> ":" <> port
    -- bracketOnError closes the socket only if the setup below throws
    open addr =
      E.bracketOnError (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)) close $ \sock -> do
        setSocketOption sock ReuseAddr 1
        withFdSocket sock setCloseOnExecIfNeeded
        logNote $ "binding to " <> tshow (addrAddress addr)
        bind sock $ addrAddress addr
        listen sock 1024
        pure sock
    setStarted sock = atomically (tryPutTMVar started True) >> pure sock

-- | Load the server TLS credential from the certificate file, the optional CA
-- certificate file (appended to the chain) and the private key file.
--
-- On failure the process prints the loader's error reason (previously it was
-- discarded, leaving only "invalid credential") and exits with a failure code.
loadServerCredential :: ServerCredentials -> IO T.Credential
loadServerCredential ServerCredentials {caCertificateFile, certificateFile, privateKeyFile} =
  T.credentialLoadX509Chain certificateFile (maybeToList caCertificateFile) privateKeyFile >>= \case
    Right credential -> pure credential
    Left e -> putStrLn ("invalid credential: " <> e) >> exitFailure

-- | Build TLS server parameters that do not request a client certificate.
--
-- SNI hook: when an SNI credential is configured and the client sends a
-- server name, the SNI credential is used and @sniCredUsed@ is set to 'True';
-- in every other case the main credential is used.
-- ALPN hook (installed only when @alpn_@ is 'Just'): picks the first protocol
-- in the client's list that the server supports, or "" if none matches.
supportedTLSServerParams :: T.Supported -> TLSServerCredential -> TVar SNICredentialUsed -> Maybe [ALPN] -> T.ServerParams
supportedTLSServerParams serverSupported TLSServerCredential {credential, sniCredential} sniCredUsed alpn_ =
  def
    { T.serverWantClientCert = False,
      T.serverHooks = hooks,
      T.serverSupported = serverSupported
    }
  where
    hooks =
      def
        { T.onServerNameIndication = pickCredentials,
          T.onALPNClientSuggest = chooseProtocol <$> alpn_
        }
    mainCredentials = T.Credentials [credential]
    pickCredentials requestedName = case (sniCredential, requestedName) of
      (Just sniCred, Just _) -> T.Credentials [sniCred] <$ atomically (writeTVar sniCredUsed True)
      _ -> pure mainCredentials
    chooseProtocol supported clientProtocols =
      pure $ fromMaybe "" $ find (`elem` supported) clientProtocols

-- | Modify TLS server parameters to request a client certificate.
--
-- The presented chain is checked with 'validateClientCertificate'. An accepted
-- chain is offered to @clientCert@ as @Just chain@; a rejected one results in
-- @Nothing@ being offered and the certificate being rejected with the reason.
-- Offers use 'tryPutTMVar', so they have no effect if @clientCert@ is already full.
paramsAskClientCert :: TMVar (Maybe X.CertificateChain) -> T.ServerParams -> T.ServerParams
paramsAskClientCert clientCert params =
  params {T.serverWantClientCert = True, T.serverHooks = updatedHooks}
  where
    originalHooks = T.serverHooks params
    updatedHooks = originalHooks {T.onClientCertificate = onCertificate}
    onCertificate chain = do
      rejection <- validateClientCertificate chain
      case rejection of
        Nothing -> do
          _ <- atomically $ tryPutTMVar clientCert (Just chain)
          pure T.CertificateUsageAccept
        Just reason -> do
          _ <- atomically $ tryPutTMVar clientCert Nothing
          pure $ T.CertificateUsageReject reason

-- | Validate a client certificate chain, returning 'Nothing' when it is
-- acceptable or a rejection reason otherwise.
--
-- An empty chain is accepted. A self-signed chain is validated against its
-- own certificate, and a chain with a CA is validated against that CA
-- certificate. Chains classified as 'CCLong' are rejected outright.
-- Validation failures are mapped to "expired" (expired or not yet valid),
-- "unknown CA", or a generic reason listing the failures.
validateClientCertificate :: X.CertificateChain -> IO (Maybe T.CertificateRejectReason)
validateClientCertificate cc =
  case chainIdCaCerts cc of
    CCEmpty -> pure Nothing -- client certificates are only used for services
    CCSelf cert -> checkAgainst cert
    CCValid {caCert} -> checkAgainst caCert
    CCLong -> pure . Just $ T.CertificateRejectOther "chain too long"
  where
    checkAgainst ca = toRejection <$> x509validate ca ("", B.empty) cc
    toRejection failures
      | null failures = Nothing
      | any (`elem` failures) [XV.Expired, XV.InFuture] = Just T.CertificateRejectExpired
      | XV.UnknownCA `elem` failures = Just T.CertificateRejectUnknownCA
      | otherwise = Just $ T.CertificateRejectOther (show failures)

-- | Fingerprint of the CA certificate file configured in the credentials.
-- Fails with an error if no CA certificate file is configured.
loadFingerprint :: ServerCredentials -> IO Fingerprint
loadFingerprint ServerCredentials {caCertificateFile} =
  maybe (error "CA file must be used in protocol credentials") loadFileFingerprint caCertificateFile

-- | SHA-256 fingerprint of the first certificate stored in the given file.
--
-- If the file contains no certificates, this fails in 'IO' with a message
-- naming the file. That replaces the previous opaque "Pattern match failure
-- in do expression" error and keeps the same exception type ('IOError').
loadFileFingerprint :: FilePath -> IO Fingerprint
loadFileFingerprint certificateFile =
  SX.readSignedObject certificateFile >>= \case
    (cert : _) -> pure $ XV.getFingerprint (cert :: X.SignedExact X.Certificate) X.HashSHA256
    [] -> fail $ "loadFileFingerprint: no certificate found in " <> certificateFile