{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}

module Simplex.Messaging.Transport.Shared
  ( ChainCertificates (..),
    chainIdCaCerts,
    x509validate,
    takePeerCertChain,
  ) where

import Control.Concurrent.STM
import qualified Control.Exception as E
import Control.Logger.Simple (logError)
import Data.ByteString (ByteString)
import qualified Data.X509 as X
import qualified Data.X509.CertificateStore as XS
import qualified Data.X509.Validation as XV
import Network.Socket (HostName)

-- | Classification of a peer certificate chain, as produced by 'chainIdCaCerts'.
data ChainCertificates
  = -- | The chain contained no certificates.
    CCEmpty
  | -- | The chain contained exactly one certificate.
    CCSelf X.SignedCertificate
  | -- | A chain of two to four certificates, split into its roles.
    CCValid
      { -- | First certificate of the chain.
        leafCert :: X.SignedCertificate,
        -- | Second certificate of the chain.
        idCert :: X.SignedCertificate,
        -- | Last certificate of the chain (equal to 'idCert' for a two-certificate chain).
        caCert :: X.SignedCertificate
      }
  | -- | The chain had more than four certificates.
    CCLong

-- | Split a certificate chain into leaf, identity and CA certificates according to its length.
--
-- * 0 certificates: 'CCEmpty'
-- * 1 certificate: 'CCSelf'
-- * 2 certificates: leaf plus one certificate, used as both identity and CA
-- * 3 certificates: leaf, identity, CA
-- * 4 certificates: leaf, identity, (ignored third certificate), CA
-- * more than 4: 'CCLong'
--
-- Local names deliberately differ from the record field names so they do not
-- shadow the 'leafCert' / 'idCert' / 'caCert' selectors defined in this module.
chainIdCaCerts :: X.CertificateChain -> ChainCertificates
chainIdCaCerts (X.CertificateChain chain) = case chain of
  [] -> CCEmpty
  [cert] -> CCSelf cert
  -- current long-term online/offline certificates chain
  [leaf, cert] -> CCValid {leafCert = leaf, idCert = cert, caCert = cert}
  -- with additional operator certificate (preset in the client)
  [leaf, ident, ca] -> CCValid {leafCert = leaf, idCert = ident, caCert = ca}
  -- with network certificate (the third certificate is not used here)
  [leaf, ident, _network, ca] -> CCValid {leafCert = leaf, idCert = ident, caCert = ca}
  _ -> CCLong

-- | Validate a peer certificate chain against a single trusted CA certificate.
--
-- Calls 'XV.validate' with:
--
-- * 'X.HashSHA256' as the hash algorithm,
-- * 'XV.defaultHooks',
-- * 'XV.defaultChecks' with fully-qualified host name matching ('XV.checkFQHN') disabled,
-- * a certificate store containing only the given CA certificate,
-- * a validation cache that never reports a cached result and never stores one.
--
-- Returns the validation failures reported by 'XV.validate' (an empty list
-- means no failures were found).
x509validate :: X.SignedCertificate -> (HostName, ByteString) -> X.CertificateChain -> IO [XV.FailedReason]
x509validate trustedCA serviceID = XV.validate X.HashSHA256 XV.defaultHooks checks certStore noCache serviceID
  where
    -- The host name from serviceID is not matched against the certificate.
    checks :: XV.ValidationChecks
    checks = XV.defaultChecks {XV.checkFQHN = False}
    -- The only trusted certificate is the one supplied by the caller.
    certStore :: XS.CertificateStore
    certStore = XS.makeCertificateStore [trustedCA]
    -- Every lookup is a miss and additions are discarded, so each call performs full validation.
    noCache :: XV.ValidationCache
    noCache = XV.ValidationCache (\_ _ _ -> pure XV.ValidationCacheUnknown) (\_ _ _ -> pure ())

-- | Take the peer certificate chain stored in the given 'TMVar', without blocking.
--
-- The variable is read with 'tryTakeTMVar', which empties it if it was full:
--
-- * @Just (Just chain)@: the chain is returned.
-- * @Just Nothing@: logs and throws @userError "peer certificate invalid"@.
-- * empty variable: logs and throws @userError "certificate hook not called"@.
--
-- NOTE(review): the variable is presumably filled by the TLS certificate hook
-- (onServerCertificate / onClientCertificate); confirm at the call sites.
takePeerCertChain :: TMVar (Maybe X.CertificateChain) -> IO X.CertificateChain
takePeerCertChain peerCert =
  atomically (tryTakeTMVar peerCert) >>= \case
    Just (Just cc) -> pure cc
    Just Nothing -> logError "peer certificate invalid" >> E.throwIO (userError "peer certificate invalid")
    -- onServerCertificate / onClientCertificate
    Nothing -> logError "certificate hook not called" >> E.throwIO (userError "certificate hook not called")