{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Token (
Config,
defaultConfig,
interval,
tokenLifetime,
threadName,
TokenManager,
spawnTokenManager,
killTokenManager,
encryptToken,
decryptToken,
) where
import Control.Concurrent
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types (AEADMode (..), AuthTag (..))
import qualified Crypto.Cipher.Types as C
import Crypto.Error (maybeCryptoError, throwCryptoError)
import Crypto.Random (getRandomBytes)
import Data.Array.IO
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.IORef as I
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.Conc.Sync (labelThread)
import Network.ByteOrder
-- | Slot number of a secret in the token manager's ring of secrets; also
-- carried (masked) in every token header.
type Index = Word16
-- | Per-secret encryption counter; mixed into the AES-GCM nonce and carried
-- (masked) in every token header.
type Counter = Word64
-- | Configuration of a token manager.
data Config = Config
    { interval :: Int
    -- ^ Seconds between secret rotations. Each rotation generates a new
    -- secret and overwrites the oldest one.
    , tokenLifetime :: Int
    -- ^ Roughly how long, in seconds, a token stays decryptable. Together
    -- with 'interval' it determines how many secrets are kept
    -- (@tokenLifetime `div` interval@).
    , threadName :: String
    -- ^ Label attached to the secret-rotation thread.
    }
    deriving (Show, Eq)
-- | Default configuration: rotate secrets every 30 minutes (1,800 seconds)
-- and keep tokens decryptable for 2 hours (7,200 seconds).
--
-- >>> defaultConfig
-- Config {interval = 1800, tokenLifetime = 7200, threadName = "Crypto token manager"}
defaultConfig :: Config
defaultConfig =
    Config
        { threadName = "Crypto token manager"
        , tokenLifetime = 2 * 60 * 60
        , interval = 30 * 60
        }
-- | Handle to a running token manager. Abstract: only the type, not its
-- constructor, is exported.
data TokenManager = TokenManager
    { headerMask :: Header
    -- ^ Random header XORed into every token header so that the slot index
    -- and counter are not stored verbatim in tokens.
    , getEncryptSecret :: IO (Secret, Index)
    -- ^ Current secret together with the slot it lives in; used when
    -- encrypting.
    , getDecryptSecret :: Index -> IO Secret
    -- ^ Secret stored in the given slot; used when decrypting.
    , threadId :: ThreadId
    -- ^ The secret-rotation thread, stopped by 'killTokenManager'.
    }
-- | Spawn a token manager.
--
-- A ring of secret slots is allocated and slot 0 is filled with a fresh
-- secret. A background thread, labelled with 'threadName', then wakes every
-- 'interval' seconds and writes a new secret into the next slot, overwriting
-- the oldest one.
--
-- The number of slots is @tokenLifetime `div` interval@, clamped to the
-- range @[1, 65535]@. This keeps every slot addressable by a 16-bit 'Index'
-- and stops the ring size from wrapping to zero.
--
-- Throws an 'IOError' if 'interval' is not positive.
spawnTokenManager :: Config -> IO TokenManager
spawnTokenManager Config{..}
    -- A zero interval would divide by zero when sizing the ring. A negative
    -- one would make the rotation thread regenerate secrets without waiting,
    -- since the delay below would not be positive.
    | interval <= 0 =
        ioError $ userError "spawnTokenManager: interval must be positive"
    | otherwise = do
        emp <- emptySecret
        arr <- newArray (0, lim - 1) emp
        ent <- generateSecret
        writeArray arr 0 ent
        ref <- I.newIORef 0
        tid <- forkIO $ loop arr ref
        labelThread tid threadName
        msk <- newHeaderMask
        return $ TokenManager msk (readCurrentSecret arr ref) (readSecret arr) tid
  where
    -- Without the clamp, a lifetime shorter than the interval gives 0 slots.
    -- Then @lim - 1@ wraps to 65535 in Word16, producing a 65536-slot ring
    -- whose Word16 size overflows to 0, and the slot arithmetic in
    -- readSecret/update divides by zero.
    lim :: Index
    lim = fromIntegral (max 1 (min maxSlots (tokenLifetime `div` interval)))

    maxSlots :: Int
    maxSlots = fromIntegral (maxBound :: Index)

    -- Rotation thread body: sleep one interval, rotate, repeat forever.
    loop :: IOArray Index Secret -> I.IORef Index -> IO b
    loop arr ref = do
        threadDelay (interval * 1000000)
        update arr ref
        loop arr ref
-- | Advance the ring to its next slot and store a freshly generated secret
-- there.
--
-- The next slot wraps around at the end of the array, which assumes the
-- array's lower bound is 0. The new secret is written before the
-- current-index reference is switched to it.
update :: IOArray Index Secret -> I.IORef Index -> IO ()
update arr ref = do
    cur <- I.readIORef ref
    (_, hi) <- getBounds arr
    -- Compute in Int: for a 65536-slot ring the size @hi + 1@ is 0 in
    -- Word16, and `mod` would divide by zero.
    let slots = fromIntegral hi + 1 :: Int
        next = fromIntegral ((fromIntegral cur + 1) `mod` slots) :: Index
    sec <- generateSecret
    writeArray arr next sec
    I.writeIORef ref next
-- | Stop a token manager by killing its secret-rotation thread.
-- The stored secrets themselves are not touched.
killTokenManager :: TokenManager -> IO ()
killTokenManager = killThread . threadId
-- | Read the secret in slot @idx@. Out-of-range indices (for example from a
-- tampered or stale token header) are reduced modulo the ring size, so the
-- read never goes out of bounds. This assumes the array's lower bound is 0.
readSecret :: IOArray Index Secret -> Index -> IO Secret
readSecret secrets idx = do
    (_, hi) <- getBounds secrets
    -- Compute in Int: for a 65536-slot ring the size @hi + 1@ is 0 in
    -- Word16, and `mod` would divide by zero.
    let slots = fromIntegral hi + 1 :: Int
        slot = fromIntegral (fromIntegral idx `mod` slots) :: Index
    readArray secrets slot
-- | Read the slot index held in @ref@ and return the secret stored there,
-- paired with that index.
readCurrentSecret :: IOArray Index Secret -> I.IORef Index -> IO (Secret, Index)
readCurrentSecret arr ref = do
    cur <- I.readIORef ref
    (,) <$> readSecret arr cur <*> pure cur
-- | One encryption secret held in a slot of the token manager's ring.
data Secret = Secret
    { secretIV :: ByteString
    -- ^ IV that is XORed with the counter to form each nonce.
    , secretKey :: ByteString
    -- ^ AES-256 key.
    , secretCounter :: I.IORef Counter
    -- ^ Counter value for the next encryption with this secret.
    }
-- | Placeholder secret for slots that have not been filled yet: empty IV,
-- empty key, counter at 0.
emptySecret :: IO Secret
emptySecret = do
    counterRef <- I.newIORef 0
    return (Secret BS.empty BS.empty counterRef)
-- | Create a fresh secret with a random IV and key and a counter starting
-- at 0.
generateSecret :: IO Secret
generateSecret = do
    iv <- genIV
    key <- genKey
    Secret iv key <$> I.newIORef 0
-- | Draw 'keyLength' random bytes to serve as an AES-256 key.
genKey :: IO ByteString
genKey = getRandomBytes keyLength
-- | Draw 'ivLength' random bytes to serve as a secret's IV.
genIV :: IO ByteString
genIV = getRandomBytes ivLength
-- | Length in bytes of a secret's IV, which is also the nonce length.
-- 'makeNonce' writes a whole 8-byte 'Counter' into a buffer of this size,
-- so this must stay 8.
ivLength :: Int
ivLength = 8

-- | Length in bytes of an AES-256 key.
keyLength :: Int
keyLength = 32

-- | Encoded size in bytes of a header's index field (a Word16).
indexLength :: Int
indexLength = 2

-- | Encoded size in bytes of a header's counter field (a Word64).
counterLength :: Int
counterLength = 8

-- | Length in bytes of the AES-GCM authentication tag appended to each
-- ciphertext.
tagLength :: Int
tagLength = 16
-- | Token header: the slot of the secret and the counter value that were
-- used to encrypt the token.
data Header = Header
    { headerIndex :: Index
    -- ^ Slot of the secret in the token manager's ring.
    , headerCounter :: Counter
    -- ^ Counter value used to build the nonce.
    }
-- | Serialise a header into @indexLength + counterLength@ bytes: the index
-- written with 'write16', followed by the counter written with 'write64'.
encodeHeader :: Header -> IO ByteString
encodeHeader hdr =
    withWriteBuffer (indexLength + counterLength) $ \wbuf -> do
        write16 wbuf (headerIndex hdr)
        write64 wbuf (headerCounter hdr)
-- | Parse the layout produced by 'encodeHeader': an index read with 'read16'
-- followed by a counter read with 'read64'. Bytes after those are not read.
-- NOTE(review): callers are assumed to pass at least
-- @indexLength + counterLength@ bytes; shorter input makes the reads fail.
decodeHeader :: ByteString -> IO Header
decodeHeader bs = withReadBuffer bs $ \rbuf -> do
    idx <- read16 rbuf
    counter <- read64 rbuf
    return (Header idx counter)
-- | Generate a random header to use as the XOR mask for token headers.
newHeaderMask :: IO Header
newHeaderMask =
    (getRandomBytes (indexLength + counterLength) :: IO ByteString)
        >>= decodeHeader
-- | Field-wise XOR of two headers. Applying the same mask twice restores the
-- original header, which is how 'addHeader' and 'delHeader' mask and unmask.
xorHeader :: Header -> Header -> Header
xorHeader (Header i1 c1) (Header i2 c2) = Header (i1 `xor` i2) (c1 `xor` c2)
-- | Prepend the masked, encoded header for @(idx, counter)@ to a ciphertext.
addHeader :: TokenManager -> Index -> Counter -> ByteString -> IO ByteString
addHeader mgr idx counter cipher = do
    hdrbin <- encodeHeader (headerMask mgr `xorHeader` Header idx counter)
    return (hdrbin `BS.append` cipher)
-- | Split a token into its unmasked header fields and the remaining
-- ciphertext. Returns 'Nothing' if the token is shorter than a header.
delHeader
    :: TokenManager -> ByteString -> IO (Maybe (Index, Counter, ByteString))
delHeader mgr token
    | BS.length token < hdrLen = return Nothing
    | otherwise = do
        masked <- decodeHeader hdrbin
        let Header idx counter = headerMask mgr `xorHeader` masked
        return (Just (idx, counter, cipher))
  where
    hdrLen = indexLength + counterLength
    (hdrbin, cipher) = BS.splitAt hdrLen token
-- | Encrypt a value into a token using the manager's current secret.
-- The token is the masked header (slot index and counter) followed by the
-- AES-256-GCM ciphertext and tag.
encryptToken
    :: TokenManager
    -> ByteString
    -> IO ByteString
encryptToken mgr x = do
    (secret, idx) <- getEncryptSecret mgr
    uncurry (addHeader mgr idx) =<< encrypt secret x
-- | Encrypt @plain@ with a secret, consuming one value of its counter.
-- Returns the counter value used (needed to rebuild the nonce) and the
-- ciphertext with its tag appended.
encrypt
    :: Secret
    -> ByteString
    -> IO (Counter, ByteString)
encrypt secret plain = do
    -- Take the current value and bump it in one atomic step, so concurrent
    -- callers never receive the same counter (and hence nonce) for one key.
    counter <- I.atomicModifyIORef' (secretCounter secret) (\n -> (n + 1, n))
    nonce <- makeNonce counter (secretIV secret)
    return (counter, aes256gcmEncrypt plain (secretKey secret) nonce)
-- | Decrypt a token back into the original value. Returns 'Nothing' if the
-- token is too short to hold a header or if decryption fails.
decryptToken
    :: TokenManager
    -> ByteString
    -> IO (Maybe ByteString)
decryptToken mgr token =
    delHeader mgr token >>= maybe (return Nothing) open
  where
    open (idx, counter, cipher) = do
        secret <- getDecryptSecret mgr idx
        decrypt secret counter cipher
-- | Rebuild the nonce from the secret's IV and @counter@, then decrypt and
-- authenticate @cipher@ (ciphertext followed by tag) with the secret's key.
decrypt
    :: Secret
    -> Counter
    -> ByteString
    -> IO (Maybe ByteString)
decrypt secret counter cipher =
    aes256gcmDecrypt cipher (secretKey secret)
        <$> makeNonce counter (secretIV secret)
-- | Derive the nonce for one encryption by XORing the counter's bytes (host
-- byte order, as laid down by 'poke') with the IV. The XOR result is as long
-- as the shorter input. For an 'ivLength'-byte IV, distinct counters
-- therefore give distinct nonces.
--
-- NOTE: a whole Word64 is poked into an 'ivLength'-byte buffer, so this
-- relies on 'ivLength' being 8.
makeNonce :: Counter -> ByteString -> IO ByteString
makeNonce counter iv = BA.xor iv <$> BS.create ivLength fill
  where
    fill ptr = poke (castPtr ptr) counter
-- | Additional authenticated data passed to AES-GCM: always empty.
constantAdditionalData :: ByteString
constantAdditionalData = mempty
-- | AES-256-GCM encrypt @plain@ under @key@ and @nonce@ with
-- 'constantAdditionalData' as associated data. Returns the ciphertext with
-- a 'tagLength'-byte tag appended. Throws (via 'throwCryptoError') if the
-- key or nonce is rejected.
aes256gcmEncrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> ByteString
aes256gcmEncrypt plain key nonce = BS.append ctext (BA.convert tagBytes)
  where
    aes :: AES256
    aes = throwCryptoError (C.cipherInit key)
    ctx = throwCryptoError (C.aeadInit AEAD_GCM aes nonce)
    (AuthTag tagBytes, ctext) =
        C.aeadSimpleEncrypt ctx constantAdditionalData plain tagLength
-- | AES-256-GCM decrypt @ctexttag@ (ciphertext followed by a 'tagLength'-byte
-- tag) under @key@ and @nonce@, with 'constantAdditionalData' as associated
-- data.
--
-- Returns 'Nothing' if:
--
-- * the input is too short to contain a full tag,
-- * the key or nonce is rejected, or
-- * authentication fails.
aes256gcmDecrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> Maybe ByteString
aes256gcmDecrypt ctexttag key nonce
    -- 'C.aeadSimpleDecrypt' recomputes only as many tag bytes as it is given.
    -- Without this guard, input shorter than 'tagLength' would be checked
    -- against a truncated tag. An empty input would be checked against an
    -- empty tag, which always matches, so forged tokens would decrypt.
    | BS.length ctexttag < tagLength = Nothing
    | otherwise = do
        aes <- maybeCryptoError (C.cipherInit key) :: Maybe AES256
        aead <- maybeCryptoError $ C.aeadInit AEAD_GCM aes nonce
        let (ctext, tag) = BS.splitAt (BS.length ctexttag - tagLength) ctexttag
            authtag = AuthTag $ BA.convert tag
        C.aeadSimpleDecrypt aead constantAdditionalData ctext authtag