Skip to content
This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert Kenny committed Feb 11, 2020
1 parent 06270dc commit cf16961
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 48 deletions.
63 changes: 34 additions & 29 deletions src/main/scala/uk/ac/wellcome/sierra/SierraPageSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ import org.slf4j.{Logger, LoggerFactory}
import scalaj.http.{Http, HttpOptions, HttpResponse}

private[sierra] class SierraPageSource(
apiUrl: String,
oauthKey: String,
oauthSecret: String,
timeoutMs: Int
apiUrl: String,
oauthKey: String,
oauthSecret: String,
timeoutMs: Int
)(
resourceType: String,
params: Map[String, String] = Map()
resourceType: String,
params: Map[String, String] = Map()
) extends GraphStage[SourceShape[List[Json]]] {

val out: Outlet[List[Json]] = Outlet("SierraSource")
Expand All @@ -31,11 +31,9 @@ private[sierra] class SierraPageSource(
var lastId: Option[Int] = None
var jsonList: List[Json] = Nil

setHandler(out,
new OutHandler {
override def onPull(): Unit = makeSierraRequestAndPush()
}
)
setHandler(out, new OutHandler {
override def onPull(): Unit = makeSierraRequestAndPush()
})

private def makeSierraRequestAndPush(): Unit = {
val newParams = lastId match {
Expand All @@ -44,37 +42,43 @@ private[sierra] class SierraPageSource(
case None => params
}

makeRequestWith(newParams, ifUnauthorized = {
token = refreshToken(apiUrl, oauthKey, oauthSecret)
makeRequestWith(newParams, ifUnauthorized = {
fail(out, new RuntimeException("Unauthorized!"))
})
})
makeRequestWith(
newParams,
ifUnauthorized = {
token = refreshToken(apiUrl, oauthKey, oauthSecret)
makeRequestWith(newParams, ifUnauthorized = {
fail(out, new RuntimeException("Unauthorized!"))
})
}
)
}

private def makeRequestWith[T](newParams: Map[String, String], ifUnauthorized: => Unit): Unit = {
private def makeRequestWith[T](newParams: Map[String, String],
ifUnauthorized: => Unit): Unit = {
val newResponse = makeRequest(apiUrl, resourceType, token, newParams)

newResponse.code match {
case 200 => refreshJsonListAndPush(newResponse)
case 404 => complete(out)
case 401 => ifUnauthorized
case code => fail(out, new RuntimeException(
s"Unexpected HTTP status code from Sierra: $code"))
case code =>
fail(out,
new RuntimeException(
s"Unexpected HTTP status code from Sierra: $code"))
}
}

private def refreshJsonListAndPush(response: HttpResponse[String]): Unit = {
val responseJson = parse(response.body).right.getOrElse(
throw new RuntimeException("response was not json"))
private def refreshJsonListAndPush(
response: HttpResponse[String]): Unit = {
val responseJson = parse(response.body).right
.getOrElse(throw new RuntimeException("response was not json"))

jsonList = root.entries.each.json.getAll(responseJson)

lastId = Some(
root.id.string
.getOption(jsonList.last)
.getOrElse(
throw new RuntimeException("id not found in item"))
.getOrElse(throw new RuntimeException("id not found in item"))
.toInt)

push(out, jsonList)
Expand All @@ -87,12 +91,13 @@ private[sierra] class SierraPageSource(
Http(s"$apiUrl/token").postForm.auth(oauthKey, oauthSecret).asString
val json = parse(tokenResponse.body).right
.getOrElse(throw new RuntimeException("response was not json"))
root.access_token.string.getOption(json).getOrElse(
throw new Exception("Failed to refresh token!")
)
root.access_token.string
.getOption(json)
.getOrElse(
throw new Exception("Failed to refresh token!")
)
}


}

private def makeRequest(apiUrl: String,
Expand Down
39 changes: 20 additions & 19 deletions src/main/scala/uk/ac/wellcome/sierra/SierraSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,41 @@ import scala.concurrent.duration._

case class ThrottleRate(elements: Int, per: FiniteDuration, maximumBurst: Int)
case object ThrottleRate {
def apply(elements: Int, per: FiniteDuration): ThrottleRate = ThrottleRate(elements, per, 0)
def apply(elements: Int, per: FiniteDuration): ThrottleRate =
ThrottleRate(elements, per, 0)
}

object SierraSource {
def apply(
apiUrl: String,
oauthKey: String,
oauthSecret: String,
throttleRate: ThrottleRate = ThrottleRate(elements = 0, per = 0 seconds),
timeoutMs: Int = 10000
)(
resourceType: String,
apiUrl: String,
oauthKey: String,
oauthSecret: String,
throttleRate: ThrottleRate = ThrottleRate(elements = 0, per = 0 seconds),
timeoutMs: Int = 10000
)(resourceType: String,
params: Map[String, String]): Source[Json, NotUsed] = {

val source = Source.fromGraph(
new SierraPageSource(
apiUrl = apiUrl,
oauthKey = oauthKey,
oauthSecret = oauthSecret,
timeoutMs = timeoutMs)(
resourceType = resourceType,
params = params)
timeoutMs = timeoutMs)(resourceType = resourceType, params = params)
)

throttleRate.elements match {
case 0 => source
.mapConcat(identity)
case 0 =>
source
.mapConcat(identity)
case _ =>
source.throttle(
throttleRate.elements,
throttleRate.per,
throttleRate.maximumBurst,
ThrottleMode.shaping
).mapConcat(identity)
source
.throttle(
throttleRate.elements,
throttleRate.per,
throttleRate.maximumBurst,
ThrottleMode.shaping
)
.mapConcat(identity)
}
}
}

0 comments on commit cf16961

Please sign in to comment.