-- Decrypt.hs: OpenPGP (RFC4880) recursive packet decryption
-- Copyright © 2013-2019  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).
{-# LANGUAGE FlexibleContexts #-}

module Data.Conduit.OpenPGP.Decrypt
  ( conduitDecrypt
  ) where

import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Trans.Resource (MonadResource, MonadThrow)
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import Data.Binary (get)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16.Lazy as B16L
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Combinators as CC
import qualified Data.Conduit.List as CL
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Maybe (fromJust, isNothing)

import Codec.Encryption.OpenPGP.CFB (decryptOpenPGPCfb, decryptPreservingNonce)
import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.Types

-- | State threaded through the recursive decryptor. Each nested decryption
-- starts from a copy of its parent's state with '_depth' incremented.
data RecursorState = RecursorState
  { _depth :: Int
    -- ^ Number of encrypted layers entered so far; used to bail out of
    -- pathologically deep (self-referential) messages.
  , _lastPKESK :: Maybe PKESK
    -- ^ Public-key session key packet slot; nothing in this module
    -- currently fills it in.
  , _lastSKESK :: Maybe SKESK
    -- ^ Most recent symmetric-key session key packet seen in the stream.
  , _lastNonce :: Maybe B.ByteString
    -- ^ Nonce returned by 'decryptPreservingNonce' for the enclosing SEIPD
    -- packet; prepended when recomputing the MDC.
  , _lastClearText :: Maybe B.ByteString
    -- ^ Full decrypted body of the enclosing SEIPD packet, used when
    -- recomputing the MDC.
  } deriving (Show, Eq)

-- | Initial recursion state: depth zero, and no remembered session-key
-- packets, nonce or cleartext.
def :: RecursorState
def =
  RecursorState
    { _depth = 0
    , _lastPKESK = Nothing
    , _lastSKESK = Nothing
    , _lastNonce = Nothing
    , _lastClearText = Nothing
    }

-- | Callback used to ask for secret input such as a passphrase: it is given
-- a prompt string and returns the answer as a lazy 'BL.ByteString'.
type InputCallback f = String -> f BL.ByteString

-- | Decrypt a stream of OpenPGP packets, recursively descending into
-- symmetrically-encrypted (optionally integrity-protected) data packets.
--
-- SKESK packets are consumed (not re-emitted); packets that are not
-- encrypted pass through unchanged. The callback is asked for the
-- passphrase each time an encrypted data packet is decrypted.
conduitDecrypt ::
     (MonadUnliftIO m, MonadResource m, MonadThrow m)
  => InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt callback = conduitDecrypt' def callback

-- | Worker for 'conduitDecrypt' that threads a 'RecursorState' through the
-- packet stream.
--
-- * 'SKESKPkt's are remembered (and dropped from the output) so that a
--   following encrypted-data packet can be decrypted with them.
-- * 'SymEncDataPkt' and 'SymEncIntegrityProtectedDataPkt' are replaced by
--   the packets obtained by decrypting them with the most recent SKESK.
--   If no SKESK has been seen, processing fails before the passphrase
--   callback is invoked.
-- * 'ModificationDetectionCodePkt' is checked against the nonce and
--   cleartext recorded for the enclosing integrity-protected packet and is
--   then passed through.
-- * Every other packet is passed through unchanged.
--
-- Processing fails once the nesting depth exceeds 42, guarding against
-- self-referential ("quine") messages.
conduitDecrypt' ::
     (MonadUnliftIO m, MonadResource m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt' rs cb = CC.concatMapAccumM push rs
  where
    push ::
         (MonadUnliftIO m, MonadResource m, MonadThrow m)
      => Pkt
      -> RecursorState
      -> m (RecursorState, [Pkt])
    push i s
      | _depth s > 42 = fail "I think we've been quine-attacked"
      | otherwise =
        case i of
          SKESKPkt {} -> return (s {_lastSKESK = Just (fromPkt i)}, [])
          (SymEncDataPkt bs) -> do
            skesk <- requireSKESK s
            d <- decryptSEDP s cb skesk bs
            return (s, d)
          (SymEncIntegrityProtectedDataPkt _ bs) -> do
            skesk <- requireSKESK s
            d <- decryptSEIPDP s cb skesk bs
            return (s, d)
          mdcPkt@(ModificationDetectionCodePkt mdc) -> do
            when (isNothing (_lastClearText s)) $ fail "MDC with no referent"
            let mcalculated = calculateMDC <$> _lastNonce s <*> _lastClearText s
            when (mcalculated /= Just mdc) $
              fail $
              "MDC indicates tampering: " ++
              show (B16L.encode mdc) ++
              " versus " ++
              maybe "<empty>" (show . B16L.encode) mcalculated ++
              "  ... " ++
              show (_lastNonce s) ++ " / " ++ show (_lastClearText s)
            return (s, [mdcPkt])
          p -> return (s, [p])
    -- Fetch the SKESK recorded earlier in the stream. When an encrypted
    -- data packet arrives without one, fail with a clear message instead
    -- of crashing in 'fromJust'. Failing here also happens before the
    -- user is prompted for a passphrase.
    requireSKESK ::
         (MonadUnliftIO m, MonadResource m, MonadThrow m)
      => RecursorState
      -> m SKESK
    requireSKESK =
      maybe (fail "encrypted data packet with no preceding SKESK") return .
      _lastSKESK

-- | Decrypt the body of a Symmetrically Encrypted Data packet. The key is
-- derived from @skesk@ and the passphrase returned by @cb@. The resulting
-- packets are then parsed, decompressed and recursively decrypted one
-- nesting level deeper.
--
-- A decryption failure is raised with 'error' when the plaintext is forced.
decryptSEDP ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEDP rs cb skesk bs = do
  -- FIXME: this shouldn't pass the whole SKESK
  pass <- liftIO (cb "Input the passphrase I want")
  let plaintext =
        either error id $
        decryptOpenPGPCfb
          (_skeskSymmetricAlgorithm skesk)
          (BL.toStrict bs)
          (skesk2Key skesk pass)
      deeper = rs {_depth = _depth rs + 1}
  runConduitRes $
    CB.sourceLbs (BL.fromStrict plaintext)
      .| conduitGet get
      .| conduitDecompress
      .| conduitDecrypt' deeper cb
      .| CL.consume

-- | Decrypt the body of a Symmetrically Encrypted Integrity Protected Data
-- packet. The key is derived from @skesk@ and the passphrase returned by
-- @cb@. The resulting packets are then parsed, decompressed and recursively
-- decrypted one nesting level deeper.
--
-- The nonce and the whole decrypted body are stored in the nested
-- 'RecursorState' so that 'conduitDecrypt'' can verify the Modification
-- Detection Code packet.
--
-- RFC 4880 places that packet at the end of the plaintext. If it is absent,
-- for example because the ciphertext was truncated to strip it, no
-- integrity check would ever run. In that case this function fails with an
-- 'IOError'.
--
-- A decryption failure is raised with 'error' when the plaintext is forced.
decryptSEIPDP ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
decryptSEIPDP rs cb skesk bs = do
  -- FIXME: this shouldn't pass the whole SKESK
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
      (nonce, decrypted) =
        case decryptPreservingNonce
               (_skeskSymmetricAlgorithm skesk)
               (BL.toStrict bs)
               key of
          Left e -> error e
          Right x -> x
  pkts <-
    runConduitRes $
    CB.sourceLbs (BL.fromStrict decrypted) .| conduitGet get .|
    conduitDecompress .|
    conduitDecrypt'
      rs
        { _depth = _depth rs + 1
        , _lastNonce = Just nonce
        , _lastClearText = Just decrypted
        }
      cb .|
    CL.consume
  -- The MDC check in 'conduitDecrypt'' only runs when an MDC packet is
  -- present, so insist that the plaintext actually ends with one.
  case reverse pkts of
    ModificationDetectionCodePkt {} : _ -> return pkts
    _ ->
      liftIO . ioError . userError $
      "SEIPD plaintext does not end with a Modification Detection Code packet"

-- | Recompute the SHA-1 Modification Detection Code for an SEIPD packet.
--
-- Arguments:
--
-- * @nonce@ is the value returned alongside the plaintext by
--   'decryptPreservingNonce'.
-- * @garbage@ is the complete decrypted body. Its final 22 octets are taken
--   to be the MDC packet itself (2-octet header plus 20-octet digest).
--
-- The digest covers the nonce, the plaintext preceding the MDC packet, and
-- the MDC packet header @0xD3 0x14@.
--
-- A body of exactly 22 octets (empty plaintext) is valid and is hashed
-- normally. Anything shorter cannot contain an MDC packet, and 'mempty' is
-- returned. That value will not equal a 20-octet digest.
calculateMDC :: B.ByteString -> B.ByteString -> BL.ByteString
calculateMDC nonce garbage
  | B.length garbage < mdcPacketLength = mempty
  | otherwise =
    BL.fromStrict . BA.convert . (CH.hash :: B.ByteString -> CH.Digest CHA.SHA1) $
    nonce <> B.take (B.length garbage - mdcPacketLength) garbage <> mdcHeader
  where
    -- Tag 19 new-format header (0xD3), length 20, then the SHA-1 digest.
    mdcPacketLength :: Int
    mdcPacketLength = 22
    mdcHeader :: B.ByteString
    mdcHeader = B.pack [0xD3, 0x14]