diff --git a/src/Servant/Auth/Hmac/Client.hs b/src/Servant/Auth/Hmac/Client.hs index fe6ab53..6c3a92f 100644 --- a/src/Servant/Auth/Hmac/Client.hs +++ b/src/Servant/Auth/Hmac/Client.hs @@ -41,14 +41,14 @@ import Servant.Auth.Hmac.Crypto ( RequestPayload (..), SecretKey, Signature (..), - authHeaderName, - keepWhitelistedHeaders, requestSignature, signSHA256, + defaultAuthHeaderName, keepWhitelistedHeaders' ) import qualified Network.HTTP.Client as Client import qualified Servant.Client.Core as Servant +import Network.HTTP.Types -- | Environment for 'HmacClientM'. Contains all required settings for hmac client. data HmacSettings = HmacSettings @@ -59,6 +59,8 @@ data HmacSettings = HmacSettings , hmacRequestHook :: Maybe (Servant.Request -> ClientM ()) -- ^ Function to call for every request after this request is signed. -- Useful for debugging. + , hmacAuthHeaderName :: HeaderName + -- ^ Header name to use to get request signature } {- | Default 'HmacSettings' with the following configuration: @@ -66,6 +68,7 @@ data HmacSettings = HmacSettings 1. Signing function is 'signSHA256'. 2. Secret key is provided. 3. 'hmacRequestHook' is 'Nothing'. +4. 'hmacAuthHeaderName' is 'Authentication'. -} defaultHmacSettings :: SecretKey -> HmacSettings defaultHmacSettings sk = @@ -73,6 +76,7 @@ defaultHmacSettings sk = { hmacSigner = signSHA256 , hmacSecretKey = sk , hmacRequestHook = Nothing + , hmacAuthHeaderName = defaultAuthHeaderName } {- | @newtype@ wrapper over 'ClientM' that signs all outgoing requests @@ -90,7 +94,7 @@ hmacClientSign :: Servant.Request -> HmacClientM Servant.Request hmacClientSign req = HmacClientM $ do HmacSettings{..} <- ask url <- lift $ asks baseUrl - let signedRequest = signRequestHmac hmacSigner hmacSecretKey url req + let signedRequest = signRequestHmac hmacAuthHeaderName hmacSigner hmacSecretKey url req case hmacRequestHook of Nothing -> pure () Just hook -> lift $ hook signedRequest @@ -118,13 +122,13 @@ hmacClient = Proxy @api `clientIn` Proxy @HmacClientM -- Internals ---------------------------------------------------------------------------- -servantRequestToPayload :: BaseUrl -> Servant.Request -> RequestPayload -servantRequestToPayload url sreq = +servantRequestToPayload :: HeaderName -> BaseUrl -> Servant.Request -> RequestPayload +servantRequestToPayload authHeaderName url sreq = RequestPayload { rpMethod = Client.method req , rpContent = "" -- toBsBody $ Client.requestBody req , rpHeaders = - keepWhitelistedHeaders $ + keepWhitelistedHeaders' authHeaderName $ ("Host", hostAndPort) : ("Accept-Encoding", "gzip") : Client.requestHeaders req @@ -159,9 +163,12 @@ servantRequestToPayload url sreq = @ Authentication: HMAC -@ + -} + signRequestHmac :: + -- | Authentication header name + HeaderName -> -- | Signing function (SecretKey -> ByteString -> Signature) -> -- | Secret key that was used for signing 'Request' @@ -172,8 +179,8 @@ signRequestHmac :: Servant.Request -> -- | Signed request Servant.Request -signRequestHmac signer sk url req = do - let payload = servantRequestToPayload url req +signRequestHmac authHeaderName signer sk url req = do + let payload = servantRequestToPayload authHeaderName url req let signature = requestSignature signer sk payload let authHead = (authHeaderName, "HMAC " <> unSignature signature) req{Servant.requestHeaders = authHead <| Servant.requestHeaders req} diff --git a/src/Servant/Auth/Hmac/Crypto.hs b/src/Servant/Auth/Hmac/Crypto.hs index 801f5c9..626df4b 100644 --- a/src/Servant/Auth/Hmac/Crypto.hs +++ b/src/Servant/Auth/Hmac/Crypto.hs @@ -12,11 +12,15 @@ module Servant.Auth.Hmac.Crypto ( RequestPayload (..), requestSignature, verifySignatureHmac, + verifySignatureHmac', whitelistHeaders, + whitelistHeaders', keepWhitelistedHeaders, + keepWhitelistedHeaders', - -- * Internals - authHeaderName, + -- * Internal + defaultAuthHeaderName, + unsignedPayload ) where import Crypto.Hash (hash) @@ -24,7 +28,7 @@ import Crypto.Hash.Algorithms (MD5, SHA256) import Crypto.Hash.IO (HashAlgorithm) import Crypto.MAC.HMAC (HMAC (hmacGetDigest), hmac) import Data.ByteString (ByteString) -import Data.CaseInsensitive (foldedCase) +import Data.CaseInsensitive (foldedCase, CI (original)) import Data.List (sort, uncons) import Network.HTTP.Types (Header, HeaderName, Method, RequestHeaders) @@ -143,12 +147,15 @@ requestSignature signer sk = signer sk . createStringToSign {- | White-listed headers. Only these headers will be taken into consideration: -1. @Authentication@ +1. An authentication header of your choosing 2. @Host@ 3. @Accept-Encoding@ -} whitelistHeaders :: [HeaderName] -whitelistHeaders = +whitelistHeaders = whitelistHeaders' defaultAuthHeaderName + +whitelistHeaders' :: HeaderName -> [HeaderName] +whitelistHeaders' authHeaderName = [ authHeaderName , "Host" , "Accept-Encoding" @@ -156,7 +163,10 @@ whitelistHeaders = -- | Keeps only headers from 'whitelistHeaders'. keepWhitelistedHeaders :: [Header] -> [Header] -keepWhitelistedHeaders = filter (\(name, _) -> name `elem` whitelistHeaders) +keepWhitelistedHeaders = keepWhitelistedHeaders' defaultAuthHeaderName + +keepWhitelistedHeaders' :: HeaderName -> [Header] -> [Header] +keepWhitelistedHeaders' authHeaderName = filter (\(name, _) -> name `elem` whitelistHeaders' authHeaderName) {- | This function takes signing function @signer@ and secret key and expects that given 'Request' has header: @@ -168,41 +178,39 @@ Authentication: HMAC It checks whether @@ is true request signature. Function returns 'Nothing' if it is true, and 'Just' error message otherwise. -} -verifySignatureHmac :: + +verifySignatureHmac :: (SecretKey -> ByteString -> Signature) -> SecretKey -> RequestPayload -> Maybe LBS.ByteString +verifySignatureHmac = verifySignatureHmac' requestSignature unsignedPayload defaultAuthHeaderName + +verifySignatureHmac' :: + -- | Function to generate signature from request: takes signing function, secret key, and request + ((SecretKey -> ByteString -> Signature) -> SecretKey -> RequestPayload -> Signature) -> + -- | Function to extract signature from request + (RequestPayload -> HeaderName -> Either LBS.ByteString (RequestPayload, Signature)) -> + -- | Auth header name + HeaderName -> -- | Signing function (SecretKey -> ByteString -> Signature) -> -- | Secret key that was used for signing 'Request' SecretKey -> RequestPayload -> Maybe LBS.ByteString -verifySignatureHmac signer sk signedPayload = case unsignedPayload of +verifySignatureHmac' mkRequestSignature extractSignature authHeaderName signer sk signedPayload = case extractSignature signedPayload authHeaderName of Left err -> Just err Right (pay, sig) -> - if sig == requestSignature signer sk pay + if sig == mkRequestSignature signer sk pay then Nothing else Just "Signatures don't match" - where - -- Extracts HMAC signature from request and returns request with @authHeaderName@ header - unsignedPayload :: Either LBS.ByteString (RequestPayload, Signature) - unsignedPayload = case extractOn isAuthHeader $ rpHeaders signedPayload of - (Nothing, _) -> Left "No 'Authentication' header" - (Just (_, val), headers) -> case BS.stripPrefix "HMAC " val of - Just sig -> - Right - ( signedPayload{rpHeaders = headers} - , Signature sig - ) - Nothing -> Left "Can not strip 'HMAC' prefix in header" ---------------------------------------------------------------------------- -- Internals ---------------------------------------------------------------------------- -authHeaderName :: HeaderName -authHeaderName = "Authentication" +defaultAuthHeaderName :: HeaderName +defaultAuthHeaderName = "Authentication" -isAuthHeader :: Header -> Bool -isAuthHeader = (== authHeaderName) . fst +isAuthHeader :: HeaderName -> Header -> Bool +isAuthHeader name = (== name) . fst hashMD5 :: ByteString -> ByteString hashMD5 = BA.convert . hash @_ @MD5 @@ -220,3 +228,14 @@ extractOn p l = in case uncons after of Nothing -> (Nothing, l) Just (x, xs) -> (Just x, before ++ xs) + +unsignedPayload :: RequestPayload -> HeaderName -> Either LBS.ByteString (RequestPayload, Signature) +unsignedPayload signedPayload authHeaderName = case extractOn (isAuthHeader authHeaderName) $ rpHeaders signedPayload of + (Nothing, _) -> Left $ "No '" <> LBS.fromStrict (original authHeaderName) <> "' header" + (Just (_, val), headers) -> case BS.stripPrefix "HMAC " val of + Just sig -> + Right + ( signedPayload{rpHeaders = headers} + , Signature sig + ) + Nothing -> Left "Can not strip 'HMAC' prefix in header" \ No newline at end of file diff --git a/src/Servant/Auth/Hmac/Server.hs b/src/Servant/Auth/Hmac/Server.hs index 6bb081a..bd18885 100644 --- a/src/Servant/Auth/Hmac/Server.hs +++ b/src/Servant/Auth/Hmac/Server.hs @@ -8,8 +8,11 @@ module Servant.Auth.Hmac.Server ( HmacAuthContext, HmacAuthHandler, hmacAuthServerContext, + hmacAuthServerContext', hmacAuthHandler, + hmacAuthHandler', hmacAuthHandlerMap, + hmacAuthHandlerMap', ) where import Control.Monad.Except (throwError) @@ -24,12 +27,12 @@ import Servant.Server.Experimental.Auth (AuthHandler, AuthServerData, mkAuthHand import Servant.Auth.Hmac.Crypto ( RequestPayload (..), SecretKey, - Signature, - keepWhitelistedHeaders, - verifySignatureHmac, + Signature, verifySignatureHmac', keepWhitelistedHeaders', defaultAuthHeaderName, requestSignature, unsignedPayload ) import qualified Network.Wai as Wai (Request) +import Network.HTTP.Types +import qualified Data.ByteString.Lazy as LBS type HmacAuth = AuthProtect "hmac-auth" @@ -39,29 +42,57 @@ type HmacAuthHandler = AuthHandler Wai.Request () type HmacAuthContextHandlers = '[HmacAuthHandler] type HmacAuthContext = Context HmacAuthContextHandlers -hmacAuthServerContext :: +hmacAuthServerContext :: (SecretKey -> ByteString -> Signature) -> SecretKey -> HmacAuthContext +hmacAuthServerContext = hmacAuthServerContext' requestSignature unsignedPayload defaultAuthHeaderName + +hmacAuthServerContext' :: + -- | Function to generate signature from request: takes signing function, secret key, and request + ((SecretKey -> ByteString -> Signature) -> SecretKey -> RequestPayload -> Signature) -> + -- | Function to extract signature from request + (RequestPayload -> HeaderName -> Either LBS.ByteString (RequestPayload, Signature)) -> + -- | Auth header name + HeaderName -> -- | Signing function (SecretKey -> ByteString -> Signature) -> -- | Secret key that was used for signing 'Request' SecretKey -> HmacAuthContext -hmacAuthServerContext signer sk = hmacAuthHandler signer sk :. EmptyContext +hmacAuthServerContext' mkRequestSignature extractSignature authHeaderName signer sk = hmacAuthHandler' mkRequestSignature extractSignature authHeaderName signer sk :. EmptyContext + +hmacAuthHandler :: (SecretKey -> ByteString -> Signature) -> SecretKey -> HmacAuthHandler +hmacAuthHandler = hmacAuthHandler' requestSignature unsignedPayload defaultAuthHeaderName -- | Create 'HmacAuthHandler' from signing function and secret key. -hmacAuthHandler :: +hmacAuthHandler' :: + -- | Function to generate signature from request: takes signing function, secret key, and request + ((SecretKey -> ByteString -> Signature) -> SecretKey -> RequestPayload -> Signature) -> + -- | Function to extract signature from request + (RequestPayload -> HeaderName -> Either LBS.ByteString (RequestPayload, Signature)) -> + -- | Auth header name + HeaderName -> -- | Signing function (SecretKey -> ByteString -> Signature) -> -- | Secret key that was used for signing 'Request' SecretKey -> HmacAuthHandler -hmacAuthHandler = hmacAuthHandlerMap pure +hmacAuthHandler' mkRequestSignature extractSignature authHeaderName = hmacAuthHandlerMap' mkRequestSignature extractSignature authHeaderName pure {- | Like 'hmacAuthHandler' but allows to specify additional mapping function for 'Wai.Request'. This can be useful if you want to print incoming request (for logging purposes) or filter some headers (to match signature). Given function is applied before signature verification. -} -hmacAuthHandlerMap :: + +hmacAuthHandlerMap :: (Wai.Request -> Handler Wai.Request) -> (SecretKey -> ByteString -> Signature) -> SecretKey -> HmacAuthHandler +hmacAuthHandlerMap = hmacAuthHandlerMap' requestSignature unsignedPayload defaultAuthHeaderName + +hmacAuthHandlerMap' :: + -- | Function to generate signature from request: takes signing function, secret key, and request + ((SecretKey -> ByteString -> Signature) -> SecretKey -> RequestPayload -> Signature) -> + -- | Function to extract signature from request + (RequestPayload -> HeaderName -> Either LBS.ByteString (RequestPayload, Signature)) -> + -- | Auth header name + HeaderName -> -- | Request mapper (Wai.Request -> Handler Wai.Request) -> -- | Signing function @@ -69,13 +100,13 @@ hmacAuthHandlerMap :: -- | Secret key that was used for signing 'Request' SecretKey -> HmacAuthHandler -hmacAuthHandlerMap mapper signer sk = mkAuthHandler handler +hmacAuthHandlerMap' mkRequestSignature extractSignature authHeaderName mapper signer sk = mkAuthHandler handler where handler :: Wai.Request -> Handler () handler req = do newReq <- mapper req - let payload = waiRequestToPayload newReq - let verification = verifySignatureHmac signer sk payload + let payload = waiRequestToPayload authHeaderName newReq + let verification = verifySignatureHmac' mkRequestSignature extractSignature authHeaderName signer sk payload case verification of Nothing -> pure () Just bs -> throwError $ err401{errBody = bs} @@ -93,12 +124,12 @@ hmacAuthHandlerMap mapper signer sk = mkAuthHandler handler -- then pure [] -- else (chunk:) <$> getChunks -waiRequestToPayload :: Wai.Request -> RequestPayload +waiRequestToPayload :: HeaderName -> Wai.Request -> RequestPayload -- waiRequestToPayload req = getWaiRequestBody req >>= \body -> pure RequestPayload -waiRequestToPayload req = +waiRequestToPayload authHeaderName req = RequestPayload { rpMethod = requestMethod req , rpContent = "" - , rpHeaders = keepWhitelistedHeaders $ requestHeaders req + , rpHeaders = keepWhitelistedHeaders' authHeaderName $ requestHeaders req , rpRawUrl = fromMaybe mempty (requestHeaderHost req) <> rawPathInfo req <> rawQueryString req } diff --git a/test/Servant/Auth/HmacSpec.hs b/test/Servant/Auth/HmacSpec.hs index 5d97df4..5f9e657 100644 --- a/test/Servant/Auth/HmacSpec.hs +++ b/test/Servant/Auth/HmacSpec.hs @@ -23,10 +23,9 @@ import Servant.Auth.Hmac ( HmacAuth, SecretKey (SecretKey), defaultHmacSettings, - hmacAuthServerContext, hmacClient, runHmacClient, - signSHA256, + signSHA256, hmacAuthServerContext ) import Servant.Client ( BaseUrl (baseUrlPort),