{-# LANGUAGE FlexibleContexts #-}
module Data.Conduit.OpenPGP.Decrypt
( conduitDecrypt
) where
import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Trans.Resource (MonadResource, MonadThrow)
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import Data.Binary (get)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16.Lazy as B16L
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Combinators as CC
import qualified Data.Conduit.List as CL
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Maybe (fromJust, isNothing)
import Codec.Encryption.OpenPGP.CFB (decryptOpenPGPCfb, decryptPreservingNonce)
import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.Types
-- | State threaded through 'conduitDecrypt'' while it walks a packet stream.
data RecursorState = RecursorState
  { _depth :: Int
    -- ^ Current decryption nesting level; each nested decryption recurses
    -- with this incremented, and it is checked against a fixed bound.
  , _lastPKESK :: Maybe PKESK
    -- ^ Most recent PKESK. NOTE(review): nothing in this module sets it.
  , _lastSKESK :: Maybe SKESK
    -- ^ Most recent SKESK, used to decrypt the next encrypted-data packet.
  , _lastNonce :: Maybe B.ByteString
    -- ^ Nonce returned by 'decryptPreservingNonce' for the enclosing
    -- integrity-protected packet.
  , _lastClearText :: Maybe B.ByteString
    -- ^ Decrypted contents of the enclosing integrity-protected packet,
    -- used when checking an MDC packet.
  } deriving (Show, Eq)
-- | Starting state: depth zero and nothing remembered yet.
def :: RecursorState
def =
  RecursorState
    { _depth = 0
    , _lastPKESK = Nothing
    , _lastSKESK = Nothing
    , _lastNonce = Nothing
    , _lastClearText = Nothing
    }
-- | Callback used to obtain a passphrase: it is given a prompt string and
-- yields the passphrase as a lazy 'BL.ByteString' in the monad @f@.
type InputCallback f = String -> f BL.ByteString
-- | Decrypt passphrase-protected (SKESK) OpenPGP data inside a packet stream.
--
-- SKESK packets are absorbed, encrypted-data packets are replaced by the
-- packets they decrypt to, and all other packets pass through. The callback
-- is asked for the passphrase each time an encrypted-data packet is met.
conduitDecrypt ::
     (MonadUnliftIO m, MonadResource m, MonadThrow m)
  => InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt callback = conduitDecrypt' def callback
-- | Worker behind 'conduitDecrypt': walks the packet stream, threading a
-- 'RecursorState' through each packet.
--
-- * 'SKESKPkt's are remembered (and not re-emitted) so that a following
--   encrypted-data packet can be decrypted with them.
-- * 'SymEncDataPkt' and 'SymEncIntegrityProtectedDataPkt' are replaced by
--   the packets obtained by decrypting them with the last seen SKESK.
-- * A 'ModificationDetectionCodePkt' is compared with the MDC computed from
--   the nonce and clear text recorded in the state, and passed through on
--   success.
-- * Every other packet passes through unchanged.
--
-- Failures (nesting too deep, encrypted data without a preceding SKESK, an
-- MDC with no referent, a mismatching MDC) are raised as 'IOError's built
-- with 'userError' — the same exception 'fail' produces in 'IO' — which
-- needs only 'MonadIO' instead of a 'MonadFail' constraint.
conduitDecrypt' ::
     (MonadUnliftIO m, MonadResource m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt' rs cb = CC.concatMapAccumM push rs
  where
    push ::
         (MonadUnliftIO m, MonadResource m, MonadThrow m)
      => Pkt
      -> RecursorState
      -> m (RecursorState, [Pkt])
    push i s
      -- Each nested decryption recurses with depth + 1; bound it so that
      -- self-containing ("quine") input cannot recurse forever.
      | _depth s > 42 = failure "I think we've been quine-attacked"
      | otherwise =
        case i of
          SKESKPkt {} -> return (s {_lastSKESK = Just (fromPkt i)}, [])
          SymEncDataPkt bs -> do
            skesk <- requireSKESK s
            d <- decryptSEDP s cb skesk bs
            return (s, d)
          SymEncIntegrityProtectedDataPkt _ bs -> do
            skesk <- requireSKESK s
            d <- decryptSEIPDP s cb skesk bs
            return (s, d)
          m@(ModificationDetectionCodePkt mdc) -> do
            when (isNothing (_lastClearText s)) $ failure "MDC with no referent"
            let mcalculated = calculateMDC <$> _lastNonce s <*> _lastClearText s
            when (mcalculated /= Just mdc) $
              failure $
              "MDC indicates tampering: " ++
              show (B16L.encode mdc) ++
              " versus " ++
              maybe "<empty>" (show . B16L.encode) mcalculated ++
              " ... " ++
              show (_lastNonce s) ++ " / " ++ show (_lastClearText s)
            return (s, [m])
          p -> return (s, [p])
    -- Fetch the last SKESK, failing descriptively (before the passphrase
    -- prompt) when encrypted data was not preceded by one.
    requireSKESK :: MonadIO n => RecursorState -> n SKESK
    requireSKESK =
      maybe (failure "encrypted data packet with no preceding SKESK") return .
      _lastSKESK
    -- Raise an 'IOError' carrying the message, as 'fail' does in 'IO'.
    failure :: MonadIO n => String -> n a
    failure = liftIO . ioError . userError
-- | Decrypt a Symmetrically Encrypted Data packet body and recursively
-- process the packets it contains.
--
-- The passphrase is requested through the callback and turned into a key
-- with 'skesk2Key'; the body is decrypted with 'decryptOpenPGPCfb' using
-- the SKESK's symmetric algorithm. The plaintext is parsed into packets,
-- decompressed, and run back through 'conduitDecrypt'' with only the
-- depth incremented (no nonce or clear text is recorded for this packet
-- type). A decryption failure is raised with 'error'.
decryptSEDP ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEDP rs cb skesk bs = do
  pass <- liftIO (cb "Input the passphrase I want")
  let plaintext =
        either error id $
        decryptOpenPGPCfb
          (_skeskSymmetricAlgorithm skesk)
          (BL.toStrict bs)
          (skesk2Key skesk pass)
      deeper = rs {_depth = _depth rs + 1}
  runConduitRes
    (CB.sourceLbs (BL.fromStrict plaintext)
       .| conduitGet get
       .| conduitDecompress
       .| conduitDecrypt' deeper cb
       .| CL.consume)
-- | Decrypt a Symmetrically Encrypted Integrity Protected Data packet body
-- and recursively process the packets it contains.
--
-- The passphrase is requested through the callback and turned into a key
-- with 'skesk2Key'; 'decryptPreservingNonce' yields the nonce together with
-- the clear text. Both are recorded in the state given to the nested
-- 'conduitDecrypt'', whose MDC handling checks an MDC packet against them.
--
-- That check only happens when an MDC packet is actually encountered, so
-- this function also requires the decrypted packets to end with one;
-- otherwise a truncated or MDC-stripped ciphertext that still parses would
-- be accepted silently. A missing MDC is raised as an 'IOError' via
-- 'userError'; a decryption failure is raised with 'error'.
decryptSEIPDP ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEIPDP rs cb skesk bs = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
      (nonce, decrypted) =
        case decryptPreservingNonce
               (_skeskSymmetricAlgorithm skesk)
               (BL.toStrict bs)
               key of
          Left e -> error e
          Right x -> x
      nested =
        rs
          { _depth = _depth rs + 1
          , _lastNonce = Just nonce
          , _lastClearText = Just decrypted
          }
  pkts <-
    runConduitRes $
    CB.sourceLbs (BL.fromStrict decrypted) .| conduitGet get .|
    conduitDecompress .|
    conduitDecrypt' nested cb .|
    CL.consume
  -- RFC 4880 places the MDC packet at the end of the clear text, and the
  -- nested conduit re-emits it after verifying it, so it must come last.
  case reverse pkts of
    ModificationDetectionCodePkt _ : _ -> return pkts
    _ ->
      liftIO . ioError . userError $
      "SEIPD clear text does not end with an MDC packet"
-- | Compute the expected Modification Detection Code for an integrity
-- protected packet.
--
-- @nonce@ is the nonce returned by 'decryptPreservingNonce' (presumably the
-- random prefix — confirm in the CFB module) and @cleartext@ the decrypted
-- data, whose last 22 octets are the MDC packet itself (header 0xD3 0x14
-- plus the 20-octet SHA-1 digest). The result is SHA-1 over the nonce, the
-- clear text preceding that packet, and the two header octets.
--
-- Returns 'mempty' when the clear text is shorter than one MDC packet; that
-- never equals a real 20-octet MDC, so the caller's comparison fails.
calculateMDC :: B.ByteString -> B.ByteString -> BL.ByteString
calculateMDC nonce cleartext
  -- Exactly 22 octets is a bare MDC packet over an empty payload and still
  -- has a well-defined digest, so only strictly shorter input is rejected.
  | B.length cleartext < mdcPacketLength = mempty
  | otherwise =
    BL.fromStrict . BA.convert . sha1 $
    nonce <>
    B.take (B.length cleartext - mdcPacketLength) cleartext <>
    B.pack [211, 20]
  where
    mdcPacketLength = 22
    sha1 :: B.ByteString -> CH.Digest CHA.SHA1
    sha1 = CH.hash