{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}

module Simplex.Messaging.Transport.KeepAlive
  ( KeepAliveOpts (..),
    defaultKeepAliveOpts,
    setSocketKeepAlive,
  ) where

import qualified Data.Aeson.TH as J
import Foreign.C (CInt (..))
import Network.Socket
import Simplex.Messaging.Parsers (defaultJSON)

-- | TCP keep-alive settings applied to a socket by 'setSocketKeepAlive'.
data KeepAliveOpts = KeepAliveOpts
  { -- | Idle time before the first keep-alive probe is sent. Passed as the
    -- TCP_KEEPIDLE option (TCP_KEEPALIVE on macOS); seconds per tcp(7).
    keepIdle :: Int,
    -- | Interval between keep-alive probes. Passed as the TCP_KEEPINTVL
    -- option; seconds per tcp(7).
    keepIntvl :: Int,
    -- | Number of unanswered probes sent before the connection is dropped.
    -- Passed as the TCP_KEEPCNT option.
    keepCnt :: Int
  }
  deriving (Eq, Show)

-- | Default keep-alive settings: the first probe is sent after 30 idle
-- seconds, then probes are repeated every 15 seconds, and the connection
-- is dropped after 4 unanswered probes.
--
-- Assuming tcp(7) semantics on the host OS, a silent peer is detected after
-- roughly 30 + 4 * 15 = 90 seconds.
defaultKeepAliveOpts :: KeepAliveOpts
defaultKeepAliveOpts =
  KeepAliveOpts
    { keepIdle = 30,
      keepIntvl = 15,
      keepCnt = 4
    }

-- | Socket option level for TCP-level options (IPPROTO_TCP, aka SOL_TCP).
--
-- 6 is the IP protocol number assigned to TCP. It is the TCP option level
-- on Linux, macOS and Windows alike, so one hard-coded value serves all
-- platforms.
_SOL_TCP :: CInt
_SOL_TCP = 6

#if defined(mingw32_HOST_OS)
-- Windows: TCP keep-alive option names as used by setSocketKeepAlive.
-- Unlike the branch below, no C header is consulted here; the numeric
-- option values are hard-coded.

-- The values are copied from windows::Win32::Networking::WinSock
-- https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/index.html

-- Option for the idle time before the first keep-alive probe.
_TCP_KEEPIDLE :: CInt
_TCP_KEEPIDLE = 3

-- Option for the interval between keep-alive probes.
_TCP_KEEPINTVL :: CInt
_TCP_KEEPINTVL = 17

-- Option for the number of unanswered probes before the connection is dropped.
_TCP_KEEPCNT :: CInt
_TCP_KEEPCNT = 16

#else
-- Mac/Linux: the option values are read from netinet/tcp.h at compile time
-- via CApiFFI value imports, so they always match the platform headers.

#if defined(darwin_HOST_OS)
-- On macOS the idle-time option is named TCP_KEEPALIVE, so it is bound to
-- the same Haskell name _TCP_KEEPIDLE used on the other platforms.
foreign import capi "netinet/tcp.h value TCP_KEEPALIVE" _TCP_KEEPIDLE :: CInt
#else
foreign import capi "netinet/tcp.h value TCP_KEEPIDLE" _TCP_KEEPIDLE :: CInt
#endif

foreign import capi "netinet/tcp.h value TCP_KEEPINTVL" _TCP_KEEPINTVL :: CInt

foreign import capi "netinet/tcp.h value TCP_KEEPCNT" _TCP_KEEPCNT :: CInt

#endif

-- | Enable TCP keep-alive on a socket and configure its probe parameters.
--
-- First turns keep-alive on via the 'KeepAlive' socket option. It then sets
-- the TCP-level idle time, probe interval and probe count from the given
-- 'KeepAliveOpts', in that order.
--
-- Exceptions raised by 'setSocketOption' are not caught; they propagate to
-- the caller. A failure part-way through therefore leaves the options that
-- were already applied in place.
setSocketKeepAlive :: Socket -> KeepAliveOpts -> IO ()
setSocketKeepAlive sock KeepAliveOpts {keepIdle, keepIntvl, keepCnt} = do
  setSocketOption sock KeepAlive 1
  setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPIDLE) keepIdle
  setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPINTVL) keepIntvl
  setSocketOption sock (SockOpt _SOL_TCP _TCP_KEEPCNT) keepCnt

-- Generate the Aeson ToJSON and FromJSON instances for KeepAliveOpts, using
-- the defaultJSON options from Simplex.Messaging.Parsers.
J.deriveJSON defaultJSON ''KeepAliveOpts