Skip to content

Commit

Permalink
Add getServingInfo unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
caiocamatta-stripe committed Feb 26, 2024
1 parent 0d1d72d commit acc83c2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
4 changes: 2 additions & 2 deletions online/src/main/scala/ai/chronon/online/FetcherBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ class FetcherBase(kvStore: KVStore,
* @param batchEndTs the new batchEndTs from the latest batch data
* @param groupByServingInfo the current GroupByServingInfo
*/
private def updateServingInfo(batchEndTs: Long,
groupByServingInfo: GroupByServingInfoParsed): GroupByServingInfoParsed = {
private[online] def updateServingInfo(batchEndTs: Long,
groupByServingInfo: GroupByServingInfoParsed): GroupByServingInfoParsed = {
val name = groupByServingInfo.groupBy.metaData.name
if (batchEndTs > groupByServingInfo.batchEndTsMillis) {
logger.info(s"""$name's value's batch timestamp of $batchEndTs is
Expand Down
51 changes: 48 additions & 3 deletions online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

package ai.chronon.online

import ai.chronon.aggregator.windowing.FinalBatchIr
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.MetaData
import ai.chronon.online.Fetcher.{ColumnSpec, Request, Response}
import ai.chronon.online.FetcherCache.BatchResponses
import ai.chronon.online.KVStore.TimedValue
import org.junit.{Before, Test}
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
Expand All @@ -25,12 +30,14 @@ import org.mockito.stubbing.Answer
import org.mockito.{Answers, ArgumentCaptor}
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import org.junit.Assert.assertSame

import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.{Failure, Success}
import scala.util.Try

class FetcherBaseTest extends MockitoSugar with Matchers {
class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper {
val GroupBy = "relevance.short_term_user_features"
val Column = "pdp_view_count_14d"
val GuestKey = "guest"
Expand Down Expand Up @@ -118,7 +125,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
// Fetch a single query
val keyMap = Map(GuestKey -> GuestId)
val query = ColumnSpec(GroupBy, Column, None, Some(keyMap))

doAnswer(new Answer[Future[Seq[Fetcher.Response]]] {
def answer(invocation: InvocationOnMock): Future[Seq[Response]] = {
Future.successful(Seq())
Expand All @@ -130,7 +137,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
queryResults.contains(query) shouldBe true
queryResults.get(query).map(_.values) match {
case Some(Failure(ex: IllegalStateException)) => succeed
case _ => fail()
case _ => fail()
}

// GroupBy request sent to KV store for the query
Expand All @@ -141,4 +148,42 @@ class FetcherBaseTest extends MockitoSugar with Matchers {
actualRequest.get.name shouldBe query.groupByName + "." + query.columnName
actualRequest.get.keys shouldBe query.keyMapping.get
}

@Test
def test_getServingInfo_ShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = {
val baseFetcher = new FetcherBase(mock[KVStore])
val spiedFetcherBase = spy(baseFetcher)
val oldServingInfo = mock[GroupByServingInfoParsed]
val updatedServingInfo = mock[GroupByServingInfoParsed]
val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L)))
val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess)
doReturn(updatedServingInfo).when(spiedFetcherBase).updateServingInfo(any(), any())

// updateServingInfo is called
val result = spiedFetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)
assertSame(result, updatedServingInfo)
verify(spiedFetcherBase).updateServingInfo(any(), any())
}

@Test
def test_getServingInfo_ShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = {
val baseFetcher = new FetcherBase(mock[KVStore])
val spiedFetcherBase = spy(baseFetcher)
val oldServingInfo = mock[GroupByServingInfoParsed]
val metaData = mock[MetaData]
val groupByOpsMock = mock[GroupByOps]
val cachedBatchResponses = BatchResponses(mock[FinalBatchIr])
val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]]
doReturn(ttlCache).when(spiedFetcherBase).getGroupByServingInfo
doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String])
metaData.name = "test"
groupByOpsMock.metaData = metaData
when(oldServingInfo.groupByOps).thenReturn(groupByOpsMock)

// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
val result = spiedFetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)
assertSame(result, oldServingInfo)
verify(ttlCache).refresh(any())
verify(spiedFetcherBase, never()).updateServingInfo(any(), any())
}
}

0 comments on commit acc83c2

Please sign in to comment.