{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Encrypted tokens/tickets to keep state in the client side.
module Crypto.Token (
    -- * Configuration
    Config,
    defaultConfig,
    interval,
    tokenLifetime,
    threadName,

    -- * Token manager
    TokenManager,
    spawnTokenManager,
    killTokenManager,

    -- * Encryption and decryption
    encryptToken,
    decryptToken,
) where

import Control.Concurrent
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types (AEADMode (..), AuthTag (..))
import qualified Crypto.Cipher.Types as C
import Crypto.Error (maybeCryptoError, throwCryptoError)
import Crypto.Random (getRandomBytes)
import Data.Array.IO
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.IORef as I
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.Conc.Sync (labelThread)
import Network.ByteOrder

----------------------------------------------------------------

-- | Position of a secret inside the token manager's array of secrets.
type Index = Word16
-- | Per-secret count of encrypted tokens; it is combined with the IV to form the nonce.
type Counter = Word64

-- | Configuration for token manager.
data Config = Config
    { -- | How often, in seconds, a fresh secret is generated and the oldest
      --   one is overwritten.
      interval :: Int
    , -- | How long, in seconds, an issued token is meant to stay decryptable.
      tokenLifetime :: Int
    , -- | Label given to the thread that rotates the secrets.
      threadName :: String
    }
    deriving (Eq, Show)

-- | Default configuration: secrets are updated every 30 minutes
-- (1,800 seconds), the token lifetime is 2 hours (7,200 seconds) and the
-- manager thread is labelled @\"Crypto token manager\"@.
--
-- >>> defaultConfig
-- Config {interval = 1800, tokenLifetime = 7200, threadName = "Crypto token manager"}
defaultConfig :: Config
defaultConfig =
    Config
        { interval = 1800
        , tokenLifetime = 7200
        , threadName = "Crypto token manager"
        }

----------------------------------------------------------------

-- fixme: mask

-- | The abstract data type for token manager.
data TokenManager = TokenManager
    { -- | Random header value xored into every token header before it is
      --   serialized, and xored out again when a token is parsed.
      headerMask :: Header
    , -- | Fetches the secret currently used for encryption together with
      --   its slot index.
      getEncryptSecret :: IO (Secret, Index)
    , -- | Fetches the secret stored in the slot named by a token header.
      getDecryptSecret :: Index -> IO Secret
    , -- | The thread that periodically replaces secrets; it is the target
      --   of 'killTokenManager'.
      threadId :: ThreadId
    }

-- | Spawning a token manager.
--
-- A background thread, labelled with 'threadName', writes a freshly generated
-- secret into the next slot every 'interval' seconds, cycling through the
-- slots. The manager keeps @tokenLifetime \`div\` interval@ slots.
--
-- Out-of-range configurations are clamped instead of failing:
--
-- * a non-positive 'interval' is treated as one second (it would otherwise
--   divide by zero or make the rotation thread spin without delay);
-- * the slot count is clamped to @1..65535@ so that both the last slot index
--   and the slot count fit the 16-bit 'Index' type. Without this, a
--   'tokenLifetime' smaller than 'interval' produced a count of zero whose
--   @- 1@ wrapped around, and the later @n + 1@ wrapped to zero and raised a
--   divide-by-zero error.
spawnTokenManager :: Config -> IO TokenManager
spawnTokenManager Config{..} = do
    emp <- emptySecret
    -- Every slot starts as an empty placeholder; slot 0 receives the first
    -- generated secret and is the initial encryption slot.
    arr <- newArray (0, slots - 1) emp
    ent <- generateSecret
    writeArray arr 0 ent
    ref <- I.newIORef 0
    tid <- forkIO $ loop arr ref
    labelThread tid threadName
    msk <- newHeaderMask
    return $ TokenManager msk (readCurrentSecret arr ref) (readSecret arr) tid
  where
    -- Rotation period in seconds, never below one.
    period :: Int
    period = max 1 interval
    -- Largest slot count for which @slots - 1@ and @idx + 1@ stay in 'Index'.
    maxSlots :: Int
    maxSlots = fromIntegral (maxBound :: Index)
    -- Number of secrets kept; clamped in 'Int' before narrowing to 'Index'.
    slots :: Index
    slots = fromIntegral . max 1 . min maxSlots $ tokenLifetime `div` period
    -- Sleep for one period, rotate, repeat until the thread is killed.
    loop :: IOArray Index Secret -> I.IORef Index -> IO ()
    loop arr ref = do
        threadDelay (period * 1000000)
        update arr ref
        loop arr ref
    -- Store a new secret in the slot after the current one and only then
    -- publish that slot as the current encryption slot.
    update :: IOArray Index Secret -> I.IORef Index -> IO ()
    update arr ref = do
        idx0 <- I.readIORef ref
        -- idx0 < slots <= 65535, so idx0 + 1 cannot wrap.
        let idx = (idx0 + 1) `mod` slots
        sec <- generateSecret
        writeArray arr idx sec
        I.writeIORef ref idx

-- | Killing a token manager.
--
-- This kills the thread stored in the manager's 'threadId' field.
killTokenManager :: TokenManager -> IO ()
killTokenManager = killThread . threadId

----------------------------------------------------------------

-- | Read the secret stored in the slot selected by an index.
--
-- The index is reduced modulo the number of slots, so any 'Index' value
-- (for instance one taken from a token header) selects a valid slot.
-- The array's lower bound is taken to be 0.
readSecret :: IOArray Index Secret -> Index -> IO Secret
readSecret secrets idx0 = do
    (_, n) <- getBounds secrets
    -- Compute the slot count in 'Int': with @n == maxBound@ the sum @n + 1@
    -- would wrap to 0 in 'Index' and the 'mod' would raise divide-by-zero.
    let size = fromIntegral n + 1 :: Int
        idx = fromIntegral (fromIntegral idx0 `mod` size) :: Index
    readArray secrets idx

-- | Read the slot index stored in the reference and return the secret held
--   in that slot, paired with the index.
readCurrentSecret :: IOArray Index Secret -> I.IORef Index -> IO (Secret, Index)
readCurrentSecret arr ref = do
    current <- I.readIORef ref
    (\secret -> (secret, current)) <$> readSecret arr current

----------------------------------------------------------------

-- | Key material for one slot of the token manager.
data Secret = Secret
    { -- | Base value that is xored with a token's counter to form its nonce.
      secretIV :: ByteString
    , -- | Key handed to the AES-256-GCM routines.
      secretKey :: ByteString
    , -- | Counter incremented once per encryption with this secret.
      secretCounter :: I.IORef Counter
    }

-- | A placeholder secret with an empty IV, an empty key and a counter
--   starting at 0.
emptySecret :: IO Secret
emptySecret = do
    counter <- I.newIORef 0
    return $ Secret BS.empty BS.empty counter

-- | Create a secret with a random IV, a random key and a counter starting
--   at 0.
generateSecret :: IO Secret
generateSecret = do
    iv <- genIV
    key <- genKey
    counter <- I.newIORef 0
    return $ Secret iv key counter

-- | Draw 'keyLength' random bytes to be used as an encryption key.
genKey :: IO ByteString
genKey = getRandomBytes keyLength

-- | Draw 'ivLength' random bytes to be used as the IV of a secret.
genIV :: IO ByteString
genIV = getRandomBytes ivLength

----------------------------------------------------------------

-- | Size in bytes of a secret's IV and of the nonce buffer built from the counter.
ivLength :: Int
ivLength = 8

-- | Size in bytes of the randomly generated encryption key.
keyLength :: Int
keyLength = 32

-- | Size in bytes of the slot index at the start of a token header.
indexLength :: Int
indexLength = 2

-- | Size in bytes of the counter that follows the index in a token header.
counterLength :: Int
counterLength = 8

-- | Size in bytes of the authentication tag appended to the ciphertext.
tagLength :: Int
tagLength = 16

----------------------------------------------------------------

-- | The fixed-size prefix of a token.
data Header = Header
    { -- | Slot of the secret the token was encrypted with.
      headerIndex :: Index
    , -- | Counter value that was used to derive the token's nonce.
      headerCounter :: Counter
    }

-- | Serialize a header into @indexLength + counterLength@ bytes: the 16-bit
--   index first, then the 64-bit counter.
encodeHeader :: Header -> IO ByteString
encodeHeader (Header idx counter) =
    withWriteBuffer (indexLength + counterLength) $ \wbuf ->
        write16 wbuf idx >> write64 wbuf counter

-- | Parse a header from the front of a byte string: a 16-bit index followed
--   by a 64-bit counter. The caller is expected to supply at least
--   @indexLength + counterLength@ bytes.
decodeHeader :: ByteString -> IO Header
decodeHeader bs = withReadBuffer bs $ \rbuf -> do
    idx <- read16 rbuf
    counter <- read64 rbuf
    return $ Header idx counter

-- | Build a random header mask by decoding @indexLength + counterLength@
--   random bytes as a header.
newHeaderMask :: IO Header
newHeaderMask = getRandomBytes (indexLength + counterLength) >>= decodeHeader

----------------------------------------------------------------

-- | Field-wise xor of two headers. Applying the same mask twice gives back
--   the original header.
xorHeader :: Header -> Header -> Header
xorHeader x y = Header idx counter
  where
    idx = headerIndex x `xor` headerIndex y
    counter = headerCounter x `xor` headerCounter y

-- | Prefix a ciphertext with its header. The header is built from the slot
--   index and the counter, xored with the manager's header mask, and then
--   serialized.
addHeader :: TokenManager -> Index -> Counter -> ByteString -> IO ByteString
addHeader mgr idx counter cipher = do
    hdrbin <- encodeHeader masked
    return $ BS.append hdrbin cipher
  where
    masked = headerMask mgr `xorHeader` Header idx counter

-- | Split a token into its unmasked header fields and the remaining
--   ciphertext. Returns 'Nothing' when the token is too short to contain a
--   header.
delHeader
    :: TokenManager -> ByteString -> IO (Maybe (Index, Counter, ByteString))
delHeader mgr token
    | BS.length token >= hdrlen = do
        masked <- decodeHeader hdrbin
        -- Xoring with the mask again recovers the header written by addHeader.
        let hdr = headerMask mgr `xorHeader` masked
        return $ Just (headerIndex hdr, headerCounter hdr, cipher)
    | otherwise = return Nothing
  where
    hdrlen = indexLength + counterLength
    (hdrbin, cipher) = BS.splitAt hdrlen token

-- | Encrypting a target value to get a token.
--
-- The value is encrypted with the manager's current secret, and the result
-- is prefixed with a masked header carrying the slot index and the counter.
encryptToken
    :: TokenManager
    -> ByteString
    -> IO ByteString
encryptToken mgr plain = do
    (secret, idx) <- getEncryptSecret mgr
    encrypt secret plain >>= uncurry (addHeader mgr idx)

-- | Encrypt a plaintext with a secret. Returns the counter value consumed
--   for this encryption together with the ciphertext (tag appended).
encrypt
    :: Secret
    -> ByteString
    -> IO (Counter, ByteString)
encrypt secret plain = do
    -- Take the current counter and bump it in one atomic step.
    counter <- I.atomicModifyIORef' (secretCounter secret) next
    nonce <- makeNonce counter (secretIV secret)
    return (counter, aes256gcmEncrypt plain (secretKey secret) nonce)
  where
    next i = (i + 1, i)

-- | Decrypting a token to get a target value.
--
-- Returns 'Nothing' when the token is too short to carry a header or when
-- decryption with the secret named by the header fails.
decryptToken
    :: TokenManager
    -> ByteString
    -> IO (Maybe ByteString)
decryptToken mgr token = delHeader mgr token >>= maybe (return Nothing) open
  where
    open (idx, counter, cipher) = do
        secret <- getDecryptSecret mgr idx
        decrypt secret counter cipher

-- | Decrypt a ciphertext with a secret, rebuilding the nonce from the
--   counter carried in the token header.
decrypt
    :: Secret
    -> Counter
    -> ByteString
    -> IO (Maybe ByteString)
decrypt secret counter cipher =
    aes256gcmDecrypt cipher (secretKey secret)
        <$> makeNonce counter (secretIV secret)

-- | Derive a nonce by xoring the IV with the counter.
--
-- The counter is stored with 'poke' into a fresh buffer of 'ivLength' bytes.
-- NOTE(review): 'poke' presumably writes the host's native byte order, so
-- the nonce layout would be machine-dependent — confirm if tokens ever have
-- to cross architectures.
makeNonce :: Counter -> ByteString -> IO ByteString
makeNonce counter iv =
    BA.xor iv <$> BS.create ivLength (\ptr -> poke (castPtr ptr) counter)

----------------------------------------------------------------

-- | Additional authenticated data used for every token: always empty.
constantAdditionalData :: ByteString
constantAdditionalData = mempty

-- | AES-256-GCM encryption of a plaintext under a key and a nonce.
--
-- The result is the ciphertext followed by the 'tagLength'-byte
-- authentication tag. A failure to initialise the cipher or the AEAD state
-- is raised through 'throwCryptoError'.
aes256gcmEncrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> ByteString
aes256gcmEncrypt plain key nonce = BS.append ciphertext (BA.convert tag)
  where
    aes :: AES256
    aes = throwCryptoError (C.cipherInit key)
    aead = throwCryptoError (C.aeadInit AEAD_GCM aes nonce)
    (AuthTag tag, ciphertext) =
        C.aeadSimpleEncrypt aead constantAdditionalData plain tagLength

-- | AES-256-GCM decryption of a ciphertext that carries its
--   'tagLength'-byte authentication tag at the end.
--
-- Returns 'Nothing' when the input is shorter than a tag, when the cipher or
-- the AEAD state cannot be initialised (for instance with an empty key), or
-- when authentication fails.
aes256gcmDecrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> Maybe ByteString
aes256gcmDecrypt ctexttag key nonce
    -- Reject inputs that cannot hold a full tag. Otherwise the split offset
    -- below is negative, the whole (short, possibly empty) input becomes the
    -- tag, and the check runs against fewer than 'tagLength' bytes.
    -- NOTE(review): aeadSimpleDecrypt appears to finalise the expected tag at
    -- the length of the supplied tag, so an empty tag would always match —
    -- confirm against the crypto library version in use.
    | BS.length ctexttag < tagLength = Nothing
    | otherwise = do
        aes <- maybeCryptoError (C.cipherInit key) :: Maybe AES256
        aead <- maybeCryptoError $ C.aeadInit AEAD_GCM aes nonce
        let (ctext, tag) = BS.splitAt (BS.length ctexttag - tagLength) ctexttag
            authtag = AuthTag $ BA.convert tag
        C.aeadSimpleDecrypt aead constantAdditionalData ctext authtag