Skip to content

Commit

Permalink
Merge pull request #140 from mdsol/tech/fix_akka_mauth_directive_to_g…
Browse files Browse the repository at this point in the history
…ive_401

Tech: Fix Akka MAuth Directive so that missing or malformed headers give a 401 status code
  • Loading branch information
jatcwang authored Jul 6, 2022
2 parents b4ba1b2 + a4d8d0e commit 8ff5c5d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package com.mdsol.mauth.akka.http

import java.util.UUID

import akka.http.javadsl.model.HttpHeader
import akka.http.javadsl.server.AuthorizationFailedRejection
import akka.http.scaladsl.model.{HttpEntity, HttpRequest}
import akka.http.scaladsl.server.Directives.{headerValueByName, headerValueByType}
import akka.http.scaladsl.server.Directives.{headerValueByName, headerValuePF}
import akka.http.scaladsl.server.directives.BasicDirectives._
import akka.http.scaladsl.server.directives.FutureDirectives.onComplete
import akka.http.scaladsl.server.directives.RouteDirectives.reject
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.directives.HeaderMagnet
import com.mdsol.mauth.MAuthRequest
import com.mdsol.mauth.http.{`X-MWS-Authentication`, `X-MWS-Time`, HttpVerbOps}
import com.mdsol.mauth.scaladsl.Authenticator
Expand All @@ -25,6 +25,12 @@ case class AuthHeaderDetail(appId: UUID, hash: String)

case object MdsolAuthFailedRejection extends AuthorizationFailedRejection with Rejection

final case class MdsolAuthMalformedHeaderRejection(headerName: String, errorMsg: String, cause: Option[Throwable] = None)
extends AuthorizationFailedRejection
with RejectionWithOptionalCause

final case class MdsolAuthMissingHeaderRejection(headerName: String) extends AuthorizationFailedRejection with Rejection

trait MAuthDirectives extends StrictLogging {

/** Directive to wrap all routes that require MAuth authentication check.
Expand Down Expand Up @@ -81,20 +87,23 @@ trait MAuthDirectives extends StrictLogging {
@deprecated("This method is for Mauth V1 protocol only", "3.0.0")
val extractMwsAuthenticationHeader: Directive1[String] = headerValueByName(`X-MWS-Authentication`.name)

def headerValueByTypeMdsol[T](magnet: HeaderMagnet[T]): Directive1[T] =
headerValuePF(magnet.extractPF) | reject(MdsolAuthMissingHeaderRejection(magnet.headerName))

/** Extracts the detail information of the HTTP request header X-MWS-Authentication
*
* @return Directive1[AuthHeaderDetail] of Mauth V1 protocol
* If invalidated, the request is rejected with a MalformedHeaderRejection.
*/
@deprecated("This method is for Mauth V1 protocol only", "3.0.0")
val extractMAuthHeader: Directive1[AuthHeaderDetail] =
headerValueByType[`X-MWS-Authentication`]((): Unit).flatMap { hdr =>
headerValueByTypeMdsol[`X-MWS-Authentication`]((): Unit).flatMap { hdr =>
extractAuthHeaderDetail(hdr.value) match {
case Some(ahd: AuthHeaderDetail) => provide(ahd)
case None =>
val msg = s"${`X-MWS-Authentication`.name} header supplied with bad format: [${hdr.value}]"
logger.error(msg)
reject(MalformedHeaderRejection(headerName = `X-MWS-Authentication`.name, errorMsg = msg, None))
reject(MdsolAuthMalformedHeaderRejection(headerName = `X-MWS-Authentication`.name, errorMsg = msg, None))

}
}
Expand All @@ -106,13 +115,13 @@ trait MAuthDirectives extends StrictLogging {
*/
@deprecated("This method is for Mauth V1 protocol only", "3.0.0")
val extractMwsTimeHeader: Directive1[Long] =
headerValueByType[`X-MWS-Time`]((): Unit).flatMap { time =>
headerValueByTypeMdsol[`X-MWS-Time`]((): Unit).flatMap { time =>
Try(time.value.toLong).toOption match {
case Some(t: Long) => provide(t)
case None =>
val msg = s"${`X-MWS-Time`.name} header supplied with bad format: [${time.value}]"
logger.error(msg)
reject(MalformedHeaderRejection(headerName = `X-MWS-Time`.name, errorMsg = msg, None))
reject(MdsolAuthMalformedHeaderRejection(headerName = `X-MWS-Time`.name, errorMsg = msg, None))
}
}

Expand Down Expand Up @@ -156,7 +165,7 @@ trait MAuthDirectives extends StrictLogging {
* Otherwise, extracts the authentication header of X-MWS-Authentication if MCC-Authentication header is not found.
*
* @return Directive1[MauthHeaderValues] of Mauth authentication header values for V1 or V2
* the request is rejected with a MissingHeaderRejection if the expected header is not present
* the request is rejected with a MdsolAuthMissingHeaderRejection if the expected header is not present
*/
def extractLatestAuthenticationHeaders(v2OnlyAuthenticate: Boolean): Directive1[MauthHeaderValues] = {
extractRequest.flatMap { request: HttpRequest =>
Expand All @@ -171,10 +180,10 @@ trait MAuthDirectives extends StrictLogging {
case None =>
val msg = s"${MAuthRequest.MCC_TIME_HEADER_NAME} header supplied with bad format: [$timeHeaderStr]"
logger.error(msg)
reject(MalformedHeaderRejection(headerName = MAuthRequest.MCC_TIME_HEADER_NAME, errorMsg = msg, None))
reject(MdsolAuthMalformedHeaderRejection(headerName = MAuthRequest.MCC_TIME_HEADER_NAME, errorMsg = msg, None))
}
} else {
reject(MissingHeaderRejection(MAuthRequest.MCC_TIME_HEADER_NAME))
reject(MdsolAuthMissingHeaderRejection(MAuthRequest.MCC_TIME_HEADER_NAME))
}
} else {
// If V2 headers not found, fallback to V1 headers if allowed
Expand All @@ -189,16 +198,16 @@ trait MAuthDirectives extends StrictLogging {
case None =>
val msg = s"${MAuthRequest.X_MWS_TIME_HEADER_NAME} header supplied with bad format: [$timeHeaderStr]"
logger.error(msg)
reject(MalformedHeaderRejection(headerName = MAuthRequest.X_MWS_TIME_HEADER_NAME, errorMsg = msg, None))
reject(MdsolAuthMalformedHeaderRejection(headerName = MAuthRequest.X_MWS_TIME_HEADER_NAME, errorMsg = msg, None))
}
} else {
reject(MissingHeaderRejection(MAuthRequest.X_MWS_TIME_HEADER_NAME))
reject(MdsolAuthMissingHeaderRejection(MAuthRequest.X_MWS_TIME_HEADER_NAME))
}
} else {
reject(MissingHeaderRejection(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME))
reject(MdsolAuthMissingHeaderRejection(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME))
}
} else {
reject(MissingHeaderRejection(MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME))
reject(MdsolAuthMissingHeaderRejection(MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
(client.getPublicKey _).expects(appUuid).never()

Get().withHeaders(RawHeader(MAuthRequest.X_MWS_TIME_HEADER_NAME, timeHeader.toString)) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual `X-MWS-Authentication`.name
}
}
Expand All @@ -142,7 +142,7 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
(client.getPublicKey _).expects(appUuid).never()

Get().withHeaders(RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, authHeader)) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual `X-MWS-Time`.name
}
}
Expand All @@ -159,16 +159,18 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
}
}

"reject with a MalformedHeaderRejection if supplied with bad format" in {
"reject with a MdsolAuthMalformedHeaderRejection if supplied with bad format" in {
Get().withHeaders(RawHeader(MAuthRequest.X_MWS_TIME_HEADER_NAME, "xyz")) ~> route ~> check {
inside(rejection) { case MalformedHeaderRejection("x-mws-time", "x-mws-time header supplied with bad format: [xyz]", None) => }
inside(rejection) { case MdsolAuthMalformedHeaderRejection("x-mws-time", "x-mws-time header supplied with bad format: [xyz]", None) => }
}
}

"reject with a MissingHeaderRejection if header is missing" in {
"reject with a MdsolAuthMissingHeaderRejection if header is missing" in {
Get() ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual `X-MWS-Time`.name
inside(rejection) {
case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual `X-MWS-Time`.name
case t => fail(t.toString)
}
}
}
Expand All @@ -184,39 +186,39 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
}
}

"reject with a MalformedHeaderRejection if Authentication is missing the Prefix MWS" in {
"reject with a MdsolAuthMalformedHeaderRejection if Authentication is missing the Prefix MWS" in {
val wrongHeader = s" $appUuid:$signature"
Get().withHeaders(RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, wrongHeader)) ~> route ~> check {
inside(rejection) { case MalformedHeaderRejection(actualHeader, actualMsg, _) =>
inside(rejection) { case MdsolAuthMalformedHeaderRejection(actualHeader, actualMsg, _) =>
actualHeader shouldBe "x-mws-authentication"
actualMsg shouldBe s"x-mws-authentication header supplied with bad format: [$wrongHeader]"
}
}
}

"reject with a MalformedHeaderRejection if Authentication is missing the App UUID" in {
"reject with a MdsolAuthMalformedHeaderRejection if Authentication is missing the App UUID" in {
val wrongHeader = s"$authPrefix :$signature"
Get().withHeaders(RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, wrongHeader)) ~> route ~> check {
inside(rejection) { case MalformedHeaderRejection(actualHeader, actualMsg, _) =>
inside(rejection) { case MdsolAuthMalformedHeaderRejection(actualHeader, actualMsg, _) =>
actualHeader shouldBe "x-mws-authentication"
actualMsg shouldBe s"x-mws-authentication header supplied with bad format: [$wrongHeader]"
}
}
}

"reject with a MalformedHeaderRejection if Authentication is missing the signature" in {
"reject with a MdsolAuthMalformedHeaderRejection if Authentication is missing the signature" in {
val wrongHeader = s"$authPrefix $appUuid:"
Get().withHeaders(RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, wrongHeader)) ~> route ~> check {
inside(rejection) { case MalformedHeaderRejection(actualHeader, actualMsg, _) =>
inside(rejection) { case MdsolAuthMalformedHeaderRejection(actualHeader, actualMsg, _) =>
actualHeader shouldBe "x-mws-authentication"
actualMsg shouldBe s"x-mws-authentication header supplied with bad format: [$wrongHeader]"
}
}
}

"reject with a MissingHeaderRejection if header is missing" in {
"reject with a MdsolAuthMissingHeaderRejection if header is missing" in {
Get() ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual `X-MWS-Authentication`.name
}
}
Expand Down Expand Up @@ -301,7 +303,7 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
(client.getPublicKey _).expects(appUuid).never()

Get().withHeaders(RawHeader(MAuthRequest.MCC_TIME_HEADER_NAME, timeHeader.toString)) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME
}
}
Expand All @@ -311,7 +313,7 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
(client.getPublicKey _).expects(appUuid).never()

Get().withHeaders(RawHeader(MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME, authHeaderV2)) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_TIME_HEADER_NAME
}
}
Expand All @@ -324,7 +326,7 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
RawHeader(MAuthRequest.X_MWS_TIME_HEADER_NAME, timeHeader.toString),
RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, authHeader)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME
}
}
Expand Down Expand Up @@ -356,32 +358,32 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
}
}

"reject with a MalformedHeaderRejection if Authentication header is missing" in {
"reject with a MdsolAuthMalformedHeaderRejection if Authentication header is missing" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.X_MWS_TIME_HEADER_NAME, timeHeader.toString)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME
}
}
}

"reject with a MalformedHeaderRejection if Time header is missing" in {
"reject with a MdsolAuthMalformedHeaderRejection if Time header is missing" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, authHeader)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.X_MWS_TIME_HEADER_NAME
}
}
}

"reject with a MalformedHeaderRejection if V1 Time header is missing (mixed headers)" in {
"reject with a MdsolAuthMalformedHeaderRejection if V1 Time header is missing (mixed headers)" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, authHeader),
RawHeader(MAuthRequest.MCC_TIME_HEADER_NAME, timeHeader.toString)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.X_MWS_TIME_HEADER_NAME
}
}
Expand Down Expand Up @@ -413,54 +415,54 @@ class MAuthDirectivesSpec extends AnyWordSpec with Matchers with ScalatestRouteT
}
}

"reject with a MalformedHeaderRejection with V1 headers only" in {
"reject with a MdsolAuthMalformedHeaderRejection with V1 headers only" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.X_MWS_TIME_HEADER_NAME, timeHeader.toString),
RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, authHeader)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME
}
}
}

"reject with a MalformedHeaderRejection if supplied with bad format" in {
"reject with a MdsolAuthMalformedHeaderRejection if supplied with bad format" in {
Get().withHeaders(
RawHeader(MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME, authHeaderV2),
RawHeader(MAuthRequest.MCC_TIME_HEADER_NAME, "xyz")
) ~> route ~> check {
inside(rejection) { case MalformedHeaderRejection("mcc-time", "mcc-time header supplied with bad format: [xyz]", None) =>
inside(rejection) { case MdsolAuthMalformedHeaderRejection("mcc-time", "mcc-time header supplied with bad format: [xyz]", None) =>
}
}
}

"reject with a MalformedHeaderRejection if V2 Authentication header is missing" in {
"reject with a MdsolAuthMalformedHeaderRejection if V2 Authentication header is missing" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.MCC_TIME_HEADER_NAME, timeHeader.toString),
RawHeader(MAuthRequest.X_MWS_AUTHENTICATION_HEADER_NAME, authHeader)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME
}
}
}

"reject with a MalformedHeaderRejection if Time header is missing" in {
"reject with a MdsolAuthMalformedHeaderRejection if Time header is missing" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME, authHeaderV2)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_TIME_HEADER_NAME
}
}
}

"reject with a MalformedHeaderRejection if V2 Time header is missing (mixed headers)" in {
"reject with a MdsolAuthMalformedHeaderRejection if V2 Time header is missing (mixed headers)" in {
Get("/").withHeaders(
RawHeader(MAuthRequest.MCC_AUTHENTICATION_HEADER_NAME, authHeaderV2),
RawHeader(MAuthRequest.X_MWS_TIME_HEADER_NAME, timeHeader.toString)
) ~> route ~> check {
inside(rejection) { case MissingHeaderRejection(headerName) =>
inside(rejection) { case MdsolAuthMissingHeaderRejection(headerName) =>
headerName.replaceAll("_", "-").toLowerCase shouldEqual MAuthRequest.MCC_TIME_HEADER_NAME
}
}
Expand Down

0 comments on commit 8ff5c5d

Please sign in to comment.