{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Transport.Server
( TransportServerConfig (..),
ServerCredentials (..),
TLSServerCredential (..),
SNICredentialUsed,
AddHTTP,
mkTransportServerConfig,
runTransportServerState,
runTransportServerState_,
SocketState,
SocketStats (..),
newSocketState,
getSocketStats,
runTransportServer,
runTransportServerSocket,
runLocalTCPServer,
startTCPServer,
loadServerCredential,
loadFingerprint,
loadFileFingerprint,
smpServerHandshake,
)
where
import Control.Applicative ((<|>))
import Control.Logger.Simple
import Control.Monad
import qualified Crypto.Store.X509 as SX
import qualified Data.ByteString as B
import Data.Default (def)
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IM
import Data.List (find)
import Data.Maybe (fromJust, fromMaybe, maybeToList)
import qualified Data.X509 as X
import Data.X509.Validation (Fingerprint (..))
import qualified Data.X509.Validation as XV
import Foreign.C.Error
import GHC.IO.Exception (ioe_errno)
import Network.Socket
import qualified Network.TLS as T
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Shared
import Simplex.Messaging.Util (catchAll_, labelMyThread, tshow, unlessM)
import System.Exit (exitFailure)
import System.IO.Error (tryIOError)
import System.Mem.Weak (Weak, deRefWeak)
import UnliftIO (timeout)
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import UnliftIO.STM
-- | Settings for a TLS transport server started by the functions in this module.
data TransportServerConfig = TransportServerConfig
  { -- | Copied into the 'TransportConfig' used for every accepted connection.
    logTLSErrors :: Bool,
    -- | ALPN protocol list handed to 'supportedTLSServerParams'.
    serverALPN :: Maybe [ALPN],
    -- | When 'True', the TLS setup asks the client for a certificate chain.
    askClientCert :: Bool,
    -- | Set to 'False' by 'mkTransportServerConfig'; not read in this part of the module.
    addCORSHeaders :: Bool,
    -- | Limit passed to 'timeout' around the TLS setup of an accepted connection.
    tlsSetupTimeout :: Int,
    -- | Not forwarded by 'serverTransportConfig' (which sets the transport timeout to 'Nothing').
    transportTimeout :: Int
  }
  deriving (Eq, Show)
-- | File locations of the server credential material.
-- NOTE(review): the loaders that read these files are outside the reviewed
-- part of the module - confirm expected formats there.
data ServerCredentials = ServerCredentials
  { -- | Optional path to the CA certificate file.
    caCertificateFile :: Maybe FilePath,
    -- | Path to the private key file.
    privateKeyFile :: FilePath,
    -- | Path to the certificate file.
    certificateFile :: FilePath
  }
  deriving (Show)
-- | Boolean flag alias.
-- NOTE(review): presumably tells a caller whether HTTP should also be served;
-- it is not used in the reviewed part of this module - confirm at use sites.
type AddHTTP = Bool
-- | TLS credentials offered by the server.
data TLSServerCredential = TLSServerCredential
  { -- | The main server credential.
    credential :: T.Credential,
    -- | Optional additional credential; 'runTransportServerState' sets it to 'Nothing'.
    -- NOTE(review): presumably selected based on the client's SNI - confirm in
    -- 'supportedTLSServerParams'.
    sniCredential :: Maybe T.Credential
  }
-- | Flag delivered to connection handlers together with the transport handle.
-- In 'runTransportServerSocketState' it is read from the 'TVar' given to
-- 'supportedTLSServerParams'; 'runTransportServerSocket' always reports 'False'.
type SNICredentialUsed = Bool
-- | Build a 'TransportServerConfig' from the three caller-supplied settings
-- (TLS error logging, ALPN list, client certificate request).
--
-- The remaining fields get fixed defaults: @addCORSHeaders = False@,
-- @tlsSetupTimeout = 60000000@ and @transportTimeout = 40000000@.
mkTransportServerConfig :: Bool -> Maybe [ALPN] -> Bool -> TransportServerConfig
mkTransportServerConfig logTLSErrors serverALPN askClientCert =
  TransportServerConfig
    { logTLSErrors,
      serverALPN,
      askClientCert,
      addCORSHeaders = False,
      tlsSetupTimeout = 60000000,
      transportTimeout = 40000000
    }
-- | Derive the per-connection 'TransportConfig' from the server configuration.
--
-- Only 'logTLSErrors' is carried over; the transport timeout is always
-- 'Nothing', so the server-level @transportTimeout@ field is not applied here.
serverTransportConfig :: TransportServerConfig -> TransportConfig
serverTransportConfig TransportServerConfig {logTLSErrors} =
  TransportConfig {logTLSErrors, transportTimeout = Nothing}
-- | Run a transport server on the given port with a freshly created
-- 'SocketState'.
--
-- All arguments are forwarded unchanged to 'runTransportServerState'; the
-- handler receives each connection's transport handle.
runTransportServer :: Transport c => TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO ()
runTransportServer started port srvSupported srvCreds cfg server = do
  ss <- newSocketState
  runTransportServerState ss started port srvSupported srvCreds cfg server
-- | Run a transport server using the caller's 'SocketState' and a single TLS
-- credential.
--
-- The credential is wrapped into a 'TLSServerCredential' without an SNI
-- credential, and the handler is adapted so that it ignores both the raw
-- socket and the 'SNICredentialUsed' flag.
runTransportServerState :: Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> T.Credential -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO ()
runTransportServerState ss started port srvSupported credential cfg server =
  runTransportServerState_ ss started port srvSupported srvCreds cfg (\_ -> server . snd)
  where
    srvCreds = TLSServerCredential {credential, sniCredential = Nothing}
-- | Run a transport server on a TCP port, passing the handler the accepted
-- socket together with the 'SNICredentialUsed' flag and transport handle.
--
-- The listening socket is obtained with @startTCPServer started Nothing port@
-- and the thread label is the transport name of @c@. The remaining arguments
-- (supported TLS parameters, credentials, config, handler) are those of
-- 'runTransportServerSocketState'.
runTransportServerState_ :: forall c. Transport c => SocketState -> TMVar Bool -> ServiceName -> T.Supported -> TLSServerCredential -> TransportServerConfig -> (Socket -> (SNICredentialUsed, c 'TServer) -> IO ()) -> IO ()
runTransportServerState_ ss started port =
  runTransportServerSocketState
    ss
    started
    (startTCPServer started Nothing port)
    -- ScopedTypeVariables ties this proxy to the @c@ of the signature.
    (transportName (TProxy :: TProxy c 'TServer))
-- | Run a transport server on a caller-provided listening socket with
-- ready-made TLS server parameters.
--
-- A fresh 'SocketState' is created. For each accepted connection the TLS
-- context is established with 'connectTLS' and turned into a transport handle
-- with an empty client certificate chain. The 'SNICredentialUsed' flag is
-- always 'False' here and the handler only receives the transport handle.
runTransportServerSocket :: Transport c => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (c 'TServer -> IO ()) -> IO ()
runTransportServerSocket started getSocket threadLabel srvParams cfg server = do
  ss <- newSocketState
  runTransportServerSocketState_ ss started getSocket threadLabel (tlsSetupTimeout cfg) setupTLS (\_ -> server . snd)
  where
    tCfg = serverTransportConfig cfg
    setupTLS conn = do
      tls <- connectTLS Nothing tCfg srvParams conn
      (False,) <$> getTransportConnection tCfg True (X.CertificateChain []) tls
-- | Run a transport server on a caller-provided listening socket, building
-- the TLS server parameters per connection from the supported parameters,
-- the server credentials and the configured ALPN list.
--
-- For each accepted connection a fresh @sniUsed@ 'TVar' (initially 'False')
-- is handed to 'supportedTLSServerParams'; its value after the TLS setup is
-- reported to the handler as the 'SNICredentialUsed' flag.
--
-- When 'askClientCert' is set, the client certificate chain is requested
-- during TLS setup and passed to 'getTransportConnection'; otherwise an empty
-- chain is used.
runTransportServerSocketState :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> T.Supported -> TLSServerCredential -> TransportServerConfig -> (Socket -> (SNICredentialUsed, c 'TServer) -> IO ()) -> IO ()
runTransportServerSocketState ss started getSocket threadLabel srvSupported srvCreds cfg server =
  runTransportServerSocketState_ ss started getSocket threadLabel (tlsSetupTimeout cfg) setupTLS server
  where
    tCfg = serverTransportConfig cfg
    setupTLS conn = do
      sniUsed <- newTVarIO False
      let srvParams = supportedTLSServerParams srvSupported srvCreds sniUsed $ serverALPN cfg
      h <- setupTLS_ srvParams
      sni <- readTVarIO sniUsed
      pure (sni, h)
      where
        setupTLS_ srvParams
          | askClientCert cfg = do
              clientCert <- newEmptyTMVarIO
              tls <- connectTLS Nothing tCfg (paramsAskClientCert clientCert srvParams) conn
              -- close the TLS context if obtaining the peer chain fails
              chain <- takePeerCertChain clientCert `E.onException` closeTLS tls
              getTransportConnection tCfg True chain tls
          | otherwise = do
              tls <- connectTLS Nothing tCfg srvParams conn
              getTransportConnection tCfg True (X.CertificateChain []) tls
-- | Run a transport server with a caller-provided connection setup action
-- and connection handler.
--
-- The calling thread is labelled @"transport server for " <> threadLabel@ and
-- each connection thread @threadLabel <> "/setup"@. For every accepted socket
-- the setup action runs under 'timeout' with @tlsSetupTimeout@; if it does
-- not finish in time the connection thread fails with "tls setup timeout".
-- After a successful setup the handler runs inside a bracket that always
-- calls 'closeConnection' on the transport handle.
runTransportServerSocketState_ :: Transport c => SocketState -> TMVar Bool -> IO Socket -> String -> Int -> (Socket -> IO (SNICredentialUsed, c 'TServer)) -> (Socket -> (SNICredentialUsed, c 'TServer) -> IO ()) -> IO ()
runTransportServerSocketState_ ss started getSocket threadLabel tlsSetupTimeout setupTLS server = do
  labelMyThread $ "transport server for " <> threadLabel
  runTCPServerSocket ss started getSocket $ \conn -> do
    labelMyThread $ threadLabel <> "/setup"
    E.bracket
      (timeout tlsSetupTimeout (setupTLS conn) >>= maybe (fail "tls setup timeout") pure)
      (closeConnection . snd)
      (server conn)
-- | Run a plain TCP server (no TLS setup) bound to host @127.0.0.1@ on the
-- given port, with a fresh 'SocketState'. Every accepted socket is passed to
-- the handler.
runLocalTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()
runLocalTCPServer started port server = do
  ss <- newSocketState
  runTCPServerSocket ss started (startTCPServer started (Just "127.0.0.1") port) server
-- | Accept connections on the socket produced by @getSocket@ forever,
-- running the handler for each connection in its own thread.
--
-- The listening socket is released with 'closeServer' when the loop exits.
-- For every accepted connection:
--
-- * the @accepted@ counter is incremented and its new value becomes the
--   connection id;
-- * the handler is forked with 'forkFinally'; when it ends (normally or with
--   an exception) the connection is marked closed, removed from @clients@,
--   closed with @gracefulClose conn 5000@ and counted in @gracefullyClosed@;
-- * a weak reference to the handler thread is stored in @clients@ under the
--   connection id, unless the handler has already finished.
--
-- If an exception is raised between accepting the socket and the end of this
-- bookkeeping, 'E.bracketOnError' closes the accepted socket.
runTCPServerSocket :: SocketState -> TMVar Bool -> IO Socket -> (Socket -> IO ()) -> IO ()
runTCPServerSocket (accepted, gracefullyClosed, clients) started getSocket server =
  E.bracket getSocket (closeServer started clients) $ \sock ->
    forever . E.bracketOnError (safeAccept sock) (close . fst) $ \(conn, _peer) -> do
      -- force the incremented counter so no thunk is stored in the TVar
      cId <- atomically $ stateTVar accepted $ \cId -> let cId' = cId + 1 in cId' `seq` (cId', cId')
      closed <- newTVarIO False
      let closeConn _ = do
            atomically $ writeTVar closed True >> modifyTVar' clients (IM.delete cId)
            -- a failure of gracefulClose is replaced with `pure ()`
            gracefulClose conn 5000 `catchAll_` pure ()
            atomically $ modifyTVar' gracefullyClosed (+ 1)
      tId <- mkWeakThreadId =<< server conn `forkFinally` closeConn
      -- register the thread only if it has not finished yet, so that a
      -- finished connection is never re-added to the clients map
      atomically $ unlessM (readTVar closed) $ modifyTVar' clients $ IM.insert cId tId
-- | Accept a connection, retrying when 'accept' fails with an error that is
-- listed as retryable.
--
-- An 'IOError' whose errno is one of the codes in @again@ is logged as a
-- warning and 'accept' is attempted again. Any other 'IOError' (including
-- one without an errno) is logged as an error and re-thrown.
safeAccept :: Socket -> IO (Socket, SockAddr)
safeAccept sock =
  tryIOError (accept sock) >>= \case
    Right r -> pure r
    Left e
      | retryAccept -> logWarn err >> safeAccept sock
      | otherwise -> logError err >> E.throwIO e
      where
        retryAccept = maybe False ((`elem` again) . Errno) errno
        -- errno values after which accepting is attempted again
        again = [eCONNABORTED, eAGAIN, eNETDOWN, ePROTO, eNOPROTOOPT, eHOSTDOWN, eNONET, eHOSTUNREACH, eOPNOTSUPP, eNETUNREACH]
        err = "socket accept error: " <> tshow e <> maybe "" ((", errno=" <>) . tshow) errno
        errno = ioe_errno e
-- | Shared bookkeeping for a TCP server, as used by 'runTCPServerSocket':
-- the count of accepted connections, the count of connections that went
-- through the close step, and the currently registered client threads
-- (weak references) keyed by connection id.
type SocketState = (TVar Int, TVar Int, TVar (IntMap (Weak ThreadId)))
-- | Counters reported by 'getSocketStats'.
data SocketStats = SocketStats
  { -- | Value of the accepted-connections counter.
    socketsAccepted :: Int,
    -- | Value of the closed-connections counter.
    socketsClosed :: Int,
    -- | Number of entries in the registered client threads map.
    socketsActive :: Int,
    -- | @socketsAccepted - socketsClosed - socketsActive@.
    socketsLeaked :: Int
  }
-- | Create an empty 'SocketState': both counters start at 0 and the client
-- threads map starts empty.
newSocketState :: IO SocketState
newSocketState = (,,) <$> newTVarIO 0 <*> newTVarIO 0 <*> newTVarIO mempty
-- | Read the current counters from a 'SocketState'.
--
-- The three variables are read one after another, not in a single STM
-- transaction, so the values may come from slightly different moments.
-- @socketsLeaked@ is what remains of the accepted count after subtracting
-- closed and active connections.
getSocketStats :: SocketState -> IO SocketStats
getSocketStats (accepted, closed, active) = do
  socketsAccepted <- readTVarIO accepted
  socketsClosed <- readTVarIO closed
  socketsActive <- IM.size <$> readTVarIO active
  let socketsLeaked = socketsAccepted - socketsClosed - socketsActive
  pure SocketStats {socketsAccepted, socketsClosed, socketsActive, socketsLeaked}
-- | Shut the server down: close the listening socket, kill every client
-- thread whose weak reference is still alive, and put 'False' into
-- @started@ if that 'TMVar' is currently empty.
closeServer :: TMVar Bool -> TVar (IntMap (Weak ThreadId)) -> Socket -> IO ()
closeServer started clients sock = do
  close sock
  -- deRefWeak yields Nothing for threads that were already collected
  readTVarIO clients >>= mapM_ (deRefWeak >=> mapM_ killThread)
  void . atomically $ tryPutTMVar started False
-- | Open a listening TCP socket for the given host and port and signal that
-- the server has started.
--
-- The address is resolved with passive, stream-socket hints; an IPv6 result
-- is preferred over an IPv4 one. The socket is created with @ReuseAddr@ set,
-- marked close-on-exec, bound, and put into listening state with a backlog of
-- 1024. Afterwards 'True' is put into @started@ (if it is still empty) and
-- the socket is returned.
--
-- Throws an 'IOError' if resolution yields neither an IPv6 nor an IPv4
-- address. If configuring, binding or listening fails, the freshly created
-- socket is closed before the exception propagates, so a failed start (for
-- example, when the port is already in use) does not leak the descriptor.
startTCPServer :: TMVar Bool -> Maybe HostName -> ServiceName -> IO Socket
startTCPServer started host port = withSocketsDo $ resolve >>= open >>= setStarted
  where
    resolve = do
      let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream}
      addrs <- getAddrInfo (Just hints) host (Just port)
      case family AF_INET6 addrs <|> family AF_INET addrs of
        Just addr -> pure addr
        -- explicit failure instead of a bare `fromJust` crash
        Nothing -> fail $ "no IPv4 or IPv6 address resolved for port " <> port
    family f = find ((== f) . addrFamily)
    open addr =
      -- close the socket if any of the setup steps below throws
      E.bracketOnError (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)) close $ \sock -> do
        setSocketOption sock ReuseAddr 1
        withFdSocket sock setCloseOnExecIfNeeded
        logNote $ "binding to " <> tshow (addrAddress addr)
        bind sock $ addrAddress addr
        listen sock 1024
        pure sock
    setStarted sock = sock <$ atomically (tryPutTMVar started True)
-- | Load the server TLS credential from the files named in 'ServerCredentials'.
--
-- The certificate file, the optional CA certificate file (passed as the chain)
-- and the private key file are given to 'T.credentialLoadX509Chain'.
-- On failure the reason reported by the loader is printed and the process
-- terminates via 'exitFailure'.
loadServerCredential :: ServerCredentials -> IO T.Credential
loadServerCredential ServerCredentials {caCertificateFile, certificateFile, privateKeyFile} =
  T.credentialLoadX509Chain certificateFile (maybeToList caCertificateFile) privateKeyFile >>= \case
    Right credential -> pure credential
    -- include the loader's reason rather than discarding it, so the operator
    -- can tell which file or key is at fault
    Left err -> putStrLn ("invalid credential: " <> err) >> exitFailure
-- | Build TLS server parameters that do not request a client certificate.
--
-- Credential selection on server name indication:
--
-- * when no SNI credential is configured, the main credential is always used;
-- * when an SNI credential is configured and the client sent a server name,
--   the SNI credential is used and @sniCredUsed@ is set to 'True';
-- * otherwise the main credential is used.
--
-- When @alpn_@ is present, the ALPN hook picks the first protocol suggested
-- by the client that is also in the configured list, or an empty protocol
-- name when there is no match. When @alpn_@ is absent no ALPN hook is set.
supportedTLSServerParams :: T.Supported -> TLSServerCredential -> TVar SNICredentialUsed -> Maybe [ALPN] -> T.ServerParams
supportedTLSServerParams serverSupported TLSServerCredential {credential, sniCredential} sniCredUsed alpn_ =
  def
    { T.serverWantClientCert = False,
      T.serverHooks =
        def
          { T.onServerNameIndication = selectCredentials,
            T.onALPNClientSuggest = fmap chooseALPN alpn_
          },
      T.serverSupported = serverSupported
    }
  where
    selectCredentials :: Maybe HostName -> IO T.Credentials
    selectCredentials sniHost = case (sniCredential, sniHost) of
      (Just sniCred, Just _) -> do
        atomically $ writeTVar sniCredUsed True
        pure $ T.Credentials [sniCred]
      _ -> pure $ T.Credentials [credential]
    chooseALPN :: [ALPN] -> [ALPN] -> IO ALPN
    chooseALPN supported suggested =
      pure $ fromMaybe "" $ find (`elem` supported) suggested
-- | Extend server parameters so that the server asks for a client certificate.
--
-- The installed hook validates the presented chain with
-- 'validateClientCertificate'. A rejected chain stores 'Nothing' in
-- @clientCert@ and returns the rejection; an accepted chain stores
-- @Just chain@ and returns acceptance. The variable is only written if it is
-- still empty. All other hooks of @params@ are kept as they are.
paramsAskClientCert :: TMVar (Maybe X.CertificateChain) -> T.ServerParams -> T.ServerParams
paramsAskClientCert clientCert params =
  params
    { T.serverWantClientCert = True,
      T.serverHooks = (T.serverHooks params) {T.onClientCertificate = checkClientCert}
    }
  where
    checkClientCert chain = do
      rejected <- validateClientCertificate chain
      case rejected of
        Nothing -> do
          remember $ Just chain
          pure T.CertificateUsageAccept
        Just reason -> do
          remember Nothing
          pure $ T.CertificateUsageReject reason
    remember = void . atomically . tryPutTMVar clientCert
-- | Decide whether a client certificate chain should be rejected.
--
-- Returns 'Nothing' when the chain is acceptable and @Just reason@ otherwise.
-- The chain is classified with 'chainIdCaCerts':
--
-- * an empty chain is accepted;
-- * a self-signed certificate, or the CA certificate of a valid chain, is
--   passed to 'x509validate' together with the whole chain;
-- * a chain classified as too long is rejected with "chain too long".
--
-- Validation failures map to a rejection reason as follows: expired or
-- not-yet-valid takes precedence, then unknown CA, and anything else is
-- reported with the shown list of failures.
validateClientCertificate :: X.CertificateChain -> IO (Maybe T.CertificateRejectReason)
validateClientCertificate cc = case chainIdCaCerts cc of
  CCLong -> pure . Just $ T.CertificateRejectOther "chain too long"
  CCValid {caCert} -> checkAgainst caCert
  CCSelf cert -> checkAgainst cert
  CCEmpty -> pure Nothing
  where
    checkAgainst rootCert = rejectReason <$> x509validate rootCert ("", B.empty) cc
    rejectReason :: [XV.FailedReason] -> Maybe T.CertificateRejectReason
    rejectReason failures
      | null failures = Nothing
      | XV.Expired `elem` failures || XV.InFuture `elem` failures = Just T.CertificateRejectExpired
      | XV.UnknownCA `elem` failures = Just T.CertificateRejectUnknownCA
      | otherwise = Just $ T.CertificateRejectOther (show failures)
-- | Load the fingerprint of the CA certificate named in 'ServerCredentials'.
--
-- Calls 'error' when no CA certificate file is configured.
loadFingerprint :: ServerCredentials -> IO Fingerprint
loadFingerprint ServerCredentials {caCertificateFile} =
  maybe missingCAFile loadFileFingerprint caCertificateFile
  where
    missingCAFile = error "CA file must be used in protocol credentials"
-- | Compute the SHA-256 fingerprint of the first certificate in a file.
--
-- Any further certificates in the file are ignored. Fails in 'IO' with a
-- message naming the file when it contains no certificates, instead of an
-- opaque pattern-match failure.
loadFileFingerprint :: FilePath -> IO Fingerprint
loadFileFingerprint certificateFile = do
  certs <- SX.readSignedObject certificateFile
  case certs :: [X.SignedExact X.Certificate] of
    cert : _ -> pure $ XV.getFingerprint cert X.HashSHA256
    [] -> fail $ "no certificate found in " <> certificateFile