{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}

module Network.TLS.Handshake.Server.ClientHello (
    processClientHello,
) where

import qualified Control.Exception as E
import Crypto.HPKE
import qualified Data.ByteString as BS

import Network.TLS.ECH.Config

import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Extension
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Measurement
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Types

-- | Validate a received ClientHello and negotiate the protocol version.
--
-- Steps, in order:
--
--   * reject renegotiation under TLS 1.3, and client-initiated
--     renegotiation unless 'supportedClientInitiatedRenegotiation' is set;
--   * consult the 'onNewHandshake' server hook;
--   * require legacy_version to be TLS 1.2;
--   * choose the version (forced by 'debugVersionForced', via
--     supported_versions when TLS 1.3 is enabled, or from legacy_version);
--   * enforce TLS_FALLBACK_SCSV (RFC 7507) against the negotiated version;
--   * check the offered compression methods;
--   * try to decrypt an Encrypted Client Hello when ECH keys are configured.
--
-- Returns the chosen version, the ClientHello to continue with (the
-- decrypted inner one when ECH succeeded, otherwise the received one) and,
-- only in the ECH case, the outer 'ClientRandom'.
processClientHello
    :: ServerParams
    -> Context
    -> ClientHello
    -> IO
        ( Version
        , ClientHello
        , Maybe ClientRandom -- Just for ECH to keep the outer one for key log
        )
processClientHello sparams ctx ch@CH{..} = do
    established <- ctxEstablished ctx
    -- renego is not allowed in TLS 1.3
    when (established /= NotEstablished) $ do
        ver <- usingState_ ctx (getVersionWithDefault TLS12)
        when (ver == TLS13) $
            throwCore $
                Error_Protocol "renegotiation is not allowed in TLS 1.3" UnexpectedMessage
    -- rejecting client initiated renegotiation to prevent DOS.
    eof <- ctxEOF ctx
    let renegotiation = established == Established && not eof
    when
        (renegotiation && not (supportedClientInitiatedRenegotiation $ ctxSupported ctx))
        $ throwCore
        $ Error_Protocol_Warning "renegotiation is not allowed" NoRenegotiation
    -- check if policy allow this new handshake to happens
    handshakeAuthorized <- withMeasure ctx (onNewHandshake $ serverHooks sparams)
    unless
        handshakeAuthorized
        (throwCore $ Error_HandshakePolicy "server: handshake denied")
    updateMeasure ctx incrementNbHandshakes

    -- Only TLS 1.2 is accepted as legacy_version; TLS 1.3 is negotiated
    -- through the supported_versions extension below.
    when (chVersion /= TLS12) $
        throwCore $
            Error_Protocol (show chVersion ++ " is not supported") ProtocolVersion

    -- choosing TLS version
    let extract (SupportedVersionsClientHello vers) = vers -- fixme: vers == []
        extract _ = []
        clientVersions =
            lookupAndDecode EID_SupportedVersions MsgTClientHello chExtensions [] extract
        clientVersion = min TLS12 chVersion
        serverVersions
            | renegotiation = filter (< TLS13) (supportedVersions $ ctxSupported ctx)
            | otherwise = supportedVersions $ ctxSupported ctx
        mVersion = debugVersionForced $ serverDebug sparams
    chosenVersion <- case mVersion of
        Just cver -> return cver
        Nothing
            | TLS13 `elem` serverVersions && not (null clientVersions) ->
                case findHighestVersionFrom13 clientVersions serverVersions of
                    Nothing ->
                        throwCore $
                            Error_Protocol
                                ("client versions " ++ show clientVersions ++ " is not supported")
                                ProtocolVersion
                    Just v -> return v
            | otherwise ->
                case findHighestVersionFrom clientVersion serverVersions of
                    Nothing ->
                        throwCore $
                            Error_Protocol
                                ("client version " ++ show clientVersion ++ " is not supported")
                                ProtocolVersion
                    Just v -> return v

    -- Fallback SCSV: RFC7507
    -- TLS_FALLBACK_SCSV: {0x56, 0x00}
    -- The check must compare against the *negotiated* version. legacy_version
    -- is guaranteed to be TLS 1.2 at this point (see above), so the former
    -- comparison `chVersion < TLS12` could never be true and the check was
    -- dead. Abort when the server supports a higher version than the one
    -- negotiated.
    when
        ( supportedFallbackScsv (ctxSupported ctx)
            && (CipherId 0x5600 `elem` chCiphers)
            && any (> chosenVersion) serverVersions
        )
        $ throwCore
        $ Error_Protocol "fallback is not allowed" InappropriateFallback

    -- Checking compression
    let nullComp = compressionID nullCompression
    if chosenVersion == TLS13
        then
            when (chComps /= [nullComp]) $
                throwCore $
                    Error_Protocol "compression is not allowed in TLS 1.3" IllegalParameter
        else
            unless (nullComp `elem` chComps) $
                throwCore $
                    Error_Protocol
                        "compressions must include nullCompression in TLS 1.2"
                        IllegalParameter

    -- Processing encrypted client hello
    (mClientHello', receivedECH) <-
        if chosenVersion == TLS13 && not (null (serverECHKey sparams))
            then
                lookupAndDecodeAndDo
                    EID_EncryptedClientHello
                    MsgTClientHello
                    chExtensions
                    (return (Nothing, False))
                    (\bs -> (,True) <$> decryptECH sparams ctx ch bs)
            else return (Nothing, False)
    case mClientHello' of
        Just chI -> do
            setupI ctx chI
            return (chosenVersion, chI, Just chRandom)
        Nothing -> do
            setupO ctx ch
            -- Under TLS 1.3, set the ECH EncryptedExtensions flag ('setECHEE')
            -- when ECH configs are published or the client sent an ECH
            -- extension that could not be used. This is the same condition as
            -- the former pair of mutually exclusive `when`s.
            let hasECHConf = not (null (sharedECHConfigList (serverShared sparams)))
            when (chosenVersion == TLS13 && (hasECHConf || receivedECH)) $
                usingHState ctx $
                    setECHEE True
            return (chosenVersion, ch, Nothing)

-- | Prepare the handshake state to continue with a decrypted inner
-- ClientHello (ECH accepted).
--
-- Unless the TLS 1.3 HRR flag ('getTLS13HRR') is already set, a TLS 1.3
-- handshake is started with the inner client random. The ClientHello is
-- then recorded in the handshake state, and the SNI host name, if any, is
-- stored.
setupI :: Context -> ClientHello -> IO ()
setupI ctx chI = do
    afterHRR <- usingState_ ctx getTLS13HRR
    unless afterHRR $ startHandshake ctx TLS13 (chRandom chI)
    usingHState ctx $ setClientHello chI
    mapM_ (usingState_ ctx . setClientSNI) $ getServerName (chExtensions chI)

-- | Prepare the handshake state to continue with the ClientHello as
-- received (no ECH, or ECH not accepted).
--
-- Unless the TLS 1.3 HRR flag ('getTLS13HRR') is already set, a handshake
-- is started with the ClientHello's own version and random. The
-- ClientHello is then recorded in the handshake state, and the SNI host
-- name, if any, is stored.
setupO :: Context -> ClientHello -> IO ()
setupO ctx ch = do
    afterHRR <- usingState_ ctx getTLS13HRR
    unless afterHRR $ startHandshake ctx (chVersion ch) (chRandom ch)
    usingHState ctx $ setClientHello ch
    mapM_ (usingState_ ctx . setClientSNI) $ getServerName (chExtensions ch)

-- | SNI (Server Name Indication): the first host name listed in the
-- ClientHello's server_name extension.
--
-- Yields 'Nothing' when the extension is not found (the default passed to
-- 'lookupAndDecode'; presumably also on a decode failure) or when it lists
-- no 'ServerNameHostName' entry. 'ServerNameOther' entries are skipped.
getServerName :: [ExtensionRaw] -> Maybe HostName
getServerName chExts =
    lookupAndDecode EID_ServerName MsgTClientHello chExts Nothing firstHost
  where
    firstHost :: ServerName -> Maybe HostName
    firstHost (ServerName names) =
        listToMaybe [host | ServerNameHostName host <- names]

-- | The highest allowed version that does not exceed the client's version,
-- or 'Nothing' if every allowed version is higher.
findHighestVersionFrom :: Version -> [Version] -> Maybe Version
findHighestVersionFrom clientVersion allowedVersions =
    -- Scanning in descending order, the first acceptable entry is the highest.
    find (<= clientVersion) (sortOn Down allowedVersions)

-- | The highest server version that also appears among the client's
-- versions. Client entries below TLS 1.2 are ignored.
--
-- Returns 'Nothing' when there is no common version.
findHighestVersionFrom13 :: [Version] -> [Version] -> Maybe Version
findHighestVersionFrom13 clientVersions serverVersions =
    find (`elem` acceptable) (sortOn Down serverVersions)
  where
    acceptable :: [Version]
    acceptable = filter (>= TLS12) clientVersions

-- | Try to recover the inner ClientHello from an outer ECH extension.
--
-- The AAD is the outer ClientHello with the ECH payload zeroed by
-- 'fill0ClientHello'. When the TLS 1.3 HRR flag is set, the encapsulated
-- key length is taken to be 0.
--
-- 'Nothing' is returned in all of these cases:
--
--   * the extension is not an outer one;
--   * no HPKE context is available ('getHPKE');
--   * HPKE raises an 'HPKEError';
--   * the plaintext is not a ClientHello;
--   * 'expandClientHello' fails.
decryptECH
    :: ServerParams
    -> Context
    -> ClientHello
    -> EncryptedClientHello
    -> IO (Maybe ClientHello)
decryptECH sparams ctx chO ech@ECHClientHelloOuter{} =
    E.handle ignoreHPKE $ do
        mhpke <- getHPKE sparams ctx ech
        case mhpke of
            Nothing -> return Nothing
            Just (opener, nenc) -> do
                hrr <- usingState_ ctx getTLS13HRR
                let encLen = if hrr then 0 else nenc
                    aad = encodeHandshake' $ ClientHello $ fill0ClientHello encLen chO
                plaintext <- opener aad (echPayload ech)
                case decodeClientHello' plaintext of
                    -- `$!` keeps the Just/Nothing decision inside the handler,
                    -- as the original nested case did.
                    Right (ClientHello chI) -> return $! expandClientHello chI chO
                    _ -> return Nothing
  where
    ignoreHPKE :: HPKEError -> IO (Maybe ClientHello)
    ignoreHPKE _ = return Nothing
decryptECH _ _ _ _ = return Nothing

-- | The same ClientHello with the payload of its encrypted_client_hello
-- extension(s) zeroed by 'fill0Exts'.
--
-- 'decryptECH' encodes this result as the AAD.
fill0ClientHello :: Int -> ClientHello -> ClientHello
fill0ClientHello nenc ch = ch{chExtensions = fill0Exts nenc (chExtensions ch)}

-- | Replace everything after the first @10 + nenc@ bytes of each
-- encrypted_client_hello extension with zero bytes of the same length.
-- Other extensions are left untouched.
--
-- NOTE(review): the 10-byte prefix presumably covers the ECH type, cipher
-- suite, config id and the length fields around @enc@ (per the ECH
-- draft's ClientHelloOuterAAD) — confirm against the encoder.
fill0Exts :: Int -> [ExtensionRaw] -> [ExtensionRaw]
fill0Exts nenc = map zeroPayload
  where
    zeroPayload :: ExtensionRaw -> ExtensionRaw
    zeroPayload (ExtensionRaw EID_EncryptedClientHello bs) =
        let (header, body) = BS.splitAt (10 + nenc) bs
         in ExtensionRaw EID_EncryptedClientHello (header <> BS.replicate (BS.length body) 0)
    zeroPayload ext = ext

-- | Rebuild the full inner ClientHello from the decoded inner ClientHello
-- and the outer one.
--
-- Each ech_outer_extensions entry in the inner extension list is replaced
-- by copies of the outer extensions it names. Those must occur in the
-- outer list in the listed order (each search resumes after the previous
-- match), and encrypted_client_hello may not be named. The session id is
-- taken from the outer ClientHello.
--
-- Returns 'Nothing' on any of these:
--
--   * an ech_outer_extensions payload does not decode;
--   * a referenced extension is missing or out of order;
--   * ech_outer_extensions remains once the outer list is exhausted.
expandClientHello :: ClientHello -> ClientHello -> Maybe ClientHello
expandClientHello inner outer = do
    exts <- go (chExtensions inner) (chExtensions outer)
    return inner{chSession = chSession outer, chExtensions = exts}
  where
    -- Walk the inner extensions, consuming the outer list as references
    -- are resolved.
    go :: [ExtensionRaw] -> [ExtensionRaw] -> Maybe [ExtensionRaw]
    go [] _ = Just []
    go innerExts [] = noOuterRefs innerExts
    go (ext : rest) outers = case ext of
        ExtensionRaw EID_EchOuterExtensions bs -> do
            EchOuterExtensions eids <- extensionDecode MsgTClientHello bs
            (copied, outers') <- takeInOrder eids outers
            (copied ++) <$> go rest outers'
        _ -> (ext :) <$> go rest outers

    -- Resolve every id in order, returning the copies and the unconsumed
    -- outer extensions.
    takeInOrder
        :: [ExtensionID] -> [ExtensionRaw] -> Maybe ([ExtensionRaw], [ExtensionRaw])
    takeInOrder [] outers = Just ([], outers)
    takeInOrder (eid : eids) outers = do
        (found, outers') <- seek eid outers
        (more, outers'') <- takeInOrder eids outers'
        Just (found : more, outers'')

    -- First outer extension with the given id plus everything after it.
    -- encrypted_client_hello can never be referenced.
    seek :: ExtensionID -> [ExtensionRaw] -> Maybe (ExtensionRaw, [ExtensionRaw])
    seek EID_EncryptedClientHello _ = Nothing
    seek eid outers = case break (hasId eid) outers of
        (_, found : rest) -> Just (found, rest)
        (_, []) -> Nothing

    hasId :: ExtensionID -> ExtensionRaw -> Bool
    hasId eid (ExtensionRaw e _) = e == eid

    -- With the outer list exhausted, no reference may remain.
    noOuterRefs :: [ExtensionRaw] -> Maybe [ExtensionRaw]
    noOuterRefs exts
        | any isOuterRef exts = Nothing
        | otherwise = Just exts

    isOuterRef :: ExtensionRaw -> Bool
    isOuterRef (ExtensionRaw EID_EchOuterExtensions _) = True
    isOuterRef _ = False

-- | Obtain the HPKE open function and encapsulated-key length for an
-- outer ECH extension.
--
-- A pair already stored in the context ('getTLS13HPKE') is reused as is.
-- Otherwise two things are looked up by the client's config id: the
-- 'ECHConfig' in the shared config list and the secret key in
-- 'serverECHKey'.
--
-- When both are found, a receiver context is created with 'setupBaseR'.
-- The info string is @"tls ech\\0"@ followed by the encoded config. The
-- result is stored with 'setTLS13HPKE' and returned.
--
-- Every other case, including a non-outer extension, yields 'Nothing'.
getHPKE
    :: ServerParams
    -> Context
    -> EncryptedClientHello
    -> IO (Maybe (HPKEF, Int))
getHPKE sparams@ServerParams{} ctx ech@ECHClientHelloOuter{} = do
    stored <- getTLS13HPKE ctx
    case stored of
        Just _ -> return stored
        Nothing -> maybe (return Nothing) setup keyMaterial
  where
    cid = echConfigId ech
    keyMaterial =
        (,)
            <$> findECHConfigById cid (sharedECHConfigList (serverShared sparams))
            <*> lookup cid (serverECHKey sparams)
    setup (config, rawSecret) = do
        let kemid = KEM_ID $ kem_id $ key_config $ contents config
            (kdfid, aeadid) = echCipherSuite ech
            info = "tls ech\x00" <> encodeECHConfig config
        ctxR <-
            setupBaseR kemid kdfid aeadid (EncodedSecretKey rawSecret) Nothing (echEnc ech) info
        let opener = open ctxR
            nenc = nEnc kemid
        setTLS13HPKE ctx opener nenc
        return $ Just (opener, nenc)
getHPKE _ _ _ = return Nothing

-- | The first ECH config whose key configuration carries the given
-- config id.
findECHConfigById :: ConfigId -> ECHConfigList -> Maybe ECHConfig
findECHConfigById cnfId = find ((== cnfId) . config_id . key_config . contents)