-- Last active February 7, 2016 15:45
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Criterion.Main
import Control.Monad.State
import Crypto.Error
import Crypto.Cipher.AES
import Crypto.Cipher.Types
import Crypto.Random (getRandomBytes)
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import Data.ByteArray (ByteArrayAccess, ByteArray, ScrubbedBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Memory.Endian (BE(..), toBE)
import Data.Memory.PtrMethods (memCopy)
import Data.List (foldl')
import Data.Word (Word64, Word8)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (peek, peekElemOff, poke, pokeElemOff)
import System.IO.Unsafe (unsafePerformIO)
-- | Benchmark the pointer-based ('aesKeyWrap') and pure ('keyWrap') key-wrap
-- implementations against each other on a random 16-byte secret.
main :: IO ()
main = do
  -- AES-256 requires a 32-byte key-encryption key; a 16-byte KEK makes
  -- 'cipherInit' return a key-size error.
  kek <- getRandomBytes 32 :: IO ByteString
  cek <- getRandomBytes 16 :: IO ByteString
  -- Surface a cipher-initialisation failure as a proper exception instead
  -- of an irrefutable-pattern crash when the cipher is first forced.
  cipher <- throwCryptoErrorIO (cipherInit kek :: CryptoFailable AES256)
  defaultMain
    [ bgroup "keywrap"
        [ bench "ptr-based" $ nf (aesKeyWrap cipher) cek
        , bench "pure"      $ nf (keyWrap cipher) cek
        ]
    ]
-- | Initial integrity register for the pure implementation, as 8 bytes of
-- 0xA6 (the same value as 'iv', but in 'ByteString' form).
iv' :: ByteString
iv' = B.replicate 8 0xA6
-- | Pure 'ByteString' implementation of the RFC 3394 key wrap.
--
-- The secret @cek@ is split into 8-byte blocks R[1..n]. Six passes
-- (j = 0..5) are made over the blocks. Each step encrypts @A | R[i]@ with the
-- block cipher in ECB mode, stores the low 8 bytes back as R[i], and sets A
-- to the high 8 bytes XORed with the step counter @t = n*j + i@.
-- The result is @A | R[1] | ... | R[n]@, 8 bytes longer than @cek@.
--
-- NOTE(review): inputs whose length is not a multiple of 8 bytes, or shorter
-- than 16 bytes, are not rejected here.
keyWrap :: BlockCipher128 cipher
        => cipher
        -> ByteString
        -> ByteString
keyWrap cipher cek =
  let p       = toBlocks cek
      (r0, r) = foldl' (doRound (ecbEncrypt cipher) 1) (iv', p) [0 .. 5]
  in  BA.concat (r0 : r)
  where
    -- Number of full 8-byte blocks in the secret.
    n = B.length cek `div` 8

    -- One pass @j@ over all blocks; @i@ is the 1-based index of the block
    -- currently being processed.
    doRound enc i (a, r:rs) j =
      let b    = enc (B.concat [a, r])
          t    = (n * j) + i
          a'   = txor t (B.take 8 b)
          r'   = B.drop 8 b
          next = doRound enc (i + 1) (a', rs) j
      in  (fst next, r' : snd next)
    doRound _ _ (a, []) _ = (a, [])

    -- XOR the 64-bit big-endian encoding of @t@ into the 8-byte register.
    -- Encode the whole counter, not just its low byte: @t@ exceeds 255 once
    -- the input has 43 or more blocks.
    txor :: Int -> ByteString -> ByteString
    txor t b = B.pack (B.zipWith xor b tBytes)
      where
        tBytes = B.pack [fromIntegral (t `div` (256 ^ k)) | k <- [7, 6 .. 0 :: Int]]
-- | Chop a 'ByteString' into consecutive 8-byte chunks. The final chunk is
-- shorter when the length is not a multiple of 8; empty input yields @[]@.
toBlocks :: ByteString -> [ByteString]
toBlocks bytes
  | B.null bytes = []
  | otherwise    = chunk : toBlocks rest
  where
    (chunk, rest) = B.splitAt 8 bytes
-- | Initial integrity register for the pointer-based implementation: the
-- byte 0xA6 repeated eight times (the RFC 3394 default initial value).
-- Each byte lane of the multiplier is 0x01, so the product has no carries.
iv :: Word64
iv = 0x0101010101010101 * 0xA6
-- | One step of the pointer-based RFC 3394 wrapping loop.
--
-- The integrity register A is carried as the 'StateT' state, and the block
-- R[i] lives at word offset @i@ of the buffer @p@. The step:
--
-- 1. encrypts the 16-byte block @A | R[i]@;
-- 2. writes the low half back into R[i];
-- 3. sets A to the high half XORed with the step counter @t@.
--
-- A and R[i] are moved between buffers with plain 'poke'/'peek', with no
-- byte-order conversion, so they always hold the raw memory bytes. Only @t@
-- needs an explicit big-endian conversion before the XOR.
aesKeyWrapStep
  :: BlockCipher128 cipher
  => cipher
  -> Ptr Word64 -- ^ register
  -> (Int, Int) -- ^ step (t) and offset (i)
  -> StateT Word64 IO ()
aesKeyWrapStep cipher p (t, i) = do
  a <- get
  r_i <- lift $ peekElemOff p i
  -- Build the 16-byte cipher input A | R[i].
  m :: ScrubbedBytes <-
    lift $ BA.alloc 16 $ \p' -> poke p' a >> pokeElemOff p' 1 r_i
  let b = ecbEncrypt cipher m
  -- Read both 64-bit halves of the ciphertext in one pass.
  (b_hi, b_lo) <- lift $ BA.withByteArray b $ \q ->
    (,) <$> peek q <*> peekElemOff q 1
  put (b_hi `xor` unBE (toBE (fromIntegral t)))
  lift $ pokeElemOff p i b_lo
-- | Wrap a secret with the RFC 3394 key wrap, operating directly on a buffer
-- of 64-bit words.
--
-- Input size must be a multiple of 8 bytes, and at least 16 bytes. Any other
-- length is rejected with 'error'. Without this check, a partial trailing
-- block would be copied into the output without being wrapped.
-- Output size is input size plus 8 bytes.
aesKeyWrap
  :: (BlockCipher128 cipher)
  => cipher
  -> ByteString
  -> ByteString
aesKeyWrap cipher m
  | n < 16 || n `mod` 8 /= 0 =
      error "aesKeyWrap: input must be a multiple of 8 bytes and at least 16 bytes"
  | otherwise = unsafePerformIO $ do
      -- Output layout: 8-byte register A at offset 0, then R[1..n/8].
      c <- BA.withByteArray m $ \p ->
        BA.alloc (n + 8) $ \p' ->
          memCopy (p' `plusPtr` 8) p n
      -- unsafePerformIO is sound here: c is freshly allocated and not yet
      -- shared, so mutating it in place before returning it cannot be
      -- observed from pure code, and the result depends only on the inputs.
      BA.withByteArray c $ \p -> do
        -- (t, i) pairs: t runs 1..6*(n/8) while i cycles 1..n/8 six times.
        let coords = zip [1 ..] (concat (replicate 6 [1 .. n `div` 8]))
        a <- execStateT (mapM_ (aesKeyWrapStep cipher p) coords) iv
        poke p a
      return c
  where
    n = BA.length m
