module Simplex.Messaging.Crypto.SNTRUP761.Bindings.RNG
  ( withDRG,
    rngFuncPtr,
    RNGContext,
    RNGFunc,
  ) where

import Control.Concurrent.STM
import Control.Exception (bracket)
import Crypto.Random (ChaChaDRG)
import Data.ByteArray (ByteArrayAccess (copyByteArrayToPtr))
import Foreign
import Foreign.C
import qualified Simplex.Messaging.Crypto as C

-- | Run an action with an opaque context pointer that refers to the given DRG.
--
-- The 'TVar' is wrapped in a 'StablePtr', which keeps it alive and gives it
-- an address that stays valid across garbage collections. That address is
-- passed to the action as a @'Ptr' 'RNGContext'@ (i.e. a C @void *@). It is
-- presumably handed to C code together with 'rngFuncPtr'.
--
-- 'bracket' frees the 'StablePtr' when the action returns or throws. The
-- pointer therefore must not be used (e.g. retained by C code) after the
-- action finishes.
withDRG :: TVar ChaChaDRG -> (Ptr RNGContext -> IO a) -> IO a
withDRG drg = bracket acquire release
  where
    -- Create the stable pointer and erase its type so C sees a plain void *.
    acquire :: IO (Ptr RNGContext)
    acquire = castStablePtrToPtr <$> newStablePtr drg

    -- Turn the raw pointer back into a StablePtr and free it. The pointee
    -- type does not matter for freeing.
    release :: Ptr RNGContext -> IO ()
    release = freeStablePtr . castPtrToStablePtr

-- | Random-bytes callback exported to C as @haskell_rng_func@ (see 'RNGFunc'
-- for the matching C typedef).
--
-- Writes @sz@ random bytes into @buf@. The bytes are drawn atomically from
-- the DRG that the context pointer @cxt@ refers to.
--
-- Preconditions (the C caller's responsibility):
--
-- * @cxt@ must be a pointer produced by 'withDRG' and still inside that
--   call's scope. Dereferencing a freed 'StablePtr' is undefined behaviour.
-- * @buf@ must have room for at least @sz@ bytes.
rngFunc :: RNGFunc
rngFunc cxt sz buf = do
  -- Recover the TVar that withDRG stored behind the opaque context pointer.
  drg <- deRefStablePtr (castPtrToStablePtr cxt) :: IO (TVar ChaChaDRG)
  -- Draw the requested number of bytes in a single STM transaction.
  bytes <- atomically $ C.randomBytes (fromIntegral sz) drg
  -- Copy them into the caller-provided C buffer.
  copyByteArrayToPtr bytes buf

-- | Pointee type of the opaque context pointer passed through C.
--
-- It is @()@, so a @'Ptr' RNGContext@ plays the role of a C @void *@. In
-- practice the pointer carries a type-erased 'StablePtr' to the DRG 'TVar'
-- (created by 'withDRG' and dereferenced by 'rngFunc').
type RNGContext = ()

-- | Haskell type of the random-bytes callback. It mirrors this C typedef:
--
-- > typedef void random_func (void *ctx, size_t length, uint8_t *dst);
--
-- The arguments are, in order: the context pointer, the number of bytes to
-- write, and the destination buffer.
type RNGFunc = Ptr RNGContext -> CSize -> Ptr Word8 -> IO ()

-- Export 'rngFunc' under the C symbol @haskell_rng_func@ using the C calling
-- convention, so that C code (and the import below) can refer to it.
foreign export ccall "haskell_rng_func" rngFunc :: RNGFunc

-- | Address of the exported @haskell_rng_func@ symbol as a C function pointer.
--
-- The @&@ form imports the symbol's address rather than a callable function.
-- This lets the callback be passed to C wherever a @random_func *@ is
-- expected (presumably the SNTRUP761 C routines; confirm at call sites).
foreign import ccall "&haskell_rng_func" rngFuncPtr :: FunPtr RNGFunc