Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate auth token code to cats effect #4318

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ch.epfl.bluebrain.nexus.delta.wiring

import cats.effect.{Clock, IO}
import ch.epfl.bluebrain.nexus.delta.Main.pluginsMaxPriority
import ch.epfl.bluebrain.nexus.delta.config.AppConfig
import ch.epfl.bluebrain.nexus.delta.kernel.cache.CacheConfig
Expand Down Expand Up @@ -34,8 +35,8 @@ object IdentitiesModule extends ModuleDef {
new OpenIdAuthService(httpClient, realms)
}

make[AuthTokenProvider].fromEffect { (authService: OpenIdAuthService) =>
AuthTokenProvider(authService)
make[AuthTokenProvider].fromEffect { (authService: OpenIdAuthService, clock: Clock[IO]) =>
AuthTokenProvider(authService)(clock)
}

many[RemoteContextResolution].addEffect(ContextValue.fromFile("contexts/identities.json").map { ctx =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import akka.http.scaladsl.model.Uri.Query
import akka.http.scaladsl.model.headers.{`Last-Event-ID`, Accept}
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes}
import akka.stream.alpakka.sse.scaladsl.EventSource
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.plugins.compositeviews.model.CompositeViewSource.RemoteProjectSource
import ch.epfl.bluebrain.nexus.delta.plugins.compositeviews.stream.CompositeBranch
import ch.epfl.bluebrain.nexus.delta.rdf.IriOrBNode.Iri
Expand Down Expand Up @@ -87,11 +88,12 @@ object DeltaClient {
)(implicit
as: ActorSystem[Nothing],
scheduler: Scheduler
) extends DeltaClient {
) extends DeltaClient
with MigrateEffectSyntax {

override def projectStatistics(source: RemoteProjectSource): HttpResult[ProjectStatistics] = {
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
request =
Get(
source.endpoint / "projects" / source.project.organization.value / source.project.project.value / "statistics"
Expand All @@ -104,7 +106,7 @@ object DeltaClient {

override def remaining(source: RemoteProjectSource, offset: Offset): HttpResult[RemainingElems] = {
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
request = Get(elemAddress(source) / "remaining")
.addHeader(accept)
.addHeader(`Last-Event-ID`(offset.value.toString))
Expand All @@ -115,7 +117,7 @@ object DeltaClient {

override def checkElems(source: RemoteProjectSource): HttpResult[Unit] = {
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
result <- client(Head(elemAddress(source)).withCredentials(authToken)) {
case resp if resp.status.isSuccess() => UIO.delay(resp.discardEntityBytes()) >> IO.unit
}
Expand All @@ -130,7 +132,7 @@ object DeltaClient {

def send(request: HttpRequest): Future[HttpResponse] = {
(for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
result <- client[HttpResponse](request.withCredentials(authToken))(IO.pure(_))
} yield result).runToFuture
}
Expand Down Expand Up @@ -164,7 +166,7 @@ object DeltaClient {
val resourceUrl =
source.endpoint / "resources" / source.project.organization.value / source.project.project.value / "_" / id.toString
for {
authToken <- authTokenProvider(credentials)
authToken <- authTokenProvider(credentials).toBIO
req = Get(
source.resourceTag.fold(resourceUrl)(t => resourceUrl.withQuery(Query("tag" -> t.value)))
).addHeader(Accept(RdfMediaTypes.`application/n-quads`)).withCredentials(authToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import akka.http.scaladsl.model.Multipart.FormData
import akka.http.scaladsl.model.Multipart.FormData.BodyPart
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model.Uri.Path
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.FetchFileRejection.UnexpectedFetchError
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.MoveFileRejection.UnexpectedMoveError
import ch.epfl.bluebrain.nexus.delta.plugins.storage.storages.operations.StorageFileRejection.{FetchFileRejection, MoveFileRejection, SaveFileRejection}
Expand Down Expand Up @@ -34,7 +35,7 @@ import scala.concurrent.duration._
*/
final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenProvider, credentials: Credentials)(
implicit as: ActorSystem
) {
) extends MigrateEffectSyntax {
import as.dispatcher

private val serviceName = Name.unsafe("remoteStorage")
Expand All @@ -58,7 +59,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
* the storage bucket name
*/
def exists(bucket: Label)(implicit baseUri: BaseUri): IO[HttpClientError, Unit] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value
val req = Head(endpoint).withCredentials(authToken)
client(req) {
Expand All @@ -82,7 +83,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
relativePath: Path,
entity: BodyPartEntity
)(implicit baseUri: BaseUri): IO[SaveFileRejection, RemoteDiskStorageFileAttributes] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "files" / relativePath
val filename = relativePath.lastSegment.getOrElse("filename")
val multipartForm = FormData(BodyPart("file", entity, Map("filename" -> filename))).toEntity()
Expand All @@ -106,7 +107,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
* the relative path to the file location
*/
def getFile(bucket: Label, relativePath: Path)(implicit baseUri: BaseUri): IO[FetchFileRejection, AkkaSource] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "files" / relativePath
client.toDataBytes(Get(endpoint).withCredentials(authToken)).mapError {
case error @ HttpClientStatusError(_, `NotFound`, _) if !bucketNotFoundType(error) =>
Expand All @@ -129,7 +130,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
bucket: Label,
relativePath: Path
)(implicit baseUri: BaseUri): IO[FetchFileRejection, RemoteDiskStorageFileAttributes] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "attributes" / relativePath
client.fromJsonTo[RemoteDiskStorageFileAttributes](Get(endpoint).withCredentials(authToken)).mapError {
case error @ HttpClientStatusError(_, `NotFound`, _) if !bucketNotFoundType(error) =>
Expand All @@ -156,7 +157,7 @@ final class RemoteDiskStorageClient(client: HttpClient, getAuthToken: AuthTokenP
sourceRelativePath: Path,
destRelativePath: Path
)(implicit baseUri: BaseUri): IO[MoveFileRejection, RemoteDiskStorageFileAttributes] = {
getAuthToken(credentials).flatMap { authToken =>
getAuthToken(credentials).toBIO.flatMap { authToken =>
val endpoint = baseUri.endpoint / "buckets" / bucket.value / "files" / destRelativePath
val payload = Json.obj("source" -> sourceRelativePath.toString.asJson)
client.fromJsonTo[RemoteDiskStorageFileAttributes](Put(endpoint, payload).withCredentials(authToken)).mapError {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,38 @@
package ch.epfl.bluebrain.nexus.delta.sdk.auth

import cats.effect.Clock
import cats.effect.{Clock, IO}
import ch.epfl.bluebrain.nexus.delta.kernel.Logger
import ch.epfl.bluebrain.nexus.delta.kernel.cache.KeyValueStore
import ch.epfl.bluebrain.nexus.delta.kernel.cache.LocalCache
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.kernel.utils.IOUtils
import ch.epfl.bluebrain.nexus.delta.kernel.utils.IOInstant
import ch.epfl.bluebrain.nexus.delta.sdk.auth.Credentials.ClientCredentials
import ch.epfl.bluebrain.nexus.delta.sdk.identities.ParsedToken
import ch.epfl.bluebrain.nexus.delta.sdk.identities.model.AuthToken
import monix.bio.UIO
import monix.bio
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get rid of monix here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I should have commented this: distage fromEffect does not seem to work with cats effect right now. I didn't want to investigate, since this was just something I was doing for myself and I didn't want to waste too much time on it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is configured to expected a Monix IO for now and I don't think we can mix both


import java.time.{Duration, Instant}

/**
* Provides an auth token for the service account, for use when comunicating with remote storage
*/
trait AuthTokenProvider {
def apply(credentials: Credentials): UIO[Option[AuthToken]]
def apply(credentials: Credentials): IO[Option[AuthToken]]
}

object AuthTokenProvider {
def apply(authService: OpenIdAuthService): UIO[AuthTokenProvider] = {
KeyValueStore[ClientCredentials, ParsedToken]().map(cache => new CachingOpenIdAuthTokenProvider(authService, cache))
def apply(authService: OpenIdAuthService)(implicit clock: Clock[IO]): bio.UIO[AuthTokenProvider] = {
LocalCache[ClientCredentials, ParsedToken]()
.map(cache => new CachingOpenIdAuthTokenProvider(authService, cache))
.toBIO
}
def anonymousForTest: AuthTokenProvider = new AnonymousAuthTokenProvider
def fixedForTest(token: String): AuthTokenProvider = new AuthTokenProvider {
override def apply(credentials: Credentials): UIO[Option[AuthToken]] = UIO.pure(Some(AuthToken(token)))
override def apply(credentials: Credentials): IO[Option[AuthToken]] = IO.pure(Some(AuthToken(token)))
}
}

private class AnonymousAuthTokenProvider extends AuthTokenProvider {
override def apply(credentials: Credentials): UIO[Option[AuthToken]] = UIO.pure(None)
override def apply(credentials: Credentials): IO[Option[AuthToken]] = IO.pure(None)
}

/**
Expand All @@ -39,42 +41,42 @@ private class AnonymousAuthTokenProvider extends AuthTokenProvider {
*/
private class CachingOpenIdAuthTokenProvider(
service: OpenIdAuthService,
cache: KeyValueStore[ClientCredentials, ParsedToken]
cache: LocalCache[ClientCredentials, ParsedToken]
)(implicit
clock: Clock[UIO]
clock: Clock[IO]
) extends AuthTokenProvider
with MigrateEffectSyntax {

private val logger = Logger.cats[CachingOpenIdAuthTokenProvider]

override def apply(credentials: Credentials): UIO[Option[AuthToken]] = {
override def apply(credentials: Credentials): IO[Option[AuthToken]] = {

credentials match {
case Credentials.Anonymous => UIO.pure(None)
case Credentials.JWTToken(token) => UIO.pure(Some(AuthToken(token)))
case Credentials.Anonymous => IO.pure(None)
case Credentials.JWTToken(token) => IO.pure(Some(AuthToken(token)))
case credentials: ClientCredentials => clientCredentialsFlow(credentials)
}
}

private def clientCredentialsFlow(credentials: ClientCredentials) = {
private def clientCredentialsFlow(credentials: ClientCredentials): IO[Some[AuthToken]] = {
for {
existingValue <- cache.get(credentials)
now <- IOUtils.instant
now <- IOInstant.now
finalValue <- existingValue match {
case None =>
logger.info("Fetching auth token, no initial value.").toUIO >>
logger.info("Fetching auth token, no initial value.") *>
fetchValue(credentials)
case Some(value) if isExpired(value, now) =>
logger.info("Fetching new auth token, current value near expiry.").toUIO >>
logger.info("Fetching new auth token, current value near expiry.") *>
fetchValue(credentials)
case Some(value) => UIO.pure(value)
case Some(value) => IO.pure(value)
}
} yield {
Some(AuthToken(finalValue.rawToken))
}
}

private def fetchValue(credentials: ClientCredentials) = {
private def fetchValue(credentials: ClientCredentials): IO[ParsedToken] = {
cache.getOrElseUpdate(credentials, service.auth(credentials))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import akka.http.javadsl.model.headers.HttpCredentials
import akka.http.scaladsl.model.HttpMethods.POST
import akka.http.scaladsl.model.headers.Authorization
import akka.http.scaladsl.model.{HttpRequest, Uri}
import cats.effect.IO
import ch.epfl.bluebrain.nexus.delta.kernel.Secret
import ch.epfl.bluebrain.nexus.delta.kernel.effect.migration.MigrateEffectSyntax
import ch.epfl.bluebrain.nexus.delta.sdk.auth.Credentials.ClientCredentials
Expand All @@ -15,7 +16,6 @@ import ch.epfl.bluebrain.nexus.delta.sdk.realms.Realms
import ch.epfl.bluebrain.nexus.delta.sdk.realms.model.Realm
import ch.epfl.bluebrain.nexus.delta.sourcing.model.Label
import io.circe.Json
import monix.bio.{IO, UIO}

/**
* Exchanges client credentials for an auth token with a remote OpenId service, as defined in the specified realm
Expand All @@ -25,7 +25,7 @@ class OpenIdAuthService(httpClient: HttpClient, realms: Realms) extends MigrateE
/**
* Exchanges client credentials for an auth token with a remote OpenId service, as defined in the specified realm
*/
def auth(credentials: ClientCredentials): UIO[ParsedToken] = {
def auth(credentials: ClientCredentials): IO[ParsedToken] = {
for {
realm <- findRealm(credentials.realm)
response <- requestToken(realm.tokenEndpoint, credentials.user, credentials.password)
Expand All @@ -35,14 +35,14 @@ class OpenIdAuthService(httpClient: HttpClient, realms: Realms) extends MigrateE
}
}

private def findRealm(id: Label): UIO[Realm] = {
private def findRealm(id: Label): IO[Realm] = {
for {
realm <- realms.fetch(id).toUIO
_ <- UIO.when(realm.deprecated)(UIO.terminate(RealmIsDeprecated(realm.value)))
realm <- realms.fetch(id)
_ <- IO.raiseWhen(realm.deprecated)(RealmIsDeprecated(realm.value))
} yield realm.value
}

private def requestToken(tokenEndpoint: Uri, user: String, password: Secret[String]): UIO[Json] = {
private def requestToken(tokenEndpoint: Uri, user: String, password: Secret[String]): IO[Json] = {
httpClient
.toJson(
HttpRequest(
Expand All @@ -62,13 +62,13 @@ class OpenIdAuthService(httpClient: HttpClient, realms: Realms) extends MigrateE
.hideErrorsWith(AuthTokenHttpError)
}

private def parseResponse(json: Json): UIO[ParsedToken] = {
private def parseResponse(json: Json): IO[ParsedToken] = {
for {
rawToken <- json.hcursor.get[String]("access_token") match {
case Left(failure) => IO.terminate(AuthTokenNotFoundInResponse(failure))
case Right(value) => UIO.pure(value)
case Left(failure) => IO.raiseError(AuthTokenNotFoundInResponse(failure))
case Right(value) => IO.pure(value)
}
parsedToken <- IO.fromEither(ParsedToken.fromToken(AuthToken(rawToken))).hideErrors
parsedToken <- IO.fromEither(ParsedToken.fromToken(AuthToken(rawToken)))
} yield {
parsedToken
}
Expand Down