-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Binary.Bits.Put
-- Copyright   :  (c) Lennart Kolmodin 2010-2011
-- License     :  BSD3-style (see LICENSE)
--
-- Maintainer  :  kolmodin@gmail.com
-- Stability   :  experimental
-- Portability :  portable (should run where the package binary runs)
--
-- Put bits easily.
-----------------------------------------------------------------------------

module Data.Binary.Bits.Put
          ( BitPut
          , runBitPut
          , joinPut

          -- * Data types
          -- ** Bool
          , putBool

          -- ** Words
          , putWord8
          , putWord16be
          , putWord32be
          , putWord64be

          -- ** ByteString
          , putByteString
          )
          where

import qualified Data.Binary.Builder as B
import Data.Binary.Builder ( Builder )
import qualified Data.Binary.Put as Put
import Data.Binary.Put ( Put )

import Data.ByteString

import Control.Applicative
import Data.Bits
import Data.Monoid
import Data.Word

-- | The bit-level writer monad: a state transformer over 'S' (completed
-- bytes, the partially filled byte, and how many of its bits are in use),
-- producing a value together with the successor state.
--
-- Declared as a 'newtype' because it has exactly one constructor with one
-- field; this avoids an extra box on every bind.  The type is exported
-- abstractly, so the representation change is invisible to callers.
newtype BitPut a = BitPut { run :: S -> PairS a }

-- | Result of one 'BitPut' step: the produced value and the strict,
-- unpacked successor writer state.
data PairS a = PairS a {-# UNPACK #-} !S

-- | Writer state.  Bits are placed into the partial byte starting from its
-- most significant end.
data S = S
  !Builder  -- all bytes completed so far
  !Word8    -- the byte currently being filled
  !Int      -- number of bits of that byte already in use

-- | Put a single bit: 1 for 'True', 0 for 'False'.
putBool :: Bool -> BitPut ()
putBool flag = putWord8 1 bitValue
  where
    bitValue
      | flag      = 0xff
      | otherwise = 0x00

-- | A value whose @n@ least significant bits are set, e.g.
-- @make_mask 3 == 0b00000111@.  When @n@ reaches the bit width of the type,
-- the shift produces 0 and the subtraction wraps around to all ones.
make_mask :: (Bits a, Num a) => Int -> a
make_mask = subtract 1 . shiftL 1
{-# SPECIALIZE make_mask :: Int -> Int #-}
{-# SPECIALIZE make_mask :: Int -> Word #-}
{-# SPECIALIZE make_mask :: Int -> Word8 #-}
{-# SPECIALIZE make_mask :: Int -> Word16 #-}
{-# SPECIALIZE make_mask :: Int -> Word32 #-}
{-# SPECIALIZE make_mask :: Int -> Word64 #-}

-- | Put the @n@ lower bits of a 'Word8', most significant bit first.
--
-- The bits are appended after those already in the partial byte.  Any byte
-- that becomes complete is moved into the builder by 'flush'.  Meaningful
-- for @0 <= n <= 8@; bits of @w@ above position @n@ are masked off.
putWord8 :: Int -> Word8 -> BitPut ()
putWord8 n w = BitPut step
  where
    masked = make_mask n .&. w
    done = PairS () . flush

    step (S acc cur used)
      -- a full byte on a byte boundary: the value itself is the next byte
      | n == 8, used == 0 = done (S acc w n)
      -- the bits fit into the unused tail of the current byte
      | n <= 8 - used =
          done (S acc (cur .|. shiftL masked (8 - n - used)) (used + n))
      -- complete the current byte and carry the leftover low bits into a
      -- new partial byte, left-aligned
      | otherwise =
          let carry = used + n - 8
              full  = cur .|. shiftR masked carry
              next  = shiftL w (8 - carry)
          in done (S (acc `mappend` B.singleton full) next carry)

-- | Put the @n@ lower bits of a 'Word16', most significant bit first.
--
-- Requests of at most 8 bits are delegated to 'putWord8'.  Larger requests
-- (9 to 16 bits) complete the current byte and then spill into either one
-- or two further bytes, depending on how many bits are already in use.
putWord16be :: Int -> Word16 -> BitPut ()
putWord16be n w
  | n <= 8    = putWord8 n (fromIntegral w)
  | otherwise = BitPut $ \s -> PairS () (flush (place s))
  where
    masked = make_mask n .&. w

    -- the lowest @carry@ bits of w, left-aligned in a fresh partial byte
    tailByte :: Int -> Word8
    tailByte carry = fromIntegral (shiftL w (8 - carry))

    place (S acc cur used)
      -- two bytes touched: finish the current one, start the next
      | used + n <= 16 =
          let carry = used + n - 8
              hi    = cur .|. fromIntegral (shiftR masked carry)
          in S (acc `mappend` B.singleton hi) (tailByte carry) carry
      -- three bytes touched: finish the current one, emit a whole middle
      -- byte, start the third
      | otherwise =
          let carry = used + n - 16
              hi    = cur .|. fromIntegral (shiftR masked (carry + 8))
              mid   = fromIntegral (shiftR w carry .&. 0xff)
          in S (acc `mappend` B.singleton hi `mappend` B.singleton mid)
               (tailByte carry)
               carry

-- | Put the @n@ lower bits of a 'Word32', most significant bit first.
--
-- Up to 16 bits are handled by 'putWord16be'.  Wider requests write the
-- upper @n - 16@ bits first and then the low 16 bits.
putWord32be :: Int -> Word32 -> BitPut ()
putWord32be n w
  | n > 16 =
      putWord32be (n - 16) (shiftR w 16) >> putWord32be 16 (w .&. 0x0000ffff)
  | otherwise = putWord16be n (fromIntegral w)

-- | Put the @n@ lower bits of a 'Word64', most significant bit first.
--
-- Up to 32 bits are handled by 'putWord32be'.  Wider requests write the
-- upper @n - 32@ bits first and then the low 32 bits.
putWord64be :: Int -> Word64 -> BitPut ()
putWord64be n w
  | n > 32 =
      putWord64be (n - 32) (shiftR w 32) >> putWord64be 32 (w .&. 0xffffffff)
  | otherwise = putWord32be n (fromIntegral w)

-- | Put a 'ByteString'.
--
-- When the writer is byte-aligned, the bytes are passed through 'joinPut'
-- in one go.  Otherwise each byte is written as an 8-bit value with
-- 'putWord8'.
putByteString :: ByteString -> BitPut ()
putByteString bs = aligned >>= emit
  where
    emit True  = joinPut (Put.putByteString bs)
    emit False = mapM_ (putWord8 8) (unpack bs) -- naive, one byte at a time

    -- True when no bits of the current partial byte are in use
    aligned = BitPut $ \s@(S _ _ o) -> PairS (o == 0) s

-- | Run a 'Put' inside 'BitPut'. Any partially written bytes will be flushed
-- before 'Put' executes to ensure byte alignment.
--
-- After the 'Put' output is appended, the writer restarts with an empty
-- partial byte.
joinPut :: Put -> BitPut ()
joinPut m = BitPut (PairS () . restart . flushIncomplete)
  where
    -- irrefutable pattern: matches the lazy binding of the original form
    restart ~(S prefix _ _) = S (prefix `mappend` Put.execPut m) 0 0

-- | Move the partial byte into the builder once all 8 of its bits are used.
-- States with fewer than 8 used bits are returned unchanged.  More than 8
-- indicates an internal error.
flush :: S -> S
flush s@(S b w o) =
  case compare o 8 of
    GT -> error "flush: offset > 8"
    EQ -> S (b `mappend` B.singleton w) 0 0
    LT -> s

-- | Emit the partial byte whenever any of its bits are in use, even if it
-- is not full, leaving the writer byte-aligned.  An already aligned state
-- is returned unchanged.
flushIncomplete :: S -> S
flushIncomplete s@(S b w o)
  | o /= 0    = S (b `mappend` B.singleton w) 0 0
  | otherwise = s

-- | Run the 'BitPut' monad inside 'Put'.
--
-- The writer starts from an empty, aligned state.  A trailing partial byte
-- is emitted via 'flushIncomplete' before the builder is handed to 'Put'.
runBitPut :: BitPut () -> Put.Put
runBitPut m = Put.putBuilder bytes
  where
    S bytes _ _ = flushIncomplete (finalState (run m (S mempty 0 0)))
    finalState (PairS _ s) = s

-- | Map over the produced value; the writer state passes through untouched.
instance Functor BitPut where
  fmap f (BitPut k) = BitPut (mapResult . k)
    where
      mapResult (PairS x s') = PairS (f x) s'

-- | Sequential composition: the function step runs first, then the argument
-- step, with the state threaded from the first into the second.
instance Applicative BitPut where
  pure a = BitPut (PairS a)

  BitPut mf <*> BitPut mx = BitPut step
    where
      -- lazy let patterns keep the original evaluation order
      step s0 =
        let PairS g s1 = mf s0
            PairS x s2 = mx s1
        in PairS (g x) s2

-- | Monadic sequencing: run @m@, feed its result to @k@, and continue with
-- the state @m@ produced.
--
-- 'return' is defined canonically as 'pure'.  The previous hand-written
-- duplicate of 'pure' behaved the same, but triggers GHC's
-- -Wnoncanonical-monad-instances warning.
instance Monad BitPut where
  m >>= k = BitPut $ \s ->
    let PairS a s'  = run m s
        PairS b s'' = run (k a) s'
    in PairS b s''
  return = pure